summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml/include/ggml.h4
-rw-r--r--ggml/src/ggml-common.h8
-rw-r--r--ggml/src/ggml-quants.c1
-rw-r--r--ggml/src/ggml.c22
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp271
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp108
-rw-r--r--ggml/src/iqk/iqk_quantize.h6
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp20
10 files changed, 421 insertions, 21 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index a8b4a44e..f8ce3edd 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -41,6 +41,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_NL_X4",LLAMA_FTYPE_MOSTLY_IQ4_NL_X4," 4.50 bpw non-linear quantization", },
+ { "IQ4_XS_R4",LLAMA_FTYPE_MOSTLY_IQ4_XS_R4," 4.25 bpw non-linear quantization", },
{ "Q4_0_R4", LLAMA_FTYPE_MOSTLY_Q4_0_R4, " 4.50 bpw quantization", },
{ "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", },
{ "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 99c39b9c..09f92eb9 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -410,7 +410,8 @@ extern "C" {
GGML_TYPE_Q4_0_R4 = 202,
GGML_TYPE_Q5_0_R4 = 206,
GGML_TYPE_Q8_0_R4 = 208,
- GGML_TYPE_IQ4_NL_X4 = 220,
+ GGML_TYPE_IQ4_NL_X4 = 220, // TODO: rename GGML_TYPE_IQ4_NL_X4 to GGML_TYPE_IQ4_NL_R4
+ GGML_TYPE_IQ4_XS_R4 = 223,
GGML_TYPE_Q6_0_R4 = 233,
GGML_TYPE_COUNT,
};
@@ -475,6 +476,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_NL_X4 = 219, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors
GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors
};
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index fb87a602..aa41bf55 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -448,6 +448,14 @@ typedef struct {
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
typedef struct {
+ ggml_half d[4];
+ uint8_t scales_h[QK_K/32];
+ uint8_t scales_l[QK_K/16];
+ uint8_t qs[QK_K*2];
+} block_iq4_xs_r4;
+static_assert(sizeof(block_iq4_xs_r4) == 4*sizeof(ggml_half) + QK_K/32 + QK_K/16 + QK_K*2, "wrong iq4_xs_rs block size/padding");
+
+typedef struct {
uint8_t scales[QK_K/32];
uint8_t qs[QK_K/2];
} block_iq4_ks;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 94950a36..4fdd2c36 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15197,6 +15197,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ4_KS: break;
case GGML_TYPE_IQ4_KSS: break;
case GGML_TYPE_IQ4_NL_X4: break;
+ case GGML_TYPE_IQ4_XS_R4: break;
case GGML_TYPE_Q4_0_R4: break;
case GGML_TYPE_Q5_0_R4: break;
case GGML_TYPE_Q6_0_R4: break;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 203b1b57..f4320e99 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1262,6 +1262,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_IQ4_XS_R4] = {
+ .type_name = "iq4_xs_r4",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_xs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_xs_r4,
+ .from_float = quantize_row_iq4_xs_r4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r4_ref,
+ .vec_dot = vec_dot_iq4_xs_r4_q8_k,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q4_0_R4] = {
.type_name = "q4_0_r4",
.blck_size = QK4_NL,
@@ -3989,6 +4002,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ4_NL_X4: wtype = GGML_TYPE_IQ4_NL_X4;break;
+ case GGML_FTYPE_MOSTLY_IQ4_XS_R4: wtype = GGML_TYPE_IQ4_XS_R4;break;
case GGML_FTYPE_MOSTLY_Q4_0_R4: wtype = GGML_TYPE_Q4_0_R4; break;
case GGML_FTYPE_MOSTLY_Q5_0_R4: wtype = GGML_TYPE_Q5_0_R4; break;
case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break;
@@ -10517,6 +10531,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -10964,6 +10979,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -11108,6 +11124,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -14298,6 +14315,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -14682,6 +14700,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -14960,6 +14979,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -15565,6 +15585,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_NL_X4:
+ case GGML_TYPE_IQ4_XS_R4:
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
@@ -22396,6 +22417,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL_X4: result = quantize_iq4_nl_x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_XS_R4: result = quantize_iq4_xs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_R4: result = quantize_q4_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index f827e460..faa4cab7 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -2656,6 +2656,172 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
#endif
+template <int nrc_y>
+static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm256_set1_epi8(0xf);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+ auto values = MM256_SET_M128I(values128, values128);
+ //auto values = load_iq4nl_values_256();
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec; uint32_t val[8]; };
+ helper_t h;
+ __m256 acc[nrc_y] = {};
+ __m256i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d));
+ auto d4 = _mm256_set_m128(dl, dl);
+ auto slbits = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_l);
+ auto sl = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits, 4), slbits), _mm256_set1_epi8(0xf));
+ auto aux64 = (const uint64_t *)iq4[ibl].scales_h;
+ auto shbits = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
+ auto sh = _mm256_and_si256(MM256_SET_M128I(shbits, _mm_slli_epi16(shbits, 4)), _mm256_set1_epi8(0x30));
+ h.vec = _mm256_sub_epi8(_mm256_or_si256(sl, sh), _mm256_set1_epi8(32));
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib]));
+ auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#ifdef HAVE_FANCY_SIMD
+ auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
+#endif
+ auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
+ auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
+ qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
+ qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
+ qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
+#ifndef HAVE_FANCY_SIMD
+ auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
+ auto sumi = _mm256_setzero_si256();
+ sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+ float d8 = q8.scale(iy, ibl);
+ float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+#else
+ auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])),
+ _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
+ auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])),
+ _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
+ auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ //auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00))),
+ // _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))));
+ //auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa))),
+ // _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))));
+ //auto sumi = _mm256_add_epi32(sumi1, sumi2);
+ //float d8 = q8.scale(iy, ibl);
+ //float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+ //acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
+ //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
+#endif
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+ acc[iy] = _mm256_setzero_ps();
+ info.store(ix+0, iy, sum);
+ }
+ }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ if constexpr (nrc_y == 1){
+ mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
+ } else {
+ GGML_ASSERT(nrc_x%8 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = _mm512_set1_epi8(0xf);
+ auto values = load_iq4nl_values_512();
+ int nbl = n / QK_K;
+ using helper_t = union { __m256i vec; uint32_t val[8]; };
+ helper_t hl, hh;
+ __m512 acc[2*nrc_y] = {};
+ __m512i qx[4];
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const block_iq4_xs_r4 * iq4l = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx);
+ const block_iq4_xs_r4 * iq4h = (const block_iq4_xs_r4 *)((const char *)vx + (ix+4)*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
+ auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
+ auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
+ auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
+ auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
+ auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
+ auto sl_l = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l), _mm256_set1_epi8(0xf));
+ auto sh_l = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l), _mm256_set1_epi8(0xf));
+ auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
+ auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
+ aux64 = (const uint64_t *)iq4h[ibl].scales_h;
+ auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
+ auto sl_h = _mm256_and_si256(MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4)), _mm256_set1_epi8(0x30));
+ auto sh_h = _mm256_and_si256(MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4)), _mm256_set1_epi8(0x30));
+ hl.vec = _mm256_sub_epi8(_mm256_or_si256(sl_l, sl_h), _mm256_set1_epi8(32));
+ hh.vec = _mm256_sub_epi8(_mm256_or_si256(sh_l, sh_h), _mm256_set1_epi8(32));
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hl.val[ib]));
+ auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hh.val[ib]));
+ auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1);
+ auto scales = _mm512_mul_ps(d4, _mm512_cvtepi32_ps(iscales));
+ auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
+ auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
+ _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
+ auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
+ _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
+ qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
+ qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
+ qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
+ qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+ auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+ auto sumi = _mm512_setzero_si512();
+ sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+ sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+ float d8 = q8.scale(iy, ibl);
+ float m8 = d8 * (q8.y[iy][ibl].bsums[2*ib+0] + q8.y[iy][ibl].bsums[2*ib+1]);
+ acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, _mm512_set1_ps(d8)), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+ acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[2*iy+1]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
+ acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+ auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+ auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+ info.store(ix+0, iy, sum1);
+ info.store(ix+4, iy, sum2);
+ }
+ }
+ }
+}
+#else
+template <int nrc_y>
+static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ mul_mat_iq4_xs_r4_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
if (j == 0) {
@@ -4625,6 +4791,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_iq4_nl_x4_q8_1<8>;
expected_typeB = GGML_TYPE_Q8_1;
break;
+ case GGML_TYPE_IQ4_XS_R4:
+ assert (ne00 % QK_K == 0);
+ mm.funcs[0] = mul_mat_iq4_xs_r4_q8_k<1>;
+ mm.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>;
+ mm.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>;
+ mm.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>;
+ mm.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>;
+ mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
+ mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
+ mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
+ expected_typeB = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q4_0_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>;
@@ -7075,6 +7253,30 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
}
+IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
+ auto sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
+ sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
+ sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
+ sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
+ sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ return sumi;
+}
+
+IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
+ qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
+ qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
+ qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
+ qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
+ qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
+ qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
+ qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
+ qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+}
+
template <int nrc_y>
void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -7091,25 +7293,10 @@ void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& i
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
- qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
- qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
- qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
- qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
- qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
- qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
- qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
- qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
+ prepare_iq4_nl_quants(values, m4, bits, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
- auto sumi = vdupq_n_s32(0);
- sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
- sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
- sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
- sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
- sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
- sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
- sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
- sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
+ auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
@@ -7122,6 +7309,45 @@ void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& i
}
}
+template <int nrc_y>
+void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ Q8<nrc_y, block_q8_K> q8(info);
+ auto m4 = vdupq_n_u8(0xf);
+ auto values = vld1q_s8(iq4k_values);
+ int nbl = n / QK_K;
+ int8x16_t qx[8];
+ float32x4_t acc[nrc_y] = {};
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx);
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ const uint32_t * scales_l = (const uint32_t *)iq4[ibl].scales_l;
+ const uint32_t * scales_h = (const uint32_t *)iq4[ibl].scales_h;
+ auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ auto ul = (scales_l[ib%4] >> 4*(ib/4)) & 0x0f0f0f0f;
+ auto uh = (scales_h[ib%2] >> 2*(ib/2)) & 0x03030303;
+ auto sl8 = vsub_s8(vreinterpret_s8_s32(vdup_n_s32(ul | (uh << 4))), vdup_n_s8(32));
+ auto sl16 = vmovl_s8(sl8);
+ auto sl32 = vmovl_s16(vget_low_s16(sl16));
+ auto scales = vmulq_f32(d4, vcvtq_f32_s32(sl32));
+ auto bits = vld1q_u8_x4(iq4[ibl].qs + 64*ib);
+ prepare_iq4_nl_quants(values, m4, bits, qx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+32*ib);
+ auto sumi = interleaved_dotq(qx, y);
+ auto d4d8 = vmulq_f32(scales, vdupq_n_f32(q8.scale(iy, ibl)));
+ acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, acc[iy]);
+ acc[iy] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -7529,6 +7755,17 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.funcs[7] = mul_mat_iq4_nl_x4_q8_0<8>;
expected_Btype = GGML_TYPE_Q8_0;
break;
+ case GGML_TYPE_IQ4_XS_R4:
+ m.funcs[0] = mul_mat_iq4_nl_x4_q8_0_1;
+ m.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>;
+ m.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>;
+ m.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>;
+ m.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>;
+ m.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>;
+ m.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>;
+ m.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>;
+ expected_Btype = GGML_TYPE_Q8_K;
+ break;
case GGML_TYPE_Q4_0_R4:
m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>;
m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>;
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index f2e6a45e..acef04db 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -3572,3 +3572,111 @@ void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t b
GGML_UNUSED(bx);
GGML_UNUSED(by);
}
+
+//
+// ========================================= iq4_xs_r4
+//
+
+void quantize_row_iq4_xs_r4_ref(const float * x, block_iq4_xs_r4 * y, int64_t k) {
+ quantize_iq4_xs_r4(x, (void *)y, 4, k/4, nullptr);
+}
+
+void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) {
+ quantize_iq4_xs_r4(x, y, 4, k/4, nullptr);
+}
+
+static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ const block_iq4_xs * x4[4];
+ for (int row = 0; row < nrows; row += 4) {
+ for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ std::memset(y[ibl].scales_l, 0, QK_K/16);
+ std::memset(y[ibl].scales_h, 0, QK_K/32);
+ for (int k = 0; k < 4; ++k) {
+ y[ibl].d[k] = x4[k][ibl].d;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ uint8_t sl = (x4[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf;
+ uint8_t sh = (x4[k][ibl].scales_h >> 2*ib) & 3;
+ int i = 4*ib + k;
+ y[ibl].scales_l[i%16] |= (sl << 4*(i/16));
+ y[ibl].scales_h[i%8 ] |= (sh << 2*(i/8));
+ }
+ }
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
+ y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row
+ y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row
+ y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row
+ y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row
+ }
+ }
+ }
+ x += 4*nblock;
+ y += nblock;
+ }
+}
+
+size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ GGML_ASSERT(nrows%4 == 0);
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ char * qcur = (char *)dst;
+ auto row_size = ggml_row_size(GGML_TYPE_IQ4_XS, n_per_row);
+ std::vector<char> qtmp(4*row_size);
+ for (int row = 0; row < nrows; row += 4) {
+ quantize_iq4_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix);
+ repack_iq4_xs(4, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur);
+ qcur += 4*row_size;
+ src += 4*n_per_row;
+ }
+ return nrows*row_size;
+}
+
+void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) {
+ auto n_per_row = k/4;
+ float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
+ int nblock = n_per_row/QK_K;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ for (int k = 0; k < 4; ++k) {
+ const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int is = 4*ib + k;
+ float dl = d * ((((x[ibl].scales_l[is%16] >> 4*(is/16)) & 0xf) | (((x[ibl].scales_h[is%8] >> 2*(is/8)) & 3) << 4)) - 32);
+ for (int i = 0; i < 4; ++i) {
+ y4[k][QK_K*ibl+32*ib+i+ 0] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+ 8] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+16] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+24] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+ 4] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+12] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] >> 4];
+ y4[k][QK_K*ibl+32*ib+i+20] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] & 0xf];
+ y4[k][QK_K*ibl+32*ib+i+28] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] >> 4];
+ }
+ }
+ }
+ //dequantize_row_iq4_xs(x + ib, ytmp, QK_K);
+ //for (int k = 0; k < 4; ++k) {
+ // for (int l = 0; l < 16; ++l) {
+ // for (int i = 0; i < 4; ++i) {
+ // //y4[k][ib*kBlockSize + i + 16*(l%4) + 4*(l/4)] = ytmp[16*l + 4*k + i];
+ // y4[k][ib*kBlockSize + i + 8*(l%8) + 4*(l/8)] = ytmp[16*l + 4*k + i];
+ // }
+ // }
+ //}
+ }
+}
+
+void vec_dot_iq4_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_XS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 3349c675..ad2294c5 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -93,6 +93,12 @@ size_t quantize_q6_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q6_0_r4(const block_q6_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q6_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_iq4_xs_r4_ref(const float * GGML_RESTRICT x, block_iq4_xs_r4 * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_iq4_xs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_iq4_xs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif
diff --git a/include/llama.h b/include/llama.h
index bf843ad2..77c988a5 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -184,6 +184,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 = 225, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 235, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
diff --git a/src/llama.cpp b/src/llama.cpp
index f307fd89..e2abc235 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3850,6 +3850,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_NL_X4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_X4;break;
+ case GGML_TYPE_IQ4_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R4;break;
case GGML_TYPE_Q4_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R4; break;
case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
@@ -4559,6 +4560,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL_X4:return "IQ4_NL_X4 - 4.5 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:return "IQ4_XS_R4 - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_Q4_0_R4: return "Q4_0_R4 - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: return "Q5_0_R4 - 5.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: return "Q6_0_R4 - 6.5 bpw";
@@ -15779,6 +15781,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_IQ4_NL_X4) {
new_type = GGML_TYPE_IQ4_NL;
}
+ else if (new_type == GGML_TYPE_IQ4_XS_R4) {
+ new_type = GGML_TYPE_IQ4_XS;
+ }
else if (new_type == GGML_TYPE_Q4_0_R4) {
new_type = GGML_TYPE_Q4_0;
}
@@ -15852,7 +15857,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
- else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 ||
+ else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 ||
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) {
new_type = GGML_TYPE_IQ5_K;
}
@@ -15883,6 +15889,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS) new_type = GGML_TYPE_Q5_K;
else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K;
else if (new_type == GGML_TYPE_IQ4_NL_X4) new_type = GGML_TYPE_Q5_K;
+ else if (new_type == GGML_TYPE_IQ4_XS_R4) new_type = GGML_TYPE_Q5_K;
else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K;
}
++qs.i_attention_wv;
@@ -15947,7 +15954,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
}
else if (i_layer < n_layer/8 && !qs.has_imatrix &&
(ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4)) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4)) {
new_type = GGML_TYPE_Q5_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
@@ -15973,7 +15981,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K ||
- ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4) {
+ ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K ||
+ ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_X4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4) {
new_type = GGML_TYPE_Q5_K;
}
} else {
@@ -16183,6 +16192,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL_X4:default_type = GGML_TYPE_IQ4_NL_X4;break;
+ case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:default_type = GGML_TYPE_IQ4_XS_R4;break;
case LLAMA_FTYPE_MOSTLY_Q4_0_R4: default_type = GGML_TYPE_Q4_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break;
@@ -16548,6 +16558,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL;
else chunk_size_multiplier = 4;
}
+ else if (new_type == GGML_TYPE_IQ4_XS_R4) {
+ if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_XS;
+ else chunk_size_multiplier = 4;
+ }
else if (new_type == GGML_TYPE_Q4_0_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
else chunk_size_multiplier = 4;