summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp150
1 files changed, 39 insertions, 111 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index d9aa074e..4d34f17b 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 {
struct DequantizerIQ1BN {
const __m256i m1_8 = _mm256_set1_epi8(1);
-#ifdef HAVE_FANCY_SIMD
- const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
-#else
- const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
-#endif
- const __m128i maskhh = _mm_set1_epi16(4096);
- const __m256i shuffles[4] = {
- _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
- _mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
- _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
- _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
+ static __m128i load_shuffle(int i) {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return _mm_loadu_si128((const __m128i*)data + i);
+ }
+ const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+ const __m256i mult[4] = {
+ _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
};
- const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
const __m256i m3 = _mm256_set1_epi16(3);
- const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
- const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
- const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
- const __m128i mask_h = _mm_set1_epi16(0x0f00);
- const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
#ifdef HAVE_FANCY_SIMD
const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif
- IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
- auto aux1 = _mm_shuffle_epi8(data, shuff_l);
- auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
-#ifdef HAVE_FANCY_SIMD
- auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
-#else
- auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
-#endif
- auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
- auto all = MM256_SET_M128I(all128, all128);
- auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
- auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
- auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
- auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+ auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3);
+ auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3);
+ auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3);
+ auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3);
#ifdef HAVE_FANCY_SIMD
v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
@@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN {
#endif
}
- //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-
- // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
- // uint32_t aux32; std::memcpy(&aux32, qh, 4);
- // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
- // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
- // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
- // auto all = MM256_SET_M128I(all128, all128);
- // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
- // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
- // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
- // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
- // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
- // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
- //}
};
template <int nrc_y>
@@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
#else
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
- _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
@@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = {
struct DequantizerIQ1BN {
const uint8x16_t m1 = vdupq_n_u8(1);
- static inline uint8x16_t load_shuffle_l() {
- static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
- return vld1q_u8(data);
- }
- static inline uint8x16_t load_shuffle_h() {
- static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
- return vld1q_u8(data);
- }
- static inline uint8x16_t load_shuffle_hh() {
- static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
- return vld1q_u8(data);
- }
- static inline int16x8_t load_shift_hh() {
- static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
- return vld1q_s16(data);
- }
- static inline uint16x8_t load_mult() {
- //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
- static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
- return vld1q_u16(data);
- }
- //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
- // uint8x16x4_t r;
- // auto step = vdupq_n_u8(4);
- // r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
- // r.val[1] = vaddq_u8(r.val[0], step);
- // r.val[2] = vaddq_u8(r.val[1], step);
- // r.val[3] = vaddq_u8(r.val[2], step);
- // return r;
- //}
-
- const uint8x16_t shuff_l = load_shuffle_l();
- const uint8x16_t shuff_h = load_shuffle_h();
- const int32x4_t shift_h = {8, 8, 4, 4};
- const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
- const uint8x16_t shuff_hh = load_shuffle_hh();
- const uint16x8_t mask_hh = vdupq_n_u16(4096);
- const int16x8_t shift_hh = load_shift_hh();
- const uint16x8_t mult = load_mult();
- const uint8x16_t step = vdupq_n_u8(2);
- const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
- //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
- //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
- //const uint16x8_t mask = vdupq_n_u16(0x1fff);
- //const uint16x8_t m3 = vdupq_n_u16(3);
+ static inline uint8x16x4_t load_shuffles() {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return vld1q_u8_x4(data);
+ }
+ static inline uint8x16x4_t load_mult() {
+ static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+ return vld1q_u8_x4(data);
+ }
+ const uint8x16x4_t shuff = load_shuffles();
+ const uint8x16x4_t mult = load_mult();
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
- auto aux1 = vqtbl1q_u8(data, shuff_l);
- auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
- auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
- auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
- auto shuffle = shuff0;
- //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
- //auto step = vdupq_n_u8(2);
for (int k = 0; k < 4; ++k) {
- auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
- auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
- //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
- //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
- v1 = vmulq_u16(v1, mult);
- v2 = vmulq_u16(v2, mult);
- v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
- v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
- //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
- //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
- v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+ v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
}
}
};