summaryrefslogtreecommitdiff
path: root/iqk_mul_mat.cpp
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-18 12:00:16 +0300
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commit8c6276f6a1c6d9d82b5f0114d838fcc4f277234a (patch)
tree6f80e1f3b1017d3d3e88082db0399c6884bc9725 /iqk_mul_mat.cpp
parent1de6476d751a02978b035feb38066462c4382877 (diff)
Bitnet: 2.25 bpw version
Just scaler and AVX2 for now. PP-512 is even faster (325 t/s on the Ryzn-7950X, 404 t/s on Ryzen-5975WX). We lose ~6-7% for TG due to being memory bound and the model being 10% larger.
Diffstat (limited to 'iqk_mul_mat.cpp')
-rw-r--r--iqk_mul_mat.cpp46
1 files changed, 13 insertions, 33 deletions
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 83d3a472..9eb063e9 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -1412,14 +1412,8 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
__m256 accd[nrc_y];
- __m256i signs[2];
const auto m1_8 = _mm256_set1_epi8(1);
- const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
- const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
- const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
- const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
- const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
const auto mask2 = _mm256_set1_epi8(3);
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
const auto m1_16 = _mm256_set1_epi16(1);
@@ -1428,37 +1422,23 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
for (int ix = 0; ix < nrc_x; ++ix) {
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
- float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
- auto extra_ptr = (const uint16_t *)x;
+ float d = GGML_FP16_TO_FP32(x[0].d);
- auto all_signs = _mm256_set1_epi8(extra_ptr[1]);
- all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
- signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
- signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
-
- auto ql = (const uint8_t *)(extra_ptr + 2);
- auto qh = ql + QK_IQ1BN/8;
- auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
- auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
- iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
-
- auto v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1),
- _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
- auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
- _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
-
- v1 = _mm256_sign_epi8(v1, signs[0]);
- v2 = _mm256_sign_epi8(v2, signs[1]);
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
- auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
+ {
+ auto q2bits = _mm_loadu_si128((const __m128i *)x[0].qs);
+ auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
+ auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8);
+ auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2);
#if defined __AVX512VNNI__ && defined __AVX512VL__
- auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+ auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
#else
- auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
#endif
- accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
+ accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
+ }
}
for (int i = 1; i < nb; ++i) {