diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-19 11:47:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-19 11:47:07 +0200 |
commit | a0ebfdd661a2ccb2700b0e36cfc10ca1a2b4de98 (patch) | |
tree | d5bb2c8f07625c617d1113348b4b67d79b8f64f4 | |
parent | 047ba895bb3d94f055756c1ec7767b3342cb9c90 (diff) |
Q8_KV: 8-bit quantization type targeting the KV cache (#208)
* Adding q8_KV - Basics + AVX2 gemm/gemv
* q8_KV: Better AVX2 gemm
* q8_KV: Better Zen4 gemm
We get 225.7 t/s for L3-8B. In comparison, q8_0 without
run-time repacking is at 169 t/s.
* q8_KV: AVX2 gemm/gemv
We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr (run-time repacking).
* q8_KV: be able to use it for K cache
This required quite a few fixes in ggml and llama.cpp:
* ggml: do not calculate row size as n/block_size*type_size. I had
removed most of this when implementing the quants with per-row scale,
but it was still lurking in ggml_copy. Not sure if these were the last
remnants of ggml-style row sizes, or if there are still places left
(a sketch of the fix follows the commit message).
* llama.cpp: get rid of the 1D K-cache assumption. Create and manage
the K-cache as a 2D tensor so we can have per-row metadata as needed
by q8_KV.
Using q8_KV for the K-cache results in non-negligible performance gains.
More details to follow, but for DeepSeek-Lite with MLA, we get
an 18% speedup for PP-8192 compared to the q8_0 K-cache.
* q8_KV: be able to use it for K cache in FA
* q8_KV: repack it for K*Q in FA
* q8_KV: slightly faster gemv on Zen4
* q8_KV: slightly faster gemv on Zen4
* q8_KV: ARM_NEON
We get PP-512 = 167 t/s for L3-8B without a pre-interleaved (repacked) variant!
We do the interleaving on the fly, so I wonder if this
could be done for other quants as well.
* q8_KV: use it in FA on NEON
* q8_KV_r8 - repacked q8_KV
On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s).
This makes no sense whatsoever, as the q8_KV_r8 GEMM is
basically the q8_k_r8 GEMM with the unnecessary block handling
removed (so one would think it would be faster).
* q8_KV_r8: don't use nrc_y = 16 on Zen4
This is faster - 350 t/s. Why?
Much better than the 290 t/s we had before, but still slower
than the 370 t/s for q8_k_r8.
* q8_KV: nrc_y = 16 also doesn't pay off in FA
* Minor
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
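For context, here is a minimal sketch (not part of the patch) of the two ggml/llama.cpp fixes described above: sizing rows with ggml_row_size() instead of the old n/block_size*type_size formula, and allocating the K-cache as a 2D tensor so each row can carry the per-row metadata that q8_KV needs. It assumes ggml_row_size() in this fork accounts for the type's row_meta_size; the helper and variable names are illustrative.

```cpp
#include "ggml.h"

// Old ggml-style row size: blocks-per-row * bytes-per-block. This ignores any
// per-row metadata (q8_KV stores a float scale plus an int32 sum per row).
static size_t naive_row_size(enum ggml_type t, int64_t ne) {
    return ne/ggml_blck_size(t)*ggml_type_size(t);
}

static void sketch_k_cache(struct ggml_context * ctx,
                           int64_t n_embd_head_k, int64_t n_head_kv, int64_t kv_size) {
    // Correct: ggml_row_size() knows about row_meta_size, so q8_KV rows are
    // sized with their extra per-row metadata bytes included.
    size_t rs_good = ggml_row_size(GGML_TYPE_Q8_KV, n_embd_head_k);
    size_t rs_bad  = naive_row_size(GGML_TYPE_Q8_KV, n_embd_head_k); // too small for q8_KV

    // K-cache as a 2D tensor (one row per cached head slot) instead of a flat
    // 1D blob, so the per-row metadata has a well-defined place to live.
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_KV,
                                                n_embd_head_k, n_head_kv*kv_size);
    (void)rs_good; (void)rs_bad; (void)k;
}
```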
-rw-r--r-- | common/common.cpp | 3
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 3
-rw-r--r-- | examples/quantize/quantize.cpp | 2
-rw-r--r-- | ggml/include/ggml.h | 4
-rw-r--r-- | ggml/src/ggml-quants.c | 6
-rw-r--r-- | ggml/src/ggml.c | 48
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 598
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 290
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 12
-rw-r--r-- | include/llama.h | 2
-rw-r--r-- | src/llama.cpp | 49 |
11 files changed, 983 insertions, 34 deletions
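Before the diff itself, a small self-contained sketch of what a q8_KV row looks like and how its reference quantization works, mirroring the block_q8_KV struct and the scalar branch of iqk_quantize_row_q8_KV in the patch below (8 bits per value plus 8 bytes of per-row metadata). The names and the rounding helper here are illustrative, not the exported C API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// One q8_KV row of length D: a single float scale, the int32 sum of the
// quants (used to undo the +127 unsigned offset applied to the other operand
// on AVX512-VNNI dot-product paths), then D int8 values.
template <int D> struct block_q8_KV {
    float   d;
    int32_t s;
    int8_t  qs[D];
};

static inline int nearest_int(float x) { return (int)std::lround(x); } // stand-in for ggml's nearest_int

template <int D>
void quantize_row_q8_KV_scalar(const float * x, block_q8_KV<D> * y) {
    float amax = 0.f;
    for (int j = 0; j < D; ++j) amax = std::max(amax, std::fabs(x[j]));
    if (amax == 0.f) { y->d = 0.f; y->s = 0; std::memset(y->qs, 0, D); return; }
    y->d = amax/127;            // per-row scale
    const float id = 1/y->d;
    int sum = 0;
    for (int j = 0; j < D; ++j) {
        int q = nearest_int(id*x[j]);
        y->qs[j] = (int8_t)q;
        sum += q;
    }
    y->s = sum;                 // row sum, needed by the VNNI GEMM kernels
}

template <int D>
void dequantize_row_q8_KV_scalar(const block_q8_KV<D> * y, float * x) {
    for (int j = 0; j < D; ++j) x[j] = y->d * y->qs[j];
}
```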
diff --git a/common/common.cpp b/common/common.cpp index 44678d7a..f7a6f76f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } throw std::runtime_error("Invalid cache type: " + s); } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 95df06dc..0222c213 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } return GGML_TYPE_COUNT; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7ceee208..916f57ec 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -56,6 +56,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", }, { "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", }, { "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", }, + { "Q8_KV", LLAMA_FTYPE_MOSTLY_Q8_KV, " 8.00 bpw quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, @@ -82,6 +83,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, { "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", }, { "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", }, + { "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 66bcb25a..d2131a15 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -416,6 +416,7 @@ extern "C" { GGML_TYPE_Q8_K32 = 148, GGML_TYPE_Q8_KR8 = 149, GGML_TYPE_Q8_K128 = 150, + GGML_TYPE_Q8_KV = 151, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, @@ -442,6 +443,7 @@ extern "C" { GGML_TYPE_IQ4_K_R4 = 339, GGML_TYPE_IQ5_K_R4 = 340, GGML_TYPE_IQ4_KS_R4 = 344, + GGML_TYPE_Q8_KV_R8 = 398, GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, }; @@ -501,6 +503,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors // GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -527,6 +530,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index fe7de167..e8218e76 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, 
size_t nbyte case GGML_TYPE_IQ3_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_IQ5_K_R4: break; - case GGML_TYPE_IQ4_KS_R4: break; - case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_IQ4_KS_R4:break; + case GGML_TYPE_Q8_KV_R8: break; + case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_Q8_KV: break; case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8ab6b0a9..0aee8dd4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K128, .row_meta_size = 0, }, + [GGML_TYPE_Q8_KV] = { + .type_name = "q8_KV", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV, + .from_float = quantize_row_q8_KV, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref, + .vec_dot = vec_dot_q8_KV_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 8, + }, + [GGML_TYPE_Q8_KV_R8] = { + .type_name = "q8_KV_r8", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8, + .from_float = quantize_row_q8_KV_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref, + .vec_dot = vec_dot_q8_KV_r8_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 4, + }, [GGML_TYPE_Q8_K16] = { .type_name = "q8_K16", .blck_size = 64, @@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; @@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break; case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break; + case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; @@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32( ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 
* (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE +#if GGML_USE_LLAMAFILE // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; @@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id( char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t)); struct mmid_row_mapping { int32_t i1; @@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: @@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_KR8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: @@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: case GGML_TYPE_Q8_K128: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_K16: case GGML_TYPE_Q8_K32: case GGML_TYPE_Q4_0_4_4: @@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + 
case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8d6b45da..3bfded73 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -269,6 +269,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: @@ -301,6 +303,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -6107,7 +6111,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn // The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) template <int nrc_y> static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8<nrc_y, block_q8_K> q8(info); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); @@ -6169,6 +6173,230 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } +// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) +template <int nrc_y> +static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nb = n / 16; + __m256i acc[nrc_y] = {}; + __m256i qx[4]; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + float sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 
8) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto dx = _mm256_loadu_ps(dptr); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows + qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0); + qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1); + qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2); + qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3); +#ifndef HAVE_FANCY_SIMD + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy])); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy])); +#endif + info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy]))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + +template <int nrc_y> +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m256i qx[2]; + __m256i acc[2*nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#else + __m256i sx[2]; + auto m1 = _mm256_set1_epi16(1); +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr+1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); +#else + qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + 
acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); +#endif + } + } + } + if (int i = 2*(n/64); i < n/32) { +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); + sx[0] = _mm256_sign_epi8(qx[0], qx[0]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); +#else + auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); +#ifdef HAVE_FANCY_SIMD + info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); +#else + info.store(ix, iy, dx[0]*dy[iy]*sumi); +#endif + acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } +} + +template <int nrc_y> +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m256i qx[4]; +#ifndef HAVE_FANCY_SIMD + __m256i sx[4]; + auto m1 = _mm256_set1_epi16(1); +#endif + __m256i acc[nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i); + auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]); +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); + qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]); + qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]); + qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto dot2 = _mm256_maddubs_epi16(sx[1], 
_mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2)); + auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34)); +#endif + } + } + auto scales_x = _mm_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1)); +#ifdef HAVE_FANCY_SIMD + sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy])); +#endif + auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy])); + info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + #ifdef __AVX512BF16__ template <int nrc_y> static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -9114,6 +9342,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>; + mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_q8_KV_q8_KV<16>; +#endif + expected_typeB = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>; + mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>; + expected_typeB = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ4_K_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>; @@ -13424,6 +13679,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf } } +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + int32x4_t acc[4] = {}; + auto dptr = (const float *)info.src1_row(0); + const float dy = dptr[0]; + auto q8y = (const int8_t *)(dptr + 2); + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + auto qx = vld1q_s8_x4(q8x + 64*i); + for (int j = 0; j < 4; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j)); + } + } + if (int i = 2*(n/64); i < n/32) { + auto qx = vld1q_s8_x2(q8x + 32*i); + for (int j = 0; j < 2; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j)); + } + } + acc[0] = vaddq_s32(acc[0], acc[1]); + acc[2] = vaddq_s32(acc[2], acc[3]); + acc[0] = vaddq_s32(acc[0], acc[2]); + info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0])); + acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0); + } +} + +template <int nrc_y> +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + 
GGML_ASSERT(n%16 == 0); + int8x16_t qx[4]; + int32x4_t acc[nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/16; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i); + auto row01 = vtrnq_s32(qx[0], qx[1]); + auto row23 = vtrnq_s32(qx[2], qx[3]); + qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]); + qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]); + qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]); + qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy] + 16*i); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3); + } + } + auto scales_x = vld1q_f32(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy])); + info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy]))); + acc[iy] = vdupq_n_s32(0); + } + } +} + +template <int nrc_y> +void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + int32x4_t acc[2*nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 8) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n/16; ++ib) { + auto q1 = vld1q_s8_x4(q8x + 128*ib + 0); + auto q2 = vld1q_s8_x4(q8x + 128*ib + 64); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy]+16*ib); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3); + } + } + auto scale1_x = vld1q_f32(dptr+0); + auto scale2_x = vld1q_f32(dptr+4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale_y = vdupq_n_f32(dy[iy]); + auto scale1 = vmulq_f32(scale1_x, scale_y); + auto scale2 = vmulq_f32(scale2_x, scale_y); + info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); + info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + } + } +} + void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<1, block_q8_0_x4> q8(info); @@ -14000,6 +14372,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k); expected_Btype = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV); + m.funcs[0] = 
mul_mat_q8_KV_q8_KV_1; + m.func16 = mul_mat_q8_KV_q8_KV<16>; + expected_Btype = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV); + expected_Btype = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ2_K_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; @@ -14347,13 +14729,49 @@ struct HelperF16 final : public BaseHelper<step> { } }; +template <int D> struct block_q8_KV { + float d; + int s; + int8_t qs[D]; +}; + +template <int D, int step> +struct HelperQ8KV final : public BaseHelper<step> { + using Base = BaseHelper<step>; + using block_q8 = block_q8_KV<D>; + constexpr static int block_size_q = D; + HelperQ8KV(const char * data, int stride) : Base(data, stride) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + auto q8 = (const block_q8_KV<D> *)Base::lblock(l1); +#ifdef __aarch64__ + auto vd = F16::set1(q8->d); + auto qs = vld1_s8_x2(q8->qs + 8*i); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); +#else + auto vd = F16::set1(q8->d); +#ifdef HAVE_FANCY_SIMD + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0)))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1)))); +#else + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0))))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8))))); +#endif +#endif + } +}; + template <int D, int step> struct HelperQ80 final : public BaseHelper<step> { using Base = BaseHelper<step>; #ifdef HAVE_FANCY_SIMD using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; #else using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #endif HelperQ80(const char * data, int stride) : Base(data, stride) {} @@ -14397,23 +14815,33 @@ struct HelperQ80 final : public BaseHelper<step> { y += D/QK8_1; } } + + static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) { + for (int i = 0; i < nq; ++i) { + quantize_row_q8_KV(q, y, D); + q += stride_q; + ++y; + } + } }; template <int D, int step> -struct HelperQ80R4 : public BaseHelper<step> { +struct HelperQ80R8 : public BaseHelper<step> { using Base = BaseHelper<step>; #ifdef __AVX2__ + constexpr static int block_size_q = QK8_1; using block_q8 = block_q8_1; #else + constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; #endif - HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) { + HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) { + static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) { static_assert(D%QK8_0 == 0); GGML_ASSERT(nk%8 == 0); constexpr int nblock = D/QK8_0; @@ -14512,10 +14940,107 @@ struct HelperQ80R4 : public BaseHelper<step> { std::vector<block_q8_0_r8> r4; }; +// TODO: unite this with the above +template <int D, int step> +struct HelperQ8KVR8 : public BaseHelper<step> { + using Base = BaseHelper<step>; + constexpr static int block_size_q = D; + using block_q8 = block_q8_KV<D>; + + struct block_q8_KV_r8 { + float d[8]; + 
int8_t qs[8*D]; + }; + + HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) { + r4 = repack(nk, q8); + Base::data = (const char *)r4.data(); + Base::stride = sizeof(block_q8_KV_r8)/8; + } + + static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) { + static_assert(D%32 == 0); + GGML_ASSERT(nk%8 == 0); + std::vector<block_q8_KV_r8> result(nk/8); + auto y = result.data(); +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + const int8_t * x8[8]; + for (int ix = 0; ix < nk/8; ++ix) { + for (int k = 0; k < 8; ++k) { + auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride); + y[ix].d[k] = dptr[0]; + x8[k] = (const int8_t *)(dptr + 2); + } + for (int ib = 0; ib < D/16; ++ib) { +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +#endif + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3); +#elif defined __ARM_NEON + // TODO + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); 
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0); + vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1); + vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2); + vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + } + return result; + } + + std::vector<block_q8_KV_r8> r4; +}; + template <int D, int step> struct HelperQ40 final : public BaseHelper<step> { using Base = BaseHelper<step>; using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; HelperQ40(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14559,6 +15084,7 @@ template <int D, int step> struct HelperQ41 final : public BaseHelper<step> { using Base = BaseHelper<step>; using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; HelperQ41(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14649,8 +15175,10 @@ template <int D, int step> struct HelperQ60 final : public BaseHelper<step> { #ifdef __aarch64__ using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #else using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; #endif using Base = BaseHelper<step>; HelperQ60(const char * data, int stride) : Base(data, stride) {} @@ -15071,9 +15599,9 @@ struct FlashQKV { } inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const { - GGML_ASSERT(fms.S[j] > 0); - auto norm = F16::set1(1/fms.S[j]); - //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); + //GGML_ASSERT(fms.S[j] > 0); + //auto norm = F16::set1(1/fms.S[j]); + auto norm = F16::set1(fms.S[j] > 0 ? 
1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); @@ -15357,13 +15885,29 @@ struct FlashQKfp32 { #endif #endif } - else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) { + else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) { +#ifdef __aarch64__ + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#else +#ifdef HAVE_FANCY_SIMD + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); +#endif + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#endif + } + else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq); #else MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq); #endif } + else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq); + } else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq); @@ -15406,7 +15950,7 @@ struct FlashQKfp32 { constexpr int kMaxQ = 8; static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15428,7 +15972,7 @@ struct FlashQKfp32 { static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15516,7 +16060,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashMS<q_step, k_step>& fms, FlashQKV<Dv, q_step, k_step>& fqkv, const float * q, const char * mask, float * qkv) { - typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)]; + typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; #if FA_TIMING Perf perf(false); #endif @@ -15613,12 +16157,28 @@ struct FlashAttn { if (nq1 >= 8) { #if FA_TIMING auto t1 = Perf::cur_time(); - HelperQ80R4<Dk, k_step> khr4(nk1, kh); + HelperQ80R8<Dk, k_step> khr4(nk1, kh); + Perf::instance().accum(4, t1); +#else + HelperQ80R8<Dk, k_step> khr4(nk1, kh); +#endif + compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } else{ + compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } + } + else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) { + if (nq1 >= 8) { +#if FA_TIMING + auto t1 = Perf::cur_time(); + HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); 
Perf::instance().accum(4, t1); #else - HelperQ80R4<Dk, k_step> khr4(nk1, kh); + HelperQ8KVR8<Dk, k_step> khr4(nk1, kh); #endif - compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( + compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else{ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>( @@ -16142,6 +16702,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80<Dv, k_step> vh(v, stride_v); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV<Dv, k_step> vh(v, stride_v); + iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; case GGML_TYPE_Q6_0: { HelperQ60<Dv, k_step> vh(v, stride_v); iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -16179,6 +16743,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV<Dk, k_step> kh(k, stride_k); + iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; case GGML_TYPE_Q6_0: { HelperQ60<Dk, k_step> kh(k, stride_k); iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -16210,7 +16778,7 @@ inline bool flash_attn_is_supported(ggml_type type) { if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true; #endif return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 24b49d89..7b777a1f 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) { } #endif } +// TODO: merge this with the above template +void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + assert(k % 32 == 0); + auto dptr = (float *)vy; + auto q8 = (int8_t *)(dptr + 2); +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + __m256 maxAbs = _mm256_setzero_ps(); + for (int ib = 0; ib < k/8; ++ib) { + const __m256 v = _mm256_loadu_ps(x + 8*ib); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + if (!maxScalar) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = maxScalar / 127.f; + auto mul = _mm256_set1_ps(1/dptr[0]); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < k/32; i++) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0)); + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8)); + 
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16)); + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24)); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = hsum_i32_8(isum); +#elif defined __ARM_NEON + int32x4_t ival[8]; + auto vmax = vdupq_n_f32(0.f); + for (int j = 0; j < k; j += 4) { + vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j))); + } + auto smax = vmaxvq_f32(vmax); + if (!smax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = smax/127; + auto vid = vdupq_n_f32(1/dptr[0]); + auto isum = vdupq_n_s32(0); + for (int ib = 0; ib < k/32; ++ib) { + auto xb = x + 32*ib; + for (int k = 0; k < 8; ++k) { + auto val = vld1q_f32(xb + 4*k); + ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid)); + isum = vaddq_s32(isum, ival[k]); + } + for (int k = 0; k < 4; ++k) { + auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1])); + vst1_s8(q8, vmovn_s16(i16)); + q8 += 8; + } + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = vaddvq_s32(isum); +#else + float amax = 0; + for (int j = 0; j < k; ++j) { + float ax = std::abs(x[j]); + amax = std::max(amax, ax); + } + if (!amax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = amax/127; + float id = 1/dptr[0]; + int isum = 0; + for (int i = 0; i < k; i++) { + q8[i] = nearest_int(id*x[i]); + isum += q8[i]; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = isum; +#endif +} } void quantize_row_q8_K128(const float * x, void * vy, int64_t k) { @@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8 #ifdef HAVE_FANCY_SIMD static void modify_q8_0_r8(int64_t k, char * cy) { - auto y = (block_iq4_nl_r8 *)cy; + auto y = (block_q8_0_r8 *)cy; int nb = k/(32*8); for (int ib = 0; ib < nb; ++ib) { for (int l = 0; l < 4; ++l) { @@ -5413,6 +5510,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b } // +// ========================================= q8_KV_r8 +// + +void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + const int8_t * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = 0; row < nrows; row += 8) { + auto dy = (float *)cy; + auto qy = (int8_t *)(dy + 8); + for (int k = 0; k < 8; ++k) { + auto dx = (const float *)(cx + k*row_size_x); + dy[k] = dx[0]; + x8[k] = (const int8_t *)(dx + 2); + } + for (int 
ib = 0; ib < n_per_row/16; ++ib) { +#ifdef __AVX2__ +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + if (online) { + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + } +#endif + _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3); +#elif defined __ARM_NEON + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(qy + 0 + 128*ib, m0); + vst1q_s8_x2(qy + 32 + 128*ib, m1); + vst1q_s8_x2(qy + 64 + 128*ib, m2); + vst1q_s8_x2(qy + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + + } + cx += 8*row_size_x; + cy += online ? 
8*row_size_x : 8*row_size_y; + //So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row + } +} +#ifdef HAVE_FANCY_SIMD +static void modify_q8_KV_r8(int64_t k, char * cy) { + int8_t * q8 = (int8_t *)(cy + 8*sizeof(float)); + for (int j = 0; j < k; ++j) q8[j] += 127; +} +#endif + +size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + char * qcur = (char *)dst; + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + std::vector<char> qtmp(8*row_size_0); + for (int row = 0; row < nrows; row += 8) { + quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix); + repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false); + qcur += 8*row_size_1; + src += 8*n_per_row; + } + return nrows*row_size_1; +} + +void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) { + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; + auto dptr = (const float *)vx; + auto q8 = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n_per_row/16; ++ib) { + for (int k = 0; k < 8; ++k) { + for (int l = 0; l < 4; ++l) { + for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i]; + } + } + } +} + +void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// // ========================================= bf16_r4 // namespace { @@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +void quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_KV(x, vy, k); +} + +void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) { + quantize_row_q8_KV(x, y, k); +} + +size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + (void)imatrix; + auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto q = (char *)dst; + for (int row = 0; row < nrows; ++row) { + quantize_row_q8_KV(src, q, n_per_row); + src += n_per_row; + q += row_size; + } + return row_size*nrows; +} + +void dequantize_row_q8_KV(const void * x, float * y, int64_t k) { + auto dptr = (const float *)x; + float d = dptr[0]; + auto q8 = (const int8_t *)(dptr + 2); + for (int j = 0; j < k; ++j) y[j] = d * q8[j]; +} + +void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + + //================================================ namespace { @@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} }, #endif #ifdef HAVE_FANCY_SIMD - { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, - { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_0_R8, 
{modify_q8_0_r8, 8} }, + { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} }, #endif }; auto it = k_mod_map.find(tensor->type); @@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, + { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}}, { GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} }, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 97719361..76fbac3b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -217,6 +217,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/include/llama.h b/include/llama.h index 39251d35..b5ad65e7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -180,6 +180,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -206,6 +207,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors 
LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file diff --git a/src/llama.cpp b/src/llama.cpp index 8c4a966d..0257a0a3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init( for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_head = hparams.n_head(i); + const uint32_t n_head_kv = hparams.n_head_kv(i); + const uint32_t n_embd_head_k= hparams.n_embd_head_k; + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ggml_tensor * k; @@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init( const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); #if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); + //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #else ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #endif @@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init( n_mla++; } else { - k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + //printf("Creating cache tensors:\n"); + //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k); + //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); @@ -4002,6 +4010,7 @@ struct llama_model_loader { case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break; case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break; case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break; @@ -4012,6 +4021,7 @@ struct llama_model_loader { case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break; case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break; + case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; @@ -4730,6 +4740,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; @@ -4746,6 +4757,7 @@ static 
std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4"; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8"; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; @@ -8283,11 +8295,20 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, - (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); - cb(k_cache_view, "k_cache_view", il); + //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, + // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); + //cb(k_cache_view, "k_cache_view", il); + + auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k); + ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv, + k_row_size, k_row_size*n_head_kv*kv_head); // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); @@ -8706,7 +8727,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il], n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa), ggml_row_size(kv.k_l[il]->type, n_embd_head_k), 0); cb(k, "k", il); @@ -13507,8 +13528,9 @@ struct llm_build_context { ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); cb(kvr, "kvr", il); - ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope), - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head); + auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); + ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens, + row_size, row_size*kv_head); ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view)); ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, @@ -16164,7 +16186,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ5_K; } else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && - new_type != GGML_TYPE_Q8_K_R8) { + new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) { new_type = GGML_TYPE_Q6_K; } } @@ -16218,6 +16240,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q8_K_R8) { new_type = GGML_TYPE_Q8_0; } + else if (new_type == GGML_TYPE_Q8_KV_R8) { + new_type = GGML_TYPE_Q8_0; + } else if (new_type == GGML_TYPE_IQ2_K_R4) { new_type = GGML_TYPE_IQ2_K; } @@ -16728,6 +16753,7 @@ static void llama_model_quantize_internal(const std::string & 
fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV:default_type = GGML_TYPE_Q8_KV;break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break; @@ -16751,6 +16777,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; @@ -17194,6 +17221,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; else chunk_size_multiplier = 8; } + else if (new_type == GGML_TYPE_Q8_KV_R8) { + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 8; + } else if (new_type == GGML_TYPE_IQ2_BN_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN; else chunk_size_multiplier = 4; |
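For readers who want to see the data layout without wading through the SIMD paths above, here is a minimal, self-contained scalar sketch (not the project's implementation) of the two ideas this diff adds: a Q8_KV row stores a per-row float scale followed by plain int8 quants (the real row header holds two floats per row; only the first, the scale, is needed to dequantize), and Q8_KV_R8 interleaves 8 such rows so that every 16-column chunk is stored as four groups of 4 consecutive quants per row, matching the index pattern used by dequantize_row_q8_KV_r8 above. All names in the sketch are local to the example, and the header is simplified to just the eight per-row scales.

// Scalar round-trip sketch of the Q8_KV / Q8_KV_R8 layout (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize one row: scale d = max|x|/127, q[j] = round(x[j]/d).
static float quantize_row(const float * x, int8_t * q, int n) {
    float amax = 0.f;
    for (int j = 0; j < n; ++j) amax = std::max(amax, std::fabs(x[j]));
    const float d  = amax > 0 ? amax/127.f : 0.f;
    const float id = d  > 0 ? 1.f/d : 0.f;
    for (int j = 0; j < n; ++j) q[j] = (int8_t)std::lround(x[j]*id);
    return d;
}

int main() {
    const int n_per_row = 64;                       // must be a multiple of 16
    std::vector<float> src(8*n_per_row);
    for (int k = 0; k < 8; ++k)
        for (int j = 0; j < n_per_row; ++j)
            src[k*n_per_row + j] = std::sin(0.1f*(k+1)*j);

    // Per-row quantization (the Q8_KV part): one scale per row, no blocks.
    std::vector<int8_t> q(8*n_per_row);
    float scales[8];
    for (int k = 0; k < 8; ++k)
        scales[k] = quantize_row(src.data() + k*n_per_row, q.data() + k*n_per_row, n_per_row);

    // Repack 8 rows (the _r8 part): for each 16-column chunk ib, store 4 groups
    // of 4 consecutive quants per row, with the 8 rows interleaved inside each group.
    std::vector<int8_t> packed(8*n_per_row);
    for (int ib = 0; ib < n_per_row/16; ++ib)
        for (int l = 0; l < 4; ++l)
            for (int k = 0; k < 8; ++k)
                for (int i = 0; i < 4; ++i)
                    packed[128*ib + 32*l + 4*k + i] = q[k*n_per_row + 16*ib + 4*l + i];

    // Dequantize straight from the packed layout (same indexing as
    // dequantize_row_q8_KV_r8) and measure the round-trip error.
    double max_err = 0;
    for (int ib = 0; ib < n_per_row/16; ++ib)
        for (int l = 0; l < 4; ++l)
            for (int k = 0; k < 8; ++k)
                for (int i = 0; i < 4; ++i) {
                    const float y = scales[k]*packed[128*ib + 32*l + 4*k + i];
                    max_err = std::max(max_err, (double)std::fabs(y - src[k*n_per_row + 16*ib + 4*l + i]));
                }
    std::printf("max round-trip error: %g\n", max_err);
    return 0;
}

Because a row carries only a single scale and no block structure, the repacked variant needs the row count to be a multiple of 8 and the row length to be a multiple of 16 (see the asserts in quantize_q8_KV_r8 above), which is also why the quantization path falls back to Q8_0 for tensors whose second dimension is not divisible by 8.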