author    Kawrakow <iwankawrakow@gmail.com>   2025-01-15 18:19:22 +0200
committer GitHub <noreply@github.com>         2025-01-15 18:19:22 +0200
commit    0b74397d596bbcdfba27299393406d2b6330b133 (patch)
tree      2101d059f79b6b268086c71878aa2da1c328c73d
parent    49b27069fd267d3dac8de5d13141b4274e4be16b (diff)
CPU Flash Attention improvements (#172)
* Slightly faster FA for bf16 KV cache: a ~2-3% improvement. Sadly, beyond 8k tokens the advantage mostly goes away.

* Slightly faster FA for Q8_0 KV cache

* FA: allow bf16 for the V-cache with any supported K-cache. E.g., -ctk q8_0 -ctv bf16 is slightly faster than -ctk q8_0 -ctv q8_0 on Zen4 for not too long context lengths (say, <= 4096).

* FA: much better bf16 KV-cache speed for large contexts. We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and run-time repacked) with a context of 32768. IIRC, the previous best at such a large context was ~90 t/s. Non-negligible improvements at 16384 and 8192 as well: 173.4 and 214 t/s.

* FA: slightly better quantized KV-cache speed for large contexts. E.g., for q8_0 and a context of 32768 we are now at 113 t/s for LLaMA-3.1-8B. Also simplified the quantized K*Q multiplication.

* Fix q8_0 KV cache when not using FA - WIP (AVX2)
  1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4, and use those to quantize activations for quants that use Q8_0 or Q8_1 as their vec_dot type.
  2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1.
  3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type.
  4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than GGML_TYPE_Q8_0_X4 as the K and V types.
  5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4 in iqk_mul_mat.
  Also added an optimization in ggml_compute_forward_mul_mat for ne12*ne13 > 1 (K*Q and V*softmax(K*Q)): we process GCD(ne12*ne13, nthread) heads simultaneously, using nthread/GCD(ne12*ne13, nthread) threads per head. This results in a non-negligible performance gain for large contexts.
  Question: why is it not allowed to use a quantized V-cache when not using FA?

* Fix q8_0 KV cache when not using FA - NEON

* Fix AVX2. Again the issue with _mm256_maddubs_epi16 overflowing that I keep forgetting.

* FA: don't use large Q steps on AVX2 for fp16 K-cache

* On Zen4 it is also better to not use large Q steps for fp16 K-cache

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--   ggml/include/ggml.h               2
-rw-r--r--   ggml/src/ggml-quants.c          105
-rw-r--r--   ggml/src/ggml.c                 167
-rw-r--r--   ggml/src/iqk/iqk_mul_mat.cpp    817
-rw-r--r--   ggml/src/iqk/iqk_quantize.cpp   251
-rw-r--r--   ggml/src/iqk/iqk_quantize.h       2
6 files changed, 738 insertions, 606 deletions
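The patch does not show the definitions of the packed activation blocks it introduces, so here is a minimal sketch of the layouts implied by the usage in the diff (y4[i4].d[ir], y4[i4].d[ir+4] for the Q8_1 block sums, and y4[i4].qs[32*ir + ...]). The real structs live in the repository headers; the field layout below simply mirrors that usage, and the ggml_half typedef is an assumption.

// Sketch of the packed layouts implied by this patch (illustrative, not the canonical definitions).
#include <stdint.h>

typedef uint16_t ggml_half;          // fp16 scale, as produced by GGML_FP32_TO_FP16() (assumption)
#define QK8_0 32
#define QK8_1 32

// Four consecutive q8_0 blocks packed together: 4 scales followed by 4*32 quants.
typedef struct {
    ggml_half d[4];                  // per-block scales
    int8_t    qs[4*QK8_0];           // quants of blocks 0..3, 32 bytes each
} block_q8_0_x4;

// Four consecutive q8_1 blocks: 4 scales, 4 block sums (d*sum(quants)), then 4*32 quants.
typedef struct {
    ggml_half d[8];                  // d[0..3] = scales, d[4..7] = d*sum(quants)
    int8_t    qs[4*QK8_1];
} block_q8_1_x4;

With these registered as separate ggml types, quantize_row_q8_0/quantize_row_q8_1 go back to producing plain block_q8_0/block_q8_1 (the reverted code below), while the matrix-multiplication paths that want the interleaved layout request it explicitly via vec_dot_type / expected_typeB.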
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 118ce969..5eea7dcd 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -396,6 +396,8 @@ extern "C" {
//
GGML_TYPE_I2_S = 36,
//
+ GGML_TYPE_Q8_0_X4 = 98,
+ GGML_TYPE_Q8_1_X4 = 99,
GGML_TYPE_Q6_0 = 133,
GGML_TYPE_IQ1_BN = 134,
GGML_TYPE_IQ2_BN = 135,
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index d460b84a..23ac9915 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -934,13 +934,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
block_q8_0 * restrict y = vy;
-#if GGML_USE_IQK_MULMAT
- const int nb4 = 4*(nb/4);
-#else
- const int nb4 = -1;
-#endif
#if defined(__ARM_NEON)
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
for (int i = 0; i < nb; i++) {
int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
@@ -959,27 +953,16 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
- if (i < nb4) {
- y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
@@ -1016,14 +999,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
}
}
#elif defined(__AVX2__) || defined(__AVX__)
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
-#ifdef __AVX2__
- const bool pack = true;
-#else
- const bool pack = false;
-#endif
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@@ -1045,11 +1021,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = maxScalar / 127.f;
- if (pack && i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@@ -1084,11 +1056,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
@@ -1287,15 +1255,8 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
block_q8_1 * restrict y = vy;
-#if GGML_USE_IQK_MULMAT
- const int nb4 = 4*(nb/4);
-#else
- const int nb4 = -1;
-#endif
#if defined(__ARM_NEON)
- block_q8_1_x4 * restrict y4 = vy;
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
@@ -1312,11 +1273,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
int32x4_t accv = vdupq_n_s32(0);
@@ -1324,26 +1281,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
- if (i < nb4) {
- y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
accv = vaddq_s32(accv, vi);
}
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- }
+ y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
@@ -1389,14 +1335,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
wasm_i32x4_extract_lane(accv, 3)));
}
#elif defined(__AVX2__) || defined(__AVX__)
- block_q8_1_x4 * restrict y4 = vy;
-#ifdef __AVX2__
- const bool pack = true;
-#else
- const bool pack = false;
-#endif
for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
@@ -1418,11 +1357,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = max_scalar / 127.f;
- if (pack && i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
+ y[i].d = GGML_FP32_TO_FP16(d);
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
@@ -1446,11 +1381,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- }
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@@ -1464,11 +1395,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b9f9b3d8..bcb8bf41 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -714,8 +714,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -735,7 +739,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
.vec_dot = ggml_vec_dot_q4_1_q8_1,
+#if GGML_USE_IQK_MULMAT
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_1,
+#endif
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
@@ -778,8 +786,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
.vec_dot = ggml_vec_dot_q5_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -795,7 +807,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
.vec_dot = ggml_vec_dot_q5_1_q8_1,
+#if GGML_USE_IQK_MULMAT
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_1,
+#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -808,8 +824,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref,
.vec_dot = ggml_vec_dot_q6_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -826,8 +846,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
+ // Remember: on pure AVX2 we cannot add 128 to the Q8 quants and subtract via the block sum stored in Q8_1,
+ // as we do on Zen4, because the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range
+ // (and it gets saturated if it does), leading to wrong results.
+ // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above.
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
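The comment in the q8_0 entry above gates the +128 trick on full AVX-512, where _mm512_dpbusd_epi32 accumulates u8 x s8 products directly into 32-bit integers; on plain AVX2 the products go through a signed 16-bit intermediate that saturates. A tiny standalone check of the worst case (not part of the patch):

// Assumes an AVX2-capable compiler; illustrative only.
#include <immintrin.h>
#include <cstdio>

int main() {
    // Worst case after the +128 shift: unsigned operand 255, signed operand 127.
    __m256i u = _mm256_set1_epi8((char)0xFF);   // 255 as unsigned bytes
    __m256i s = _mm256_set1_epi8(127);          // +127 as signed bytes
    // Multiplies u8*s8 and adds adjacent pairs with signed 16-bit saturation:
    // 255*127 + 255*127 = 64770 > 32767, so every lane saturates.
    __m256i p = _mm256_maddubs_epi16(u, s);
    short out[16];
    _mm256_storeu_si256((__m256i *)out, p);
    printf("%d (expected 64770 without saturation)\n", out[0]);  // prints 32767
    return 0;
}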
@@ -849,6 +877,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q8_0_X4] = {
+ .type_name = "q8_0_x4",
+ .blck_size = QK8_0,
+ .type_size = sizeof(block_q8_0),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_0_x4,
+ .from_float_ref = quantize_row_q8_0_x4,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
+ [GGML_TYPE_Q8_1_X4] = {
+ .type_name = "q8_1_x4",
+ .blck_size = QK8_1,
+ .type_size = sizeof(block_q8_1),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_1_x4,
+ .from_float_ref = quantize_row_q8_1_x4,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@@ -1196,8 +1244,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1516,8 +1568,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_r4_ref,
.vec_dot = vec_dot_iq4_nl_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1546,8 +1602,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref,
.vec_dot = vec_dot_q4_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1563,8 +1623,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref,
.vec_dot = vec_dot_q8_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1580,8 +1644,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q5_0_r4_ref,
.vec_dot = vec_dot_q5_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1597,8 +1665,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref,
.vec_dot = vec_dot_q6_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -11280,6 +11352,8 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -11443,6 +11517,8 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -13889,6 +13965,14 @@ static void ggml_compute_forward_mul_mat_one_chunk(
}
}
+static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
+ while (a != b) {
+ if (a > b) a -= b;
+ else b -= a;
+ }
+ return a;
+}
+
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -13905,10 +13989,12 @@ static void ggml_compute_forward_mul_mat(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
- ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
int64_t const matmul_num_cols = type_traits[type].ncols;
+#if !GGML_USE_IQK_MULMAT
+ ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
+#endif
ggml_gemv_t const gemv = type_traits[type].gemv;
ggml_gemm_t const gemm = type_traits[type].gemm;
@@ -14011,6 +14097,7 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0;
+#if !GGML_USE_IQK_MULMAT
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
@@ -14019,6 +14106,7 @@ UseGgmlGemm1:;
}
i11_processed = ne11 - ne11 % 4;
}
+#endif
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
@@ -14049,14 +14137,31 @@ AlreadyQuantized:;
#if GGML_USE_IQK_MULMAT
if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
+ // When computing K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing
+ // one matrix multiplication, but to work on several heads at once.
+ // Hence, we find GCD(ne12*ne13, nth) and use nth/GCD(ne12*ne13, nth) threads per head.
+ // Leaving the previous version commented out for now, just in case.
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
- for (int64_t i13 = 0; i13 < ne13; i13++)
- for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
- vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
- (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
- ith, nth)) goto IQK_MulMat_Not_Available2;
+ int ntg = simple_gcd(ne12*ne13, nth);
+ int counter = 0;
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
+ if (counter++ % ntg == ith%ntg) {
+ if (!iqk_mul_mat(ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2;
+ }
+ }
+ }
+ //for (int64_t i13 = 0; i13 < ne13; i13++)
+ // for (int64_t i12 = 0; i12 < ne12; i12++)
+ // if (!iqk_mul_mat(ne01, ne11, ne00,
+ // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
+ // (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ // ith, nth)) goto IQK_MulMat_Not_Available2;
return;
}
IQK_MulMat_Not_Available2:;
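To make the new work split above concrete, here is a standalone sketch of the assignment it produces: with ne12*ne13 = 32 head matrix multiplications and nth = 12 threads, simple_gcd gives 4 thread groups, so 4 heads run simultaneously with 3 threads each (numbers are illustrative, not from the patch).

#include <cstdio>
#include <cstdint>

static uint32_t simple_gcd(uint32_t a, uint32_t b) {
    while (a != b) { if (a > b) a -= b; else b -= a; }
    return a;
}

int main() {
    const int n_heads = 32;                        // ne12*ne13 for K*Q with 32 attention heads
    const int nth     = 12;                        // total threads
    const int ntg     = simple_gcd(n_heads, nth);  // 4 thread groups
    // Head h is handled by group h % ntg; within a group, a thread contributes
    // as sub-thread ith/ntg out of nth/ntg to that head's matrix multiplication.
    for (int ith = 0; ith < nth; ++ith)
        printf("thread %2d -> group %d, sub-thread %d of %d\n",
               ith, ith % ntg, ith / ntg, nth / ntg);
    // So 4 heads are processed simultaneously with 3 threads per head, matching the
    // counter++ % ntg == ith%ntg selection and the (ith/ntg, nth/ntg) arguments above.
    return 0;
}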
@@ -15055,6 +15160,8 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -15352,6 +15459,8 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -15977,6 +16086,8 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index cfca477d..5577ea99 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6755,10 +6755,10 @@ struct Q_Unpacker {
}
};
-struct Q8_0_x4_Unpacker {
+struct Q8_0_x4_Unpacker_256 {
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK8_0; }
- Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
+ Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
const char * cx_0;
const block_q8_0_x4 * x;
@@ -6784,6 +6784,44 @@ struct Q8_0_x4_Unpacker {
}
};
+#ifdef HAVE_FANCY_SIMD
+struct Q8_0_x4_Unpacker_512 {
+ using Sum4T = Sum4TypeQ81;
+ inline static int block_size() { return QK8_0; }
+ Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
+
+ const char * cx_0;
+ const block_q8_0_x4 * x;
+ size_t bx;
+ const __m128 min = _mm_set1_ps(-128.f);
+
+ __m256i qx[4];
+
+ inline const __m256i* quants() const { return qx; }
+
+ inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); }
+
+ inline auto set_block_4(int i) {
+ auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d));
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80));
+ }
+ return _mm256_set_m128(_mm_mul_ps(scales, min), scales);
+ }
+ inline auto set_block(int i) {
+ auto q8 = (const block_q8_0 *)(x + i);
+ qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs);
+ qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80));
+ float d = GGML_FP16_TO_FP32(q8->d);
+ return std::make_pair(d, -128.f*d);
+ }
+};
+using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512;
+#else
+using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256;
+#endif
+
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
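Q8_0_x4_Unpacker_512 above turns the signed q8_0 quants into unsigned bytes by XOR-ing with 0x80 (i.e. adding 128) and compensates through the extra -128*d scale it returns, which is later combined with the block sum that a Q8_1 activation block already carries. A scalar sketch of that identity (illustrative values; not part of the patch):

#include <cstdio>

int main() {
    const int   QK = 32;
    signed char qx[QK], qy[QK];
    for (int i = 0; i < QK; ++i) { qx[i] = (signed char)(i - 16); qy[i] = (signed char)(3*i - 48); }
    const float dx = 0.05f, dy = 0.02f;

    // Reference: signed x signed dot product.
    int ref = 0; for (int i = 0; i < QK; ++i) ref += qx[i] * qy[i];

    // Trick: make the x operand unsigned (q+128), then subtract 128*sum(qy) via the
    // block sum that Q8_1 already stores (s = dy * sum(qy)).
    int dot_u = 0, sum_y = 0;
    for (int i = 0; i < QK; ++i) { dot_u += (qx[i] + 128) * qy[i]; sum_y += qy[i]; }
    const float s = dy * sum_y;

    const float lhs = dx * dy * ref;                     // what we actually want
    const float rhs = dx * dy * dot_u - 128.0f * dx * s; // what the unpacker computes
    printf("%f vs %f\n", lhs, rhs);                      // identical up to float rounding
    return 0;
}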
@@ -7414,37 +7452,42 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q4_0:
assert (ne00 % QK4_0 == 0);
MulMat::set_functions<Q4_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q4_1:
assert (ne00 % QK4_1 == 0);
MulMat::set_functions<Q4_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q5_0:
assert (ne00 % QK5_0 == 0);
MulMat::set_functions<Q5_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q5_1:
assert (ne00 % QK5_1 == 0);
MulMat::set_functions<Q5_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q6_0:
assert (ne00 % QK6_0 == 0);
MulMat::set_functions<Q6_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q8_0:
assert (ne00 % QK8_0 == 0);
+#ifdef HAVE_FANCY_SIMD
MulMat::set_functions<Q8_0_1_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
+#else
+ MulMat::set_functions<Q8_0_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_0_X4;
+#endif
break;
case GGML_TYPE_IQ4_NL:
assert (ne00 % QK4_NL == 0);
MulMat::set_functions<IQ4_NL_Unpacker>(mm);
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_IQ4_NL_R4:
assert (ne00 % QK4_NL == 0);
@@ -7456,7 +7499,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_iq4_nl_r4_q8_1<6>;
mm.funcs[6] = mul_mat_iq4_nl_r4_q8_1<7>;
mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>;
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_IQ4_XS_R4:
assert (ne00 % QK_K == 0);
@@ -7689,7 +7732,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>;
mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>;
mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>;
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q5_0_R4:
assert (ne00 % QK4_NL == 0);
@@ -7701,7 +7744,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_q5_0_r4_q8_1<6>;
mm.funcs[6] = mul_mat_q5_0_r4_q8_1<7>;
mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>;
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q6_0_R4:
assert (ne00 % QK4_NL == 0);
@@ -7713,7 +7756,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>;
mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>;
mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>;
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q8_0_R4:
assert (ne00 % QK4_NL == 0);
@@ -7725,7 +7768,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[5] = mul_mat_q8_0_r4_q8_1<6>;
mm.funcs[6] = mul_mat_q8_0_r4_q8_1<7>;
mm.funcs[7] = mul_mat_q8_0_r4_q8_1<8>;
- expected_typeB = GGML_TYPE_Q8_1;
+ expected_typeB = GGML_TYPE_Q8_1_X4;
break;
default:
@@ -11998,35 +12041,35 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q4_1:
MulMat::set_functions<DequantizerQ41>(m);
- expected_Btype = GGML_TYPE_Q8_1;
+ expected_Btype = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q5_0:
MulMat::set_functions<DequantizerQ50>(m);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q5_1:
MulMat::set_functions<DequantizerQ51>(m);
- expected_Btype = GGML_TYPE_Q8_1;
+ expected_Btype = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_Q6_0:
MulMat::set_functions<DequantizerQ60>(m);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q8_0:
MulMat::set_functions<DequantizerQ80>(m);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_NL:
MulMat::set_functions<DequantizerIQ4NL>(m);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_NL_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_IQ4_XS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
@@ -12103,19 +12146,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
break;
case GGML_TYPE_Q4_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q5_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q6_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q8_0_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0);
- expected_Btype = GGML_TYPE_Q8_0;
+ expected_Btype = GGML_TYPE_Q8_0_X4;
break;
default:
return false;
@@ -12407,281 +12450,36 @@ struct HelperF16 final : public BaseHelper<step> {
}
};
-void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) {
- const int nb = k / QK8_0;
- const int nb4 = 4*(nb/4);
-
-#if defined(__aarch64__)
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)y;
- for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- if (i < nb4) {
- y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
- }
- }
-#else
- block_q8_0_x4 * y4 = (block_q8_0_x4 *)y;
- for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float maxScalar = _mm_cvtss_f32( max4 );
-
- const float d = maxScalar / 127.f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
- }
-#endif
-}
-
-void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
- assert(k % QK8_1 == 0);
- const int nb = k / QK8_1;
-
- const int nb4 = 4*(nb/4);
- block_q8_1_x4 * y4 = (block_q8_1_x4 *)y;
-#if defined(__aarch64__)
- for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
-
- int32x4_t accv = vdupq_n_s32(0);
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- if (i < nb4) {
- y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
- y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
- y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
- y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
- } else {
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
-
- accv = vaddq_s32(accv, vi);
- }
-
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
- }
- }
-#else
- for (int i = 0; i < nb; i++) {
- int i4 = i/4, ir = i%4;
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float max_scalar = _mm_cvtss_f32( max4 );
-
- // Quantize these floats
- const float d = max_scalar / 127.f;
- if (i < nb4) {
- y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
- } else {
- y[i].d = GGML_FP32_TO_FP16(d);
- }
- const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- // Apply the multiplier
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- // Round to nearest integer
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- // Convert floats to integers
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
- // Compute the sum of the quants and set y[i].s
- if (i < nb4) {
- y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- } else {
- y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
- }
-
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- if (i < nb4) {
- _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
- } else {
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- }
- }
-#endif
-}
-
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
using Base = BaseHelper<step>;
+#ifdef HAVE_FANCY_SIMD
+ //using block_q8 = block_q8_1;
+ using block_q8 = block_q8_1;
+#else
using block_q8 = block_q8_0;
+#endif
HelperQ80(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
- auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0);
- int ii = (j/QK8_0)%4;
+ auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0;
#ifdef __aarch64__
- const float16_t * d = (const float16_t *)dl->d;
- auto vd = F16::set1(d[ii]);
- auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32);
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
+ int ii = j%QK8_0;
+ auto qs = vld1_s8_x2(dl->qs + ii);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0])));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1])));
#else
- auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii]));
+ auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
#ifdef HAVE_FANCY_SIMD
- v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0))));
- v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1))));
+ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0))));
+ v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1))));
#else
- v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32)))));
- v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8)))));
+ int ii = j%QK8_0;
+ v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+0))));
+ v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+1))));
#endif
#endif
}
@@ -12689,7 +12487,7 @@ struct HelperQ80 final : public BaseHelper<step> {
static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) {
GGML_ASSERT(nq <= step);
for (int i = 0; i < nq; ++i) {
- quantize_row_q8_0(q, y, D);
+ quantize_row_q8_0_x4(q, y, D);
q += stride_q;
y += D/QK8_0;
}
@@ -12698,7 +12496,7 @@ struct HelperQ80 final : public BaseHelper<step> {
static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) {
GGML_ASSERT(nq <= step);
for (int i = 0; i < nq; ++i) {
- quantize_row_q8_1(q, y, D);
+ quantize_row_q8_1_x4(q, y, D);
q += stride_q;
y += D/QK8_1;
}
@@ -13144,13 +12942,15 @@ struct FlashQKV {
}
}
}
- F16::Data v1, v2;
- for (int l1 = 0; l1 < k_step; ++l1) {
- vh.load(l1, i, v1, v2);
+ F16::Data v1, v2, v3, v4;
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ vh.load(l1+0, i, v1, v2);
+ vh.load(l1+1, i, v3, v4);
for (int j = 0; j < q_step; ++j) {
- auto vs = F16::set1(fms.cache[k_step*j + l1]);
- vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
- vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
+ auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
+ auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
+ vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
+ vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
}
}
for (int j = 0; j < q_step; ++j) {
@@ -13178,13 +12978,15 @@ struct FlashQKV {
}
}
}
- F16::Data v1, v2;
- for (int l1 = 0; l1 < k_step; ++l1) {
- vh.load(l1, i, v1, v2);
+ F16::Data v1, v2, v3, v4;
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ vh.load(l1+0, i, v1, v2);
+ vh.load(l1+1, i, v3, v4);
for (int j = 0; j < nq1; ++j) {
- auto vs = F16::set1(fms.cache[k_step*j + l1]);
- vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
- vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);
+ auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]);
+ auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]);
+ vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2);
+ vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2);
}
}
for (int j = 0; j < nq1; ++j) {
@@ -13388,57 +13190,83 @@ struct FlashQKfp32 {
}
#endif
- template <typename KHelper, typename block_q8>
- static inline void mul_mask_kq(const KHelper& kh, int stride_m,
- const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- static_assert(q_step <= 8);
+ template <typename KHelper>
+ static inline std::pair<mul_mat_t, int> mul_mat_kernel(int nq) {
+ constexpr int kMaxQ = 8;
+#define MAKE_FUNCS(mul_mat, n) \
+ if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
+ else {\
+ switch (n) {\
+ case 1: return std::make_pair(mul_mat, 1>, 1);\
+ case 2: return std::make_pair(mul_mat, 2>, 2);\
+ case 3: return std::make_pair(mul_mat, 3>, 3);\
+ case 4: return std::make_pair(mul_mat, 4>, 4);\
+ case 5: return std::make_pair(mul_mat, 5>, 5);\
+ case 6: return std::make_pair(mul_mat, 6>, 6);\
+ case 7: return std::make_pair(mul_mat, 7>, 7);\
+ }\
+ }
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
- mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq);
#else
- mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
- mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
if constexpr (D >= 128) {
- mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#ifdef HAVE_FANCY_SIMD
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq);
+#else
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+#endif
} else {
- mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0
+ MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
}
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
- mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq);
#else
- mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
- mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
#else
- mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
- mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq);
#else
- mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+ MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq);
#endif
}
else {
GGML_ASSERT(false);
}
+ return std::make_pair<mul_mat_t, int>(nullptr, 0);
+ }
+
+ template <typename KHelper, typename block_q8>
+ static inline void mul_mask_kq(const KHelper& kh, int stride_m,
+ const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
+ constexpr int kMaxQ = 8;
+ static_assert(q_step < kMaxQ || q_step%kMaxQ == 0);
+ auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step);
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+ mul_mat(D, kh.block, kh.stride, info, k_step);
+ info.cur_y += nrc_q;
+ }
#ifdef __aarch64__
float32x4_t vk[k_step/4];
for (int j = 0; j < q_step; ++j) {
@@ -13451,136 +13279,21 @@ struct FlashQKfp32 {
}
#endif
}
+
template <typename KHelper, typename block_q8>
static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m,
const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) {
- GGML_ASSERT(nq < 8);
- if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
- switch (nq) {
-#ifdef __aarch64__
- case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
-#else
- case 1: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
-#endif
- }
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
-#ifdef __aarch64__
- switch (nq) {
- case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
- }
-#else
- if constexpr (D >= 128) {
- switch (nq) {
- case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
- }
- } else {
- switch (nq) {
- case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
- }
- }
-#endif
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
- switch (nq) {
-#ifdef __aarch64__
- case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
-#else
- case 1: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
-#endif
- }
- }
- else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
- switch (nq) {
-#ifdef __aarch64__
- case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break;
-#else
- case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
-#endif
- }
- }
- else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
- DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
- switch (nq) {
-#ifdef __aarch64__
- case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break;
-#else
- case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
- case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
- case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
- case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
- case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
- case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
- case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
-#endif
- }
- }
- else {
- GGML_ASSERT(false);
+ auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq);
+ DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
+ for (int iq = 0; iq < nq/nrc_q; ++iq) {
+ mul_mat(D, kh.block, kh.stride, info, k_step);
+ info.cur_y += nrc_q;
+ }
+ int iq = nrc_q*(nq/nrc_q);
+ if (iq < nq) {
+ auto [mul_mat1, nrc_q1] = mul_mat_kernel<KHelper>(nq - iq);
+ GGML_ASSERT(nrc_q1 == nq - iq);
+ mul_mat1(D, kh.block, kh.stride, info, k_step);
}
#ifdef __aarch64__
float32x4_t vk[k_step/4];
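The mul_mat_kernel() helper above replaces the per-type switch blocks with a single dispatch that returns a kernel and the number of query rows it handles per call; callers then loop over full chunks and dispatch once more for the remainder. A stripped-down sketch of that pattern, with illustrative names only:

#include <cstdio>
#include <utility>

using kernel_t = void (*)(int);

template <int NRC> void kernel(int first_row) { printf("rows [%d, %d)\n", first_row, first_row + NRC); }

static std::pair<kernel_t, int> pick_kernel(int nq) {
    constexpr int kMaxQ = 8;
    if (nq >= kMaxQ) return {kernel<kMaxQ>, kMaxQ};
    switch (nq) {
        case 1: return {kernel<1>, 1}; case 2: return {kernel<2>, 2};
        case 3: return {kernel<3>, 3}; case 4: return {kernel<4>, 4};
        case 5: return {kernel<5>, 5}; case 6: return {kernel<6>, 6};
        case 7: return {kernel<7>, 7};
    }
    return {nullptr, 0};
}

int main() {
    int nq = 19;                        // e.g. 19 query rows left over
    auto [mm, nrc] = pick_kernel(nq);   // 8-row kernel
    int row = 0;
    for (int i = 0; i < nq/nrc; ++i, row += nrc) mm(row);   // two full calls of 8 rows
    if (row < nq) {                                          // remainder of 3 rows
        auto [mm1, nrc1] = pick_kernel(nq - row);
        mm1(row);
    }
    return 0;
}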
@@ -13864,6 +13577,11 @@ struct FlashQKbf16 {
}
}
+ static inline __m128 hsum_float_4x4(__m128 * a) {
+ for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
+ return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
+ }
+
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
@@ -13893,6 +13611,34 @@ struct FlashQKbf16 {
}
}
+ static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i));
+ __m128 sum[4];
+ for (int k = 0; k < 4; ++k) {
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]);
+ auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1));
+ sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1));
+ }
+ //auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY));
+ //_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4);
+ _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum));
+ }
+
+ static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q,
+ __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) {
+ auto qr = q + m1*D;
+ for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i));
+ auto vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]);
+ fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
+ vsum = _mm512_setzero_ps();
+ for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]);
+ fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
+ }
+
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q,
const char * mask, FlashMS<q_step, k_step>& fms) {
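mult_mask_kq_4 and mult_mask_kq_one above compute the K*Q dot products of one query row against four (or two) K rows: _mm512_dpbf16_ps widens bf16 pairs to fp32 and accumulates, and the partial sums are then reduced into fms.cache. A scalar reference of what one 4-row call produces; bf16_to_fp32 and the flat row layout are illustrative assumptions, not code from the patch:

#include <cstdint>
#include <cstring>

static inline float bf16_to_fp32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16; float f; std::memcpy(&f, &u, sizeof f);
    return f;
}

// cache[k_step*m1 + l1 + k] = sum_i K[l1+k][i] * Q[m1][i]   for k = 0..3
static void kq_dot_4_ref(int D, int k_step, int l1, int m1,
                         const uint16_t *Q, const uint16_t *K, float *cache) {
    for (int k = 0; k < 4; ++k) {
        float sum = 0.f;
        for (int i = 0; i < D; ++i)
            sum += bf16_to_fp32(K[(l1 + k)*D + i]) * bf16_to_fp32(Q[m1*D + i]);
        cache[k_step*m1 + l1 + k] = sum;
    }
}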
@@ -13902,23 +13648,44 @@ struct FlashQKbf16 {
__m512bh vkh[D/8];
for (int l1 = 0; l1 < k_step; l1 += 4) {
kh.load_4(l1, vkh);
- for (int j = 0; j < q_step; ++j) {
- mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms);
- }
+ for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
}
} else {
__m512bh vkh[D/16];
for (int l1 = 0; l1 < k_step; l1 += 2) {
kh.load_2(l1, vkh);
- for (int j = 0; j < q_step; ++j) {
- mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms);
- }
+ for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
}
}
}
- __m512 vk[k_step/16];
+ F16::Data vk[k_step/16];
for (int j = 0; j < q_step; ++j) {
- fms.update_M_S(j, vk);
+ fms.update_M_S(j, vk, mask + stride_m*j);
+ }
+ }
+
+ template <typename KHelper>
+ static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q,
+ const char * mask, FlashMS<q_step, k_step>& fms) {
+ {
+ __m512bh qv[D/32];
+ if constexpr (D <= 128) {
+ __m512bh vkh[D/8];
+ for (int l1 = 0; l1 < k_step; l1 += 4) {
+ kh.load_4(l1, vkh);
+ for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms);
+ }
+ } else {
+ __m512bh vkh[D/16];
+ for (int l1 = 0; l1 < k_step; l1 += 2) {
+ kh.load_2(l1, vkh);
+ for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms);
+ }
+ }
+ }
+ F16::Data vk[k_step/16];
+ for (int j = 0; j < nq; ++j) {
+ fms.update_M_S(j, vk, mask + stride_m*j);
}
}
@@ -13953,6 +13720,19 @@ struct FlashQKbf16 {
bf16 += D;
}
}
+
+ static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) {
+ auto qr = q;
+ for (int j = 0; j < nq; ++j) {
+ for (int i = 0; i < D/32; ++i) {
+ auto val1 = _mm512_loadu_ps(qr + 32*i);
+ auto val2 = _mm512_loadu_ps(qr + 32*i + 16);
+ _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1));
+ }
+ qr += stride_q;
+ bf16 += D;
+ }
+ }
};
template <int D, int q_step, int k_step>
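convert() turns nq rows of fp32 queries into contiguous bf16 rows so the dpbf16 path can consume them; _mm512_cvtne2ps_pbh rounds two 16-float vectors to 32 bf16 values (nearest, ties to even). A scalar view of the same conversion, ignoring NaN special-casing:

#include <cstdint>
#include <cstring>

// Round a fp32 value to bf16 (nearest, ties to even) by keeping the top 16 bits.
static inline uint16_t fp32_to_bf16_ne(float f) {
    uint32_t u; std::memcpy(&u, &f, sizeof u);
    u += 0x7fff + ((u >> 16) & 1);
    return (uint16_t)(u >> 16);
}

static void convert_rows_ref(int nq, int D, int stride_q, const float *q, uint16_t *bf16) {
    for (int j = 0; j < nq; ++j, q += stride_q, bf16 += D)
        for (int i = 0; i < D; ++i) bf16[i] = fp32_to_bf16_ne(q[i]);
}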
@@ -13991,9 +13771,10 @@ struct FlashAttnBF16 {
fms.init_qstep();
kh.reset_block();
vh.reset_block();
+ FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16);
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
- FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
+ FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms);
fqkv.accumulate_qkv(n_left, vh, fms);
kh.next_block();
vh.next_block();
@@ -14008,28 +13789,59 @@ struct FlashAttnBF16 {
};
#endif
-template <int D, int q_step, int k_step, typename KHelper, typename VHelper>
+template <int D, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv) {
- if (nq1 >= q_step) {
- FlashAttn<D, q_step, k_step> fa(scale, softcap);
+#if defined __AVX2__
+ constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
+#else
+ constexpr bool kUseLargeStepsQ = true;
+#endif
+ if constexpr (kUseLargeStepsQ) {
+ if (nk1 >= 4096) {
+ if (nq1 >= 32) {
+ FlashAttn<D, 32, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
+ }
+ else if (nq1 >= 8) {
+ FlashAttn<D, 8, k_step> fa(scale, softcap);
+ fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+ return;
+ }
+ }
+ }
+ if (nq1 >= 8) {
+ FlashAttn<D, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
- } else {
+ }
+ else {
FlashAttn<D, 1, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
}
}
#ifdef __AVX512BF16__
-template <int D, int q_step, int k_step>
+template <int D, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv) {
HelperBF16<D, k_step> kh(k, stride_k);
HelperBF16<D, k_step> vh(v, stride_v);
- if (nq1 >= q_step) {
- FlashAttnBF16<D, q_step, k_step> fa(scale, softcap);
+    if (nk1 >= 4096) {
+        if (nq1 >= 64) {
+            FlashAttnBF16<D, 64, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            return;
+        }
+        if (nq1 >= 16) {
+            FlashAttnBF16<D, 16, k_step> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+            return;
+        }
+    }
+ if (nq1 >= 8) {
+ FlashAttnBF16<D, 8, k_step> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
} else {
FlashAttnBF16<D, 1, k_step> fa(scale, softcap);
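The q_step selection across both helpers, summarized in one place: large query blocks only pay off for long contexts (nk1 >= 4096), the bf16 path can afford bigger blocks (64 or 16) than the generic one (32), everything else falls back to 8, and single or tiny batches use q_step 1; on AVX2 an fp16 K-cache skips the large steps entirely. A sketch of that policy with an illustrative name, not a function in the patch:

// large_steps_ok models the kUseLargeStepsQ flag (false for fp16 K-cache on AVX2).
static inline int choose_q_step(int nq1, int nk1, bool bf16_path, bool large_steps_ok) {
    if (large_steps_ok && nk1 >= 4096) {
        if (bf16_path) {
            if (nq1 >= 64) return 64;
            if (nq1 >= 16) return 16;
        } else if (nq1 >= 32) {
            return 32;
        }
    }
    return nq1 >= 8 ? 8 : 1;
}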
@@ -14038,7 +13850,7 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int
}
#endif
-template <int D, int q_step, int k_step, typename KHelper>
+template <int D, int k_step, typename KHelper>
inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * v, const char * mask,
@@ -14047,33 +13859,39 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
switch (type_v) {
case GGML_TYPE_F16: {
HelperF16<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
+#ifdef HAVE_FANCY_SIMD
+ case GGML_TYPE_BF16: {
+ HelperBF16<D, k_step> vh(v, stride_v);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ } break;
+#endif
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> vh(v, stride_v);
- iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+ iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
default: break;
}
}
-template <int D, int q_step, int k_step>
+template <int D, int k_step>
inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
@@ -14082,27 +13900,27 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
switch (type_k) {
case GGML_TYPE_F16: {
HelperF16<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q8_0: {
HelperQ80<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_0: {
HelperQ40<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60<D, k_step> kh(k, stride_k);
- iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
+ iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
default: break;
}
@@ -14149,17 +13967,17 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
stride_q /= sizeof(float); // q stride as float
#ifdef __AVX512BF16__
- if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) {
- if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types
+ if (type_k == GGML_TYPE_BF16) {
+        if (type_v != GGML_TYPE_BF16) return false; // a bf16 K-cache requires a bf16 V-cache
switch (D) {
case 64:
- iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
@@ -14168,21 +13986,42 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
}
#endif
+ if (nk1%64 == 0) {
+ switch (D) {
+ case 64:
+ iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ // Disable until we fix accumulate_qkv for odd D/16
+ //case 80:
+        //    iqk_flash_helper_T< 80, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 96:
+ iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ // Disable until we fix accumulate_qkv for odd D/16
+ //case 112:
+        //    iqk_flash_helper_T<112, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 128:
+ iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ case 256:
+ iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ default:
+ return false;
+ }
+ return true;
+ }
switch (D) {
case 64:
- iqk_flash_helper_T< 64, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 80:
- // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 96:
- iqk_flash_helper_T< 96, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
// Disable until we fix accumulate_qkv for odd D/16
//case 112:
- // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 128:
- iqk_flash_helper_T<128, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
case 256:
- iqk_flash_helper_T<256, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
+ iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
default:
return false;
}
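The non-bf16 entry point now picks the K block size from the context length and still dispatches on head size; the two selections in short form (names are illustrative):

// k_step: prefer 64-wide K blocks when the KV length divides evenly, else the 32-wide fallback.
static inline int choose_k_step(int nk1) { return nk1 % 64 == 0 ? 64 : 32; }

// Head sizes currently wired up; 80 and 112 stay disabled until
// accumulate_qkv handles odd D/16.
static inline bool head_size_supported(int D) {
    return D == 64 || D == 96 || D == 128 || D == 256;
}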
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 2404246d..a2ade6a7 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -654,6 +654,257 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) {
#endif
}
+void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) {
+ const int nb = k / QK8_0;
+ const int nb4 = 4*(nb/4);
+
+ block_q8_0 * y = (block_q8_0 *)vy;
+ block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
+#if defined(__aarch64__)
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+ }
+ }
+#else
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ const float d = maxScalar / 127.f;
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
+ }
+#endif
+}
+
+void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ const int nb4 = 4*(nb/4);
+ block_q8_1 * y = (block_q8_1 *)vy;
+ block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy;
+#if defined(__aarch64__)
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ int32x4_t accv = vdupq_n_s32(0);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+
+ accv = vaddq_s32(accv, vi);
+ }
+
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+ }
+ }
+#else
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float max_scalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = max_scalar / 127.f;
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ // Compute the sum of the quants and set y[i].s
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ }
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
+ }
+#endif
+}
+
//
// ============================================== iq2_K
//
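quantize_row_q8_0_x4 above packs four consecutive q8_0 blocks into one block_q8_0_x4 record (four fp16 scales followed by the 4x32 quants, the same byte footprint as four plain blocks); a trailing group of fewer than four blocks is stored as ordinary block_q8_0. quantize_row_q8_1_x4 does the same and additionally stores d*sum(quants) in d[ir+4]. A scalar reference for the q8_0 path; the struct is mirrored from the indexing above rather than taken from the headers:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include "ggml.h"   // ggml_fp16_t, ggml_fp32_to_fp16()

struct block_q8_0_x4_ref {      // assumed mirror of block_q8_0_x4
    ggml_fp16_t d[4];           // scales of blocks 0..3
    int8_t      qs[4*32];       // quants of blocks 0..3, back to back
};

// Quantize nb4 groups of 4 blocks (i.e. nb4*128 floats) into the x4 layout.
static void quantize_q8_0_x4_ref(const float *x, block_q8_0_x4_ref *y, int nb4) {
    for (int g = 0; g < nb4; ++g) {
        for (int ir = 0; ir < 4; ++ir) {
            const float *xb = x + 32*(4*g + ir);
            float amax = 0.f;
            for (int j = 0; j < 32; ++j) amax = std::max(amax, std::fabs(xb[j]));
            const float d  = amax/127.f, id = d ? 1.f/d : 0.f;
            y[g].d[ir] = ggml_fp32_to_fp16(d);
            for (int j = 0; j < 32; ++j)
                y[g].qs[32*ir + j] = (int8_t)std::nearbyintf(xb[j]*id);   // round to nearest
        }
    }
}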
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 68a38811..729b0ec0 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -211,6 +211,8 @@ void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
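Caller-side sketch for the two new declarations: the x4 variants write into the same buffer size as the plain formats, so a caller only swaps the quantization routine. The include path, the block-size constants, and the helper name are assumptions for illustration:

#include <vector>
#include "iqk/iqk_quantize.h"    // declares quantize_row_q8_0_x4 (include path assumed)

static constexpr int kBlockSize  = 32;   // QK8_0
static constexpr int kBlockBytes = 34;   // sizeof(block_q8_0): fp16 scale + 32 int8 quants

// Quantize one row of n activations (n a multiple of 32) into the interleaved
// q8_0_x4 layout; the buffer footprint is the same as for plain q8_0.
static void quantize_row_x4_example(const float *row, int n, std::vector<char> &buf) {
    buf.resize((size_t)(n/kBlockSize)*kBlockBytes);
    quantize_row_q8_0_x4(row, buf.data(), n);
}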