Diffstat (limited to 'ggml/src/iqk/iqk_mul_mat.cpp')
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp | 421
1 files changed, 417 insertions, 4 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index faa4cab7..b6ff7ab7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -161,6 +161,17 @@ struct MulMat {
}
}
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
+ static inline int num_rows(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0_R4:
+ case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
+ case GGML_TYPE_Q8_0_R4:
+ case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ2_BN_R4: return 4;
+ default: return 1;
+ }
+ }
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
};
@@ -181,13 +192,15 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
//if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA));
- auto nrc_x = (Nx + nth - 1)/nth;
+ auto num_rows = MulMat::num_rows(ggml_type(typeA));
+ GGML_ASSERT(Nx%num_rows == 0);
+ auto nrc_x = (Nx/num_rows + nth - 1)/nth;
auto first_x = ith*nrc_x;
- if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
- DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+ DataInfo info{C + first_x*num_rows, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
- mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+ mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x*num_rows, row_size_qx, info, nrc_x*num_rows, Ny);
return true;
}
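// A minimal standalone sketch (illustrative names, not from the patch) of the
// thread partitioning the modified iqk_mul_mat performs: the row range is split
// in units of num_rows so the interleaved R4 row groups are never cut across
// threads, mirroring the arithmetic in the hunk above.
#include <algorithm>
#include <cassert>

struct RowChunk { long first_row; long n_rows; };

static RowChunk split_rows(long Nx, int num_rows, int ith, int nth) {
    assert(Nx % num_rows == 0);
    long n_groups   = Nx / num_rows;                 // indivisible row groups
    long per_thread = (n_groups + nth - 1) / nth;    // groups per thread, rounded up
    long first      = (long)ith * per_thread;
    long count      = std::max(0L, std::min(per_thread, n_groups - first)); // clamped for trailing threads
    return { first * num_rows, count * num_rows };   // convert back to row units
}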
@@ -319,6 +332,30 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
const block_q8 * y[nrc_y];
};
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); }
+#endif
+ inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
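// A minimal sketch of the per-row Q8_K16 layout that Q8_16's constructor and
// accessor names suggest: four float scales plus one float row sum, followed by
// the int8 quants. Q8K16RowView and view_q8_k16_row are illustrative names only.
#include <cstdint>
#include <cstring>

struct Q8K16RowView {
    float          scale[4];  // per-chunk scales, read via Q8_16::scale(iy, k)
    float          row_sum;   // read via Q8_16::sum_row(iy)
    const int8_t * q;         // quants, read via Q8_16::load_quants*()
};

static Q8K16RowView view_q8_k16_row(const void * row) {
    Q8K16RowView v;
    std::memcpy(v.scale, row, 4 * sizeof(float));
    std::memcpy(&v.row_sum, (const float *)row + 4, sizeof(float));
    v.q = (const int8_t *)((const float *)row + 5);
    return v;
}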
struct Scales8KBase {
template <typename Q8>
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
@@ -2079,6 +2116,228 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif // Zen4 or vanilla AVX2
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm256_set1_epi8(0x3);
+ auto m1 = _mm256_set1_epi16(1);
+ int nb = n / QK_IQ1BN;
+ __m256i qx[4];
+ if constexpr (nrc_y > 4) {
+ __m256i acc[nrc_y] = {};
+ __m128 sum4[nrc_y];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
+ sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
+ auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
+ s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
+ info.store(ix, iy, s4);
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+ } else {
+ __m256i acc[2*nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+0);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
+ qx[0] = _mm256_and_si256(bits, m3);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants(iy, 2*ib+1);
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
+ auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
+ auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+ }
+}
+
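// Scalar equivalent of the _mm256_srli_epi16 / _mm256_and_si256(m3) sequence in
// the AVX2 kernel above: every packed byte holds four 2-bit quants at bit
// offsets 0, 2, 4 and 6. Sketch only; unpack_2bit is an illustrative name.
#include <cstdint>

static inline void unpack_2bit(uint8_t packed, uint8_t out[4]) {
    out[0] =  packed       & 0x3;
    out[1] = (packed >> 2) & 0x3;
    out[2] = (packed >> 4) & 0x3;
    out[3] = (packed >> 6) & 0x3;
}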
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ if constexpr (nrc_y == 1) {
+ mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ Q8_16<nrc_y> q8(info);
+ auto m3 = _mm512_set1_epi8(0x3);
+ int nb = n / QK_IQ1BN;
+ __m512i acc[2*nrc_y] = {};
+ __m512i qx[8];
+ for (int ix = 0; ix < nrc_x/8; ++ix) {
+ const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
+ const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
+ auto dl = _mm_loadu_ps(dptr1);
+ auto dh = _mm_loadu_ps(dptr2);
+ const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
+ const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(bits_h, m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
+ qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
+ qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
+ sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
+ acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
+ acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ __m128 sum4;
+ for (int k = 0; k < 2; ++k) {
+ const auto& dx = k == 0 ? dl : dh;
+ auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
+ sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(8*ix + 4*k, iy, sum4);
+ }
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+ }
+ }
+ if (int ix = 8*(nrc_x/8); ix < nrc_x) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = _mm_loadu_ps(dptr);
+ const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
+ qx[0] = _mm512_and_si512(bits_l, m3);
+ qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants64(iy, ib);
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ auto sumf = _mm512_cvtepi32_ps(acc[iy]);
+ auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
+ sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
+ sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
+ info.store(ix, iy, sum4);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
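// For readers unfamiliar with the AVX-512 VNNI path above: _mm512_dpbusd_epi32
// multiplies unsigned bytes of the first operand with signed bytes of the
// second and accumulates groups of four products into the 32-bit lanes of the
// accumulator. A scalar model of one 512-bit accumulation (dpbusd_model is an
// illustrative name, not a real intrinsic):
#include <cstdint>

static void dpbusd_model(int32_t acc[16], const uint8_t a[64], const int8_t b[64]) {
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 4; ++j)
            acc[i] += (int32_t)a[4*i + j] * (int32_t)b[4*i + j];
}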
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -4744,6 +5003,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
+ case GGML_TYPE_IQ2_BN_R4:
+ assert (ne00 % QK_IQ1BN == 0);
+ mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+ mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+ mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+ mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+ mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+ mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+//#ifdef HAVE_FANCY_SIMD
+ mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+ mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+//#endif
+ expected_typeB = GGML_TYPE_Q8_K16;
+ break;
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_1_Unpacker>(mm);
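// The funcs[] assignments above follow the file's usual pattern: slot k holds
// the kernel instantiated for k+1 right-hand-side columns, presumably indexed
// by Ny-1 at the call site. A self-contained toy illustration of that dispatch
// (toy_kernel and toy_fn are made-up names):
#include <cstdio>

template <int nrc_y> static void toy_kernel(int n) { std::printf("nrc_y=%d, n=%d\n", nrc_y, n); }

typedef void (*toy_fn)(int);

int main() {
    toy_fn funcs[8] = { toy_kernel<1>, toy_kernel<2>, toy_kernel<3>, toy_kernel<4>,
                        toy_kernel<5>, toy_kernel<6>, toy_kernel<7>, toy_kernel<8> };
    int Ny = 3;           // number of right-hand-side columns for this call
    funcs[Ny - 1](256);   // picks toy_kernel<3>
    return 0;
}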
@@ -7171,6 +7444,135 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
}
+template <int nrc> struct Q8_16 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto ptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
+ y[iy] = (const int8_t *)(ptr + 5);
+ }
+ }
+
+ inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
+ inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
+ inline float scale(int iy, int k) const { return d[5*iy+k]; }
+ inline float sum_row(int iy) const { return d[5*iy + 4]; }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
+
+ float d[5*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+template <int nrc_y>
+static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if (nrc_x%4) {
+ printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
+ GGML_ABORT("fatal error");
+ }
+ Q8_16<nrc_y> q8(info);
+ auto m3 = vdupq_n_u8(0x3);
+ int nb = n / QK_IQ1BN;
+ if constexpr (nrc_y == 1) {
+ auto mc = vdupq_n_u8(0xc);
+ int32x4_t acc[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto y = q8.load_quants(0, ib);
+ for (int j = 0; j < 4; ++j) {
+ auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
+ auto bits2 = vshrq_n_u8(bits1, 4);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
+ acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
+ acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
+ }
+ }
+ auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
+ auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
+ auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
+ dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
+ sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
+ sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
+ auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
+ info.store(ix, 0, sumf);
+ }
+ } else {
+ int32x4_t acc[4*nrc_y] = {};
+ uint8x16_t qx[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto dl = vld1q_f32(dptr);
+ const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
+ for (int ib = 0; ib < nb; ++ib) {
+ auto bits = vld1q_u8_x2(iq2 + 64*ib);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
+ acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
+ acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
+ }
+ bits = vld1q_u8_x2(iq2 + 64*ib + 32);
+ qx[0] = vandq_u8(bits.val[0], m3);
+ qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
+ qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
+ qx[3] = vshrq_n_u8(bits.val[0], 6);
+ qx[4] = vandq_u8(bits.val[1], m3);
+ qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
+ qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
+ qx[7] = vshrq_n_u8(bits.val[1], 6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = q8.load_quants_32(iy, 2*ib+1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
+ acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
+ acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dy = q8.scale(iy);
+ float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
+ sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
+ sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
+ info.store(ix, iy, sumf);
+ acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
+ }
+ }
+ }
+}
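// The NEON kernel above leans on vdotq_laneq_s32: each 32-bit accumulator lane
// gains the dot product of the corresponding four bytes of the first vector
// with a selected group of four bytes of the second. A scalar model
// (vdotq_laneq_model is an illustrative name):
#include <cstdint>

static void vdotq_laneq_model(int32_t acc[4], const int8_t a[16], const int8_t b[16], int lane) {
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            acc[i] += (int32_t)a[4*i + j] * (int32_t)b[4*lane + j];
}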
+
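// Scalar sketch of the scale combination the multi-column paths end with, for
// one of the four interleaved rows: the per-chunk integer dot products are
// scaled by dl*dy[k], then dl times the Q8 row sum is subtracted (presumably
// removing a constant offset baked into the packed 2-bit values). Names and the
// single-row framing are illustrative only.
static float combine_row(const int32_t sumi[4],  // integer dot products, one per chunk k
                         float dl,               // per-row scale from the IQ2_BN_R4 data
                         const float dy[4],      // per-chunk scales from the Q8_K16 row
                         float row_sum) {        // Q8_16::sum_row() for this column
    float acc = 0.0f;
    for (int k = 0; k < 4; ++k) acc += dy[k] * (float)sumi[k];
    return dl * acc - dl * row_sum;
}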
template <int nrc_y>
static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
@@ -7716,6 +8118,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_Btype = GGML_TYPE_Q8_K64;
break;
+ case GGML_TYPE_IQ2_BN_R4:
+ m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
+ m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
+ m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
+ m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
+ m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
+ //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
+ //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
+ //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
+ expected_Btype = GGML_TYPE_Q8_K16;
+ break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
expected_Btype = GGML_TYPE_Q8_0;