author    Kawrakow <iwankawrakow@gmail.com>  2024-12-17 07:51:11 +0100
committer GitHub <noreply@github.com>        2024-12-17 07:51:11 +0100
commit    d69344f8ea72c6fe6ec16300b939586fa9633e2e (patch)
tree      b8c0efb7322169372543b020360bf0e27549fba5 /ggml/src
parent    1714e46f137318152370beee16af92991042d7b4 (diff)
IQ3_K_R4 (#145)
* iq3_k_r4 WIP
* iq3_k_r4: Zen4
* iq3_k_r4: AVX2
* iq3_k_r4: NEON
* iq3_k_r4: faster matrix x vector multiplication on NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml-common.h          10
-rw-r--r--  ggml/src/ggml-quants.c           1
-rw-r--r--  ggml/src/ggml.c                 22
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp   322
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp  136
-rw-r--r--  ggml/src/iqk/iqk_quantize.h      6
6 files changed, 461 insertions, 36 deletions
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index d77ba12c..ca56704c 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -539,6 +539,16 @@ typedef struct {
static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding");
typedef struct {
+ ggml_half d[4];
+ uint8_t extra[8];
+ uint8_t scales_h[QK_K/32];
+ uint8_t scales_l[QK_K/8];
+ uint8_t qs[QK_K];
+ uint8_t qh[QK_K/2];
+} block_iq3_k_r4;
+static_assert(sizeof(block_iq3_k_r4) == 4*sizeof(block_iq3_k), "wrong iq3_k_r4 block size/padding");
+
+typedef struct {
ggml_half d;
uint16_t extra;
uint8_t scales_h[QK_K/64];
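
The new interleaved block packs four rows of one QK_K super-block side by side, which is exactly what the static_assert above encodes. A quick size check, assuming QK_K = 256 and a 2-byte ggml_half as elsewhere in ggml-common.h:

    // sizeof(block_iq3_k_r4) -- 4 interleaved rows of QK_K = 256 weights
    //   ggml_half d[4]          ->   8 bytes, one fp16 super-block scale per row
    //   uint8_t   extra[8]      ->   8 bytes, per-row table-shift flags (64 bits)
    //   uint8_t   scales_h[8]   ->   8 bytes, sign bits of the 64 block scales
    //   uint8_t   scales_l[32]  ->  32 bytes, 4-bit scale magnitudes (2*l + 1)
    //   uint8_t   qs[256]       -> 256 bytes, low 2 bits of the 3-bit quants
    //   uint8_t   qh[128]       -> 128 bytes, high bit of the 3-bit quants
    //   total                   -> 440 = 4 * 110 = 4 * sizeof(block_iq3_k)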
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 0b157295..1d022672 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15207,6 +15207,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q4_K_R4: break;
case GGML_TYPE_Q5_K_R4: break;
case GGML_TYPE_Q6_K_R4: break;
+ case GGML_TYPE_IQ3_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_Q8_K_R8: break;
case GGML_TYPE_BF16_R16: break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 51ef6eb2..4194d943 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1360,6 +1360,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_IQ3_K_R4] = {
+ .type_name = "iq3_k_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq3_k),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq3_k_r4,
+ .from_float = quantize_row_iq3_k_r4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq3_k_r4_ref,
+ .vec_dot = vec_dot_iq3_k_r4_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_IQ5_K] = {
.type_name = "iq5_k",
.blck_size = QK_K,
@@ -4163,6 +4176,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break;
case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break;
case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
+ case GGML_FTYPE_MOSTLY_IQ3_K_R4: wtype = GGML_TYPE_IQ3_K_R4; break;
case GGML_FTYPE_MOSTLY_IQ4_K_R4: wtype = GGML_TYPE_IQ4_K_R4; break;
case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break;
case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break;
@@ -10700,6 +10714,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -11156,6 +11171,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -11309,6 +11325,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -14508,6 +14525,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -14901,6 +14919,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -15188,6 +15207,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -15804,6 +15824,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
@@ -22648,6 +22669,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_K_R4:result = quantize_iq3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_K_R4:result = quantize_iq4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
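
Note that the traits entry keeps blck_size = QK_K and type_size = sizeof(block_iq3_k) even though four rows share one block_iq3_k_r4: the interleaved struct is exactly four plain iq3_k blocks, so per-row sizing is unchanged. A minimal sketch of what generic ggml code computes from these fields:

    // bytes per logical row of n weights (n % QK_K == 0), via the existing helper:
    size_t per_row = ggml_row_size(GGML_TYPE_IQ3_K_R4, n);  // n/QK_K * sizeof(block_iq3_k)
    // four consecutive interleaved rows together occupy
    // 4*per_row = n/QK_K * sizeof(block_iq3_k_r4) bytes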
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index e74a15f0..bcf96c0a 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -182,6 +182,7 @@ struct MulMat {
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_NL_R4:
case GGML_TYPE_IQ4_XS_R4:
+ case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_Q8_K_R8: return 8;
@@ -3957,6 +3958,119 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
#endif
template <int nrc_y>
+static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto ms = _mm256_set1_epi8(8);
+ auto m03 = _mm256_set1_epi8(0x03);
+ auto m04 = _mm256_set1_epi8(0x04);
+ auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
+ auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values);
+ auto values = MM256_SET_M128I(values128, values128);
+ values = _mm256_add_epi8(values, _mm256_set1_epi8(64));
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ uint64_t stored_scales[8];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra);
+ auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
+ auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1));
+ auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1));
+ auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]);
+ auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1));
+ auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1));
+ auto i8scales1 = _mm256_sign_epi8(sl1, sh1);
+ auto i8scales2 = _mm256_sign_epi8(sl2, sh2);
+ _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1);
+ _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
+ __m256i isum[nrc_y] = {};
+ {
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(-64), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
+ auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(-64), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
+ auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(-64), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
+ auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(-64), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+#ifdef HAVE_FANCY_SIMD
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib)));
+#else
+ auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle);
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
+ auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
+ auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
+ auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1);
+ shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00));
+ auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55));
+ auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+template <int nrc_y>
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -6167,6 +6281,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq4_k_r4_q8_k<8>;
expected_typeB = GGML_TYPE_Q8_K;
break;
+ case GGML_TYPE_IQ3_K_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_iq3_k_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_iq3_k_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_iq3_k_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_iq3_k_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_iq3_k_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_iq3_k_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_iq3_k_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_iq3_k_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q4_0_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>;
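
A note on the _mm256_set1_epi16(-64) terms in the AVX2 kernel above: the iq3nl_values lookup table is signed, but _mm256_maddubs_epi16 and _mm256_dpbusd_epi32 require an unsigned first operand, so the kernel biases the table by +64 (the _mm256_add_epi8(values, _mm256_set1_epi8(64)) line) and cancels the bias against the per-block activation sums (bsums). A scalar sketch of the identity, with hypothetical values:

    // per block with scale s:  s * sum_j q[j]*y[j]
    //                        = s * sum_j (q[j] + 64)*y[j]  -  64 * s * sum_j y[j]
    // the second term is the s1..s4 x bsums contribution accumulated into isum[]
    int q = -23, y = 7;                      // hypothetical quant value and activation
    int unbiased = q * y;                    // what the dot product should contribute
    int biased   = (q + 64) * y - 64 * y;    // what the kernel computes -- identical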
@@ -8844,6 +8970,161 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
}
template <int nrc_y>
+inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
+ uint8x16_t ms, int32x4_t * isum) {
+ auto s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
+ auto s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
+ auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
+ auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
+ auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
+ auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+ s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
+ s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
+ s16_1 = vmovl_s8(vget_low_s8 (s8_1));
+ s16_2 = vmovl_s8(vget_high_s8(s8_1));
+ s16_3 = vmovl_s8(vget_low_s8 (s8_2));
+ s16_4 = vmovl_s8(vget_high_s8(s8_2));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+}
+
+template <int nrc_y>
+void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
+ auto m03 = vdupq_n_u8(0x03);
+ auto m04 = vdupq_n_u8(0x04);
+ uint8x16x2_t shift_shuffle = {
+ vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
+ vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
+ };
+ uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
+ auto values = vld1q_s8(iq3nl_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ int8x16x4_t i8scales;
+ int16x8x4_t i16scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
+ auto extra8 = vld1_u8(iq3[ibl].extra);
+ uint8x16_t extra;
+ if constexpr (nrc_y == 1) {
+ extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
+ } else {
+ extra = vcombine_u8(extra8, extra8);
+ }
+ auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
+ auto sh8 = vld1_u8(iq3[ibl].scales_h);
+ auto sh = vcombine_u8(sh8, sh8);
+ i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
+ i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
+ i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
+ i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
+ i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
+ i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
+ sh = vshrq_n_u8(sh, 4);
+ i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
+ i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
+ int32x4_t isum[nrc_y] = {};
+ if constexpr (nrc_y == 1) {
+ iq3_4_add_shift(ibl, q8, i8scales, extra, ms, isum);
+ }
+ for (int is = 0; is < 2; ++is) {
+ i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
+ i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
+ i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
+ i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
+ auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
+ qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
+ qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
+ qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
+ qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
+ uint8x16_t shifts;
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
+ extra = vshrq_n_u8(extra, 1);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
+ qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
+ qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
+ qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
+ if constexpr (nrc_y == 1) {
+ qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
+ qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
+ qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
+ } else {
+ auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
+ qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
+ qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
+ qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -8880,42 +9161,7 @@ void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in
i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
int32x4_t isum[nrc_y] = {};
if constexpr (nrc_y == 1) {
- auto s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
- auto s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
- auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
- auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
- auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
- auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
- s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
- s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
- s16_1 = vmovl_s8(vget_low_s8 (s8_1));
- s16_2 = vmovl_s8(vget_high_s8(s8_1));
- s16_3 = vmovl_s8(vget_low_s8 (s8_2));
- s16_4 = vmovl_s8(vget_high_s8(s8_2));
- for (int iy = 0; iy < nrc_y; ++iy) {
- auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
- b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
- isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
- }
+ iq3_4_add_shift(ibl, q8, i8scales, extra, ms, isum);
}
for (int is = 0; is < 2; ++is) {
i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
@@ -9803,6 +10049,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_KR8;
break;
+ case GGML_TYPE_IQ3_K_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_k_r4_q8_k);
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_IQ4_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
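
Both the AVX2 and NEON kernels in this file decode the same packed scale and quant encoding; a scalar reference of what the shuffle/sign tricks compute (the helper name here is illustrative, not part of the diff):

    // block scale: odd magnitude from a low nibble, sign from a high bit,
    // matching the cmpeq/or/sign (AVX2) and vceqq/vorrq/vmulq (NEON) sequences
    static inline int iq3_k_r4_block_scale(uint8_t l, uint8_t h) {
        int s = 2 * (l & 0xf) + 1;           // magnitudes 1, 3, ..., 31
        return (h & 1) ? -s : s;
    }
    // quant: 2 low bits from qs, 1 high bit from qh; the per-block 'extra' bit
    // selects the upper half of the 16-entry iq3nl_values table
    // weight = d * scale * iq3nl_values[(lo2 | (hi1 << 2)) + (extra_bit ? 8 : 0)]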
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index abe81858..373a15bb 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4795,3 +4795,139 @@ void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT d
repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst);
}
+//
+// ========================================= iq3_k_r4
+//
+
+void quantize_row_iq3_k_r4_ref(const float * x, block_iq3_k_r4 * y, int64_t k) {
+ quantize_iq3_k_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_iq3_k_r4(const float * x, void * y, int64_t k) {
+ quantize_iq3_k_r4(x, y, 4, k/4, nullptr);
+}
+
+namespace {
+inline void convert_iq3_k(const block_iq3_k& x, uint8_t * L) {
+ const uint8_t * qs = x.qs;
+ const uint8_t * qh = x.qh;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ int shift_l = 2*(ib32%4);
+ int shift_h = ib32%8;
+ for (int j = 0; j < 16; ++j) {
+ L[j+ 0] = ((qs[j+ 0] >> shift_l) & 3) | (((qh[j+ 0] >> shift_h) & 1) << 2);
+ L[j+16] = ((qs[j+16] >> shift_l) & 3) | (((qh[j+16] >> shift_h) & 1) << 2);
+ }
+ L += 32;
+ if (shift_l == 6) qs += 32;
+ }
+}
+}
+
+static void repack_iq3_k(int nrows, int n_per_row, const block_iq3_k * x, block_iq3_k_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ const block_iq3_k * x4[4];
+ uint8_t L[QK_K];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ std::memset(y[ibl].extra, 0, 8);
+ std::memset(y[ibl].scales_l, 0, QK_K/8);
+ std::memset(y[ibl].scales_h, 0, QK_K/32);
+ for (int k = 0; k < 4; ++k) {
+ y[ibl].d[k] = x4[k][ibl].d;
+ auto extra = x4[k][ibl].extra;
+ uint16_t sh = x4[k][ibl].scales_h;
+ convert_iq3_k(x4[k][ibl], L);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ if (extra & 1) y[ibl].extra[k+0] |= (1 << ib);
+ if (extra & 2) y[ibl].extra[k+4] |= (1 << ib);
+ extra >>= 2;
+ uint8_t sl1 = x4[k][ibl].scales_l[ib] & 0xf;
+ uint8_t sl2 = x4[k][ibl].scales_l[ib] >> 4;
+ uint8_t sh1 = (sh >> 0) & 1;
+ uint8_t sh2 = (sh >> 1) & 1;
+ sh >>= 2;
+ int i = 8*ib + k;
+ y[ibl].scales_l[i%32] |= (sl1 << 4*(i/32));
+ y[ibl].scales_h[i%8 ] |= (sh1 << (i/8));
+ i += 4;
+ y[ibl].scales_l[i%32] |= (sl2 << 4*(i/32));
+ y[ibl].scales_h[i%8 ] |= (sh2 << (i/8));
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6);
+ y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) | ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6);
+ y[ibl].qh[16*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] >> 2) << 0) | ((L[32*ib+i+ 4] >> 2) << 1) | ((L[32*ib+i+ 8] >> 2) << 2) | ((L[32*ib+i+12] >> 2) << 3)
+ | ((L[32*ib+i+16] >> 2) << 4) | ((L[32*ib+i+20] >> 2) << 5) | ((L[32*ib+i+24] >> 2) << 6) | ((L[32*ib+i+28] >> 2) << 7);
+ }
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_iq3_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ3_K, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_iq3_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_iq3_k(4, n_per_row, (const block_iq3_k *)qtmp.data(), (block_iq3_k_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_iq3_k_r4(const block_iq3_k_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
+ auto ql = x[ibl].qs;
+ auto qh = x[ibl].qh;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int is = 8*ib + k;
+ float dl1 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1);
+ is += 4;
+ float dl2 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1);
+ auto values1 = iq3nl_values + (x[ibl].extra[k+0] & (1 << ib) ? 8 : 0);
+ auto values2 = iq3nl_values + (x[ibl].extra[k+4] & (1 << ib) ? 8 : 0);
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[((ql[4*k+i+ 0] >> 0) & 3) | ((qh[4*k+i] << 2) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[((ql[4*k+i+ 0] >> 2) & 3) | ((qh[4*k+i] << 1) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[((ql[4*k+i+ 0] >> 4) & 3) | ((qh[4*k+i] << 0) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[((ql[4*k+i+ 0] >> 6) & 3) | ((qh[4*k+i] >> 1) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[((ql[4*k+i+16] >> 0) & 3) | ((qh[4*k+i] >> 2) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[((ql[4*k+i+16] >> 2) & 3) | ((qh[4*k+i] >> 3) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[((ql[4*k+i+16] >> 4) & 3) | ((qh[4*k+i] >> 4) & 4)];
+ y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[((ql[4*k+i+16] >> 6) & 3) | ((qh[4*k+i] >> 5) & 4)];
+ }
+ ql += 32;
+ qh += 16;
+ }
+ }
+ }
+}
+
+void vec_dot_iq3_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK_K == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
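
For orientation, a round trip through the new entry points (a hedged sketch; buffers are assumed valid and nrows a multiple of 4, as the asserts require):

    // quantize nrows x n_per_row floats directly into the interleaved format
    size_t total = quantize_iq3_k_r4(src, dst, nrows, n_per_row, /*imatrix=*/NULL);
    // dequantize the first group of four interleaved rows back to floats;
    // the row helpers take k = 4*n_per_row because each call covers four rows
    dequantize_row_iq3_k_r4((const block_iq3_k_r4 *)dst, out, 4*n_per_row);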
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index e8721a5e..1ca66bd8 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -145,6 +145,12 @@ size_t quantize_iq4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT d
void dequantize_row_iq4_k_r4(const block_iq4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq3_k_r4_ref(const float * GGML_RESTRICT x, block_iq3_k_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq3_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq3_k_r4(const block_iq3_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq3_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);