author    | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-20 18:39:31 +0300
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300
commit    | f0325c5826c55bb9796485d49bc971a17735e96a (patch)
tree      | e70069ee59e64f3882468cc65f09831ae266d744
parent    | e05cca9ef652eee7b42927485a3821b14e3c565f (diff)
bitnet(scale in a separate tensor): more CPU improvements
It seems it is enough to have 4 scales per row for Q8.
I get PPL = 8.5470 with this, which is slightly higher than
the 8.5430 we get with 1 scale per 128 activations, but still
OK, I think.
With this, we get the following performance:
System     | quant | PP-512 (t/s)  | TG-128 (t/s) | quant | PP-512 (t/s)  | TG-128 (t/s)
M2 Max     | iq2bn | 229.02 ± 0.37 | 78.75 ± 0.61 | iq1bn | 146.67 ± 2.85 | 33.12 ± 0.03
Ryzen-7950 | iq2bn | 379.36 ± 1.03 | 49.08 ± 0.18 | iq1bn | 247.12 ± 1.53 | 32.80 ± 0.02
Ryzen-5975 | iq2bn | 465.28 ± 0.57 | 39.17 ± 0.02 | iq1bn | 325.86 ± 0.46 | 26.60 ± 0.10
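For reference, the row layout behind these numbers can be restated as a small, self-contained C++ sketch (condensed from quantize_row_q8_K64_reference in the diff below; std::lround stands in for the repo's nearest_int helper, and the row length k is assumed to be a multiple of 16). Each row stores 4 float scales followed by the int8 quants, and scale i covers the i-th group of 4 values inside every block of 16:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Quantize one row of k activations (k % 16 == 0 assumed) to int8 with
    // 4 scales per row. Group i inside each block of 16 uses scale d[i].
    void quantize_row_q8_4scales(const float * x, void * vy, int k) {
        float amax[4] = {0.f, 0.f, 0.f, 0.f};
        for (int j = 0; j < k; j += 16) {
            for (int i = 0; i < 4; ++i) {
                for (int l = 0; l < 4; ++l) {
                    amax[i] = std::max(amax[i], std::fabs(x[j + 4*i + l]));
                }
            }
        }
        float * d = (float *)vy;           // 4 scales at the start of the row
        float id[4];
        for (int i = 0; i < 4; ++i) {
            d[i]  = amax[i]/127;
            id[i] = d[i] > 0 ? 1/d[i] : 0.f;
        }
        int8_t * qs = (int8_t *)(d + 4);   // quants follow the scales
        for (int j = 0; j < k; j += 16) {
            for (int i = 0; i < 4; ++i) {
                for (int l = 0; l < 4; ++l) {
                    qs[j + 4*i + l] = (int8_t)std::lround(id[i]*x[j + 4*i + l]);
                }
            }
        }
    }

The interleaved grouping (4 values per scale inside every 16) is what lets the SIMD kernels below keep one int32 partial sum per scale in the lanes of a single vector register and apply the four scales only once per row.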
-rw-r--r-- | iqk-quantize.cpp | 60
-rw-r--r-- | iqk_mul_mat.cpp  | 270
2 files changed, 159 insertions, 171 deletions
diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp
index 1a672803..40eff93f 100644
--- a/iqk-quantize.cpp
+++ b/iqk-quantize.cpp
@@ -355,8 +355,8 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
 }
 
 void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) {
-    assert(k % 64 == 0);
-    const int64_t nb = k / 64;
+    //assert(k % 64 == 0);
+    //const int64_t nb = k / 64;
 
     // Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher
     //float amax = 0;
@@ -374,50 +374,24 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k)
     //    x += 64;
     //}
 
-    block_q8_K128 * yp = (block_q8_K128 *)y;
-    for (int i = 0; i < nb/2; i++) {
-        float max = 0;
-        float amax = 0;
-        for (int j = 0; j < 128; ++j) {
-            float ax = fabsf(x[j]);
-            if (ax > amax) {
-                amax = ax; max = x[j];
+    float aux[4] = {0.f, 0.f, 0.f, 0.f};
+    for (int j = 0; j < k; j += 16) {
+        for (int i = 0; i < 4; ++i) {
+            for (int l = 0; l < 4; ++l) {
+                float ax = fabsf(x[j+4*i+l]);
+                aux[i] = std::max(aux[i], ax);
             }
         }
-        if (!amax) {
-            yp[i].d = 0;
-            memset(yp[i].qs, 0, 128);
-            x += 128;
-            continue;
-        }
-        const float iscale = -127.f/max;
-        for (int j = 0; j < 128; ++j) {
-            int v = nearest_int(iscale*x[j]);
-            yp[i].qs[j] = MIN(127, v);
-        }
-        yp[i].d = 1/iscale;
-        x += 128;
     }
-    int i = 2*(nb/2);
-    if (i < nb) {
-        float max = 0;
-        float amax = 0;
-        for (int j = 0; j < 64; ++j) {
-            float ax = fabsf(x[j]);
-            if (ax > amax) {
-                amax = ax; max = x[j];
-            }
-        }
-        if (!amax) {
-            yp[i/2].d = 0;
-            memset(yp[i/2].qs, 0, 64);
-        } else {
-            const float iscale = -127.f/max;
-            for (int j = 0; j < 64; ++j) {
-                int v = nearest_int(iscale*x[j]);
-                yp[i/2].qs[j] = MIN(127, v);
-            }
-            yp[i/2].d = 1/iscale;
+    float * dptr = (float *)y;
+    for (int i = 0; i < 4; ++i) {
+        dptr[i] = aux[i]/127;
+        aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+    }
+    auto qs = (int8_t *)(dptr + 4);
+    for (int j = 0; j < k; j += 16) {
+        for (int i = 0; i < 4; ++i) {
+            for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
         }
     }
 }
diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp
index 5148d184..1d382a41 100644
--- a/iqk_mul_mat.cpp
+++ b/iqk_mul_mat.cpp
@@ -256,6 +256,13 @@ inline float hsum_float_4(__m128 x) {
 inline float hsum_float_8(__m256 x) {
     return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
 }
+inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
@@ -1318,12 +1325,19 @@ template <int nrc> struct Q8_K64 {
 
     constexpr static int nrc_y = nrc;
 
-    Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
+    Q8_K64(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            const float * dptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 4);
+        }
+    }
 
-    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
-    inline float scale(int iy, int i) const { return y[iy][i].d; }
+    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
 
-    const block_q8_K128 * y[nrc_y];
+    float d[4*nrc_y];
+    const int8_t * y[nrc_y];
 };
 
 struct DequantizerIQ1BN {
@@ -1333,13 +1347,8 @@ struct DequantizerIQ1BN {
     const __m256i shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
     const __m256i shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
     const __m256i mask1  = _mm256_set1_epi64x(0x8040201008040201);
-    //__m256i signs[2];
 
     IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-        //auto all_signs = _mm256_set1_epi8(extra);
-        //all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
-        //signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
-        //signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
 
         auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
                                       iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
@@ -1350,8 +1359,6 @@ struct DequantizerIQ1BN {
                              _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1));
         v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
                              _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
-        //v1 = _mm256_sign_epi8(v1, signs[0]);
-        //v2 = _mm256_sign_epi8(v2, signs[1]);
 
         auto all_signs = _mm256_set1_epi8(extra);
         all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
@@ -1365,7 +1372,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
     DequantizerIQ1BN deq;
-    __m256  accd[nrc_y];
+    __m256i accd[nrc_y];
     __m256i val[4];
 
 #if !(defined __AVX512VNNI__ && defined __AVX512VL__)
@@ -1378,30 +1385,55 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
 
         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
 
-        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+        if constexpr (nrc_y == 1) {
+            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
+            for (int i = 0; i < nb/2; ++i) {
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+                auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
+                auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
+                auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
+                auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
+                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
+                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
+#else
+                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
+                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
+#endif
+            }
+            accd[0] = _mm256_add_epi32(acc1, acc2);
+        }
+        else {
 
-        for (int i = 0; i < nb/2; ++i) {
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
 
-            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
-            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+            for (int i = 0; i < nb/2; ++i) {
 
-            for (int iy = 0; iy < nrc_y; ++iy) {
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]);
+
+                for (int iy = 0; iy < nrc_y; ++iy) {
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
-                auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
-                auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
-                auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
-                auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                                _mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
-                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
+                    auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+                    auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+                    auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+                    auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
+                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+                                    accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
 #else
-                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
-                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                                             _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
-                dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
-                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot1), accd[iy]);
+                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
+                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
+                                                 _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
+                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
+                }
             }
         }
         int i = 2*(nb/2);
@@ -1411,17 +1443,20 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
                 auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
                 auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-                auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2);
+                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
 #else
                 auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1),
                                                                      _mm256_maddubs_epi16(deq.m1_8, dot2)));
+                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
-                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot), accd[iy]);
             }
         }
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
+            auto vd = q8.scale(iy);
+            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+            auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+            info.store(ix, iy, hsum_float_4(sumf));
         }
     }
 }
@@ -1431,7 +1466,7 @@ template <int nrc_y>
 IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
     Q8_K64<nrc_y> q8(info);
-    __m256  accd[nrc_y];
+    __m256i accd[nrc_y];
 
     const auto m1_8  = _mm256_set1_epi8(1);
     const auto mask2 = _mm256_set1_epi8(3);
@@ -1458,14 +1493,13 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3);
             auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
                             _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
 #else
-            auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+            accd[iy] = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
 #endif
-            accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot));
         }
     }
 
@@ -1484,14 +1518,14 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
             auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
-                            _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+                            accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
 #else
             auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
                             _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
+            accd[iy] = _mm256_add_epi32(dot, accd[iy]);
 #endif
-            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
         }
     }
     int i = 2*(nb/2);
@@ -1504,18 +1538,20 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
             auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
             auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
 #if defined __AVX512VNNI__ && defined __AVX512VL__
-            dot1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
+            accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], m1_8, dot1), m1_8, dot2);
 #else
             dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1),
                                                              _mm256_maddubs_epi16(m1_8, dot2)));
+            accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
-            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot1), accd[iy]);
         }
     }
 
     for (int iy = 0; iy < nrc_y; ++iy) {
-        info.store(ix, iy, hsum_float_8(accd[iy]));
+        auto vd = q8.scale(iy);
+        auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+        auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+        info.store(ix, iy, hsum_float_4(sumf));
     }
-
     }
 }
@@ -4149,13 +4185,20 @@ template <int nrc> struct Q8_K64 {
 
     constexpr static int nrc_y = nrc;
 
-    Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); }
+    Q8_K64(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto dptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+            y[iy] = (const int8_t *)(dptr + 4);
+        }
+    }
 
-    inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
-    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
-    inline float scale(int iy, int i) const { return y[iy][i].d; }
+    inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
+    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
+    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
 
-    const block_q8_K128 * y[nrc_y];
+    float d[4*nrc_y];
+    const int8_t * y[nrc_y];
 };
 
 struct DequantizerIQ1BN {
@@ -4204,8 +4247,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     Q8_K64<nrc_y> q8(info);
     DequantizerIQ1BN deq;
 
-    float32x4_t accd[nrc_y];
-    int8x16x4_t v1, v2;
+    int32x4_t   accd[nrc_y];
+    int8x16x4_t v1, v2;
 
     const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
 
@@ -4213,35 +4256,37 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
         x = (const block_iq1_bn *)((const char *)vx + ix*bx);
 
-        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
-        for (int i = 0; i < nb/2; ++i) {
+        if constexpr (nrc_y == 1) {
+            int32x4_t acc[4] = {};
+            for (int i = 0; i < nb/2; ++i) {
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+                auto q = q8.load_quants64(0, i, 0);
+                for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1);
+                q = q8.load_quants64(0, i, 1);
+                for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+            }
+            accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+        }
+        else {
 
-            deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
-            deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+
+            for (int i = 0; i < nb/2; ++i) {
+
+                deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
+                deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
 
-            int32x4_t sumi1 = vdupq_n_s32(0);
-            int32x4_t sumi2 = vdupq_n_s32(0);
-            if constexpr (nrc_y == 1) {
-                auto q1 = q8.load_quants64(0, i, 0);
-                auto q2 = q8.load_quants64(0, i, 1);
-                for (int j = 0; j < 4; ++j) {
-                    sumi1 = ggml_vdotq_s32(sumi1, q1.val[j], v1.val[j]);
-                    sumi2 = ggml_vdotq_s32(sumi2, q2.val[j], v2.val[j]);
-                }
-                accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
-            } else {
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    int32x4_t sumi = vdupq_n_s32(0);
                     auto q = q8.load_quants(iy, i, 0);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                     q = q8.load_quants(iy, i, 1);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
                     q = q8.load_quants(iy, i, 2);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
                     q = q8.load_quants(iy, i, 3);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                    accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
                 }
             }
         }
@@ -4250,25 +4295,21 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
         deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
         if constexpr (nrc_y == 1) {
             auto q = q8.load_quants(0, i/2, 0);
-            int32x4_t sumi = vdupq_n_s32(0);
             for (int j = 0; j < 4; ++j) {
-                sumi = ggml_vdotq_s32(sumi, q.val[j], v1.val[j]);
+                accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
             }
-            accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i/2)), vcvtq_f32_s32(sumi));
         } else {
             for (int iy = 0; iy < nrc_y; ++iy) {
-                int32x4_t sumi = vdupq_n_s32(0);
                 auto q = q8.load_quants(iy, i/2, 0);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                 q = q8.load_quants(iy, i/2, 1);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
            }
         }
     }
 
     for (int iy = 0; iy < nrc_y; ++iy) {
-        info.store(ix, iy, vaddvq_f32(accd[iy]));
+        info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
     }
 }
@@ -4280,8 +4321,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
     Q8_K64<nrc_y> q8(info);
 
-    float32x4_t accd[nrc_y];
-    int8x16x4_t v1, v2;
+    int32x4_t accd[nrc_y];
 
     const auto m1 = vdupq_n_u8(1);
     const auto mask2 = vdupq_n_s8(3);
@@ -4290,36 +4330,10 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
 
         const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
 
-        {
-            auto q2bits = vld1q_u8(x[0].qs);
-            v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-            v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-            v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-            v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-            q2bits = vld1q_u8(x[1].qs);
-            v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
-            v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
-            v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
-            v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                int32x4_t sumi = vdupq_n_s32(0);
-                auto q = q8.load_quants(iy, 0, 0);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
-                q = q8.load_quants(iy, 0, 1);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                q = q8.load_quants(iy, 0, 2);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
-                q = q8.load_quants(iy, 0, 3);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi));
-            }
-
-        }
-
         if constexpr (nrc_y == 1) {
-            for (int i = 1; i < nb/2; ++i) {
-                auto sumi1 = vdupq_n_s32(0);
-                auto sumi2 = vdupq_n_s32(0);
+            int8x16x4_t v1;
+            int32x4_t acc[4] = {};
+            for (int i = 0; i < nb/2; ++i) {
                 for (int j = 0; j < 2; ++j) {
                     auto q = q8.load_quants64(0, i, j);
                     auto q2bits = vld1q_u8(x[2*i+j].qs);
@@ -4327,13 +4341,17 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
                     v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
                     v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
                     v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
-                    sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
-                    sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q.val[2], v1.val[2]), q.val[3], v1.val[3]);
+                    acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
+                    acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
+                    acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
+                    acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
                 }
-                accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)));
             }
+            accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
         } else {
-            for (int i = 1; i < nb/2; ++i) {
+            int8x16x4_t v1, v2;
+            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+            for (int i = 0; i < nb/2; ++i) {
                 auto q2bits = vld1q_u8(x[2*i+0].qs);
                 v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
                 v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
@@ -4345,40 +4363,36 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
                 v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
                 v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    int32x4_t sumi = vdupq_n_s32(0);
                     auto q = q8.load_quants(iy, i, 0);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                    q = q8.load_quants(iy, i, 1);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
                     q = q8.load_quants(iy, i, 2);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
                     q = q8.load_quants(iy, i, 3);
-                    sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]);
-                    accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi));
+                    accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
                 }
             }
         }
         int i = 2*(nb/2);
         if (i < nb) {
-            auto q2bits = vld1q_u8(x[2*i+0].qs);
+            auto q2bits = vld1q_u8(x[i].qs);
+            int8x16x4_t v1;
            v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
             v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
             v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
             v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                int32x4_t sumi = vdupq_n_s32(0);
                 auto q = q8.load_quants(iy, i/2, 0);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
                 q = q8.load_quants(iy, i/2, 1);
-                sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]);
-                accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi));
+                accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
             }
         }
 
         for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(accd[iy]));
+            info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
         }
-
     }
 }
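To summarize the kernel changes above: all of the int8 x int8 work is now accumulated in integer registers, and the four row scales are applied exactly once per row at the end. A scalar sketch of that contract (hypothetical helper name; it assumes the same row layout as the quantization sketch further up, with the ternary weights already unpacked to -1/0/+1 int8 values in w):

    #include <cstdint>

    // Dot product of one quantized activation row (4 float scales followed by
    // k int8 quants) with k unpacked ternary weights. Integer sums are kept
    // per scale group; the scales are applied only at the very end.
    float vec_dot_q8_4scales(const void * vy, const int8_t * w, int k) {
        const float  * d  = (const float *)vy;
        const int8_t * qs = (const int8_t *)(d + 4);
        int32_t sumi[4] = {0, 0, 0, 0};
        for (int j = 0; j < k; j += 16) {
            for (int i = 0; i < 4; ++i) {
                for (int l = 0; l < 4; ++l) {
                    sumi[i] += int32_t(qs[j + 4*i + l]) * int32_t(w[j + 4*i + l]);
                }
            }
        }
        return d[0]*sumi[0] + d[1]*sumi[1] + d[2]*sumi[2] + d[3]*sumi[3];
    }

This is what the vector epilogues in the diff do in register form: on AVX2 the two 128-bit halves of accd[iy] are folded, multiplied by the four scales (_mm_mul_ps of q8.scale(iy) with the converted sums) and reduced with hsum_float_4; on NEON the same happens via vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy])) followed by vaddvq_f32.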