summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-12-03 14:48:26 +0100
committerGitHub <noreply@github.com>2024-12-03 14:48:26 +0100
commitf1f4eb988fe5ee969100cd0d3782fd7460d13949 (patch)
tree97bb1a75ba7189f05e82835de6b2b65661a1ce7a /ggml/src
parentc5bf589367cd609f4c0ff73a6534bbde7902abe8 (diff)
Q6_0_R4 (#122)
* Adding q6_0_r4 We get PP-512(LLaMA-3.1-8B) = 257 t/s on a Ryzen-7950X. * q6_0_r4: NEON We get PP-512(LLaMA-3.1-8B) = 95 t/s on M2-Max. In terms of ops, q6_0_r4 is identical to q5_0_r4 except for loading the high bits being vld1q_u8_x2 instead of vld1q_u8. It is strange that this can make a 5% difference in performance, especially considering that this is amortized (re-used) over 8 columns in the right matrix. Or am I running out of vector registers? * Fix AVX2 --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c26
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp194
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp97
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
5 files changed, 324 insertions, 0 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 1953fb7e..94950a36 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15199,6 +15199,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ4_NL_X4: break;
case GGML_TYPE_Q4_0_R4: break;
case GGML_TYPE_Q5_0_R4: break;
+ case GGML_TYPE_Q6_0_R4: break;
case GGML_TYPE_Q8_0_R4: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 0eb76a07..203b1b57 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1313,6 +1313,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q6_0_R4] = {
+ .type_name = "q6_0_r4",
+ .blck_size = QK6_0,
+ .type_size = sizeof(block_q6_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_0_r4,
+ .from_float = quantize_row_q6_0_r4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref,
+ .vec_dot = vec_dot_q6_0_r4_q8_0,
+#if GGML_USE_IQK_MULMAT && defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0,
+#endif
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
};
// For internal test use
@@ -3974,6 +3991,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ4_NL_X4: wtype = GGML_TYPE_IQ4_NL_X4;break;
case GGML_FTYPE_MOSTLY_Q4_0_R4: wtype = GGML_TYPE_Q4_0_R4; break;
case GGML_FTYPE_MOSTLY_Q5_0_R4: wtype = GGML_TYPE_Q5_0_R4; break;
+ case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break;
case GGML_FTYPE_MOSTLY_Q8_0_R4: wtype = GGML_TYPE_Q8_0_R4; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
@@ -10501,6 +10519,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -10947,6 +10966,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -11090,6 +11110,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -14279,6 +14300,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -14662,6 +14684,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -14939,6 +14962,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -15543,6 +15567,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ4_NL_X4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
+ case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -22373,6 +22398,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ4_NL_X4: result = quantize_iq4_nl_x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_R4: result = quantize_q4_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0_R4: result = quantize_q8_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 4cdc1a08..f827e460 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2400,6 +2400,128 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
#endif
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m6 = _mm256_set1_epi8(0x30);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nb = n / QK6_0;
+ GGML_ASSERT(nb%4 == 0);
+ __m256 acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d));
+ auto scales = _mm256_set_m128(scales128, scales128);
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f));
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh);
+ auto q1 = _mm256_and_si256(bits1, m4) | _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6);
+ auto q2 = _mm256_and_si256(bits2, m4) | _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6);
+ auto q3 = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4) | _mm256_and_si256(hbits, m6);
+ auto q4 = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4) | _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6);;
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+#endif
+ auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ info.store(ix, iy, sum);
+ acc[iy] = _mm256_setzero_ps();
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1) {
+ mul_mat_q6_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_1_x4> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto m6 = _mm512_set1_epi8(0x30);
+ int nb = n / QK6_0;
+ GGML_ASSERT(nb%4 == 0);
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d));
+ auto scales1 = _mm256_set_m128(scales128, scales128);
+ scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d));
+ auto scales2 = _mm256_set_m128(scales128, scales128);
+ auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+ auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f));
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)),
+ _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)),
+ _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1);
+ auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh);
+ auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh);
+ auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
+ qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
+ qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
+ qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
+ qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_q6_0_r4_q8_1_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -4527,6 +4649,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>;
expected_typeB = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_Q6_0_R4:
+ assert (ne00 % QK4_NL == 0);
+ mm.funcs[0] = mul_mat_q6_0_r4_q8_1<1>;
+ mm.funcs[1] = mul_mat_q6_0_r4_q8_1<2>;
+ mm.funcs[2] = mul_mat_q6_0_r4_q8_1<3>;
+ mm.funcs[3] = mul_mat_q6_0_r4_q8_1<4>;
+ mm.funcs[4] = mul_mat_q6_0_r4_q8_1<5>;
+ mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>;
+ mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>;
+ mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>;
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
case GGML_TYPE_Q8_0_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_q8_0_r4_q8_1<1>;
@@ -7130,6 +7264,55 @@ void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
template <int nrc_y>
+void mul_mat_q6_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_0_x4> q8(info);
+ auto m4 = vdupq_n_u8(0x0f);
+ auto m6 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
+ int nb = n / QK6_0;
+ GGML_ASSERT(nb%4 == 0);
+ int8x16_t qx[8];
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
+ for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+ for (int k = 0; k < 4; ++k) {
+ auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
+ auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
+ auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
+ qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
+ qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
+ qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
+ qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
+ qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
+ qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
+ qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
+ qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
@@ -7368,6 +7551,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>;
expected_Btype = GGML_TYPE_Q8_0;
break;
+ case GGML_TYPE_Q6_0_R4:
+ m.funcs[0] = mul_mat_q6_0_r4_q8_0<1>;
+ m.funcs[1] = mul_mat_q6_0_r4_q8_0<2>;
+ m.funcs[2] = mul_mat_q6_0_r4_q8_0<3>;
+ m.funcs[3] = mul_mat_q6_0_r4_q8_0<4>;
+ m.funcs[4] = mul_mat_q6_0_r4_q8_0<5>;
+ m.funcs[5] = mul_mat_q6_0_r4_q8_0<6>;
+ m.funcs[6] = mul_mat_q6_0_r4_q8_0<7>;
+ m.funcs[7] = mul_mat_q6_0_r4_q8_0<8>;
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
case GGML_TYPE_Q8_0_R4:
m.funcs[0] = mul_mat_q8_0_r4_q8_0<1>;
m.funcs[1] = mul_mat_q8_0_r4_q8_0<2>;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index eafb2887..f2e6a45e 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3475,3 +3475,100 @@ void vec_dot_q5_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
+
+//
+// ========================================= q6_0_r4
+//
+void quantize_row_q6_0_r4_ref(const float * x, block_q6_0_r4 * y, int64_t k) {
+ // we assume we are called with 4 rows
+ quantize_q6_0_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_q6_0_r4(const float * x, void * y, int64_t k) {
+ // we assume we are called with 4 rows
+ quantize_q6_0_r4(x, y, 4, k/4, nullptr);
+}
+
+static inline void convert_q6_0(const block_q6_0& x, uint8_t * L) {
+
+ for (int j = 0; j < QK6_0/2; ++j) {
+ const uint8_t h = x.qh[j%(QK6_0/4)] >> 4*(j/(QK6_0/4));
+ L[j ] = (x.qs[j] & 0x0F) | ((h << 4) & 0x30);
+ L[j + QK6_0/2] = (x.qs[j] >> 4) | ((h << 2) & 0x30);
+ }
+}
+
+static void repack_q6_0(int nrows, int n_per_row, const block_q6_0 * x, block_q6_0_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK5_0 == 0);
+ int nblock = n_per_row/QK6_0;
+ const block_q6_0 * x4[4];
+ uint8_t L[QK6_0];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ib = 0; ib < nblock; ++ib) {
+ std::memset(y[ib].qh, 0, QK6_0);
+ for (int k = 0; k < 4; ++k) {
+ y[ib].d[k] = x4[k][ib].d;
+ convert_q6_0(x4[k][ib], L);
+ for (int l = 0; l < 4; ++l) {
+ int l1 = 4*(l/2) + 16*(l%2), l2 = l1 + 8;
+ for (int i = 0; i < 4; ++i) {
+ y[ib].qs[4*k+i+16*l] = (L[i + l1] & 0xf) | ((L[i + l2] & 0xf) << 4);
+ y[ib].qh[4*k+i+16*(l%2)] |= ((L[i + l1] >> 4) | ((L[i + l2] >> 4) << 4)) << 2*(l/2);
+ }
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_q6_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ auto row_size_0 = ggml_row_size(GGML_TYPE_Q6_0, n_per_row);
+ std::vector<char> qtmp(4*row_size_0);
+ char * qrow = (char *)dst;
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_q6_0(src, qtmp.data(), 4, n_per_row, imatrix);
+ repack_q6_0(4, n_per_row, (const block_q6_0 *)qtmp.data(), (block_q6_0_r4 *)qrow);
+ src += 4*n_per_row;
+ qrow += 4*row_size_0;
+ }
+ return nrows*row_size_0;
+}
+
+void dequantize_row_q6_0_r4(const block_q6_0_r4 * x, float * y, int64_t k) {
+ // we assume we are called with 4 rows
+ int n_per_row = k/4;
+ int nb = n_per_row/QK6_0;
+ float * yk[4];
+ for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row;
+ for (int ib = 0; ib < nb; ++ib) {
+ for (int k = 0; k < 4; ++k) {
+ float d = GGML_FP16_TO_FP32(x[ib].d[k]);
+ float m = -32*d;
+ for (int l = 0; l < 4; ++l) {
+ int ll = 16*(l%2) + 4*(l/2);
+ for (int i = 0; i < 4; ++i) {
+ yk[k][QK4_0*ib+i+ll+0] = d * ((x[ib].qs[4*k+i+16*l] & 0xf) | (((x[ib].qh[4*k+i+16*(l%2)] >> (2*(l/2)+0)) & 3) << 4)) + m;
+ yk[k][QK4_0*ib+i+ll+8] = d * ((x[ib].qs[4*k+i+16*l] >> 4) | (((x[ib].qh[4*k+i+16*(l%2)] >> (2*(l/2)+4)) & 3) << 4)) + m;
+ }
+ }
+ }
+ }
+}
+
+void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q6_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 24c241a2..3349c675 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -87,6 +87,12 @@ size_t quantize_q5_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q5_0_r4(const block_q5_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q5_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q6_0_r4_ref(const float * GGML_RESTRICT x, block_q6_0_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q6_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q6_0_r4(const block_q6_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q6_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif