Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp  421
1 file changed, 417 insertions(+), 4 deletions(-)
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index faa4cab7..b6ff7ab7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -161,6 +161,17 @@ struct MulMat {
         }
     }
     static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
+    static inline int num_rows(ggml_type type) {
+        switch (type) {
+            case GGML_TYPE_Q4_0_R4:
+            case GGML_TYPE_Q5_0_R4:
+            case GGML_TYPE_Q6_0_R4:
+            case GGML_TYPE_Q8_0_R4:
+            case GGML_TYPE_IQ4_NL_X4:
+            case GGML_TYPE_IQ2_BN_R4: return 4;
+            default: return 1;
+        }
+    }
 private:
     template <typename Dequantizer> static void set_functions(MulMat& m);
 };
@@ -181,13 +192,15 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
     size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
     //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA));
 
-    auto nrc_x = (Nx + nth - 1)/nth;
+    auto num_rows = MulMat::num_rows(ggml_type(typeA));
+    GGML_ASSERT(Nx%num_rows == 0);
+    auto nrc_x = (Nx/num_rows + nth - 1)/nth;
     auto first_x = ith*nrc_x;
-    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+    if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
 
-    DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+    DataInfo info{C + first_x*num_rows, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
 
-    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x*num_rows, row_size_qx, info, nrc_x*num_rows, Ny);
 
     return true;
 }
@@ -319,6 +332,30 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
 
     const block_q8 * y[nrc_y];
 };
 
+template <int nrc> struct Q8_16 {
+
+    constexpr static int nrc_y = nrc;
+
+    Q8_16(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto ptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+            y[iy] = (const int8_t *)(ptr + 5);
+        }
+    }
+
+#ifdef HAVE_FANCY_SIMD
+    inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+    inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+    inline float scale(int iy, int k) const { return d[5*iy+k]; }
+    inline float sum_row(int iy) const { return d[5*iy + 4]; }
+    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+    float d[5*nrc_y];
+    const int8_t * y[nrc_y];
+};
+
 struct Scales8KBase {
     template <typename Q8>
     inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
@@ -2079,6 +2116,228 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataIn
 
 #endif // Zen4 or vanilla AVX2
 
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    Q8_16<nrc_y> q8(info);
+    auto m3 = _mm256_set1_epi8(0x3);
+    auto m1 = _mm256_set1_epi16(1);
+    int nb = n / QK_IQ1BN;
+    __m256i qx[4];
+    if constexpr (nrc_y > 4) {
+        __m256i acc[nrc_y] = {};
+        __m128 sum4[nrc_y];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+0);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
+                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
+                acc[iy] = _mm256_setzero_si256();
+            }
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+1);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
+                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
+                info.store(ix, iy, s4);
+                acc[iy] = _mm256_setzero_si256();
+            }
+        }
+    } else {
+        __m256i acc[2*nrc_y] = {};
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+0);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+                qx[0] = _mm256_and_si256(bits, m3);
+                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants(iy, 2*ib+1);
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
+                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
+                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                info.store(ix, iy, sum4);
+                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
+            }
+        }
+    }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    if constexpr (nrc_y == 1) {
+        mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
+    } else {
+        Q8_16<nrc_y> q8(info);
+        auto m3 = _mm512_set1_epi8(0x3);
+        int nb = n / QK_IQ1BN;
+        __m512i acc[2*nrc_y] = {};
+        __m512i qx[8];
+        for (int ix = 0; ix < nrc_x/8; ++ix) {
+            const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
+            const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
+            auto dl = _mm_loadu_ps(dptr1);
+            auto dh = _mm_loadu_ps(dptr2);
+            const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
+            const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+                auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
+                qx[0] = _mm512_and_si512(bits_l, m3);
+                qx[1] = _mm512_and_si512(bits_h, m3);
+                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
+                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
+                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants64(iy, ib);
+                    auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
+                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                __m128 sum4;
+                for (int k = 0; k < 2; ++k) {
+                    const auto& dx = k == 0 ? dl : dh;
+                    auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
+                    sum4 = _mm_mul_ps  (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                    sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                    info.store(8*ix + 4*k, iy, sum4);
+                }
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+            }
+        }
+        if (int ix = 8*(nrc_x/8); ix < nrc_x) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = _mm_loadu_ps(dptr);
+            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+                qx[0] = _mm512_and_si512(bits_l, m3);
+                qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants64(iy, ib);
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                auto sumf = _mm512_cvtepi32_ps(acc[iy]);
+                auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+                info.store(ix, iy, sum4);
+            }
+        }
+    }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
 #ifdef HAVE_FANCY_SIMD
 template <int nrc_y>
 static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -4744,6 +5003,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
             mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
             expected_typeB = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN_R4:
+            assert (ne00 % QK_IQ1BN == 0);
+            mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+            mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+            mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+            mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+            mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+            mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+//#ifdef HAVE_FANCY_SIMD
+            mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+            mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+//#endif
+            expected_typeB = GGML_TYPE_Q8_K16;
+            break;
         case GGML_TYPE_Q4_0:
             assert (ne00 % QK4_0 == 0);
            MulMat::set_functions<Q4_0_1_Unpacker>(mm);
@@ -7171,6 +7444,135 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
     }
 }
 
+template <int nrc> struct Q8_16 {
+
+    constexpr static int nrc_y = nrc;
+
+    Q8_16(const DataInfo& info) {
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto ptr = (const float *)info.src1_row(iy);
+            std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+            y[iy] = (const int8_t *)(ptr + 5);
+        }
+    }
+
+    inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
+    inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
+    inline float scale(int iy, int k) const { return d[5*iy+k]; }
+    inline float sum_row(int iy) const { return d[5*iy + 4]; }
+    inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
+
+    float d[5*nrc_y];
+    const int8_t * y[nrc_y];
+};
+
+template <int nrc_y>
+static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if (nrc_x%4) {
+        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+        GGML_ABORT("fatal error");
+    }
+    Q8_16<nrc_y> q8(info);
+    auto m3 = vdupq_n_u8(0x3);
+    int nb = n / QK_IQ1BN;
+    if constexpr (nrc_y == 1) {
+        auto mc = vdupq_n_u8(0xc);
+        int32x4_t acc[8];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto y = q8.load_quants(0, ib);
+                for (int j = 0; j < 4; ++j) {
+                    auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
+                    auto bits2 = vshrq_n_u8(bits1, 4);
+                    acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
+                    acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
+                    acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
+                    acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
+                }
+            }
+            auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
+            auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
+            auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
+            dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
+            sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
+            sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
+            auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
+            sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+            info.store(ix, 0, sumf);
+        }
+    } else {
+        int32x4_t acc[4*nrc_y] = {};
+        uint8x16_t qx[8];
+        for (int ix = 0; ix < nrc_x; ix += 4) {
+            const float * dptr = (const float *)((const char *)vx + ix*bx);
+            auto dl = vld1q_f32(dptr);
+            const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+            for (int ib = 0; ib < nb; ++ib) {
+                auto bits = vld1q_u8_x2(iq2 + 64*ib);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+                    acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+                    acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+                }
+                bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+                qx[0] = vandq_u8(bits.val[0], m3);
+                qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+                qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+                qx[3] = vshrq_n_u8(bits.val[0], 6);
+                qx[4] = vandq_u8(bits.val[1], m3);
+                qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+                qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+                qx[7] = vshrq_n_u8(bits.val[1], 6);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto y = q8.load_quants_32(iy, 2*ib+1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+                    acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+                    acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto dy = q8.scale(iy);
+                float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+                sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+                sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+                info.store(ix, iy, sumf);
+                acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+            }
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     const int nb = n / QK_IQ1BN;
@@ -7716,6 +8118,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
             m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
             expected_Btype = GGML_TYPE_Q8_K64;
             break;
+        case GGML_TYPE_IQ2_BN_R4:
+            m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+            m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+            m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+            m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+            m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+            //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+            //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+            //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+            expected_Btype = GGML_TYPE_Q8_K16;
+            break;
        case GGML_TYPE_Q4_0:
            MulMat::set_functions<DequantizerQ40>(m);
            expected_Btype = GGML_TYPE_Q8_0;
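The first two hunks change how iqk_mul_mat partitions work across threads: for the row-interleaved types listed in MulMat::num_rows (the *_R4 and *_X4 cases), four matrix rows are packed into one stored row group, so the per-thread split is computed in units of num_rows and scaled back to plain rows when indexing A and C. A minimal standalone sketch of that arithmetic, with illustrative names that are not part of the patch:

    #include <cassert>
    #include <cstdio>

    // Hypothetical illustration of the patched partitioning: Nx rows, nth
    // threads, rows packed in groups of num_rows (4 for the R4/X4 types).
    static void partition(long Nx, int nth, int num_rows) {
        assert(Nx % num_rows == 0);                   // mirrors the GGML_ASSERT in the patch
        long nchunks    = Nx / num_rows;              // work items are row groups, not rows
        long per_thread = (nchunks + nth - 1) / nth;  // ceiling division, as in the diff
        for (int ith = 0; ith < nth; ++ith) {
            long first = ith * per_thread;
            long n     = per_thread;
            if (first + n > nchunks) n = nchunks - first;  // clamp the last thread
            if (n <= 0) continue;                          // extra guard, not in the patch
            // thread ith handles plain rows [first*num_rows, (first+n)*num_rows)
            printf("thread %d: rows %ld..%ld\n", ith, first*num_rows, (first + n)*num_rows - 1);
        }
    }

    int main() { partition(32, 3, 4); }                // e.g. 8 groups over 3 threads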
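All three new mul_mat_iq2_bn_r4_q8_k16 kernels (AVX2, AVX512, NEON) finish by fused-multiply-adding -q8.sum_row(iy) into the result. A plausible reading, shown here as a hedged scalar sketch rather than the exact R4 layout: IQ2_BN stores BitNet-style ternary weights {-1, 0, +1} as unsigned 2-bit codes q = weight + 1, which lets the hot loops use unsigned-times-signed dot products (maddubs/dpbusd/vdotq) on q directly and fold the -1 offset into a single correction with the activation row sum that Q8_16 keeps as its fifth float. Assuming that encoding:

    #include <cstdint>

    // Scalar reference for one row: d is the combined row scale, q[] are the
    // unpacked 2-bit codes in {0,1,2}, y[] the int8 activations. The layout
    // details of the 4-row interleaving are deliberately omitted.
    static float ternary_dot(float d, const uint8_t *q, const int8_t *y, int n) {
        int32_t sum_qy = 0, sum_y = 0;
        for (int i = 0; i < n; ++i) {
            sum_qy += q[i] * y[i];   // what the SIMD dot-product instructions accumulate
            sum_y  += y[i];          // precomputed once per activation row (Q8_16 d[4])
        }
        return d * (float)(sum_qy - sum_y);  // == d * sum((q[i] - 1) * y[i])
    }

Under this reading, the NEON nrc_y == 1 path's 0.25f factor would compensate for masking bits 2-3 with 0xc instead of shifting them down, which leaves those products scaled by 4.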