author | Kawrakow <iwankawrakow@gmail.com> | 2024-12-06 12:15:39 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-06 12:15:39 +0100 |
commit | 3682e4700db6b8cb2ca8e3da365578078f21ab0c (patch) | |
tree | ea1680494ca00580b0a038cdef035c596e80e58c /ggml/src/iqk/iqk_mul_mat.cpp | |
parent | f64de08203aaee95ca755336de3e1db85d990198 (diff) |
iq2_bn_r4: fastest Bitnet CPU implementation on the planet (#124)
* Adding iq2_bn_r4
This Zen4-only implementation achieves PP-512 = 826 t/s (!!!)
for Bitnet-1.58b-3B, up from 620 t/s for iq2_bn.
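To make the interleaving concrete, here is a hypothetical scalar sketch of the `_r4` idea as it appears in the kernels in the diff below: four consecutive rows share a header of four per-row float scales, and the ternary weights are stored as 2-bit fields holding (weight + 1). The field ordering and the `R4GroupSketch`/`decode` names are illustrative only; the real IQ2_BN_R4 packing is defined by the quantization code, which is not part of this diff.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy 4-row interleaved ternary group: a header of four per-row float
// scales, then 2-bit fields storing (ternary weight + 1). The bit/byte
// ordering here is an assumption, not the actual IQ2_BN_R4 layout.
struct R4GroupSketch {
    float scale[4];              // one scale per interleaved row
    std::vector<uint8_t> qs;     // qs[j]: column j for all four rows (assumed)
};

// Decode column j of interleaved row r (r = 0..3) of the group.
static float decode(const R4GroupSketch& g, int r, int j) {
    int q = (g.qs[j] >> (2*r)) & 0x3;   // 2-bit field of row r (assumed ordering)
    return g.scale[r] * float(q - 1);   // map {0,1,2} back to {-1,0,+1}
}

int main() {
    R4GroupSketch g{{0.1f, 0.2f, 0.3f, 0.4f}, std::vector<uint8_t>(8, 0x24)};
    for (int r = 0; r < 4; ++r) std::printf("row %d, col 0 -> %g\n", r, decode(g, r, 0));
}
```

Storing four rows per group is what lets a single SIMD load feed four output rows at once, which is where the PP-512 gain comes from.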
* Make sure rows per thread are a multiple of the number of interleaved rows
With this I can run iq2_bn_r4 with 32 threads, which increases
PP-512 to 872 t/s.
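A sketch of the work split this commit enforces, mirroring the `num_rows` handling added to `iqk_mul_mat` in the diff below (the `partition` helper and its parameters are illustrative): rows are divided among threads in units of the interleaved group size, so no thread ever gets a partial group.

```cpp
#include <cassert>
#include <cstdio>

// Split Nx rows over nth threads in units of num_rows (4 for the _r4 types),
// so every thread receives whole interleaved groups.
static void partition(int Nx, int nth, int ith, int num_rows, int& first_row, int& n_rows) {
    assert(Nx % num_rows == 0);
    int n_groups   = Nx / num_rows;                 // interleaved groups of rows
    int per_thread = (n_groups + nth - 1) / nth;    // ceiling division over threads
    int first      = ith * per_thread;
    if (first + per_thread > n_groups) per_thread = n_groups - first;
    if (per_thread < 0) per_thread = 0;             // threads beyond the last group get no work
    first_row = first * num_rows;                   // convert groups back to actual rows
    n_rows    = per_thread * num_rows;
}

int main() {
    int first, n;
    partition(/*Nx=*/3200, /*nth=*/32, /*ith=*/5, /*num_rows=*/4, first, n);
    std::printf("thread 5 handles rows [%d, %d)\n", first, first + n);  // [500, 600)
}
```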
* iq2_bn_r4: 1st shot at NEON
PP-512 is already faster than iq2_bn (284 t/s vs 246 t/s
for Bitnet-1.58b-3B). TG-128 is ~5% slower.
* iq2_bn_r4: NEON
PP-512 is now 296 t/s. TG-128 is ~20% faster than iq2_bn
for 1 thread, but saturates to about the same 93 t/s at
8 threads.
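The NEON kernels in the diff unpack four 2-bit planes with shift-and-mask and feed them to the dot-product instruction, pairing each plane with a 4-byte group of the activations via `vdotq_laneq_s32`. A condensed, simplified sketch of that inner step (here every plane is dotted against the same activation vector, and the `step` helper is illustrative):

```cpp
#include <arm_neon.h>   // build with -march=armv8.2-a+dotprod
#include <cstdint>
#include <cstdio>

// Unpack four 2-bit planes from 16 packed bytes and accumulate them against
// int8 activations with the dot-product instruction.
static inline int32x4_t step(int32x4_t acc, const uint8_t* packed, int8x16_t y) {
    const uint8x16_t m3   = vdupq_n_u8(0x3);
    const uint8x16_t bits = vld1q_u8(packed);
    // four 2-bit planes extracted with shift + mask, as in the kernels in the diff
    int8x16_t q0 = vreinterpretq_s8_u8(vandq_u8(bits, m3));
    int8x16_t q1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits, 2), m3));
    int8x16_t q2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits, 4), m3));
    int8x16_t q3 = vreinterpretq_s8_u8(vshrq_n_u8(bits, 6));
    // the real kernel pairs each plane with its own 4-byte group of y via
    // vdotq_laneq_s32; here each plane is dotted against the same vector
    acc = vdotq_s32(acc, q0, y);
    acc = vdotq_s32(acc, q1, y);
    acc = vdotq_s32(acc, q2, y);
    acc = vdotq_s32(acc, q3, y);
    return acc;
}

int main() {
    uint8_t packed[16]; int8_t ys[16];
    for (int i = 0; i < 16; ++i) { packed[i] = 0x24; ys[i] = int8_t(i - 8); }
    int32x4_t acc = step(vdupq_n_s32(0), packed, vld1q_s8(ys));
    std::printf("lane0 = %d\n", vgetq_lane_s32(acc, 0));
}
```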
* iq2_bn_r4: Experimenting on NEON
The matrix x vector multiplication performance is erratic.
iq2_bn_r4 is faster at 1, 2, and 4 threads, but
saturates to a lower t/s at 8 threads compared to
iq2_bn. iq2_bn actually manages 99 t/s at 8 threads,
not the 93 t/s I wrote in the last commit. iq2_bn_r4
performance shows huge fluctuations at 4 and 8 threads.
* Some cleanup
* iq2_bn_r4: AVX2
As expected, PP is slightly slower as we just don't have
enough vector registers (690 vs 710 t/s). TG is slightly faster
(18.2 vs 16.7 t/s at 1 thread).
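AVX2 has no `_mm512_dpbusd_epi32`, so the kernel widens through `_mm256_maddubs_epi16` (unsigned 2-bit quants times signed int8 activations) followed by `_mm256_madd_epi16` against ones, as in the diff below. A condensed sketch of that step, showing only two of the four 2-bit planes (`dot_step` is an illustrative helper, not the actual code):

```cpp
#include <immintrin.h>

// One AVX2 accumulation step: u8 quants * s8 activations with maddubs
// (adjacent products summed into s16), then madd against 1s to widen to s32.
static inline __m256i dot_step(__m256i acc, __m256i packed, __m256i y) {
    const __m256i m3   = _mm256_set1_epi8(0x3);
    const __m256i ones = _mm256_set1_epi16(1);
    __m256i q0 = _mm256_and_si256(packed, m3);
    __m256i q1 = _mm256_and_si256(_mm256_srli_epi16(packed, 2), m3);
    // the real kernel also uses shifts by 4 and 6, pairing each 2-bit plane
    // with a broadcast 4-byte group of y via _mm256_shuffle_epi32
    __m256i s16 = _mm256_add_epi16(_mm256_maddubs_epi16(q0, _mm256_shuffle_epi32(y, 0x00)),
                                   _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x55)));
    return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, s16));
}
```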
* iq2_bn_r4: use AVX2 implementation on Zen4 for matrix x vector
It is faster - we get 29.6 t/s at 1 thread vs 25.9 t/s for iq2_bn.
* iq2_bn_r4: simdify q8_K16 quantization (AVX2)
PP-512 becomes 834 t/s and TG-128 now saturates to the same
performance as iq2_bn for 4 threads.
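The Q8_K16 activation rows these kernels read carry a small float header, four block scales plus the scaled row sum used to undo the +1 offset of the ternary quants, followed by the int8 values; that is the layout the `Q8_16` helper in this diff consumes. Below is a simplified scalar reference of that layout; the actual quantization is the SIMD-ified code referenced in this commit (not part of this diff) and may differ in rounding and chunking.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Simplified per-row Q8_K16-style layout: 4 block scales + scaled row sum,
// then n int8 quants. Assumes n is a multiple of 4.
struct Q8K16RowSketch {
    float   d[4];            // one scale per quarter of the row (assumed chunking)
    float   row_sum;         // sum of the dequantized row, used to undo the +1 offset
    std::vector<int8_t> q;   // n int8 quants
};

static Q8K16RowSketch quantize_row_sketch(const float* x, int n) {
    Q8K16RowSketch r; r.q.resize(n); r.row_sum = 0.0f;
    const int chunk = n/4;
    for (int k = 0; k < 4; ++k) {
        float amax = 0.0f;
        for (int j = 0; j < chunk; ++j) amax = std::max(amax, std::fabs(x[k*chunk + j]));
        const float d  = amax > 0 ? amax/127.0f : 0.0f;
        const float id = d    > 0 ? 1.0f/d      : 0.0f;
        int sum = 0;
        for (int j = 0; j < chunk; ++j) {
            int v = (int)std::lround(x[k*chunk + j]*id);
            r.q[k*chunk + j] = (int8_t)v;
            sum += v;
        }
        r.d[k] = d;
        r.row_sum += d*sum;      // scaled row sum, as consumed by Q8_16::sum_row()
    }
    return r;
}
```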
* iq2_bn_r4: simdify q8_K16 quantization (NEON)
PP-512 is now 304.7 t/s, and TG-128 @ 8 threads
very slightly outperforms iq2_bn (100.7 t/s vs 99.6 t/s)
* iq2_bn_r4: fix AVX2 after breaking it two commits ago
* iq2_bn_r4: better AVX2
As we don't have enough vector registers on AVX2, it is better
to do two passes per row; that way only half of the accumulator
registers are needed at any one time.
With this, we now beat iq2_bn PP also on AVX2 by a small margin.
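A schematic illustration of the two-pass idea (the `two_pass_sketch` and `reduce` helpers are illustrative, not the actual kernel): each pass keeps only one int32 accumulator per activation row live, and the first pass's result is parked in a cheap 128-bit float partial sum before the same registers are reused for the second half.

```cpp
#include <immintrin.h>

// Two passes over the blocks so only NRC_Y integer accumulators are live at
// once; the first pass is folded into float partial sums and the accumulator
// registers are reused for the second half of each block.
template <int NRC_Y>
static void two_pass_sketch(const __m256i* half0, const __m256i* half1,
                            const __m256i* y, int nb, __m128* out) {
    __m256i acc[NRC_Y];
    __m128  partial[NRC_Y];
    auto reduce = [](__m256i v) {                       // add the two 128-bit halves, convert to float
        __m128i lo = _mm256_castsi256_si128(v);
        __m128i hi = _mm256_extracti128_si256(v, 1);
        return _mm_cvtepi32_ps(_mm_add_epi32(lo, hi));
    };
    // pass 1: first half of every block, NRC_Y accumulators live
    for (int iy = 0; iy < NRC_Y; ++iy) acc[iy] = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib)
        for (int iy = 0; iy < NRC_Y; ++iy)
            acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(half0[ib], y[iy*nb + ib]));
    for (int iy = 0; iy < NRC_Y; ++iy) partial[iy] = reduce(acc[iy]);
    // pass 2: second half of every block, reusing the same registers
    for (int iy = 0; iy < NRC_Y; ++iy) acc[iy] = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib)
        for (int iy = 0; iy < NRC_Y; ++iy)
            acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(half1[ib], y[iy*nb + ib]));
    for (int iy = 0; iy < NRC_Y; ++iy) out[iy] = _mm_add_ps(partial[iy], reduce(acc[iy]));
}
```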
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 421 |
1 file changed, 417 insertions, 4 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index faa4cab7..b6ff7ab7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -161,6 +161,17 @@ struct MulMat {
         }
     }
     static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
+    static inline int num_rows(ggml_type type) {
+        switch (type) {
+            case GGML_TYPE_Q4_0_R4:
+            case GGML_TYPE_Q5_0_R4:
+            case GGML_TYPE_Q6_0_R4:
+            case GGML_TYPE_Q8_0_R4:
+            case GGML_TYPE_IQ4_NL_X4:
+            case GGML_TYPE_IQ2_BN_R4: return 4;
+            default: return 1;
+        }
+    }
 private:
     template <typename Dequantizer> static void set_functions(MulMat& m);
 };
@@ -181,13 +192,15 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
     size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
     //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA));
-    auto nrc_x = (Nx + nth - 1)/nth;
+    auto num_rows = MulMat::num_rows(ggml_type(typeA));
+    GGML_ASSERT(Nx%num_rows == 0);
+    auto nrc_x = (Nx/num_rows + nth - 1)/nth;
     auto first_x = ith*nrc_x;
-    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+    if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
-    DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+    DataInfo info{C + first_x*num_rows, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
-    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x*num_rows, row_size_qx, info, nrc_x*num_rows, Ny);
     return true;
 }
@@ -319,6 +332,30 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
     const block_q8 * y[nrc_y];
 };
+template <int nrc> struct Q8_16 {
+
+    constexpr static int nrc_y = nrc;
+
+    Q8_16(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto ptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+            y[iy] = (const int8_t *)(ptr + 5);
+        }
+    }
+
+#ifdef HAVE_FANCY_SIMD
+    inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+    inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+    inline float scale(int iy, int k) const { return d[5*iy+k]; }
+    inline float sum_row(int iy) const { return d[5*iy + 4]; }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+    float d[5*nrc_y];
+    const int8_t * y[nrc_y];
+};
+
 struct Scales8KBase {
     template <typename Q8>
     inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
@@ -2079,6 +2116,228 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
 #endif  // Zen4 or vanilla AVX2
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    Q8_16<nrc_y> q8(info);
+    auto m3 = _mm256_set1_epi8(0x3);
+    auto m1 = _mm256_set1_epi16(1);
+    int nb = n / QK_IQ1BN;
+    __m256i qx[4];
+    if constexpr (nrc_y > 4) {
+        __m256i acc[nrc_y] = {};
+        __m128 sum4[nrc_y];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+0);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
+                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
+                acc[iy] = _mm256_setzero_si256();
+            }
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+1);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
+                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
+                info.store(ix, iy, s4);
+                acc[iy] = _mm256_setzero_si256();
+            }
+        }
+    } else {
+        __m256i acc[2*nrc_y] = {};
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+0);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+1);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
+                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
+                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                info.store(ix, iy, sum4);
+                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
+            }
+        }
+    }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    if constexpr (nrc_y == 1) {
+        mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
+    } else {
+        Q8_16<nrc_y> q8(info);
+        auto m3 = _mm512_set1_epi8(0x3);
+        int nb = n / QK_IQ1BN;
+        __m512i acc[2*nrc_y] = {};
+        __m512i qx[8];
+        for (int ix = 0; ix < nrc_x/8; ++ix) {
+            const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
+            const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
+            auto dl = _mm_loadu_ps(dptr1);
+            auto dh = _mm_loadu_ps(dptr2);
+            const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
+            const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+                auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
+                qx[0] = _mm512_and_si512(bits_l, m3);
+                qx[1] = _mm512_and_si512(bits_h, m3);
+                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
+                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
+                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants64(iy, ib);
+                    auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                __m128 sum4;
+                for (int k = 0; k < 2; ++k) {
+                    const auto& dx = k == 0 ? dl : dh;
+                    auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
+                    sum4 = _mm_mul_ps  (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                    sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                    info.store(8*ix + 4*k, iy, sum4);
+                }
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+            }
+        }
+        if (int ix = 8*(nrc_x/8); ix < nrc_x) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+                qx[0] = _mm512_and_si512(bits_l, m3);
+                qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants64(iy, ib);
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf = _mm512_cvtepi32_ps(acc[iy]);
+                auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                info.store(ix, iy, sum4);
+            }
+        }
+    }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
 #ifdef HAVE_FANCY_SIMD
 template <int nrc_y>
 static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -4744,6 +5003,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
             expected_typeB = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN_R4:
+            assert (ne00 % QK_IQ1BN == 0);
+            mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+            mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+            mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+            mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+            mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+            mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+//#ifdef HAVE_FANCY_SIMD
+            mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+            mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+//#endif
+            expected_typeB = GGML_TYPE_Q8_K16;
+            break;
         case GGML_TYPE_Q4_0:
             assert (ne00 % QK4_0 == 0);
             MulMat::set_functions<Q4_0_1_Unpacker>(mm);
@@ -7171,6 +7444,135 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     }
 }
+template <int nrc> struct Q8_16 {
+
+    constexpr static int nrc_y = nrc;
+
+    Q8_16(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto ptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+            y[iy] = (const int8_t *)(ptr + 5);
+        }
+    }
+
+    inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
+    inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
+    inline float scale(int iy, int k) const { return d[5*iy+k]; }
+    inline float sum_row(int iy) const { return d[5*iy + 4]; }
+    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
+
+    float d[5*nrc_y];
+    const int8_t * y[nrc_y];
+};
+
+template <int nrc_y>
+static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    Q8_16<nrc_y> q8(info);
+    auto m3 = vdupq_n_u8(0x3);
+    int nb = n / QK_IQ1BN;
+    if constexpr (nrc_y == 1) {
+        auto mc = vdupq_n_u8(0xc);
+        int32x4_t acc[8];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto y = q8.load_quants(0, ib);
+                for (int j = 0; j < 4; ++j) {
+                    auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
+                    auto bits2 = vshrq_n_u8(bits1, 4);
+                    acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
+                    acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
+                    acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
+                    acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
+                }
+            }
+            auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
+            auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
+            auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
+            auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
+            sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+            info.store(ix, 0, sumf);
+        }
+    } else {
+        int32x4_t acc[4*nrc_y] = {};
+        uint8x16_t qx[8];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = vld1q_u8_x2(iq2 + 64*ib);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+                }
+                bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+                sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+                info.store(ix, iy, sumf);
+                acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+            }
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
@@ -7716,6 +8118,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
             expected_Btype = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN_R4:
+            m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+            m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+            m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+            m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+            m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+            //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+            //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+            //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+            expected_Btype = GGML_TYPE_Q8_K16;
+            break;
         case GGML_TYPE_Q4_0:
             MulMat::set_functions<DequantizerQ40>(m);
             expected_Btype = GGML_TYPE_Q8_0;