summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h2
-rw-r--r--ggml/src/ggml-common.h6
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c26
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp201
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp101
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp21
10 files changed, 364 insertions, 2 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index d505f493..cefe2735 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -51,6 +51,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q8_0_R4", LLAMA_FTYPE_MOSTLY_Q8_0_R4, " 8.50 bpw quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
+ { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", },
{ "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 07142692..1a4e7f19 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -427,6 +427,7 @@ extern "C" {
GGML_TYPE_IQ3_K_R4 = 338,
GGML_TYPE_IQ4_K_R4 = 339,
GGML_TYPE_IQ5_K_R4 = 340,
+ GGML_TYPE_IQ4_KS_R4 = 344,
GGML_TYPE_Q8_K_R8 = 399,
GGML_TYPE_COUNT,
};
@@ -504,6 +505,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ3_K_R4 = 331, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
};
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 0af461c7..6e87dbaa 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -509,6 +509,12 @@ typedef struct {
static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");
typedef struct {
+ uint8_t scales[QK_K/8];
+ uint8_t qs[QK_K*2];
+} block_iq4_ks_r4;
+static_assert(sizeof(block_iq4_ks_r4) == 4*sizeof(block_iq4_ks), "wrong iq4_ks_r4 block size/padding");
+
+typedef struct {
uint32_t qs[QK_K/8];
} block_iq4_kss;
static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index d0de3d0f..2b59808b 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15211,6 +15211,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ3_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_IQ5_K_R4: break;
+ case GGML_TYPE_IQ4_KS_R4: break;
case GGML_TYPE_Q8_K_R8: break;
case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 526b1139..758ee018 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1165,6 +1165,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 4,
},
+ [GGML_TYPE_IQ4_KS_R4] = {
+ .type_name = "iq4_ks_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_ks),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_ks_r4,
+ .from_float = quantize_row_iq4_ks_r4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_ks_r4_ref,
+ .vec_dot = vec_dot_iq4_ks_r4_q8_k,
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_K32,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_K,
+#endif
+ .nrows = 1,
+ .row_meta_size = 4,
+ },
[GGML_TYPE_IQ4_KSS] = {
.type_name = "iq4_kss",
.blck_size = QK_K,
@@ -4197,6 +4214,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q8_0_R4: wtype = GGML_TYPE_Q8_0_R4; break;
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
+ case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break;
case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break;
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
case GGML_FTYPE_MOSTLY_IQ2_K_R4: wtype = GGML_TYPE_IQ2_K_R4; break;
@@ -10737,6 +10755,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -11196,6 +11215,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -11352,6 +11372,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -14554,6 +14575,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -14950,6 +14972,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -15240,6 +15263,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -15859,6 +15883,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
@@ -22706,6 +22731,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q8_0_R4: result = quantize_q8_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_K_R4:result = quantize_iq2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 3e8cbf06..f4576319 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -186,6 +186,7 @@ struct MulMat {
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K_R4:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ2_BN_R4: return 4;
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
@@ -3103,6 +3104,115 @@ static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
#endif
template <int nrc_y>
+static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+#ifndef HAVE_FANCY_SIMD
+ auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+#else
+ auto values = load_iq4nl_values_256();
+#endif
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec; uint32_t val[8]; };
+#ifndef HAVE_FANCY_SIMD
+ helper_t h, h_shift;
+#else
+ using helper512_t = union { __m512i vec; uint64_t val[8]; };
+ helper_t h;
+ helper512_t h_shift;
+#endif
+ __m256 acc[nrc_y] = {};
+ __m256i isum[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
+ auto d4 = _mm_loadu_ps(dptr);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales);
+ h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127));
+#ifndef HAVE_FANCY_SIMD
+ h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2);
+ {
+ __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0])))));
+ __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1])))));
+ __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2])))));
+ __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3])))));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
+ acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]);
+ }
+ }
+#else
+ auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1));
+ h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec));
+#endif
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+#ifdef HAVE_FANCY_SIMD
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
+ auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib]));
+ auto scales_m = _mm256_cvtepi32_ps(ishifts);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+ }
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+#ifndef HAVE_FANCY_SIMD
+ auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle);
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)));
+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
+ isum[iy] = _mm256_setzero_si256();
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, _mm_mul_ps(d4, sum));
+ }
+ }
+}
+
+template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
@@ -6482,6 +6592,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
expected_typeB = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_IQ4_KS_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_iq4_ks_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_iq4_ks_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_iq4_ks_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_iq4_ks_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_iq4_ks_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_iq4_ks_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_iq4_ks_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_iq4_ks_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K32;
+ break;
case GGML_TYPE_Q2_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_q2_k_r4_q8_k<1>;
@@ -9278,6 +9400,81 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
}
}
+template <int nrc_y>
+void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto values = vld1q_s8(iq4k_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ int16x8x4_t iscales;
+ int32x4x4_t scales;
+ float32x4_t acc[nrc_y] = {};
+ int32x4_t isum[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto d4 = vld1q_f32(dptr);
+ const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ auto sas = vld1q_u8_x2(iq4[ibl].scales);
+ auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
+ iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+ scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
+ iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
+ iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
+ // Adding the block shifts costs us ~9% in performance drop.
+ // Is there a better way?
+ sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
+ sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
+ {
+ auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
+ auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
+ auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
+ auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
+ auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
+ auto b8 = vget_low_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
+ b8 = vget_high_s16(bs);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
+ isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
+ }
+ }
+ for (int is = 0; is < 2; ++is) {
+ scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
+ scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
+ scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
+ scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
+ for (int ib = 0; ib < 4; ++ib) {
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
+ isum[iy] = vdupq_n_s32(0);
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vmulq_f32(d4, acc[iy]));
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
int32x4_t * isum) {
@@ -10571,6 +10768,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
break;
+ case GGML_TYPE_IQ4_KS_R4:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k);
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 0007dc04..874c6d55 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3839,6 +3839,107 @@ void vec_dot_iq4_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
}
//
+// ========================================= iq4_ks_r4
+//
+
+void quantize_row_iq4_ks_r4_ref(const float * x, block_iq4_ks_r4 * y, int64_t k) {
+ quantize_iq4_ks_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_iq4_ks_r4(const float * x, void * y, int64_t k) {
+ quantize_iq4_ks_r4(x, y, 4, k/4, nullptr);
+}
+
+static void repack_iq4_ks(int nrows, int n_per_row, const block_iq4_ks * x, block_iq4_ks_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row);
+ int nblock = n_per_row/QK_K;
+ char * cy = (char *)y;
+ const char * cx = (const char *)x;
+ const block_iq4_ks * x4[4];
+ for (int row = 0; row < nrows; row += 4) {
+ float * dptr = (float *)cy;
+ block_iq4_ks_r4 * y = (block_iq4_ks_r4 *)(dptr + 4);
+ for (int k = 0; k < 4; ++k) {
+ auto dk = (const float *)(cx + k*row_size);
+ dptr[k] = dk[0];
+ x4[k] = (const block_iq4_ks *)(dk + 1);
+ }
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib];
+ for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row
+ y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row
+ y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row
+ y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row
+ }
+ }
+ }
+ }
+ cx += 4*row_size;
+ cy += 4*row_size;
+ }
+}
+
+size_t quantize_iq4_ks_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_iq4_ks(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_iq4_ks(4, n_per_row, (const block_iq4_ks *)qtmp.data(), (block_iq4_ks_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_iq4_ks_r4(const block_iq4_ks_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ const float * dptr = (const float *)x;
+ x = (const block_iq4_ks_r4 *)(dptr + 4);
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = dptr[k];
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ float dl = d * ((x[ibl].scales[4*ib + k] & 254) - 127);
+ auto values = iq4k_values + ((x[ibl].scales[4*ib + k] & 1) << 4);
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl * values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl * values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+16] = dl * values[x[ibl].qs[64*ib+4*k+i+16] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+24] = dl * values[x[ibl].qs[64*ib+4*k+i+16] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl * values[x[ibl].qs[64*ib+4*k+i+32] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+12] = dl * values[x[ibl].qs[64*ib+4*k+i+32] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+20] = dl * values[x[ibl].qs[64*ib+4*k+i+48] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+28] = dl * values[x[ibl].qs[64*ib+4*k+i+48] >> 4];
+ }
+ }
+ }
+ }
+}
+
+void vec_dot_iq4_ks_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+//
// ========================================= iq2_bn_r4
//
void quantize_row_iq2_bn_r4_ref(const float * x, block_iq2_bn * y, int64_t k) {
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index b8604caa..bcec8432 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -163,6 +163,12 @@ size_t quantize_iq2_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT d
void dequantize_row_iq2_k_r4(const block_iq2_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_iq2_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq4_ks_r4_ref(const float * GGML_RESTRICT x, block_iq4_ks_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_ks_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_ks_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_ks_r4(const block_iq4_ks_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_ks_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
diff --git a/include/llama.h b/include/llama.h
index 7267fbd4..5480f4e9 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -197,6 +197,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ3_K_R4 = 339, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
diff --git a/src/llama.cpp b/src/llama.cpp
index f04d7f6f..3f0ceb0a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3868,6 +3868,7 @@ struct llama_model_loader {
case GGML_TYPE_Q8_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R4; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break;
+ case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break;
case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break;
case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break;
case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break;
@@ -4593,6 +4594,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q8_0_R4: return "Q8_0_R4 - 8.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_K_R4: return "IQ2_K_R4 - 2.375 bpw";
@@ -15794,7 +15796,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K;
}
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4) && !qs.has_output) {
new_type = GGML_TYPE_IQ5_K;
}
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R4 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 &&
@@ -15859,6 +15861,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_IQ5_K_R4) {
new_type = GGML_TYPE_IQ5_K;
}
+ else if (new_type == GGML_TYPE_IQ4_KS_R4) {
+ new_type = GGML_TYPE_IQ4_KS;
+ }
else if (new_type == GGML_TYPE_Q4_0_R4) {
new_type = GGML_TYPE_Q4_0;
}
@@ -15949,6 +15954,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) {
new_type = GGML_TYPE_IQ5_K;
}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 && qs.model.hparams.n_gqa() >= 2) {
+ new_type = GGML_TYPE_IQ5_K_R4;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) {
new_type = GGML_TYPE_IQ5_K;
}
@@ -16053,6 +16061,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4)) {
new_type = GGML_TYPE_Q5_K;
}
+ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 && i_layer < n_layer/8 && !qs.has_imatrix) {
+ new_type = GGML_TYPE_Q5_K_R4;
+ }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
new_type = GGML_TYPE_Q5_K;
@@ -16155,7 +16166,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 ||
new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 ||
new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4||
- new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4) {
+ new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@@ -16191,6 +16202,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS:
+ case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
case GGML_TYPE_IQ4_K:
@@ -16322,6 +16334,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q8_0_R4: default_type = GGML_TYPE_Q8_0_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
+ case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break;
case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break;
case LLAMA_FTYPE_MOSTLY_IQ2_K_R4:default_type = GGML_TYPE_IQ2_K_R4;break;
@@ -16752,6 +16765,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ5_K;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_IQ4_KS_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_KS;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_BF16_R16) {
if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16;
else chunk_size_multiplier = 16;