summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-12-10 12:26:40 +0100
committerGitHub <noreply@github.com>2024-12-10 12:26:40 +0100
commit361174ee6aee8792b4fbb227b9bc328bf9bd6eb9 (patch)
tree887c0cb982664604ef977cdac27657d7c61c88a4
parent3ec193b4856df8e5827b83a8c7686e8498c5e5b8 (diff)
Q6_K_R4 (#130)
* Adding q6_k_r4 * q6_k_r4: 1st functional AVX2 version * q6_k_r4: AVX2 and simple Zen4 "Simple" as in processing 4 instead of 8 rows at once. On Zen4 we get PP-512(LLaMA-3.1-8B) = 238.3 t/s vs 195.2 t/s for Q6_K. TG-128 @ 1 thread is 7.94 t/s vs 5.38 t/s for Q6_K. * q6_k_r4: 1st NEON version PP-512(LLaMA-3.1-8B) = 78 t/s vs 57.6 t/s for q6_K. TG-128 is slightly lower rthan q6_K for low number of threads, becomes very slightly better at 8 threads. * q6_k_r4: slightly faster NEON PP-512(LLaMA-3.1-8B) = 83.25 t/s * q6_k_r4: slightly faster Zen4 238.3 t/s -> 243.2 t/s --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h8
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c22
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp208
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp119
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp15
10 files changed, 381 insertions, 2 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 731ea3ff..00cd3cf0 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -65,6 +65,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
+ { "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 38e01dd5..6486407f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -413,6 +413,7 @@ extern "C" {
GGML_TYPE_Q5_0_R4 = 206,
GGML_TYPE_Q8_0_R4 = 208,
GGML_TYPE_Q4_K_R4 = 212,
+ GGML_TYPE_Q6_K_R4 = 214,
GGML_TYPE_IQ4_NL_R4 = 220,
GGML_TYPE_IQ4_XS_R4 = 223,
GGML_TYPE_Q6_0_R4 = 233,
@@ -480,6 +481,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K_R4 = 212, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q6_K_R4 = 214, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 784125fc..8ace3d6f 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -331,6 +331,14 @@ typedef struct {
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+typedef struct {
+ ggml_half d[4]; // super-block scale
+ int8_t scales[QK_K/4]; // scales, quantized with 8 bits
+ uint8_t qh[QK_K]; // quants, upper 2 bits
+ uint8_t ql[QK_K*2]; // quants, lower 4 bits
+} block_q6_k_r4;
+static_assert(sizeof(block_q6_k_r4) == 4*sizeof(ggml_half) + QK_K/4 + 3*QK_K, "wrong q6_k_r4 block size/padding");
+
// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index a4b234c5..67f54da7 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15203,6 +15203,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q6_0_R4: break;
case GGML_TYPE_Q8_0_R4: break;
case GGML_TYPE_Q4_K_R4: break;
+ case GGML_TYPE_Q6_K_R4: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index fadda3e3..b92c2352 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -927,6 +927,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q6_K_R4] = {
+ .type_name = "q6_k_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q6_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_k_r4,
+ .from_float = quantize_row_q6_k_r4,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q6_k_r4_ref,
+ .vec_dot = vec_dot_q6_k_r4_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
.blck_size = QK_K,
@@ -4036,6 +4049,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_K_R4: wtype = GGML_TYPE_Q4_K_R4; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+ case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
@@ -10567,6 +10581,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -11017,6 +11032,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -11164,6 +11180,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -14357,6 +14374,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -14744,6 +14762,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -15025,6 +15044,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -15633,6 +15653,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
@@ -22469,6 +22490,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q4_K_R4: result = quantize_q4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index a53e08f3..69f1ff0e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -163,11 +163,14 @@ struct MulMat {
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
static inline int num_rows(ggml_type type) {
switch (type) {
+ case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_NL_R4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
default: return 1;
}
@@ -3248,6 +3251,118 @@ static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
}
#endif
+template <int nrc_y>
+static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+ auto m3 = _mm256_set1_epi8(0x30);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
+ __m256 d4s[nrc_y];
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+#ifdef HAVE_FANCY_SIMD
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ d4s[iy] = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
+ }
+#else
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
+ }
+#endif
+ {
+#ifndef HAVE_FANCY_SIMD
+ auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
+#endif
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4s[iy], _mm256_set1_ps(-32.f)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+ const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi32_ps(iscales);
+#else
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
+ auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, d4s[iy]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // Quants are in 0...63, so we can add at most 4 as int16_t to be sure of no int16_t overflow
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
if (j == 0) {
@@ -5255,6 +5370,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_q4_k_r4_q8_k<8>;
expected_typeB = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q6_K_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_q6_k_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_q6_k_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_q6_k_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_q6_k_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_q6_k_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_q6_k_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_q6_k_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q4_0_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>;
@@ -7847,6 +7974,28 @@ IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16
return sumi;
}
+IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
+ int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
+ sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
+ sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
+ return sumi;
+}
+
+IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
+ return sumi;
+}
+
IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
@@ -7993,6 +8142,61 @@ void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+template <int nrc_y>
+void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0x0f);
+ auto m3 = vdupq_n_u8(0x30);
+ auto m32 = vdupq_n_s8(-32);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ float32x4x2_t scales;
+ float32x4_t acc[nrc_y] = {};
+ float32x4_t d4[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto dtmp = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ibl].d));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ d4[iy] = vmulq_f32(dtmp, vdupq_n_f32(q8.scale(iy, ibl)));
+ }
+ for (int is = 0; is < 2; ++is) {
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq6[ibl].ql + 256*is + 64*ib);
+ auto hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib);
+ auto iscales = vmovl_s8(vld1_s8(iq6[ibl].scales + 32*is + 8*ib));
+ scales.val[0] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales)));
+ scales.val[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales)));
+ qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[0], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
+ qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[2], mf), vandq_u8(m3, hbits)));
+ qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
+ qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(scales.val[0], d4[iy]), vcvtq_f32_s32(sumi));
+ }
+ hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib + 16);
+ qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[1], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
+ qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[3], mf), vandq_u8(m3, hbits)));
+ qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
+ qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(scales.val[1], d4[iy]), vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -8394,6 +8598,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q4_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q6_K_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q4_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 93a3a0ea..71578bf8 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4063,3 +4063,122 @@ void vec_dot_q4_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
+//
+// ========================================= q6_k_r4
+//
+
+void quantize_row_q6_k_r4_ref(const float * x, block_q6_k_r4 * y, int64_t k) {
+ quantize_q6_k_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_q6_k_r4(const float * x, void * y, int64_t k) {
+ quantize_q6_k_r4(x, y, 4, k/4, nullptr);
+}
+
+namespace {
+inline void convert_q6_k(const block_q6_K& x, uint8_t * L) {
+ const uint8_t * ql = x.ql;
+ const uint8_t * qh = x.qh;
+
+ for (int n = 0; n < QK_K; n += 128) {
+ for (int l = 0; l < 32; ++l) {
+ L[n + l + 0] = (ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4);
+ L[n + l + 32] = (ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4);
+ L[n + l + 64] = (ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4);
+ L[n + l + 96] = (ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4);
+ }
+ ql += 64;
+ qh += 32;
+ }
+}
+}
+
+static void repack_q6_k(int nrows, int n_per_row, const block_q6_K * x, block_q6_k_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ const block_q6_K * x4[4];
+ uint8_t L[QK_K];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ y[ibl].d[k] = x4[k][ibl].d;
+ convert_q6_k(x4[k][ibl], L);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ y[ibl].scales[8*ib+k+0] = x4[k][ibl].scales[2*ib+0];
+ y[ibl].scales[8*ib+k+4] = x4[k][ibl].scales[2*ib+1];
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].ql[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4);
+ y[ibl].ql[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4);
+ y[ibl].ql[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4);
+ y[ibl].ql[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4);
+ y[ibl].qh[32*ib+4*k+i+ 0] = (L[32*ib+i+ 0] >> 4) | ((L[32*ib+i+ 8] >> 4) << 2) | ((L[32*ib+i+ 4] >> 4) << 4) | ((L[32*ib+i+12] >> 4) << 6);
+ y[ibl].qh[32*ib+4*k+i+16] = (L[32*ib+i+16] >> 4) | ((L[32*ib+i+24] >> 4) << 2) | ((L[32*ib+i+20] >> 4) << 4) | ((L[32*ib+i+28] >> 4) << 6);
+ }
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_q6_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_q6_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_q6_k(4, n_per_row, (const block_q6_K *)qtmp.data(), (block_q6_k_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_q6_k_r4(const block_q6_k_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
+ auto ql = x[ibl].ql;
+ auto qh = x[ibl].qh;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ float dl1 = d * x[ibl].scales[8*ib+k+0];
+ float dl2 = d * x[ibl].scales[8*ib+k+4];
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * (((ql[4*k+i+ 0] & 0xf) | ((qh[4*k+i+ 0] << 4) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * (((ql[4*k+i+ 0] >> 4) | ((qh[4*k+i+ 0] << 2) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+16] = dl2 * (((ql[4*k+i+16] & 0xf) | ((qh[4*k+i+16] << 4) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+24] = dl2 * (((ql[4*k+i+16] >> 4) | ((qh[4*k+i+16] << 2) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * (((ql[4*k+i+32] & 0xf) | ((qh[4*k+i+ 0] >> 0) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+12] = dl1 * (((ql[4*k+i+32] >> 4) | ((qh[4*k+i+ 0] >> 2) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+20] = dl2 * (((ql[4*k+i+48] & 0xf) | ((qh[4*k+i+16] >> 0) & 0x30)) - 32);
+ y4[k][QK_K*ibl+32*ib+i+28] = dl2 * (((ql[4*k+i+48] >> 4) | ((qh[4*k+i+16] >> 2) & 0x30)) - 32);
+ }
+ ql += 64;
+ qh += 32;
+ }
+ }
+ }
+}
+
+void vec_dot_q6_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q6_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 5b4a3e44..8819620d 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -115,6 +115,12 @@ size_t quantize_q4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q4_k_r4(const block_q4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q6_k_r4_ref(const float * GGML_RESTRICT x, block_q6_k_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q6_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q6_k_r4(const block_q6_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q6_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
diff --git a/include/llama.h b/include/llama.h
index 2fa78879..92234d6c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -184,6 +184,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_R4 = 214, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q6_K_R4 = 218, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 235, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 18c6e111..3b617b06 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3840,6 +3840,7 @@ struct llama_model_loader {
case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break;
case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
+ case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
case GGML_TYPE_IQ2_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS; break;
@@ -4552,6 +4553,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
+ case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_KS: return "IQ2_KS - 2.1875 bpw";
@@ -15757,7 +15759,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) {
new_type = GGML_TYPE_IQ5_K;
}
- else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R4 && new_type != GGML_TYPE_IQ6_K) {
+ else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R4 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4) {
new_type = GGML_TYPE_Q6_K;
}
}
@@ -15791,6 +15793,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q4_K_R4) {
new_type = GGML_TYPE_Q4_K;
}
+ else if (new_type == GGML_TYPE_Q6_K_R4) {
+ new_type = GGML_TYPE_Q6_K;
+ }
else if (new_type == GGML_TYPE_Q4_0_R4) {
new_type = GGML_TYPE_Q4_0;
}
@@ -16062,7 +16067,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 ||
new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R4 ||
- new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS) {
+ new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -16102,6 +16107,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_IQ5_K:
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break;
case GGML_TYPE_IQ6_K:
+ case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
}
@@ -16194,6 +16200,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
+ case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_KS: default_type = GGML_TYPE_IQ2_KS; break;
@@ -16597,6 +16604,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_K;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_Q6_K_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_K;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_IQ2_BN_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN;
else chunk_size_multiplier = 4;