-rw-r--r--  iqk_mul_mat.cpp  90
1 file changed, 48 insertions(+), 42 deletions(-)
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 7f57593b..f8ce62a9 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1462,43 +1462,59 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
}
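+// Dequantizer for IQ2_BN (ternary) blocks: 64 values per block, packed as
+// 2-bit fields four to a byte, each field holding value+1 in {0,1,2}.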
+struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
+ DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
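+ // Unpack two consecutive blocks (128 values) from one 32-byte load.
+ // q2bits_2 is the same data shifted right by 2, so the two 128-bit permutes
+ // pair each block with its next 2-bit plane: val[0..1] hold block 2*i,
+ // val[2..3] hold block 2*i+1.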
+ inline void prepare4(int i, __m256i * val) const {
+ auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
+ auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
+ }
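+ // (q & 0x03) - 1 is the ternary value itself; (q & 0x30) - 16 is 16x the
+ // value of the upper field. The factor of 16 is harmless: val[] is only
+ // ever consumed through _mm256_sign_epi8, which uses just the sign and
+ // zero pattern of its second argument.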
+ inline void make2(__m256i q2_1, __m256i * val) const {
+ val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
+ }
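+ // Tail variant: unpack a single block (64 values) from one 16-byte load.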
+ inline void prepare2(int i, __m256i * val) const {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+ make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
+ }
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ const __m256i mf_8 = _mm256_set1_epi8(16);
+ const __m256i mask2 = _mm256_set1_epi8(0x03);
+ const __m256i mask3 = _mm256_set1_epi8(0x30);
+};
+
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
+ DequantizeIQ2BN deq(vx, bx);
__m256i accd[nrc_y];
+ __m256i val[4];
- const auto m1_8 = _mm256_set1_epi8(1);
- const auto mask2 = _mm256_set1_epi8(3);
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
const auto m1_16 = _mm256_set1_epi16(1);
#endif
for (int ix = 0; ix < nrc_x; ++ix) {
- const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+ deq.new_row(ix);
if constexpr (nrc_y == 1) {
__m256i acc[2] = {};
for (int i = 0; i < nb/2; ++i) {
- auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
- auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
- auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
- auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
- auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+ deq.prepare4(i, val);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
- m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2));
- acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
- m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4));
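+ // Ternary weights reduce the multiply to a sign transfer: sign_epi8 applies
+ // val's sign to the q8 quants (zeroing where val is 0), then a dot with the
+ // all-ones vector m1_8 sums the bytes (dpbusd here, maddubs+madd below).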
+ acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+ deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
+ acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+ deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
#else
- auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
- _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2)));
- auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
- _mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4)));
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
#endif
@@ -1510,26 +1526,19 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
for (int i = 0; i < nb/2; ++i) {
- auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
- auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
- auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
- auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
- auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
+ deq.prepare4(i, val);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
- auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
- auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
- accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+ accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
#else
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
- _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
}
@@ -1537,17 +1546,14 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
}
int i = 2*(nb/2);
if (i < nb) {
- auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
- auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
- auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
- auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
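+ // nb is odd here: finish the last block via the single-block path.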
+ deq.prepare2(i, val);
for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], m1_8, dot1), m1_8, dot2);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
#else
- dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
}