diff options
-rw-r--r-- | iqk_mul_mat.cpp | 61 |
1 files changed, 48 insertions, 13 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 44d0703a..0f53e02c 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1342,8 +1342,11 @@ template <int nrc> struct Q8_K64 { struct DequantizerIQ1BN { const __m256i m1_8 = _mm256_set1_epi8(1); - const __m128i mask1 = _mm_set1_epi8(0xf0); +#ifdef HAVE_FANCY_SIMD + const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12); +#else const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096); +#endif const __m128i maskhh = _mm_set1_epi16(4096); const __m256i shuffles[4] = { _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100), @@ -1353,22 +1356,54 @@ struct DequantizerIQ1BN { }; const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496); const __m256i m3 = _mm256_set1_epi16(3); + const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1); + const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128); + const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0); + const __m128i mask_h = _mm_set1_epi16(0x0f00); + const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0); +#ifdef HAVE_FANCY_SIMD + const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); +#endif - IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { - - auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql)); - uint32_t aux32; std::memcpy(&aux32, qh, 4); - auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1)); - auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh); - auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3)); + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) { + auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! + auto aux1 = _mm_shuffle_epi8(data, shuff_l); + auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h); +#ifdef HAVE_FANCY_SIMD + auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh); +#else + auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh); +#endif + auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3); auto all = MM256_SET_M128I(all128, all128); auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); +#ifdef HAVE_FANCY_SIMD + v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8); + v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8); +#else v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); +#endif } + + //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { + + // auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql)); + // uint32_t aux32; std::memcpy(&aux32, qh, 4); + // auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1)); + // auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh); + // auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3)); + // auto all = MM256_SET_M128I(all128, all128); + // auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3); + // auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3); + // auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3); + // auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3); + // v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8); + // v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8); + //} }; template <int nrc_y> @@ -1392,8 +1427,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const if constexpr (nrc_y == 1) { __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256(); for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]); + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); #if defined __AVX512VNNI__ && defined __AVX512VL__ auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]); auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]); @@ -1418,8 +1453,8 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]); - deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]); + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); for (int iy = 0; iy < nrc_y; ++iy) { #if defined __AVX512VNNI__ && defined __AVX512VL__ @@ -1442,7 +1477,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const } int i = 2*(nb/2); if (i < nb) { - deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]); + deq.prepare_iq1bn_quants(x + i, val[0], val[1]); for (int iy = 0; iy < nrc_y; ++iy) { auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); |