 common/common.cpp                    |   3
 examples/llama-bench/llama-bench.cpp |   3
 examples/quantize/quantize.cpp       |   2
 ggml/include/ggml.h                  |   4
 ggml/src/ggml-quants.c               |   6
 ggml/src/ggml.c                      |  48
 ggml/src/iqk/iqk_mul_mat.cpp         | 598
 ggml/src/iqk/iqk_quantize.cpp        | 290
 ggml/src/iqk/iqk_quantize.h          |  12
 include/llama.h                      |   2
 src/llama.cpp                        |  49
 11 files changed, 983 insertions, 34 deletions
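
In short, this commit introduces a new 8-bit quantization type, GGML_TYPE_Q8_KV, together with a row-interleaved variant GGML_TYPE_Q8_KV_R8, and wires both through the type tables, the quantize tool, the KV-cache type strings, the iqk GEMM kernels, and the iqk flash-attention helpers. Unlike Q8_0, which carries one scale per 32-weight block, q8_KV stores a single float scale for the whole row plus the integer sum of the row's quants (used by the AVX512-VNNI kernels, see the note after the iqk_mul_mat.cpp diff). The scalar sketch below, inferred from block_q8_KV<D> and quantize_row_q8_KV in this patch, summarizes the row layout and the dot product that vec_dot_q8_KV_q8_KV evaluates; the helper names are made up for illustration and are not part of the patch.

    #include <cstdint>

    // One q8_KV row of n weights occupies 8 + n bytes:
    //   float   d      -- scale, max|x| / 127
    //   int32_t s      -- sum of the n int8 quants
    //   int8_t  qs[n]  -- quants, q = round(x / d)
    struct q8_KV_row {
        float          d;
        int32_t        s;
        const int8_t * qs;
    };

    static inline q8_KV_row q8_KV_unpack(const void * row) {
        const float * dptr = (const float *)row;
        return { dptr[0], *(const int32_t *)(dptr + 1), (const int8_t *)(dptr + 2) };
    }

    // Reference for the per-row dot product computed by vec_dot_q8_KV_q8_KV.
    static inline float q8_KV_dot(const void * x, const void * y, int n) {
        q8_KV_row rx = q8_KV_unpack(x), ry = q8_KV_unpack(y);
        int32_t sumi = 0;
        for (int j = 0; j < n; ++j) sumi += (int32_t)rx.qs[j] * (int32_t)ry.qs[j];
        return rx.d * ry.d * (float)sumi;
    }

The _R8 variant packs eight such rows together (eight float scales followed by the eight rows' quants interleaved in 16-byte chunks) so the AVX2/NEON kernels can produce eight output rows at a time.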
diff --git a/common/common.cpp b/common/common.cpp
index 44678d7a..f7a6f76f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q6_0") {
return GGML_TYPE_Q6_0;
}
+ if (s == "q8_KV") {
+ return GGML_TYPE_Q8_KV;
+ }
throw std::runtime_error("Invalid cache type: " + s);
}
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 95df06dc..0222c213 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "q6_0") {
return GGML_TYPE_Q6_0;
}
+ if (s == "q8_KV") {
+ return GGML_TYPE_Q8_KV;
+ }
return GGML_TYPE_COUNT;
}
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 7ceee208..916f57ec 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -56,6 +56,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", },
{ "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", },
{ "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", },
+ { "Q8_KV", LLAMA_FTYPE_MOSTLY_Q8_KV, " 8.00 bpw quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", },
@@ -82,6 +83,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
{ "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", },
{ "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", },
+ { "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 66bcb25a..d2131a15 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -416,6 +416,7 @@ extern "C" {
GGML_TYPE_Q8_K32 = 148,
GGML_TYPE_Q8_KR8 = 149,
GGML_TYPE_Q8_K128 = 150,
+ GGML_TYPE_Q8_KV = 151,
GGML_TYPE_Q4_0_R8 = 202,
GGML_TYPE_Q5_0_R4 = 206,
@@ -442,6 +443,7 @@ extern "C" {
GGML_TYPE_IQ4_K_R4 = 339,
GGML_TYPE_IQ5_K_R4 = 340,
GGML_TYPE_IQ4_KS_R4 = 344,
+ GGML_TYPE_Q8_KV_R8 = 398,
GGML_TYPE_Q8_K_R8 = 399,
GGML_TYPE_COUNT,
};
@@ -501,6 +503,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors
//
GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
@@ -527,6 +530,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
};
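
The new enum values follow the numbering convention already used in this header: base quantization types sit in the 100+ block (GGML_TYPE_Q8_KV = 151, next to the other Q8_K variants), row-interleaved types sit next to their relatives in the upper blocks (GGML_TYPE_Q8_KV_R8 = 398, just before Q8_K_R8 = 399), and the GGML_FTYPE values mirror the same scheme.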
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index fe7de167..e8218e76 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ3_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_IQ5_K_R4: break;
- case GGML_TYPE_IQ4_KS_R4: break;
- case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_IQ4_KS_R4:break;
+ case GGML_TYPE_Q8_KV_R8: break;
+ case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_Q8_KV: break;
case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8ab6b0a9..0aee8dd4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_K128,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q8_KV] = {
+ .type_name = "q8_KV",
+ .blck_size = 32,
+ .type_size = 32,
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_KV,
+ .from_float = quantize_row_q8_KV,
+ .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref,
+ .vec_dot = vec_dot_q8_KV_q8_KV,
+ .vec_dot_type = GGML_TYPE_Q8_KV,
+ .row_meta_size = 8,
+ },
+ [GGML_TYPE_Q8_KV_R8] = {
+ .type_name = "q8_KV_r8",
+ .blck_size = 32,
+ .type_size = 32,
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8,
+ .from_float = quantize_row_q8_KV_r8,
+ .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref,
+ .vec_dot = vec_dot_q8_KV_r8_q8_KV,
+ .vec_dot_type = GGML_TYPE_Q8_KV,
+ .row_meta_size = 4,
+ },
[GGML_TYPE_Q8_K16] = {
.type_name = "q8_K16",
.blck_size = 64,
@@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+ case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
+ case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
@@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32(
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
-#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
+#if GGML_USE_LLAMAFILE
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
@@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id(
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
- (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+ (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
struct mmid_row_mapping {
int32_t i1;
@@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
@@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_KR8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
@@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q8_K64:
case GGML_TYPE_Q8_K128:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_K16:
case GGML_TYPE_Q8_K32:
case GGML_TYPE_Q4_0_4_4:
@@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
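
Besides registering the new types, two generic fixes in ggml.c are what make a type with per-row metadata usable here: the dup paths now compute the destination row size with ggml_row_size() instead of nb0 * (ne00 / blck_size), and mul_mat_id sizes its work buffer per row rather than from ggml_nelements(). For q8_KV the difference is exactly the 8-byte row header. A small stand-alone illustration, using the sizes from the type_traits entries added above rather than calling into ggml:

    #include <cstddef>
    #include <cstdio>

    // q8_KV:    blck_size = 32, type_size = 32 (1 byte/weight), row_meta_size = 8
    // q8_KV_r8: same quants, but 4 bytes of metadata per row (8 floats shared by 8 rows)
    static size_t row_size_q8_KV   (size_t ne00) { return 8 + ne00; }
    static size_t row_size_q8_KV_r8(size_t ne00) { return 4 + ne00; }

    int main() {
        const size_t ne00 = 4096;
        // The old formula nb0 * (ne00 / blck_size) = 32 * (4096 / 32) = 4096 bytes drops the header.
        std::printf("q8_KV    row: %zu bytes (%.4f bpw)\n", row_size_q8_KV(ne00),    8.0 * row_size_q8_KV(ne00)    / ne00);
        std::printf("q8_KV_r8 row: %zu bytes (%.4f bpw)\n", row_size_q8_KV_r8(ne00), 8.0 * row_size_q8_KV_r8(ne00) / ne00);
        return 0;
    }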
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8d6b45da..3bfded73 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -269,6 +269,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
@@ -301,6 +303,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@@ -6107,7 +6111,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
@@ -6169,6 +6173,230 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
+// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
+template <int nrc_y>
+static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nb = n / 16;
+ __m256i acc[nrc_y] = {};
+ __m256i qx[4];
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ float sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto dx = _mm256_loadu_ps(dptr);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+ info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m256i qx[2];
+ __m256i acc[2*nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#else
+ __m256i sx[2];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr+1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
+#else
+ qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
+ acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
+ sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
+ acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
+#ifdef HAVE_FANCY_SIMD
+ info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
+#else
+ info.store(ix, iy, dx[0]*dy[iy]*sumi);
+#endif
+ acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m256i qx[4];
+#ifndef HAVE_FANCY_SIMD
+ __m256i sx[4];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256i acc[nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/32; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
+ auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
+ auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
+ auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
+ auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
+ qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+ qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
+ qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
+ qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
+ auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
+#endif
+ }
+ }
+ auto scales_x = _mm_loadu_ps(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
+#endif
+ auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
+ info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
#ifdef __AVX512BF16__
template <int nrc_y>
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -9114,6 +9342,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
#endif
expected_typeB = GGML_TYPE_Q8_KR8;
break;
+ case GGML_TYPE_Q8_KV:
+ assert (ne00 % 32 == 0);
+ mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>;
+ mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
+ mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
+ mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;
+ mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>;
+ mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>;
+ mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>;
+ mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>;
+#ifdef HAVE_FANCY_SIMD
+ mm.func16 = mul_mat_q8_KV_q8_KV<16>;
+#endif
+ expected_typeB = GGML_TYPE_Q8_KV;
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ assert (ne00 % 32 == 0);
+ mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>;
+ mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>;
+ mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>;
+ mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>;
+ mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>;
+ mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>;
+ mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>;
+ mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>;
+ expected_typeB = GGML_TYPE_Q8_KV;
+ break;
case GGML_TYPE_IQ4_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
@@ -13424,6 +13679,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%32 == 0);
+ int32x4_t acc[4] = {};
+ auto dptr = (const float *)info.src1_row(0);
+ const float dy = dptr[0];
+ auto q8y = (const int8_t *)(dptr + 2);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ auto qx = vld1q_s8_x4(q8x + 64*i);
+ for (int j = 0; j < 4; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+ auto qx = vld1q_s8_x2(q8x + 32*i);
+ for (int j = 0; j < 2; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
+ }
+ }
+ acc[0] = vaddq_s32(acc[0], acc[1]);
+ acc[2] = vaddq_s32(acc[2], acc[3]);
+ acc[0] = vaddq_s32(acc[0], acc[2]);
+ info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
+ acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(n%16 == 0);
+ int8x16_t qx[4];
+ int32x4_t acc[nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/16; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
+ auto row01 = vtrnq_s32(qx[0], qx[1]);
+ auto row23 = vtrnq_s32(qx[2], qx[3]);
+ qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
+ qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
+ qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
+ qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy] + 16*i);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
+ }
+ }
+ auto scales_x = vld1q_f32(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
+ info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
+ acc[iy] = vdupq_n_s32(0);
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ int32x4_t acc[2*nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < n/16; ++ib) {
+ auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
+ auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy]+16*ib);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
+ }
+ }
+ auto scale1_x = vld1q_f32(dptr+0);
+ auto scale2_x = vld1q_f32(dptr+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale_y = vdupq_n_f32(dy[iy]);
+ auto scale1 = vmulq_f32(scale1_x, scale_y);
+ auto scale2 = vmulq_f32(scale2_x, scale_y);
+ info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
+ info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -14000,6 +14372,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_KR8;
break;
+ case GGML_TYPE_Q8_KV:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
+ m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
+ m.func16 = mul_mat_q8_KV_q8_KV<16>;
+ expected_Btype = GGML_TYPE_Q8_KV;
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV);
+ expected_Btype = GGML_TYPE_Q8_KV;
+ break;
case GGML_TYPE_IQ2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
@@ -14347,13 +14729,49 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
+template <int D> struct block_q8_KV {
+ float d;
+ int s;
+ int8_t qs[D];
+};
+
+template <int D, int step>
+struct HelperQ8KV final : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ using block_q8 = block_q8_KV<D>;
+ constexpr static int block_size_q = D;
+ HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
+#ifdef __aarch64__
+ auto vd = F16::set1(q8->d);
+ auto qs = vld1_s8_x2(q8->qs + 8*i);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+ auto vd = F16::set1(q8->d);
+#ifdef HAVE_FANCY_SIMD
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
+#else
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
+#endif
+#endif
+ }
+};
+
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef HAVE_FANCY_SIMD
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
#else
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
#endif
HelperQ80(const char * data, int stride) : Base(data, stride) {}
@@ -14397,23 +14815,33 @@ struct HelperQ80 final : public BaseHelper<step> {
y += D/QK8_1;
}
}
+
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_KV(q, y, D);
+ q += stride_q;
+ ++y;
+ }
+ }
};
template <int D, int step>
-struct HelperQ80R4 : public BaseHelper<step> {
+struct HelperQ80R8 : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef __AVX2__
+ constexpr static int block_size_q = QK8_1;
using block_q8 = block_q8_1;
#else
+ constexpr static int block_size_q = QK8_0;
using block_q8 = block_q8_0;
#endif
- HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+ HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
r4 = repack(nk, q8);
Base::data = (const char *)r4.data();
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
}
- static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
+ static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
static_assert(D%QK8_0 == 0);
GGML_ASSERT(nk%8 == 0);
constexpr int nblock = D/QK8_0;
@@ -14512,10 +14940,107 @@ struct HelperQ80R4 : public BaseHelper<step> {
std::vector<block_q8_0_r8> r4;
};
+// TODO: unite this with the above
+template <int D, int step>
+struct HelperQ8KVR8 : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ constexpr static int block_size_q = D;
+ using block_q8 = block_q8_KV<D>;
+
+ struct block_q8_KV_r8 {
+ float d[8];
+ int8_t qs[8*D];
+ };
+
+ HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
+ r4 = repack(nk, q8);
+ Base::data = (const char *)r4.data();
+ Base::stride = sizeof(block_q8_KV_r8)/8;
+ }
+
+ static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
+ static_assert(D%32 == 0);
+ GGML_ASSERT(nk%8 == 0);
+ std::vector<block_q8_KV_r8> result(nk/8);
+ auto y = result.data();
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ const int8_t * x8[8];
+ for (int ix = 0; ix < nk/8; ++ix) {
+ for (int k = 0; k < 8; ++k) {
+ auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
+ y[ix].d[k] = dptr[0];
+ x8[k] = (const int8_t *)(dptr + 2);
+ }
+ for (int ib = 0; ib < D/16; ++ib) {
+#ifdef __AVX2__
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+#endif
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
+#elif defined __ARM_NEON
+ // TODO
+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
+ vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
+ vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
+ vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
+#else
+ // TODO
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ }
+ return result;
+ }
+
+ std::vector<block_q8_KV_r8> r4;
+};
+
template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
HelperQ40(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@@ -14559,6 +15084,7 @@ template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@@ -14649,8 +15175,10 @@ template <int D, int step>
struct HelperQ60 final : public BaseHelper<step> {
#ifdef __aarch64__
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
#else
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
#endif
using Base = BaseHelper<step>;
HelperQ60(const char * data, int stride) : Base(data, stride) {}
@@ -15071,9 +15599,9 @@ struct FlashQKV {
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
- GGML_ASSERT(fms.S[j] > 0);
- auto norm = F16::set1(1/fms.S[j]);
- //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+ //GGML_ASSERT(fms.S[j] > 0);
+ //auto norm = F16::set1(1/fms.S[j]);
+ auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
@@ -15357,13 +15885,29 @@ struct FlashQKfp32 {
#endif
#endif
}
- else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
+#ifdef __aarch64__
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+#endif
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#endif
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
+ }
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
@@ -15406,7 +15950,7 @@ struct FlashQKfp32 {
constexpr int kMaxQ = 8;
static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@@ -15428,7 +15972,7 @@ struct FlashQKfp32 {
static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < nq/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@@ -15516,7 +16060,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
- typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
+ typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
#if FA_TIMING
Perf perf(false);
#endif
@@ -15613,12 +16157,28 @@ struct FlashAttn {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
- HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
+#else
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+#endif
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ } else{
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
+ if (nq1 >= 8) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
- HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
@@ -16142,6 +16702,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ80<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16179,6 +16743,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16210,7 +16778,7 @@ inline bool flash_attn_is_supported(ggml_type type) {
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
#else
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
#endif
return false;
}
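
The bulk of the new code is in this file, and two details are worth spelling out. First, the flash-attention scaffolding gains HelperQ8KV and HelperQ8KVR8 (and HelperQ80R4 is renamed to HelperQ80R8 to match its actual 8-row interleaving); the new block_size_q constant lets the Q8 scratch buffer and the DataInfo strides be sized per helper type instead of assuming QK8_0-sized blocks. Second, the HAVE_FANCY_SIMD (AVX512-VNNI) kernels use _mm256_dpbusd_epi32, which multiplies unsigned bytes by signed bytes; the x operand is therefore biased by +127, and the bias is removed with the precomputed row sum from the q8_KV header (sy[iy] = -127 * iptr[0]). A scalar sketch of that identity, for illustration only:

    #include <cstdint>

    // dpbusd accumulates sum_j u_j * y_j with u_j unsigned.  Feeding u_j = x_j + 127 gives
    //   sum_j (x_j + 127) * y_j = sum_j x_j * y_j + 127 * sum_j y_j,
    // so subtracting 127 * sum(y) (the row sum stored next to the q8_KV scale) recovers the true dot.
    static int32_t dot_via_biased_x(const int8_t * x, const int8_t * y, int n, int32_t y_row_sum) {
        int32_t acc = 0;
        for (int j = 0; j < n; ++j) {
            const uint8_t u = (uint8_t)(x[j] + 127);   // what the AVX512 path feeds to dpbusd
            acc += (int32_t)u * (int32_t)y[j];
        }
        return acc - 127 * y_row_sum;                  // == sum_j x[j] * y[j]
    }

The plain AVX2 path does not need the row sum: it uses _mm256_maddubs_epi16 on |x| with y sign-adjusted by x, which keeps the products exact without a bias.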
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 24b49d89..7b777a1f 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
}
#endif
}
+// TODO: merge this with the above template
+void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+ assert(k % 32 == 0);
+ auto dptr = (float *)vy;
+ auto q8 = (int8_t *)(dptr + 2);
+#ifdef __AVX2__
+ const __m256 signBit = _mm256_set1_ps(-0.0f);
+ const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+ __m256 maxAbs = _mm256_setzero_ps();
+ for (int ib = 0; ib < k/8; ++ib) {
+ const __m256 v = _mm256_loadu_ps(x + 8*ib);
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
+ }
+ const float maxScalar = hmax_f32_8(maxAbs);
+ if (!maxScalar) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = maxScalar / 127.f;
+ auto mul = _mm256_set1_ps(1/dptr[0]);
+ auto isum = _mm256_setzero_si256();
+ for (int i = 0; i < k/32; i++) {
+ __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0));
+ __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8));
+ __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16));
+ __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24));
+ v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
+ v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
+ v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
+ v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
+ __m256i i0 = _mm256_cvtps_epi32(v0);
+ __m256i i1 = _mm256_cvtps_epi32(v1);
+ __m256i i2 = _mm256_cvtps_epi32(v2);
+ __m256i i3 = _mm256_cvtps_epi32(v3);
+ isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+ i0 = _mm256_packs_epi32( i0, i1 );
+ i2 = _mm256_packs_epi32( i2, i3 );
+ i0 = _mm256_packs_epi16( i0, i2 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+ _mm256_storeu_si256((__m256i *)q8, i0);
+ q8 += 32;
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = hsum_i32_8(isum);
+#elif defined __ARM_NEON
+ int32x4_t ival[8];
+ auto vmax = vdupq_n_f32(0.f);
+ for (int j = 0; j < k; j += 4) {
+ vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
+ }
+ auto smax = vmaxvq_f32(vmax);
+ if (!smax) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = smax/127;
+ auto vid = vdupq_n_f32(1/dptr[0]);
+ auto isum = vdupq_n_s32(0);
+ for (int ib = 0; ib < k/32; ++ib) {
+ auto xb = x + 32*ib;
+ for (int k = 0; k < 8; ++k) {
+ auto val = vld1q_f32(xb + 4*k);
+ ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
+ isum = vaddq_s32(isum, ival[k]);
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
+ vst1_s8(q8, vmovn_s16(i16));
+ q8 += 8;
+ }
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = vaddvq_s32(isum);
+#else
+ float amax = 0;
+ for (int j = 0; j < k; ++j) {
+ float ax = std::abs(x[j]);
+ amax = std::max(amax, ax);
+ }
+ if (!amax) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = amax/127;
+ float id = 1/dptr[0];
+ int isum = 0;
+ for (int i = 0; i < k; i++) {
+ q8[i] = nearest_int(id*x[i]);
+ isum += q8[i];
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = isum;
+#endif
+}
}
void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
@@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8
#ifdef HAVE_FANCY_SIMD
static void modify_q8_0_r8(int64_t k, char * cy) {
- auto y = (block_iq4_nl_r8 *)cy;
+ auto y = (block_q8_0_r8 *)cy;
int nb = k/(32*8);
for (int ib = 0; ib < nb; ++ib) {
for (int l = 0; l < 4; ++l) {
@@ -5413,6 +5510,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
}
//
+// ========================================= q8_KV_r8
+//
+
+void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
+ quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
+ quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
+ GGML_ASSERT(nrows%8 == 0);
+ GGML_ASSERT(n_per_row%16 == 0);
+ auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+ const int8_t * x8[8];
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ for (int row = 0; row < nrows; row += 8) {
+ auto dy = (float *)cy;
+ auto qy = (int8_t *)(dy + 8);
+ for (int k = 0; k < 8; ++k) {
+ auto dx = (const float *)(cx + k*row_size_x);
+ dy[k] = dx[0];
+ x8[k] = (const int8_t *)(dx + 2);
+ }
+ for (int ib = 0; ib < n_per_row/16; ++ib) {
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+ if (online) {
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+ }
+#endif
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
+#elif defined __ARM_NEON
+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(qy + 0 + 128*ib, m0);
+ vst1q_s8_x2(qy + 32 + 128*ib, m1);
+ vst1q_s8_x2(qy + 64 + 128*ib, m2);
+ vst1q_s8_x2(qy + 96 + 128*ib, m3);
+#else
+ // TODO
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+
+ }
+ cx += 8*row_size_x;
+ cy += online ? 8*row_size_x : 8*row_size_y;
+ //So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row
+ }
+}
+#ifdef HAVE_FANCY_SIMD
+static void modify_q8_KV_r8(int64_t k, char * cy) {
+ int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
+ for (int j = 0; j < k; ++j) q8[j] += 127;
+}
+#endif
+
+size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
+ GGML_ASSERT(nrows%8 == 0);
+ GGML_ASSERT(n_per_row%16 == 0);
+ char * qcur = (char *)dst;
+ auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+ std::vector<char> qtmp(8*row_size_0);
+ for (int row = 0; row < nrows; row += 8) {
+ quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
+ repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
+ qcur += 8*row_size_1;
+ src += 8*n_per_row;
+ }
+ return nrows*row_size_1;
+}
+
+void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
+ auto n_per_row = k/8;
+ float * y8[8];
+ for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
+ auto dptr = (const float *)vx;
+ auto q8 = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < n_per_row/16; ++ib) {
+ for (int k = 0; k < 8; ++k) {
+ for (int l = 0; l < 4; ++l) {
+ for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
+ }
+ }
+ }
+}
+
+void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+//
// ========================================= bf16_r4
//
namespace {
@@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(by);
}
+void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+ iqk_quantize_row_q8_KV(x, vy, k);
+}
+
+void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
+ quantize_row_q8_KV(x, y, k);
+}
+
+size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ (void)imatrix;
+ auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto q = (char *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ quantize_row_q8_KV(src, q, n_per_row);
+ src += n_per_row;
+ q += row_size;
+ }
+ return row_size*nrows;
+}
+
+void dequantize_row_q8_KV(const void * x, float * y, int64_t k) {
+ auto dptr = (const float *)x;
+ float d = dptr[0];
+ auto q8 = (const int8_t *)(dptr + 2);
+ for (int j = 0; j < k; ++j) y[j] = d * q8[j];
+}
+
+void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+
//================================================
namespace {
@@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
#endif
#ifdef HAVE_FANCY_SIMD
- { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
- { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
+ { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
+ { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
+ { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
#endif
};
auto it = k_mod_map.find(tensor->type);
@@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
{ GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} },
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
+ { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} },
#ifdef __AVX512BF16__
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
{ GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} },
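
Note how the two load-time hooks divide the work: tensors stored as plain Q8_KV are repacked to Q8_KV_R8 on the fly via the iqk_repack_tensor entry added above, and in that online case repack_q8_KV also folds in the +127 VNNI bias and keeps the original row stride, leaving a little unused space at the end of each 8-row group (as the comment in the code says); tensors already stored as Q8_KV_R8 in the model file are written without the bias by quantize_q8_KV_r8, so on AVX512 only modify_q8_KV_r8 runs at load time to add 127 to every quant in place.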
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 97719361..76fbac3b 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -217,6 +217,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
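
The declarations above follow the usual per-row quantize/dequantize/vec_dot entry points. A hedged usage sketch of the scalar Q8_KV round trip is shown below; the buffer size assumes the 8-byte per-row header seen in dequantize_row_q8_KV, and n must satisfy the n % QK4_NL == 0 assert of the vec_dot routine.

// Hedged usage sketch; the helper name and buffer sizing are illustrative.
#include <cstdint>
#include <vector>
#include "iqk_quantize.h"

static void q8_KV_roundtrip_example(const float * row, int64_t n) {
    std::vector<uint8_t> packed(n + 8);   // assumed: [8-byte header][n x int8]
    std::vector<float>   out(n);
    quantize_row_q8_KV(row, packed.data(), n);
    dequantize_row_q8_KV(packed.data(), out.data(), n);
    float dot = 0.0f;
    // dot product of the quantized row with itself; bs/bx/by are unused here
    vec_dot_q8_KV_q8_KV((int)n, &dot, 0, packed.data(), 0, packed.data(), 0, 1);
}
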
diff --git a/include/llama.h b/include/llama.h
index 39251d35..b5ad65e7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -180,6 +180,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors
//
LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
@@ -206,6 +207,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
diff --git a/src/llama.cpp b/src/llama.cpp
index 8c4a966d..0257a0a3 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+ const uint32_t n_head = hparams.n_head(i);
+ const uint32_t n_head_kv = hparams.n_head_kv(i);
+    const uint32_t n_embd_head_k = hparams.n_embd_head_k;
+
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
@@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init(
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
#if MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
+ //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#else
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
#endif
@@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init(
n_mla++;
}
else {
- k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ //printf("Creating cache tensors:\n");
+ //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
+ //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
@@ -4002,6 +4010,7 @@ struct llama_model_loader {
case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break;
case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break;
case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break;
+ case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break;
case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break;
case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break;
case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
@@ -4012,6 +4021,7 @@ struct llama_model_loader {
case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
+ case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break;
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
@@ -4730,6 +4740,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
+ case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
@@ -4746,6 +4757,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4";
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8";
+ case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
@@ -8283,11 +8295,20 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+ const int64_t n_head = hparams.n_head(il);
+ const int64_t n_head_kv = hparams.n_head_kv(il);
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
+
GGML_ASSERT(kv.size == n_ctx);
- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
- (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
- cb(k_cache_view, "k_cache_view", il);
+ //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
+ // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
+ //cb(k_cache_view, "k_cache_view", il);
+
+ auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
+ ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
+ k_row_size, k_row_size*n_head_kv*kv_head);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
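
The 2D view above addresses the K cache one head at a time (n_embd_head_k values per row) instead of one row per token spanning all heads, so a row-wise cache type such as Q8_KV gets one scale per head. A hedged sketch of the offset arithmetic follows; the concrete numbers (n_embd_head_k = 128, n_head_kv = 8, a 136-byte Q8_KV row) are illustrative assumptions rather than values taken from the patch.

// Illustrative only: the byte offset passed to the ggml_view_2d call above,
// spelled out as plain arithmetic (k_row_size stands in for ggml_row_size()).
#include <cstddef>

static size_t k_cache_view_offset(size_t k_row_size, int n_head_kv, int kv_head) {
    // one quantized row per head per token => n_head_kv rows per cached token
    return k_row_size * (size_t)n_head_kv * (size_t)kv_head;
}

// Example (assumed dimensions): k_row_size = 136 bytes (Q8_KV, n_embd_head_k = 128),
// n_head_kv = 8, kv_head = 100  ->  offset = 136 * 8 * 100 = 108800 bytes.
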
@@ -8706,7 +8727,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
- ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
cb(k, "k", il);
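
Read together with the head-major cache layout above, the two stride arguments of this view can be summarized as follows; this is a descriptive note using only the local names already present in the call.

// Descriptive note on the ggml_view_3d strides above (head-major K cache):
//   nb1 = ggml_row_size(type, n_embd_head_k) * n_head_kv
//         -> step to the next kv position: skip all of one token's head rows
//   nb2 = ggml_row_size(type, n_embd_head_k)
//         -> step to the next head at the same kv position: one per-head row
// The previous nb1, ggml_row_size(type, n_embd_k_gqa), assumed a single cache
// row per token spanning all heads, which the per-head layout replaces.
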
@@ -13507,8 +13528,9 @@ struct llm_build_context {
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
- ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
+ auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
+ ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens,
+ row_size, row_size*kv_head);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank + n_embd_head_qk_rope, n_kv,
@@ -16164,7 +16186,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = GGML_TYPE_IQ5_K;
}
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 &&
- new_type != GGML_TYPE_Q8_K_R8) {
+ new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) {
new_type = GGML_TYPE_Q6_K;
}
}
@@ -16218,6 +16240,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
else if (new_type == GGML_TYPE_Q8_K_R8) {
new_type = GGML_TYPE_Q8_0;
}
+ else if (new_type == GGML_TYPE_Q8_KV_R8) {
+ new_type = GGML_TYPE_Q8_0;
+ }
else if (new_type == GGML_TYPE_IQ2_K_R4) {
new_type = GGML_TYPE_IQ2_K;
}
@@ -16728,6 +16753,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q8_KV: default_type = GGML_TYPE_Q8_KV; break;
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break;
@@ -16751,6 +16777,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break;
case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break;
+ case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
@@ -17194,6 +17221,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
else chunk_size_multiplier = 8;
}
+ else if (new_type == GGML_TYPE_Q8_KV_R8) {
+ if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0;
+ else chunk_size_multiplier = 8;
+ }
else if (new_type == GGML_TYPE_IQ2_BN_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN;
else chunk_size_multiplier = 4;