-rw-r--r--   iqk_mul_mat.cpp   90
1 file changed, 48 insertions, 42 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 7f57593b..f8ce62a9 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1462,43 +1462,59 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     }
 }
+struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
+    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+    inline void prepare4(int i, __m256i * val) const {
+        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
+        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
+        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
+        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
+    }
+    inline void make2(__m256i q2_1, __m256i * val) const {
+        val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+        val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
+    }
+    inline void prepare2(int i, __m256i * val) const {
+        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
+    }
+    const __m256i m1_8  = _mm256_set1_epi8(1);
+    const __m256i mf_8  = _mm256_set1_epi8(16);
+    const __m256i mask2 = _mm256_set1_epi8(0x03);
+    const __m256i mask3 = _mm256_set1_epi8(0x30);
+};
+
 template <int nrc_y>
 IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
+    DequantizeIQ2BN deq(vx, bx);
     __m256i accd[nrc_y];
+    __m256i val[4];
-    const auto m1_8  = _mm256_set1_epi8(1);
-    const auto mask2 = _mm256_set1_epi8(3);
 #if !(defined __AVX512VNNI__ && defined __AVX512VL__)
     const auto m1_16 = _mm256_set1_epi16(1);
 #endif
     for (int ix = 0; ix < nrc_x; ++ix) {
-        const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+        deq.new_row(ix);
         if constexpr (nrc_y == 1) {
             __m256i acc[2] = {};
             for (int i = 0; i < nb/2; ++i) {
-                auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
-                auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
-                auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
-                auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
-                auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-                auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
-                auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
-                auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+                deq.prepare4(i, val);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
-                                             m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2));
-                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
-                                             m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4));
+                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+                                             deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
+                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+                                             deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
-                                             _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2)));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
-                                             _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4)));
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
                 acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
                 acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
 #endif
@@ -1510,26 +1526,19 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
             for (int i = 0; i < nb/2; ++i) {
-                auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
-                auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
-                auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
-                auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
-                auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-                auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
-                auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
-                auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+                deq.prepare4(i, val);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
-                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
-                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
-                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
+                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
                     accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                    accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+                                    accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
 #else
                     auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
-                                _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
-                                _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+                                _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
+                                _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
                     accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
                 }
@@ -1537,17 +1546,14 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
         }
         int i = 2*(nb/2);
         if (i < nb) {
-            auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
-            auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
-            auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
-            auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
+            deq.prepare2(i, val);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
+                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
+                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], m1_8, dot1), m1_8, dot2);
+                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
 #else
-                dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+                dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
                 accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
             }
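
A note on why the refactor preserves results, since it is easy to miss in the diff: DequantizeIQ2BN::make2 extracts the 2-bit digit in bits 0-1 with mask2 (0x03) and subtracts 1, giving exact ternary values in {-1, 0, +1}, but for the digit in bits 4-5 it masks with mask3 (0x30) and subtracts mf_8 (16) without shifting down first, giving {-16, 0, +16}. That is enough because these vectors only ever feed _mm256_sign_epi8, which reads just the sign (or zero) of each byte and ignores its magnitude. Below is a minimal scalar sketch of the packing and of this recentering trick; it is an illustration, not code from the commit, and the example byte is made up.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // One byte of block_iq2_bn::qs packs four 2-bit ternary digits at bit
        // offsets 0, 2, 4 and 6. Here: d0 = 0, d1 = 1, d2 = 2, d3 = 1.
        uint8_t byte    = 0x64;       // 0b01100100
        uint8_t shifted = byte >> 2;  // mirrors _mm256_srli_epi16(q2bits_1, 2)

        int v0 = (byte    & 0x03) - 1;   // digits at offsets 0 and 2 come out
        int v1 = (shifted & 0x03) - 1;   // exactly as {-1, 0, +1}
        int v2 = (byte    & 0x30) - 16;  // digits at offsets 4 and 6 come out as
        int v3 = (shifted & 0x30) - 16;  // {-16, 0, +16}; only the sign matters,
                                         // which is all _mm256_sign_epi8 uses
        printf("%d %d %d %d\n", v0, v1, v2, v3);  // prints: -1 0 16 0
        return 0;
    }

The payoff in prepare4 is that one 32-byte load, one 2-bit shift, and two cross-lane permutes replace the old path's two 16-byte loads, two 128-bit lane insertions, and two extra 4-bit shifts; the {-16, 0, +16} digits never need normalizing because the magnitudes come entirely from the q8 side of the dot product.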