summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h7
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c24
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp255
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp115
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp21
10 files changed, 430 insertions, 3 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 2c8b33c2..4c650b87 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -31,6 +31,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
{ "IQ2_BN_R4",LLAMA_FTYPE_MOSTLY_IQ2_BN_R4," 2.00 bpw quantization (Bitnet)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
+ { "Q2_K_R4", LLAMA_FTYPE_MOSTLY_Q2_K_R4, "Q2_K_S repacked", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 0ab34f27..2ed0fb1f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -412,6 +412,7 @@ extern "C" {
GGML_TYPE_Q4_0_R4 = 202,
GGML_TYPE_Q5_0_R4 = 206,
GGML_TYPE_Q8_0_R4 = 208,
+ GGML_TYPE_Q2_K_R4 = 210,
GGML_TYPE_Q3_K_R4 = 211,
GGML_TYPE_Q4_K_R4 = 212,
GGML_TYPE_Q5_K_R4 = 213,
@@ -482,6 +483,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors
GGML_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K_R4 = 212, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_K_R4 = 215, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index bc34718e..61e8dfd3 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -276,6 +276,13 @@ typedef struct {
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+typedef struct {
+ ggml_half d[8];
+ uint8_t scales[QK_K/4]; // scales and mins, quantized with 4 bits
+ uint8_t qs[QK_K]; // quants
+} block_q2_k_r4;
+static_assert(sizeof(block_q2_k_r4) == 8*sizeof(ggml_half) + QK_K/4 + QK_K, "wrong q2_k_r4 block size/padding");
+
// 3-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index c2fdf6fa..ff857087 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15202,6 +15202,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q5_0_R4: break;
case GGML_TYPE_Q6_0_R4: break;
case GGML_TYPE_Q8_0_R4: break;
+ case GGML_TYPE_Q2_K_R4: break;
case GGML_TYPE_Q3_K_R4: break;
case GGML_TYPE_Q4_K_R4: break;
case GGML_TYPE_Q5_K_R4: break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 0bb59d2b..6c574933 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -862,6 +862,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q2_K_R4] = {
+ .type_name = "q2_k_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q2_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q2_k_r4,
+ .from_float = quantize_row_q2_k_r4,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q2_k_r4_ref,
+ .vec_dot = vec_dot_q2_k_r4_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q3_K] = {
.type_name = "q3_K",
.blck_size = QK_K,
@@ -4070,7 +4083,8 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
- case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
+ case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break;
+ case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q3_K_R4: wtype = GGML_TYPE_Q3_K_R4; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q4_K_R4: wtype = GGML_TYPE_Q4_K_R4; break;
@@ -10604,6 +10618,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -11057,6 +11072,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -11207,6 +11223,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -14403,6 +14420,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -14793,6 +14811,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -15077,6 +15096,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -15688,6 +15708,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K:
@@ -22527,6 +22548,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K_R4: result = quantize_q3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8ab8b2bd..4316373a 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -163,7 +163,10 @@ struct MulMat {
static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
static inline int num_rows(ggml_type type) {
switch (type) {
+ case GGML_TYPE_Q2_K_R4:
+ case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_Q4_K_R4:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
@@ -3440,6 +3443,116 @@ static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
#endif
template <int nrc_y>
+static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mxf = _mm256_set1_epi8(0xf);
+ auto m03 = _mm256_set1_epi8(0x03);
+ static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
+ auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
+ __m256 d4s[nrc_y];
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ int8_t scales[64];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
+ auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
+ auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
+ m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));
+ auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
+ auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
+ auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
+ auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
+ {
+ auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+ auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+ auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+ auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+ auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
+ auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
+ auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
+ auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, ibl);
+ auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
+ sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
+ sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
+ sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
+ auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ d4s[iy] = _mm256_mul_ps(d4, d8);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, d8);
+ }
+#endif
+ }
+ }
+ all_scales1 = _mm256_and_si256(all_scales1, mxf);
+ all_scales2 = _mm256_and_si256(all_scales2, mxf);
+ _mm256_storeu_si256((__m256i *)scales+0, all_scales1);
+ _mm256_storeu_si256((__m256i *)scales+1, all_scales2);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
+#ifdef HAVE_FANCY_SIMD
+ auto scales = _mm256_cvtepi32_ps(iscales);
+#else
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
+ auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
+ qx[0] = _mm256_and_si256(lb, m03);
+ qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
+ qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
+ qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, d4s[iy]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // Quants are in 0...3, so we can add add up all of them as int16_t without overflowing
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+template <int nrc_y>
static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -3450,7 +3563,11 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
auto m04 = _mm256_set1_epi8(0x04);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
__m256 d4s[nrc_y];
+#else
+ auto m1 = _mm256_set1_epi16(1);
+#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
@@ -3460,9 +3577,15 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
+#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
d4s[iy] = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
}
+#else
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
+ }
+#endif
auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
@@ -3471,6 +3594,9 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
_mm256_storeu_si256((__m256i *)scales+0, scales1);
_mm256_storeu_si256((__m256i *)scales+1, scales2);
{
+#ifndef HAVE_FANCY_SIMD
+ auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
+#endif
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
@@ -3482,16 +3608,32 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4s[iy], _mm256_set1_ps(-4.f)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+ sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
}
}
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
+#ifdef HAVE_FANCY_SIMD
auto scales = _mm256_cvtepi32_ps(iscales);
+#else
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
@@ -3501,12 +3643,27 @@ static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, d4s[iy]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ // Quants are in 0...8, so we can add add up all of them as int16_t without overflowing
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+#endif
+
}
}
}
@@ -5625,6 +5782,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
expected_typeB = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q2_K_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_q2_k_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_q2_k_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_q2_k_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_q2_k_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_q2_k_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_q2_k_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_q2_k_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_q2_k_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q3_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_q3_k_r4_q8_k<1>;
@@ -8361,6 +8530,88 @@ IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x
}
template <int nrc_y>
+void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0x0f);
+ auto m03 = vdupq_n_u8(0x03);
+ int nbl = n / QK_K;
+ int8x16_t qx[4];
+ float32x4_t acc[nrc_y] = {};
+ int16x8x4_t i16scales;
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ int32x4_t isum[nrc_y] = {};
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
+ auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4)));
+ for (int is = 0; is < 2; ++is) {
+ auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is);
+ auto m = vshrq_n_u8(sl.val[0], 4);
+ i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[1] = vmovl_u8(vget_high_u8(m));
+ m = vshrq_n_u8(sl.val[1], 4);
+ i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[3] = vmovl_u8(vget_high_u8(m));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = vdupq_n_s32(0);
+ auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is);
+ auto b8 = vget_low_s16(bsums);
+ //auto bsums = q8.load_bsums(iy, ibl);
+ //auto b8 = vget_low_s16(bsums.val[0]);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3);
+ b8 = vget_high_s16(bsums);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1);
+ sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2);
+ sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3);
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
+ }
+ m = vandq_u8(sl.val[0], mf);
+ i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[1] = vmovl_u8(vget_high_u8(m));
+ m = vandq_u8(sl.val[1], mf);
+ i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
+ i16scales.val[3] = vmovl_u8(vget_high_u8(m));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
+ auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
+ qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03));
+ qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03));
+ qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03));
+ qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
+ qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03));
+ qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03));
+ qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03));
+ qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -9025,6 +9276,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q2_K_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q3_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 2e59fefe..49e2cf8e 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4437,3 +4437,118 @@ void vec_dot_q3_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(by);
}
+//
+// ========================================= q2_k_r4
+//
+
+void quantize_row_q2_k_r4_ref(const float * x, block_q2_k_r4 * y, int64_t k) {
+ quantize_q3_k_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_q2_k_r4(const float * x, void * y, int64_t k) {
+ quantize_q2_k_r4(x, y, 4, k/4, nullptr);
+}
+
+namespace {
+inline void convert_q2_k(const block_q2_K& x, uint8_t * L) {
+
+ const uint8_t * qs = x.qs;
+ for (int n = 0; n < QK_K; n += 128) {
+ for (int j = 0; j < 32; ++j) {
+ L[n + j + 0] = (qs[j] >> 0) & 0x3;
+ L[n + j + 32] = (qs[j] >> 2) & 0x3;
+ L[n + j + 64] = (qs[j] >> 4) & 0x3;
+ L[n + j + 96] = (qs[j] >> 6) & 0x3;
+ }
+ qs += 32;
+ }
+}
+}
+
+static void repack_q2_k(int nrows, int n_per_row, const block_q2_K * x, block_q2_k_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ const block_q2_K * x4[4];
+ uint8_t L[QK_K];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ y[ibl].d[k+0] = x4[k][ibl].d;
+ y[ibl].d[k+4] = x4[k][ibl].dmin;
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib];
+ }
+ convert_q2_k(x4[k][ibl], L);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6);
+ y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) | ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6);
+ }
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_q2_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_q2_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_q2_k(4, n_per_row, (const block_q2_K *)qtmp.data(), (block_q2_k_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_q2_k_r4(const block_q2_k_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = GGML_FP16_TO_FP32(x[ibl].d[k+0]);
+ const float m = GGML_FP16_TO_FP32(x[ibl].d[k+4]);
+ auto ql = x[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ float dl1 = d * (x[ibl].scales[8*ib + k + 0] & 0xf);
+ float ml1 = m * (x[ibl].scales[8*ib + k + 0] >> 4);
+ float dl2 = d * (x[ibl].scales[8*ib + k + 4] & 0xf);
+ float ml2 = m * (x[ibl].scales[8*ib + k + 4] >> 4);
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * ((ql[4*k+i+ 0] >> 0) & 3) - ml1;
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * ((ql[4*k+i+ 0] >> 2) & 3) - ml1;
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * ((ql[4*k+i+ 0] >> 4) & 3) - ml1;
+ y4[k][QK_K*ibl+32*ib+i+12] = dl1 * ((ql[4*k+i+ 0] >> 6) & 3) - ml1;
+ y4[k][QK_K*ibl+32*ib+i+16] = dl2 * ((ql[4*k+i+16] >> 0) & 3) - ml2;
+ y4[k][QK_K*ibl+32*ib+i+20] = dl2 * ((ql[4*k+i+16] >> 2) & 3) - ml2;
+ y4[k][QK_K*ibl+32*ib+i+24] = dl2 * ((ql[4*k+i+16] >> 4) & 3) - ml2;
+ y4[k][QK_K*ibl+32*ib+i+28] = dl2 * ((ql[4*k+i+16] >> 6) & 3) - ml2;
+ }
+ ql += 32;
+ }
+ }
+ }
+}
+
+void vec_dot_q2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q2_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index f3a4d8e2..4a1c31f8 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -115,6 +115,12 @@ size_t quantize_q3_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q3_k_r4(const block_q3_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q3_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q2_k_r4_ref(const float * GGML_RESTRICT x, block_q2_k_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q2_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q2_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q2_k_r4(const block_q2_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q2_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_q4_k_r4_ref(const float * GGML_RESTRICT x, block_q4_k_r4 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index f87d13ff..0992b10a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -183,6 +183,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_R4 = 214, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_R4 = 216, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 9f41724f..6ecf0452 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4545,6 +4545,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
+ case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
@@ -15794,6 +15795,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_IQ4_XS_R4) {
new_type = GGML_TYPE_IQ4_XS;
}
+ else if (new_type == GGML_TYPE_Q2_K_R4) {
+ new_type = GGML_TYPE_Q2_K;
+ }
else if (new_type == GGML_TYPE_Q3_K_R4) {
new_type = GGML_TYPE_Q3_K;
}
@@ -15859,6 +15863,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4 && qs.model.hparams.n_gqa() >= 4) {
+ new_type = GGML_TYPE_Q4_K_R4;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K
: !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
@@ -15950,6 +15957,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4) {
+ if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K_R4;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
}
@@ -16009,7 +16019,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K ||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 ||
+ ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4) {
new_type = GGML_TYPE_Q5_K;
}
} else {
@@ -16079,7 +16090,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 ||
new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R4 ||
new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 ||
- new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4) {
+ new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -16106,6 +16117,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q3_K_R4:
case GGML_TYPE_IQ2_K:
@@ -16204,6 +16216,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
+ case LLAMA_FTYPE_MOSTLY_Q2_K_R4: default_type = GGML_TYPE_Q2_K_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_S:
case LLAMA_FTYPE_MOSTLY_Q3_K_M:
@@ -16616,6 +16629,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_Q2_K_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_Q3_K_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q3_K;
else chunk_size_multiplier = 4;