summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h9
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c22
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp280
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp119
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp14
10 files changed, 454 insertions, 1 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 00cd3cf0..db0fc0d4 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -62,6 +62,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
{ "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", },
+ { "Q5_K_R4", LLAMA_FTYPE_MOSTLY_Q5_K_R4, "Q5_K_S repacked", },
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 6486407f..7f766497 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -413,6 +413,7 @@ extern "C" {
GGML_TYPE_Q5_0_R4 = 206,
GGML_TYPE_Q8_0_R4 = 208,
GGML_TYPE_Q4_K_R4 = 212,
+ GGML_TYPE_Q5_K_R4 = 213,
GGML_TYPE_Q6_K_R4 = 214,
GGML_TYPE_IQ4_NL_R4 = 220,
GGML_TYPE_IQ4_XS_R4 = 223,
@@ -481,6 +482,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
GGML_FTYPE_MOSTLY_Q4_K_R4 = 212, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q5_K_R4 = 215, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_K_R4 = 214, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 8ace3d6f..2d73d3f8 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -319,6 +319,15 @@ typedef struct {
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+typedef struct {
+ ggml_half d[8];
+ uint8_t scales_h[QK_K/16];// scales and mins, quantized with 6 bits
+ uint8_t scales_l[QK_K/8]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/2]; // quants, high bit
+ uint8_t qs[QK_K*2]; // quants, low 4 bits
+} block_q5_k_r4;
+static_assert(sizeof(block_q5_k_r4) == 8*sizeof(ggml_half) + QK_K/16 + QK_K/8 + QK_K/2 + QK_K*2, "wrong q5_k_r4 block size/padding");
+
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 67f54da7..f4f375c9 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15203,6 +15203,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_Q6_0_R4: break;
case GGML_TYPE_Q8_0_R4: break;
case GGML_TYPE_Q4_K_R4: break;
+ case GGML_TYPE_Q5_K_R4: break;
case GGML_TYPE_Q6_K_R4: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b92c2352..53c51ba6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -914,6 +914,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q5_K_R4] = {
+ .type_name = "q5_k_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q5_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_k_r4,
+ .from_float = quantize_row_q5_k_r4,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q5_k_r4_ref,
+ .vec_dot = vec_dot_q5_k_r4_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K32,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q6_K] = {
.type_name = "q6_K",
.blck_size = QK_K,
@@ -4048,6 +4061,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q4_K_R4: wtype = GGML_TYPE_Q4_K_R4; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
+ case GGML_FTYPE_MOSTLY_Q5_K_R4: wtype = GGML_TYPE_Q5_K_R4; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
@@ -10580,6 +10594,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -11031,6 +11046,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -11179,6 +11195,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -14373,6 +14390,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -14761,6 +14779,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -15043,6 +15062,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -15652,6 +15672,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_IQ2_XXS:
@@ -22489,6 +22510,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K_R4: result = quantize_q4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_K_R4: result = quantize_q5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5cf9013d..a1600cbc 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3256,6 +3256,190 @@ static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
#endif
template <int nrc_y>
+static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = _mm256_set1_epi8(0xf);
+ auto m10 = _mm256_set1_epi8(0x10);
+ auto m30 = _mm256_set1_epi8(0x30);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nbl = n / QK_K;
+ union { __m256i vec; uint32_t val[8]; } hd;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
+ auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
+ auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
+ if constexpr (nrc_y == 1) {
+ d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
+ }
+ auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
+ auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
+ auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
+ hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
+ auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
+ auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
+ auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
+ shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
+ auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
+ shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
+ auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
+ shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
+ auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
+ auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
+ auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
+ auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
+ auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
+ qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
+ qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
+ qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
+ qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+ _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+ _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+ auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+#endif
+ if constexpr (nrc_y == 1) {
+ acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]);
+ } else {
+ float d8 = q8.scale(iy, ibl);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ }
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1){
+ mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = _mm512_set1_epi8(0xf);
+ auto m10 = _mm512_set1_epi8(0x10);
+ int nbl = n / QK_K;
+ using helper_t = union { __m512i vec; uint32_t val[16]; };
+ helper_t hd, hm;
+ __m512 acc[nrc_y] = {};
+ __m512 d4s[nrc_y];
+ __m512i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_q5_k_r4 * iq5l = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d));
+ auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d));
+ auto dl = _mm256_castps256_ps128(d1);
+ auto ml = _mm256_extractf128_ps(d1, 1);
+ auto dh = _mm256_castps256_ps128(d2);
+ auto mh = _mm256_extractf128_ps(d2, 1);
+ auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ d4s[iy] = _mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl)));
+ }
+ auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1);
+ m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f));
+ auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l);
+ auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l);
+ auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1);
+ auto sld = _mm512_and_si512(slb, mf);
+ auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf);
+ auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h);
+ auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h);
+ auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h);
+ auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h);
+ auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1);
+ auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30));
+ auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30));
+ hd.vec = _mm512_or_si512(sld, shd);
+ hm.vec = _mm512_or_si512(slm, shm);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0]));
+ auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8]));
+ auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
+ auto scales = _mm512_cvtepi32_ps(iscales);
+ scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0]));
+ scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8]));
+ iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
+ auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales));
+ auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)),
+ _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1);
+ auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)),
+ _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1);
+ auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib);
+ auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib);
+ auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4));
+ auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4));
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1);
+ qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits));
+ qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2)));
+ qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1)));
+ qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(scales, d4s[iy]), _mm512_cvtepi32_ps(sumi), acc[iy]);
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ acc[iy] = _mm512_setzero_ps();
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_q5_k_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+template <int nrc_y>
static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -5374,6 +5558,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_q4_k_r4_q8_k<8>;
expected_typeB = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q5_K_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_q5_k_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_q5_k_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_q5_k_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_q5_k_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_q5_k_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_q5_k_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_q5_k_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_q5_k_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K32;
+ break;
case GGML_TYPE_Q6_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_q6_k_r4_q8_k<1>;
@@ -8147,6 +8343,86 @@ void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
template <int nrc_y>
+void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto mf = vdupq_n_u8(0xf);
+ auto m30 = vdupq_n_u8(0x30);
+ auto m10 = vdupq_n_u8(0x10);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int8x16x4_t iscales;
+ float32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
+ auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d+4));
+ m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
+ if constexpr (nrc_y == 1) {
+ d4 = vmulq_f32(d4, vdupq_n_f32(q8.scale(0, ibl)));
+ }
+ auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
+ auto sh = vld1q_u8(iq5[ibl].scales_h);
+ iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30));
+ iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m30));
+ iscales.val[2] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m30));
+ iscales.val[3] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30));
+ for (int is = 0; is < 2; ++is) {
+ auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is+2]));
+ auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is+2]));
+ scales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
+ scales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
+ scales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
+ scales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], scales.val[0], m8, 0);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], scales.val[1], m8, 1);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], scales.val[2], m8, 2);
+ acc[iy] = vmlaq_laneq_f32(acc[iy], scales.val[3], m8, 3);
+ }
+ iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
+ iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
+ scales.val[0] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
+ scales.val[1] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
+ scales.val[2] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
+ scales.val[3] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
+ auto hbits2 = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
+ auto hbits1 = vshlq_n_u8(hbits2, 4);
+ prepare_q4_k_quants(mf, lbits, qx);
+ qx[0] = vorrq_u8(qx[0], vandq_u8(m10, hbits1));
+ qx[1] = vorrq_u8(qx[1], vandq_u8(m10, hbits2));
+ qx[2] = vorrq_u8(qx[2], vandq_u8(m10, vshrq_n_u8(hbits1, 2)));
+ qx[3] = vorrq_u8(qx[3], vandq_u8(m10, vshrq_n_u8(hbits2, 2)));
+ qx[4] = vorrq_u8(qx[4], vandq_u8(m10, vshrq_n_u8(hbits1, 1)));
+ qx[5] = vorrq_u8(qx[5], vandq_u8(m10, vshrq_n_u8(hbits2, 1)));
+ qx[6] = vorrq_u8(qx[6], vandq_u8(m10, vshrq_n_u8(hbits1, 3)));
+ qx[7] = vorrq_u8(qx[7], vandq_u8(m10, vshrq_n_u8(hbits2, 3)));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ if constexpr (nrc_y == 1) {
+ acc[iy] = vfmaq_f32(acc[iy], scales.val[ib], vcvtq_f32_s32(sumi));
+ } else {
+ auto d4d8 = vmulq_f32(scales.val[ib], vdupq_n_f32(q8.scale(iy, ibl)));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
+template <int nrc_y>
void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -8602,6 +8878,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q4_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_Q5_K_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q5_k_r4_q8_k);
+ expected_Btype = GGML_TYPE_Q8_K32;
+ break;
case GGML_TYPE_Q6_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 71578bf8..8ca18060 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -4182,3 +4182,122 @@ void vec_dot_q6_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
}
+//
+// ========================================= q5_k_r4
+//
+
+void quantize_row_q5_k_r4_ref(const float * x, block_q5_k_r4 * y, int64_t k) {
+ quantize_q5_k_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_q5_k_r4(const float * x, void * y, int64_t k) {
+ quantize_q5_k_r4(x, y, 4, k/4, nullptr);
+}
+
+namespace {
+inline void convert_q5_k(const block_q5_K& x, uint8_t * L, uint8_t * Ld, uint8_t * Lm) {
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ get_scale_min_k4(2*ib64+0, x.scales, Ld[2*ib64+0], Lm[2*ib64+0]);
+ get_scale_min_k4(2*ib64+1, x.scales, Ld[2*ib64+1], Lm[2*ib64+1]);
+ for (int j = 0; j < 32; ++j) {
+ L[64*ib64+j+ 0] = (x.qs[32*ib64+j] & 0xf) | (((x.qh[j] >> (2*ib64+0)) & 1) << 4);
+ L[64*ib64+j+32] = (x.qs[32*ib64+j] >> 4) | (((x.qh[j] >> (2*ib64+1)) & 1) << 4);
+ }
+ }
+}
+}
+
+static void repack_q5_k(int nrows, int n_per_row, const block_q5_K * x, block_q5_k_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ const block_q5_K * x4[4];
+ uint8_t L[QK_K], Ld[QK_K/32], Lm[QK_K/32];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ std::memset(y[ibl].scales_l, 0, QK_K/8);
+ std::memset(y[ibl].scales_h, 0, QK_K/16);
+ for (int k = 0; k < 4; ++k) {
+ y[ibl].d[k+0] = x4[k][ibl].d;
+ y[ibl].d[k+4] = x4[k][ibl].dmin;
+ convert_q5_k(x4[k][ibl], L, Ld, Lm);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ y[ibl].scales_l[4*ib+k] = (Ld[ib] & 0xf) | ((Lm[ib] & 0xf) << 4);
+ uint8_t h = (Ld[ib] >> 4) | ((Lm[ib] >> 4) << 2);
+ y[ibl].scales_h[(4*ib+k)%16] |= (h << 4*((4*ib+k)/16));
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4);
+ y[ibl].qs[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4);
+ y[ibl].qs[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4);
+ y[ibl].qs[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4);
+ y[ibl].qh[16*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] >> 4) << 0) | ((L[32*ib+i+ 8] >> 4) << 1) | ((L[32*ib+i+ 4] >> 4) << 2) | ((L[32*ib+i+12] >> 4) << 3) |
+ ((L[32*ib+i+16] >> 4) << 4) | ((L[32*ib+i+24] >> 4) << 5) | ((L[32*ib+i+20] >> 4) << 6) | ((L[32*ib+i+28] >> 4) << 7);
+ }
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_q5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_q5_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_q5_k(4, n_per_row, (const block_q5_K *)qtmp.data(), (block_q5_k_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_q5_k_r4(const block_q5_k_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = GGML_FP16_TO_FP32(x[ibl].d[k+0]);
+ const float m = GGML_FP16_TO_FP32(x[ibl].d[k+4]);
+ auto ql = x[ibl].qs;
+ auto qh = x[ibl].qh;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int is = 4*ib + k;
+ float dl = d * ((x[ibl].scales_l[is] & 0xf) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x03) << 4));
+ float ml = m * ((x[ibl].scales_l[is] >> 4) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x0c) << 2));
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl * ((ql[4*k+i+ 0] & 0xf) | ((qh[4*k+i] << 4) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl * ((ql[4*k+i+ 0] >> 4) | ((qh[4*k+i] << 3) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+16] = dl * ((ql[4*k+i+16] & 0xf) | ((qh[4*k+i] >> 0) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+24] = dl * ((ql[4*k+i+16] >> 4) | ((qh[4*k+i] >> 1) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl * ((ql[4*k+i+32] & 0xf) | ((qh[4*k+i] << 2) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+12] = dl * ((ql[4*k+i+32] >> 4) | ((qh[4*k+i] << 1) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+20] = dl * ((ql[4*k+i+48] & 0xf) | ((qh[4*k+i] >> 2) & 0x10)) - ml;
+ y4[k][QK_K*ibl+32*ib+i+28] = dl * ((ql[4*k+i+48] >> 4) | ((qh[4*k+i] >> 3) & 0x10)) - ml;
+ }
+ ql += 64;
+ qh += 16;
+ }
+ }
+ }
+}
+
+void vec_dot_q5_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q5_K_R4, vx, 0, GGML_TYPE_Q8_K32, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 8819620d..77c34fea 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -115,6 +115,12 @@ size_t quantize_q4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q4_k_r4(const block_q4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q5_k_r4_ref(const float * GGML_RESTRICT x, block_q5_k_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q5_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q5_k_r4(const block_q5_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q5_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_q6_k_r4_ref(const float * GGML_RESTRICT x, block_q6_k_r4 * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q6_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index 92234d6c..7290f18f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -184,6 +184,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_R4 = 214, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q5_K_R4 = 216, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K_R4 = 218, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors
diff --git a/src/llama.cpp b/src/llama.cpp
index 3b617b06..dc6d307d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3839,6 +3839,7 @@ struct llama_model_loader {
case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break;
case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
+ case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break;
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
@@ -4551,6 +4552,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_K_R4: return "Q4_K_R4";
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
+ case LLAMA_FTYPE_MOSTLY_Q5_K_R4: return "Q5_K_R4";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4";
@@ -15793,6 +15795,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q4_K_R4) {
new_type = GGML_TYPE_Q4_K;
}
+ else if (new_type == GGML_TYPE_Q5_K_R4) {
+ new_type = GGML_TYPE_Q5_K;
+ }
else if (new_type == GGML_TYPE_Q6_K_R4) {
new_type = GGML_TYPE_Q6_K;
}
@@ -16067,7 +16072,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 ||
new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R4 ||
- new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4) {
+ new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 ||
+ new_type == GGML_TYPE_Q5_K_R4) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -16105,6 +16111,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
case GGML_TYPE_IQ5_K:
+ case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break;
case GGML_TYPE_IQ6_K:
case GGML_TYPE_Q6_K_R4:
@@ -16199,6 +16206,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q4_K_R4: default_type = GGML_TYPE_Q4_K_R4; break;
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
+ case LLAMA_FTYPE_MOSTLY_Q5_K_R4: default_type = GGML_TYPE_Q5_K_R4; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
@@ -16604,6 +16612,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_K;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_Q5_K_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_K;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_Q6_K_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_K;
else chunk_size_multiplier = 4;