diff options
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r-- | iqk_mul_mat.cpp | 150 |
1 files changed, 39 insertions, 111 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index d9aa074e..4d34f17b 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 { struct DequantizerIQ1BN { const __m256i m1_8 = _mm256_set1_epi8(1); -#ifdef HAVE_FANCY_SIMD - const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12); -#else - const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096); -#endif - const __m128i maskhh = _mm_set1_epi16(4096); - const __m256i shuffles[4] = { - _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100), - _mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504), - _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908), - _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c), + static __m128i load_shuffle(int i) { + static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12, + 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12, + 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12}; + return _mm_loadu_si128((const __m128i*)data + i); + } + const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) }; + const __m256i mult[4] = { + _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), }; - const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496); const __m256i m3 = _mm256_set1_epi16(3); - const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1); - const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128); - const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0); - const __m128i mask_h = _mm_set1_epi16(0x0f00); - const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0); #ifdef HAVE_FANCY_SIMD const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); #endif - IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) { + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const { auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! - auto aux1 = _mm_shuffle_epi8(data, shuff_l); - auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h); -#ifdef HAVE_FANCY_SIMD - auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh); -#else - auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh); -#endif - auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3); - auto all = MM256_SET_M128I(all128, all128); - auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); - auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); - auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); - auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); + auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3); + auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3); + auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3); + auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3); #ifdef HAVE_FANCY_SIMD v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8); v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8); @@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN { #endif } - //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { - - // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql)); - // uint32_t aux32; std::memcpy(&aux32, qh, 4); - // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1)); - // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh); - // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3)); - // auto all = MM256_SET_M128I(all128, all128); - // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); - // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); - // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); - // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); - // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); - // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); - //} }; template <int nrc_y> @@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4); #else auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]))); + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]))); auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])), - _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]))); + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]))); dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2)); accd[iy] = _mm256_add_epi32(dot1, accd[iy]); #endif @@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = { struct DequantizerIQ1BN { const uint8x16_t m1 = vdupq_n_u8(1); - static inline uint8x16_t load_shuffle_l() { - static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255}; - return vld1q_u8(data); - } - static inline uint8x16_t load_shuffle_h() { - static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255}; - return vld1q_u8(data); - } - static inline uint8x16_t load_shuffle_hh() { - static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}; - return vld1q_u8(data); - } - static inline int16x8_t load_shift_hh() { - static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5}; - return vld1q_s16(data); - } - static inline uint16x8_t load_mult() { - //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; - static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8}; - return vld1q_u16(data); - } - //static inline uint8x16x4_t load_shuffles(uint16_t s0) { - // uint8x16x4_t r; - // auto step = vdupq_n_u8(4); - // r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0)); - // r.val[1] = vaddq_u8(r.val[0], step); - // r.val[2] = vaddq_u8(r.val[1], step); - // r.val[3] = vaddq_u8(r.val[2], step); - // return r; - //} - - const uint8x16_t shuff_l = load_shuffle_l(); - const uint8x16_t shuff_h = load_shuffle_h(); - const int32x4_t shift_h = {8, 8, 4, 4}; - const uint16x8_t mask_h = vdupq_n_u16(0x0f00); - const uint8x16_t shuff_hh = load_shuffle_hh(); - const uint16x8_t mask_hh = vdupq_n_u16(4096); - const int16x8_t shift_hh = load_shift_hh(); - const uint16x8_t mult = load_mult(); - const uint8x16_t step = vdupq_n_u8(2); - const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100)); - //const uint8x16x4_t shuff1 = load_shuffles(0x0100); - //const uint8x16x4_t shuff2 = load_shuffles(0x0302); - //const uint16x8_t mask = vdupq_n_u16(0x1fff); - //const uint16x8_t m3 = vdupq_n_u16(3); + static inline uint8x16x4_t load_shuffles() { + static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12, + 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12, + 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12}; + return vld1q_u8_x4(data); + } + static inline uint8x16x4_t load_mult() { + static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9, + 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3}; + return vld1q_u8_x4(data); + } + const uint8x16x4_t shuff = load_shuffles(); + const uint8x16x4_t mult = load_mult(); IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const { auto data = vld1q_u8((const uint8_t *)x); - auto aux1 = vqtbl1q_u8(data, shuff_l); - auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h); - auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh); - auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3); - auto shuffle = shuff0; - //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100)); - //auto step = vdupq_n_u8(2); for (int k = 0; k < 4; ++k) { - auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step); - auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step); - //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k])); - //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k])); - v1 = vmulq_u16(v1, mult); - v2 = vmulq_u16(v2, mult); - v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14); - v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14); - //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13); - //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13); - v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1); + auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]); + val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6); + v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1); } } }; |