diff options
-rw-r--r-- | ggml-common.h | 5 | ||||
-rw-r--r-- | iqk-quantize.cpp | 42 | ||||
-rw-r--r-- | iqk_mul_mat.cpp | 387 |
3 files changed, 303 insertions, 131 deletions
diff --git a/ggml-common.h b/ggml-common.h index d3945975..4de80794 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -306,6 +306,11 @@ typedef struct { int8_t qs[64]; // quants } block_q8_K64; static_assert(sizeof(block_q8_K64) == sizeof(float) + 64, "wrong q8_K64 block size/padding"); +typedef struct { + float d; // delta + int8_t qs[128]; // quants +} block_q8_K128; +static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding"); // (Almost) "true" 2-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp index 6622d5ba..1a672803 100644 --- a/iqk-quantize.cpp +++ b/iqk-quantize.cpp @@ -374,29 +374,51 @@ void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) // x += 64; //} - for (int i = 0; i < nb; i++) { - + block_q8_K128 * yp = (block_q8_K128 *)y; + for (int i = 0; i < nb/2; i++) { float max = 0; float amax = 0; - for (int j = 0; j < 64; ++j) { + for (int j = 0; j < 128; ++j) { float ax = fabsf(x[j]); if (ax > amax) { amax = ax; max = x[j]; } } if (!amax) { - y[i].d = 0; - memset(y[i].qs, 0, 64); - x += 64; + yp[i].d = 0; + memset(yp[i].qs, 0, 128); + x += 128; continue; } const float iscale = -127.f/max; - for (int j = 0; j < 64; ++j) { + for (int j = 0; j < 128; ++j) { int v = nearest_int(iscale*x[j]); - y[i].qs[j] = MIN(127, v); + yp[i].qs[j] = MIN(127, v); + } + yp[i].d = 1/iscale; + x += 128; + } + int i = 2*(nb/2); + if (i < nb) { + float max = 0; + float amax = 0; + for (int j = 0; j < 64; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + yp[i/2].d = 0; + memset(yp[i/2].qs, 0, 64); + } else { + const float iscale = -127.f/max; + for (int j = 0; j < 64; ++j) { + int v = nearest_int(iscale*x[j]); + yp[i/2].qs[j] = MIN(127, v); + } + yp[i/2].d = 1/iscale; } - y[i].d = 1/iscale; - x += 64; } } diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index 0cee0ff4..5148d184 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -1318,27 +1318,56 @@ template <int nrc> struct Q8_K64 { constexpr static int nrc_y = nrc; - Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); } + Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); } inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } inline float scale(int iy, int i) const { return y[iy][i].d; } - const block_q8_K64 * y[nrc_y]; + const block_q8_K128 * y[nrc_y]; +}; + +struct DequantizerIQ1BN { + const __m256i m1_8 = _mm256_set1_epi8(1); + const __m256i shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000); + const __m256i shuff2 = _mm256_add_epi8(shuff1, m1_8); + const __m256i shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const __m256i shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404); + const __m256i mask1 = _mm256_set1_epi64x(0x8040201008040201); + //__m256i signs[2]; + + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) { + //auto all_signs = _mm256_set1_epi8(extra); + //all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8); + //signs[0] = _mm256_shuffle_epi8(all_signs, shuff3); + //signs[1] = _mm256_shuffle_epi8(all_signs, shuff4); + + auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]); + auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], + iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]); + + v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1), + _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1)); + v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1), + _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1)); + //v1 = _mm256_sign_epi8(v1, signs[0]); + //v2 = _mm256_sign_epi8(v2, signs[1]); + + auto all_signs = _mm256_set1_epi8(extra); + all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8); + v1 = _mm256_sign_epi8(v1, _mm256_shuffle_epi8(all_signs, shuff3)); + v2 = _mm256_sign_epi8(v2, _mm256_shuffle_epi8(all_signs, shuff4)); + } }; template <int nrc_y> IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; Q8_K64<nrc_y> q8(info); + DequantizerIQ1BN deq; __m256 accd[nrc_y]; - __m256i signs[2]; + __m256i val[4]; - const auto m1_8 = _mm256_set1_epi8(1); - const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000); - const auto shuff2 = _mm256_add_epi8(shuff1, m1_8); - const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); - const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404); - const auto mask1 = _mm256_set1_epi64x(0x8040201008040201); #if !(defined __AVX512VNNI__ && defined __AVX512VL__) const auto m1_16 = _mm256_set1_epi16(1); #endif @@ -1351,47 +1380,43 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb/2; ++i) { - auto all_signs = _mm256_set1_epi8(x[i].extra); - all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8); - signs[0] = _mm256_shuffle_epi8(all_signs, shuff3); - signs[1] = _mm256_shuffle_epi8(all_signs, shuff4); + deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, val[0], val[1]); + deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, val[2], val[3]); - auto ql = x[i].ql; - auto qh = x[i].qh; - auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)], - iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]); - auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)], - iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]); - - auto v1 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1), - _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1)); - auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1), - _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1)); - - if constexpr (nrc_y == 1) { - auto dot1 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 0), signs[0]), v1); - auto dot2 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 1), signs[1]), v2); + for (int iy = 0; iy < nrc_y; ++iy) { #if defined __AVX512VNNI__ && defined __AVX512VL__ - auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]); + auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]); + auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]); + auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( + _mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4); + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]); #else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])), + _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]))); + dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2)); + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot1), accd[iy]); #endif - accd[0] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(0, i)), _mm256_cvtepi32_ps(dot), accd[0]); - } else { - v1 = _mm256_sign_epi8(v1, signs[0]); - v2 = _mm256_sign_epi8(v2, signs[1]); - for (int iy = 0; iy < nrc_y; ++iy) { - auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1); - auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2); + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, val[0], val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]); #if defined __AVX512VNNI__ && defined __AVX512VL__ - auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); + auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.m1_8, dot1), deq.m1_8, dot2); #else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); + auto dot = _mm256_madd_epi16(m1_16, + _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2))); #endif - accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]); - } + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot), accd[iy]); } } @@ -1419,38 +1444,73 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx); { - auto q2bits = _mm_loadu_si128((const __m128i *)x[0].qs); - auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits); - auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8); - auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8); + auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs); + auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1); + auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2); + auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8); + auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8); + auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8); + auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8); for (int iy = 0; iy < nrc_y; ++iy) { auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, 0, 0), v1); auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, 0, 1), v2); + auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, 0, 2), v3); + auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, 0, 3), v4); #if defined __AVX512VNNI__ && defined __AVX512VL__ - auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); + auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( + _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4); #else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)), + _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4)))); #endif accd[iy] = _mm256_mul_ps(_mm256_set1_ps(q8.scale(iy, 0)), _mm256_cvtepi32_ps(dot)); } } - for (int i = 1; i < nb; ++i) { - auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs); - auto q2 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits); - auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2, mask2), m1_8); - auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2, 4), mask2), m1_8); + for (int i = 1; i < nb/2; ++i) { + auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs); + auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs); + auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1); + auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2); + auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8); + auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8); + auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8); + auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8); for (int iy = 0; iy < nrc_y; ++iy) { auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1); auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2); + auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3); + auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4); #if defined __AVX512VNNI__ && defined __AVX512VL__ - auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); + auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32( + _mm256_setzero_si256(), m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4); #else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)), + _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4)))); #endif accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]); } } + int i = 2*(nb/2); + if (i < nb) { + auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs); + auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits); + auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8); + auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8); + for (int iy = 0; iy < nrc_y; ++iy) { + auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1); + auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2); +#if defined __AVX512VNNI__ && defined __AVX512VL__ + dot1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2); +#else + dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2))); +#endif + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i/2)), _mm256_cvtepi32_ps(dot1), accd[iy]); + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); @@ -4089,13 +4149,52 @@ template <int nrc> struct Q8_K64 { constexpr static int nrc_y = nrc; - Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); } + Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K128 *)info.src1_row(iy); } - inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy][i].qs); } + inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } inline float scale(int iy, int i) const { return y[iy][i].d; } - const block_q8_K64 * y[nrc_y]; + const block_q8_K128 * y[nrc_y]; +}; + +struct DequantizerIQ1BN { + const uint8x16_t m1 = vdupq_n_u8(1); + const uint8x16x4_t sign_shuffles = { + vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}), + vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}), + vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}), + vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}), + }; + const int8x16_t shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00)); + const uint8x16_t qmask = vdupq_n_u8(3); + const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808}); + const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + int8x16x4_t signs; + uint64x2x4_t a; + inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) { + auto all_signs = vdupq_n_u8(extra); + all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1); + signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]); + signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]); + signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]); + signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]); + + a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]}; + a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]}; + a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]}; + a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]}; + + v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1); + v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1); + v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1); + v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1); + + v.val[0] = vmulq_s8(v.val[0], signs.val[0]); + v.val[1] = vmulq_s8(v.val[1], signs.val[1]); + v.val[2] = vmulq_s8(v.val[2], signs.val[2]); + v.val[3] = vmulq_s8(v.val[3], signs.val[3]); + } }; template <int nrc_y> @@ -4103,23 +4202,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn const int nb = n / QK_IQ1BN; Q8_K64<nrc_y> q8(info); + DequantizerIQ1BN deq; float32x4_t accd[nrc_y]; - int8x16x4_t signs; - uint64x2x4_t a; - int8x16x4_t v; - - const auto m1 = vdupq_n_u8(1); - const uint8x16x4_t sign_shuffles = { - vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}), - vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}), - vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}), - vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}), - }; - const auto shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00)); - const auto qmask = vdupq_n_u8(3); - const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808}); - const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + int8x16x4_t v1, v2; const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); @@ -4129,50 +4215,57 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb/2; ++i) { - auto all_signs = vdupq_n_u8(x[i].extra); - all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1); - signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]); - signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]); - signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]); - signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]); - - auto ql = x[i].ql; - auto qh = x[i].qh; - a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]}; - a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]}; - a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]}; - a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]}; - - v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1); - v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1); - v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1); - v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1); - - v.val[0] = vmulq_s8(v.val[0], signs.val[0]); - v.val[1] = vmulq_s8(v.val[1], signs.val[1]); - v.val[2] = vmulq_s8(v.val[2], signs.val[2]); - v.val[3] = vmulq_s8(v.val[3], signs.val[3]); + deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1); + deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2); + int32x4_t sumi1 = vdupq_n_s32(0); + int32x4_t sumi2 = vdupq_n_s32(0); if constexpr (nrc_y == 1) { - auto q = q8.load_quants(0, i); - int32x4_t sumi = vdupq_n_s32(0); + auto q1 = q8.load_quants64(0, i, 0); + auto q2 = q8.load_quants64(0, i, 1); for (int j = 0; j < 4; ++j) { - sumi = ggml_vdotq_s32(sumi, q.val[j], v.val[j]); + sumi1 = ggml_vdotq_s32(sumi1, q1.val[j], v1.val[j]); + sumi2 = ggml_vdotq_s32(sumi2, q2.val[j], v2.val[j]); } - accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(sumi)); + accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2))); } else { for (int iy = 0; iy < nrc_y; ++iy) { int32x4_t sumi = vdupq_n_s32(0); auto q = q8.load_quants(iy, i, 0); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]); q = q8.load_quants(iy, i, 1); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]); + q = q8.load_quants(iy, i, 2); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]); + q = q8.load_quants(iy, i, 3); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]); accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi)); } } } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1); + if constexpr (nrc_y == 1) { + auto q = q8.load_quants(0, i/2, 0); + int32x4_t sumi = vdupq_n_s32(0); + for (int j = 0; j < 4; ++j) { + sumi = ggml_vdotq_s32(sumi, q.val[j], v1.val[j]); + } + accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i/2)), vcvtq_f32_s32(sumi)); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + int32x4_t sumi = vdupq_n_s32(0); + auto q = q8.load_quants(iy, i/2, 0); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i/2, 1); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]); + accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi)); + } + } + } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(accd[iy])); @@ -4188,7 +4281,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn Q8_K64<nrc_y> q8(info); float32x4_t accd[nrc_y]; - int8x16x4_t v; + int8x16x4_t v1, v2; const auto m1 = vdupq_n_u8(1); const auto mask2 = vdupq_n_s8(3); @@ -4199,34 +4292,86 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn { auto q2bits = vld1q_u8(x[0].qs); - v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); - v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); - v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); - v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + q2bits = vld1q_u8(x[1].qs); + v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); for (int iy = 0; iy < nrc_y; ++iy) { int32x4_t sumi = vdupq_n_s32(0); auto q = q8.load_quants(iy, 0, 0); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]); q = q8.load_quants(iy, 0, 1); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]); + q = q8.load_quants(iy, 0, 2); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]); + q = q8.load_quants(iy, 0, 3); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]); accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi)); } } - for (int i = 1; i < nb; ++i) { - auto q2bits = vld1q_u8(x[i].qs); - v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); - v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); - v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); - v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + if constexpr (nrc_y == 1) { + for (int i = 1; i < nb/2; ++i) { + auto sumi1 = vdupq_n_s32(0); + auto sumi2 = vdupq_n_s32(0); + for (int j = 0; j < 2; ++j) { + auto q = q8.load_quants64(0, i, j); + auto q2bits = vld1q_u8(x[2*i+j].qs); + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q.val[0], v1.val[0]), q.val[1], v1.val[1]); + sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q.val[2], v1.val[2]), q.val[3], v1.val[3]); + } + accd[0] = vfmaq_f32(accd[0], vdupq_n_f32(q8.scale(0, i)), vcvtq_f32_s32(vaddq_s32(sumi1, sumi2))); + } + } else { + for (int i = 1; i < nb/2; ++i) { + auto q2bits = vld1q_u8(x[2*i+0].qs); + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + q2bits = vld1q_u8(x[2*i+1].qs); + v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); + for (int iy = 0; iy < nrc_y; ++iy) { + int32x4_t sumi = vdupq_n_s32(0); + auto q = q8.load_quants(iy, i, 0); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i, 1); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]); + q = q8.load_quants(iy, i, 2); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[0]), q.val[1], v2.val[1]); + q = q8.load_quants(iy, i, 3); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v2.val[2]), q.val[1], v2.val[3]); + accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi)); + } + } + } + int i = 2*(nb/2); + if (i < nb) { + auto q2bits = vld1q_u8(x[2*i+0].qs); + v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1); + v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1); + v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1); + v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1); for (int iy = 0; iy < nrc_y; ++iy) { int32x4_t sumi = vdupq_n_s32(0); - auto q = q8.load_quants(iy, i, 0); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[0]), q.val[1], v.val[1]); - q = q8.load_quants(iy, i, 1); - sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]); - accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i)), vcvtq_f32_s32(sumi)); + auto q = q8.load_quants(iy, i/2, 0); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[0]), q.val[1], v1.val[1]); + q = q8.load_quants(iy, i/2, 1); + sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v1.val[2]), q.val[1], v1.val[3]); + accd[iy] = vfmaq_f32(accd[iy], vdupq_n_f32(q8.scale(iy, i/2)), vcvtq_f32_s32(sumi)); } } |