summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-02-19 11:47:07 +0200
committerGitHub <noreply@github.com>2025-02-19 11:47:07 +0200
commita0ebfdd661a2ccb2700b0e36cfc10ca1a2b4de98 (patch)
treed5bb2c8f07625c617d1113348b4b67d79b8f64f4 /ggml/src
parent047ba895bb3d94f055756c1ec7767b3342cb9c90 (diff)
Q8_KV: 8-bit quantization type targeting the KV cache (#208)
* Adding q8_KV - Basics + AVX2 gemm/gemv * q8_KV: Better AVX2 gemm * q8_KV: Better Zen4 gemm We get 225.7 t/s for L3-8B. In comparison q8_0 without run-tinme-repacking is at 169 t/s. * q8_KV: AVX2 gemm/gemv We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr. * q8_KV: be able to use it for K cache This required quite a few fixes in ggml and llama.cpp: * ggml: do not calculate row size as n/block_size*type_size. I had removed most of it when implementing the quants with per row scale, bit it was stull lurking in ggml_copy. Not sure if these were the last remnants of ggmil-style row sizes, or if there are still places left * llama.cpp: get rid of the the 1d K cache assumption. Create and manage the K-cache as a 2D tensor so we can have per row meta data as needed by q8_KV. Using q8_KV for K-cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA, we get 18% speedup for PP-8192 compared to q8_0 K-cache. * q8_KV: be able to use it for K cache in FA * q8_KV: repack it for K*Q in FA * q8_KV: slightly faster gemv on Zen4 * q8_KV: slightly faster gemv on Zen4 * q8_KV: ARM_NEON We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well. * q8_KV: use it in FA on NEON * q8_KV_r8 - repacked q8_KV On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s) This makes no sense whatsoever as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block stuff removed (so, one would think that it would be faster). * q8_KV_r8: don't use nrc_y = 16 on Zen4 This is faster - 350 t/s. Why? Much better than the 290 t/s we had before, but still slower than the 370 t/s for q8_k_r8. * q8_KV: nrc_y = 16 also doesn't pay off in FA * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-quants.c6
-rw-r--r--ggml/src/ggml.c48
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp598
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp290
-rw-r--r--ggml/src/iqk/iqk_quantize.h12
5 files changed, 929 insertions, 25 deletions
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index fe7de167..e8218e76 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ3_K_R4: break;
case GGML_TYPE_IQ4_K_R4: break;
case GGML_TYPE_IQ5_K_R4: break;
- case GGML_TYPE_IQ4_KS_R4: break;
- case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_IQ4_KS_R4:break;
+ case GGML_TYPE_Q8_KV_R8: break;
+ case GGML_TYPE_Q8_K_R8: break;
+ case GGML_TYPE_Q8_KV: break;
case GGML_TYPE_BF16_R16: break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8ab6b0a9..0aee8dd4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_K128,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q8_KV] = {
+ .type_name = "q8_KV",
+ .blck_size = 32,
+ .type_size = 32,
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_KV,
+ .from_float = quantize_row_q8_KV,
+ .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref,
+ .vec_dot = vec_dot_q8_KV_q8_KV,
+ .vec_dot_type = GGML_TYPE_Q8_KV,
+ .row_meta_size = 8,
+ },
+ [GGML_TYPE_Q8_KV_R8] = {
+ .type_name = "q8_KV_r8",
+ .blck_size = 32,
+ .type_size = 32,
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8,
+ .from_float = quantize_row_q8_KV_r8,
+ .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref,
+ .vec_dot = vec_dot_q8_KV_r8_q8_KV,
+ .vec_dot_type = GGML_TYPE_Q8_KV,
+ .row_meta_size = 4,
+ },
[GGML_TYPE_Q8_K16] = {
.type_name = "q8_K16",
.blck_size = 64,
@@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+ case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
@@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
+ case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
@@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16(
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32(
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
size_t id = 0;
- size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
@@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
-#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
+#if GGML_USE_LLAMAFILE
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
@@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id(
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
- (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+ (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
struct mmid_row_mapping {
int32_t i1;
@@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_0_X4:
case GGML_TYPE_Q8_1_X4:
@@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS:
@@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K_R4:
case GGML_TYPE_Q8_K_R8:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_KR8:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XXS_R4:
@@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q8_K64:
case GGML_TYPE_Q8_K128:
+ case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_K16:
case GGML_TYPE_Q8_K32:
case GGML_TYPE_Q4_0_4_4:
@@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8d6b45da..3bfded73 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -269,6 +269,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_K_R4:
case GGML_TYPE_Q5_K_R4:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
@@ -301,6 +303,8 @@ struct MulMat {
case GGML_TYPE_IQ4_XS_R8:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
+ case GGML_TYPE_Q8_KV:
+ case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@@ -6107,7 +6111,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
- GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
@@ -6169,6 +6173,230 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn
}
}
+// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
+template <int nrc_y>
+static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+#ifndef HAVE_FANCY_SIMD
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ int nb = n / 16;
+ __m256i acc[nrc_y] = {};
+ __m256i qx[4];
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ float sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ auto dptr = (const float *)((const char *)vx + ix*bx);
+ auto dx = _mm256_loadu_ps(dptr);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
+ qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
+ qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
+ qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
+#ifndef HAVE_FANCY_SIMD
+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
+ auto y = MM256_SET_M128I(y128, y128);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+ auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
+#endif
+ info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m256i qx[2];
+ __m256i acc[2*nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#else
+ __m256i sx[2];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr+1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
+#else
+ qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
+ sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+#endif
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ for (int j = 0; j < 2; ++j) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
+ acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
+ sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#ifdef HAVE_FANCY_SIMD
+ acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
+#else
+ auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
+ acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
+#endif
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
+#ifdef HAVE_FANCY_SIMD
+ info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
+#else
+ info.store(ix, iy, dx[0]*dy[iy]*sumi);
+#endif
+ acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
+ }
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ GGML_ASSERT(n%32 == 0);
+ __m256i qx[4];
+#ifndef HAVE_FANCY_SIMD
+ __m256i sx[4];
+ auto m1 = _mm256_set1_epi16(1);
+#endif
+ __m256i acc[nrc_y] = {};
+ float dy[nrc_y];
+#ifdef HAVE_FANCY_SIMD
+ int32_t sy[nrc_y];
+#endif
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+#ifdef HAVE_FANCY_SIMD
+ auto iptr = (const int32_t *)(dptr + 1);
+ sy[iy] = -127*iptr[0];
+#endif
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/32; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
+ auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
+ auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
+ auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
+ auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
+#ifdef HAVE_FANCY_SIMD
+ qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
+ qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
+ qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
+#else
+ qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
+ qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
+ qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
+ qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
+#endif
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+#ifdef HAVE_FANCY_SIMD
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
+ acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
+#else
+ auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
+ auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
+ auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
+ auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
+ auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
+ auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
+ acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
+#endif
+ }
+ }
+ auto scales_x = _mm_loadu_ps(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
+#ifdef HAVE_FANCY_SIMD
+ sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
+#endif
+ auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
+ info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
+ acc[iy] = _mm256_setzero_si256();
+ }
+ }
+}
+
#ifdef __AVX512BF16__
template <int nrc_y>
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -9114,6 +9342,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
#endif
expected_typeB = GGML_TYPE_Q8_KR8;
break;
+ case GGML_TYPE_Q8_KV:
+ assert (ne00 % 32 == 0);
+ mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>;
+ mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>;
+ mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>;
+ mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>;
+ mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>;
+ mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>;
+ mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>;
+ mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>;
+#ifdef HAVE_FANCY_SIMD
+ mm.func16 = mul_mat_q8_KV_q8_KV<16>;
+#endif
+ expected_typeB = GGML_TYPE_Q8_KV;
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ assert (ne00 % 32 == 0);
+ mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>;
+ mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>;
+ mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>;
+ mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>;
+ mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>;
+ mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>;
+ mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>;
+ mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>;
+ expected_typeB = GGML_TYPE_Q8_KV;
+ break;
case GGML_TYPE_IQ4_K_R4:
assert (ne00 % QK_K == 0);
mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
@@ -13424,6 +13679,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
+static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%32 == 0);
+ int32x4_t acc[4] = {};
+ auto dptr = (const float *)info.src1_row(0);
+ const float dy = dptr[0];
+ auto q8y = (const int8_t *)(dptr + 2);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto dx = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dx + 2);
+ for (int i = 0; i < n/64; ++i) {
+ auto qx = vld1q_s8_x4(q8x + 64*i);
+ for (int j = 0; j < 4; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
+ }
+ }
+ if (int i = 2*(n/64); i < n/32) {
+ auto qx = vld1q_s8_x2(q8x + 32*i);
+ for (int j = 0; j < 2; ++j) {
+ acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
+ }
+ }
+ acc[0] = vaddq_s32(acc[0], acc[1]);
+ acc[2] = vaddq_s32(acc[2], acc[3]);
+ acc[0] = vaddq_s32(acc[0], acc[2]);
+ info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
+ acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%4 == 0);
+ GGML_ASSERT(n%16 == 0);
+ int8x16_t qx[4];
+ int32x4_t acc[nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ const int8_t * q8x[4];
+ float dx[4];
+ for (int ix = 0; ix < nrc_x; ix += 4) {
+ for (int kx = 0; kx < 4; ++kx) {
+ auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+ dx[kx] = dptr[0];
+ q8x[kx] = (const int8_t *)(dptr + 2);
+ }
+ for (int i = 0; i < n/16; ++i) {
+ for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
+ auto row01 = vtrnq_s32(qx[0], qx[1]);
+ auto row23 = vtrnq_s32(qx[2], qx[3]);
+ qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]);
+ qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]);
+ qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]);
+ qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy] + 16*i);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
+ acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
+ }
+ }
+ auto scales_x = vld1q_f32(dx);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
+ info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
+ acc[iy] = vdupq_n_s32(0);
+ }
+ }
+}
+
+template <int nrc_y>
+void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(nrc_x%8 == 0);
+ int32x4_t acc[2*nrc_y] = {};
+ float dy[nrc_y];
+ const int8_t * q8y[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ dy[iy] = dptr[0];
+ q8y[iy] = (const int8_t *)(dptr + 2);
+ }
+ for (int ix = 0; ix < nrc_x; ix += 8) {
+ const float * dptr = (const float *)((const char *)vx + ix*bx);
+ auto q8x = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < n/16; ++ib) {
+ auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
+ auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto y = vld1q_s8(q8y[iy]+16*ib);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
+ acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
+ acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
+ }
+ }
+ auto scale1_x = vld1q_f32(dptr+0);
+ auto scale2_x = vld1q_f32(dptr+4);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto scale_y = vdupq_n_f32(dy[iy]);
+ auto scale1 = vmulq_f32(scale1_x, scale_y);
+ auto scale2 = vmulq_f32(scale2_x, scale_y);
+ info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
+ info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ }
+ }
+}
+
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -14000,6 +14372,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_KR8;
break;
+ case GGML_TYPE_Q8_KV:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV);
+ m.funcs[0] = mul_mat_q8_KV_q8_KV_1;
+ m.func16 = mul_mat_q8_KV_q8_KV<16>;
+ expected_Btype = GGML_TYPE_Q8_KV;
+ break;
+ case GGML_TYPE_Q8_KV_R8:
+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV);
+ expected_Btype = GGML_TYPE_Q8_KV;
+ break;
case GGML_TYPE_IQ2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
@@ -14347,13 +14729,49 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
+template <int D> struct block_q8_KV {
+ float d;
+ int s;
+ int8_t qs[D];
+};
+
+template <int D, int step>
+struct HelperQ8KV final : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ using block_q8 = block_q8_KV<D>;
+ constexpr static int block_size_q = D;
+ HelperQ8KV(const char * data, int stride) : Base(data, stride) {}
+
+ // Needed for v * softmax(k * q)
+ inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
+ auto q8 = (const block_q8_KV<D> *)Base::lblock(l1);
+#ifdef __aarch64__
+ auto vd = F16::set1(q8->d);
+ auto qs = vld1_s8_x2(q8->qs + 8*i);
+ v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
+ v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
+#else
+ auto vd = F16::set1(q8->d);
+#ifdef HAVE_FANCY_SIMD
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1))));
+#else
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0)))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8)))));
+#endif
+#endif
+ }
+};
+
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef HAVE_FANCY_SIMD
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
#else
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
#endif
HelperQ80(const char * data, int stride) : Base(data, stride) {}
@@ -14397,23 +14815,33 @@ struct HelperQ80 final : public BaseHelper<step> {
y += D/QK8_1;
}
}
+
+ static inline void convert(int nq, int stride_q, const float * q, block_q8_KV<D> * y) {
+ for (int i = 0; i < nq; ++i) {
+ quantize_row_q8_KV(q, y, D);
+ q += stride_q;
+ ++y;
+ }
+ }
};
template <int D, int step>
-struct HelperQ80R4 : public BaseHelper<step> {
+struct HelperQ80R8 : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef __AVX2__
+ constexpr static int block_size_q = QK8_1;
using block_q8 = block_q8_1;
#else
+ constexpr static int block_size_q = QK8_0;
using block_q8 = block_q8_0;
#endif
- HelperQ80R4(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
+ HelperQ80R8(int nk, const HelperQ80<D, step>& q8) : Base(q8.data, q8.stride) {
r4 = repack(nk, q8);
Base::data = (const char *)r4.data();
Base::stride = (D/QK8_0)*sizeof(block_q8_0);
}
- static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step> q8) {
+ static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
static_assert(D%QK8_0 == 0);
GGML_ASSERT(nk%8 == 0);
constexpr int nblock = D/QK8_0;
@@ -14512,10 +14940,107 @@ struct HelperQ80R4 : public BaseHelper<step> {
std::vector<block_q8_0_r8> r4;
};
+// TODO: unite this with the above
+template <int D, int step>
+struct HelperQ8KVR8 : public BaseHelper<step> {
+ using Base = BaseHelper<step>;
+ constexpr static int block_size_q = D;
+ using block_q8 = block_q8_KV<D>;
+
+ struct block_q8_KV_r8 {
+ float d[8];
+ int8_t qs[8*D];
+ };
+
+ HelperQ8KVR8(int nk, const HelperQ8KV<D, step>& q8) : Base(q8.data, q8.stride) {
+ r4 = repack(nk, q8);
+ Base::data = (const char *)r4.data();
+ Base::stride = sizeof(block_q8_KV_r8)/8;
+ }
+
+ static std::vector<block_q8_KV_r8> repack(int nk, const HelperQ8KV<D, step>& q8) {
+ static_assert(D%32 == 0);
+ GGML_ASSERT(nk%8 == 0);
+ std::vector<block_q8_KV_r8> result(nk/8);
+ auto y = result.data();
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ const int8_t * x8[8];
+ for (int ix = 0; ix < nk/8; ++ix) {
+ for (int k = 0; k < 8; ++k) {
+ auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride);
+ y[ix].d[k] = dptr[0];
+ x8[k] = (const int8_t *)(dptr + 2);
+ }
+ for (int ib = 0; ib < D/16; ++ib) {
+#ifdef __AVX2__
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+#endif
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2);
+ _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3);
+#elif defined __ARM_NEON
+ // TODO
+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0);
+ vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1);
+ vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2);
+ vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3);
+#else
+ // TODO
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+ }
+ }
+ return result;
+ }
+
+ std::vector<block_q8_KV_r8> r4;
+};
+
template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
HelperQ40(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@@ -14559,6 +15084,7 @@ template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
@@ -14649,8 +15175,10 @@ template <int D, int step>
struct HelperQ60 final : public BaseHelper<step> {
#ifdef __aarch64__
using block_q8 = block_q8_0;
+ constexpr static int block_size_q = QK8_0;
#else
using block_q8 = block_q8_1;
+ constexpr static int block_size_q = QK8_1;
#endif
using Base = BaseHelper<step>;
HelperQ60(const char * data, int stride) : Base(data, stride) {}
@@ -15071,9 +15599,9 @@ struct FlashQKV {
}
inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
- GGML_ASSERT(fms.S[j] > 0);
- auto norm = F16::set1(1/fms.S[j]);
- //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+ //GGML_ASSERT(fms.S[j] > 0);
+ //auto norm = F16::set1(1/fms.S[j]);
+ auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
@@ -15357,13 +15885,29 @@ struct FlashQKfp32 {
#endif
#endif
}
- else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
+#ifdef __aarch64__
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#else
+#ifdef HAVE_FANCY_SIMD
+ if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+#endif
+ if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
+#endif
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ80R8<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq);
#else
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq);
#endif
}
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
+ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
+ }
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
@@ -15406,7 +15950,7 @@ struct FlashQKfp32 {
constexpr int kMaxQ = 8;
static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@@ -15428,7 +15972,7 @@ struct FlashQKfp32 {
static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr};
for (int iq = 0; iq < nq/nrc_q; ++iq) {
mul_mat(D, kh.block, kh.stride, info, k_step);
info.cur_y += nrc_q;
@@ -15516,7 +16060,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
- typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)];
+ typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
#if FA_TIMING
Perf perf(false);
#endif
@@ -15613,12 +16157,28 @@ struct FlashAttn {
if (nq1 >= 8) {
#if FA_TIMING
auto t1 = Perf::cur_time();
- HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+ Perf::instance().accum(4, t1);
+#else
+ HelperQ80R8<Dk, k_step> khr4(nk1, kh);
+#endif
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ } else{
+ compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
+ }
+ }
+ else if constexpr (std::is_same_v<KHelper, HelperQ8KV<Dk, k_step>>) {
+ if (nq1 >= 8) {
+#if FA_TIMING
+ auto t1 = Perf::cur_time();
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
Perf::instance().accum(4, t1);
#else
- HelperQ80R4<Dk, k_step> khr4(nk1, kh);
+ HelperQ8KVR8<Dk, k_step> khr4(nk1, kh);
#endif
- compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R4<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
+ compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk, k_step>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else{
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
@@ -16142,6 +16702,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ80<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dv, k_step> vh(v, stride_v);
+ iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
@@ -16179,6 +16743,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ80<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
+ case GGML_TYPE_Q8_KV: {
+ HelperQ8KV<Dk, k_step> kh(k, stride_k);
+ iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ } break;
case GGML_TYPE_Q6_0: {
HelperQ60<Dk, k_step> kh(k, stride_k);
iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
@@ -16210,7 +16778,7 @@ inline bool flash_attn_is_supported(ggml_type type) {
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
#else
- if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true;
+ if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true;
#endif
return false;
}
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 24b49d89..7b777a1f 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
}
#endif
}
+// TODO: merge this with the above template
+void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+ assert(k % 32 == 0);
+ auto dptr = (float *)vy;
+ auto q8 = (int8_t *)(dptr + 2);
+#ifdef __AVX2__
+ const __m256 signBit = _mm256_set1_ps(-0.0f);
+ const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+ __m256 maxAbs = _mm256_setzero_ps();
+ for (int ib = 0; ib < k/8; ++ib) {
+ const __m256 v = _mm256_loadu_ps(x + 8*ib);
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
+ }
+ const float maxScalar = hmax_f32_8(maxAbs);
+ if (!maxScalar) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = maxScalar / 127.f;
+ auto mul = _mm256_set1_ps(1/dptr[0]);
+ auto isum = _mm256_setzero_si256();
+ for (int i = 0; i < k/32; i++) {
+ __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0));
+ __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8));
+ __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16));
+ __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24));
+ v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
+ v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
+ v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST);
+ v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST);
+ __m256i i0 = _mm256_cvtps_epi32(v0);
+ __m256i i1 = _mm256_cvtps_epi32(v1);
+ __m256i i2 = _mm256_cvtps_epi32(v2);
+ __m256i i3 = _mm256_cvtps_epi32(v3);
+ isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+ i0 = _mm256_packs_epi32( i0, i1 );
+ i2 = _mm256_packs_epi32( i2, i3 );
+ i0 = _mm256_packs_epi16( i0, i2 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+ _mm256_storeu_si256((__m256i *)q8, i0);
+ q8 += 32;
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = hsum_i32_8(isum);
+#elif defined __ARM_NEON
+ int32x4_t ival[8];
+ auto vmax = vdupq_n_f32(0.f);
+ for (int j = 0; j < k; j += 4) {
+ vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j)));
+ }
+ auto smax = vmaxvq_f32(vmax);
+ if (!smax) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = smax/127;
+ auto vid = vdupq_n_f32(1/dptr[0]);
+ auto isum = vdupq_n_s32(0);
+ for (int ib = 0; ib < k/32; ++ib) {
+ auto xb = x + 32*ib;
+ for (int k = 0; k < 8; ++k) {
+ auto val = vld1q_f32(xb + 4*k);
+ ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid));
+ isum = vaddq_s32(isum, ival[k]);
+ }
+ for (int k = 0; k < 4; ++k) {
+ auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1]));
+ vst1_s8(q8, vmovn_s16(i16));
+ q8 += 8;
+ }
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = vaddvq_s32(isum);
+#else
+ float amax = 0;
+ for (int j = 0; j < k; ++j) {
+ float ax = std::abs(x[j]);
+ amax = std::max(amax, ax);
+ }
+ if (!amax) {
+ dptr[0] = dptr[1] = 0;
+ std::memset(q8, 0, k*sizeof(int8_t));
+ return;
+ }
+ dptr[0] = amax/127;
+ float id = 1/dptr[0];
+ int isum = 0;
+ for (int i = 0; i < k; i++) {
+ q8[i] = nearest_int(id*x[i]);
+ isum += q8[i];
+ }
+ auto iptr = (int32_t *)(dptr + 1);
+ iptr[0] = isum;
+#endif
+}
}
void quantize_row_q8_K128(const float * x, void * vy, int64_t k) {
@@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8
#ifdef HAVE_FANCY_SIMD
static void modify_q8_0_r8(int64_t k, char * cy) {
- auto y = (block_iq4_nl_r8 *)cy;
+ auto y = (block_q8_0_r8 *)cy;
int nb = k/(32*8);
for (int ib = 0; ib < nb; ++ib) {
for (int l = 0; l < 4; ++l) {
@@ -5413,6 +5510,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
}
//
+// ========================================= q8_KV_r8
+//
+
+void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) {
+ quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) {
+ quantize_q8_KV_r8(x, y, 8, k/8, nullptr);
+}
+
+static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) {
+ GGML_ASSERT(nrows%8 == 0);
+ GGML_ASSERT(n_per_row%16 == 0);
+ auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+ const int8_t * x8[8];
+#ifdef __ARM_NEON
+ int8x16x2_t m0, m1, m2, m3;
+#endif
+ for (int row = 0; row < nrows; row += 8) {
+ auto dy = (float *)cy;
+ auto qy = (int8_t *)(dy + 8);
+ for (int k = 0; k < 8; ++k) {
+ auto dx = (const float *)(cx + k*row_size_x);
+ dy[k] = dx[0];
+ x8[k] = (const int8_t *)(dx + 2);
+ }
+ for (int ib = 0; ib < n_per_row/16; ++ib) {
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+ auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib));
+ auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib));
+ auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib));
+ auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib));
+ auto t0 = _mm256_unpacklo_epi32(m0, m1);
+ auto t1 = _mm256_unpacklo_epi32(m2, m3);
+ auto t2 = _mm256_unpackhi_epi32(m0, m1);
+ auto t3 = _mm256_unpackhi_epi32(m2, m3);
+ m0 = _mm256_unpacklo_epi64(t0, t1);
+ m1 = _mm256_unpackhi_epi64(t0, t1);
+ m2 = _mm256_unpacklo_epi64(t2, t3);
+ m3 = _mm256_unpackhi_epi64(t2, t3);
+#ifdef HAVE_FANCY_SIMD
+ if (online) {
+ m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127));
+ m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127));
+ m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127));
+ m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127));
+ }
+#endif
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2);
+ _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3);
+#elif defined __ARM_NEON
+ m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib);
+ m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib);
+ m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib);
+ m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib);
+ auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+ auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+ m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+ row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+ m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+ m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+ vst1q_s8_x2(qy + 0 + 128*ib, m0);
+ vst1q_s8_x2(qy + 32 + 128*ib, m1);
+ vst1q_s8_x2(qy + 64 + 128*ib, m2);
+ vst1q_s8_x2(qy + 96 + 128*ib, m3);
+#else
+ // TODO
+ for (int l = 0; l < 4; ++l) {
+ for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+ y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
+ y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
+ }
+ }
+#endif
+
+ }
+ cx += 8*row_size_x;
+ cy += online ? 8*row_size_x : 8*row_size_y;
+ //So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row
+ }
+}
+#ifdef HAVE_FANCY_SIMD
+static void modify_q8_KV_r8(int64_t k, char * cy) {
+ int8_t * q8 = (int8_t *)(cy + 8*sizeof(float));
+ for (int j = 0; j < k; ++j) q8[j] += 127;
+}
+#endif
+
+size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) {
+ GGML_ASSERT(nrows%8 == 0);
+ GGML_ASSERT(n_per_row%16 == 0);
+ char * qcur = (char *)dst;
+ auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row);
+ std::vector<char> qtmp(8*row_size_0);
+ for (int row = 0; row < nrows; row += 8) {
+ quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix);
+ repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false);
+ qcur += 8*row_size_1;
+ src += 8*n_per_row;
+ }
+ return nrows*row_size_1;
+}
+
+void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) {
+ auto n_per_row = k/8;
+ float * y8[8];
+ for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k;
+ auto dptr = (const float *)vx;
+ auto q8 = (const int8_t *)(dptr + 8);
+ for (int ib = 0; ib < n_per_row/16; ++ib) {
+ for (int k = 0; k < 8; ++k) {
+ for (int l = 0; l < 4; ++l) {
+ for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i];
+ }
+ }
+ }
+}
+
+void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+//
// ========================================= bf16_r4
//
namespace {
@@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(by);
}
+void quantize_row_q8_KV(const float * x, void * vy, int64_t k) {
+ iqk_quantize_row_q8_KV(x, vy, k);
+}
+
+void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) {
+ quantize_row_q8_KV(x, y, k);
+}
+
+size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ (void)imatrix;
+ auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row);
+ auto q = (char *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ quantize_row_q8_KV(src, q, n_per_row);
+ src += n_per_row;
+ q += row_size;
+ }
+ return row_size*nrows;
+}
+
+void dequantize_row_q8_KV(const void * x, float * y, int64_t k) {
+ auto dptr = (const float *)x;
+ float d = dptr[0];
+ auto q8 = (const int8_t *)(dptr + 2);
+ for (int j = 0; j < k; ++j) y[j] = d * q8[j];
+}
+
+void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+ GGML_ASSERT(n%QK4_NL == 0);
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+}
+
+
//================================================
namespace {
@@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} },
#endif
#ifdef HAVE_FANCY_SIMD
- { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
- { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
+ { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} },
+ { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} },
+ { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} },
#endif
};
auto it = k_mod_map.find(tensor->type);
@@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
{ GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
{ GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} },
{ GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
+ { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} },
#ifdef __AVX512BF16__
{ GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>}},
{ GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_half>} },
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 97719361..76fbac3b 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -217,6 +217,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);