author     Kawrakow <iwankawrakow@gmail.com>    2025-01-15 18:19:22 +0200
committer  GitHub <noreply@github.com>          2025-01-15 18:19:22 +0200
commit     0b74397d596bbcdfba27299393406d2b6330b133 (patch)
tree       2101d059f79b6b268086c71878aa2da1c328c73d
parent     49b27069fd267d3dac8de5d13141b4274e4be16b (diff)
CPU Flash Attention improvements (#172)
* Slightly faster FA for bf16 KV cache
~2-3% sort of thing. Sadly, when we go beyond 8k tokens, the
advantage kind of goes away.
* Slightly faster FA for Q8_0 KV cache
* FA: allow bf16 for V-cache with any supported K-cache
E.g., -ctk q8_0 -ctv bf16 is slightly faster than
-ctk q8_0 -ctv q8_0 on Zen4 for moderate context lengths
(say, <= 4096).
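For reference, a usage sketch (assuming the usual llama-bench flags
-m/-fa/-ctk/-ctv/-p; the model file name is hypothetical):

    # Q8_0 K-cache, bf16 V-cache, flash attention enabled
    ./llama-bench -m llama-3.1-8b-iq4_xs.gguf -fa 1 -ctk q8_0 -ctv bf16 -p 4096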
* FA: much better bf16 kv-cache speed for large contexts
We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and
run-time-repacked) with a context of 32768. IIRC, the previous
best for such a large context was ~90 t/s.
Non-negligible improvements at 16384 and 8192 as well:
173.4 and 214 t/s, respectively.
* FA: slightly better quantized kv-cache speed for large contexts
E.g., for q8_0 and context of 32768, we are now at 113 t/s
for LLaMA-3.1-8B.
Also simplified the quantized K*Q multiplication.
* Fix q8_0 KV cache when not using FA - WIP (AVX2)
1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 (layout sketched
   after this list), and use those to quantize activations for quants that use
   Q8_0 or Q8_1 as their vec_dot type.
2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1.
3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type.
4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than
   GGML_TYPE_Q8_0_X4 as the K and V types.
5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4
   in iqk_mul_mat.
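The X4 layout, as a minimal sketch (inferred from how the quantization and
unpacking code indexes these blocks; four consecutive Q8_0/Q8_1 blocks are
interleaved so their scales can be loaded together):

    // Assumed layout; same total byte size as 4 plain blocks.
    typedef struct {
        ggml_half d[4];         // scales of 4 consecutive blocks
        int8_t    qs[4*QK8_0];  // 4 x 32 quants
    } block_q8_0_x4;

    typedef struct {
        ggml_half d[8];         // d[0..3]: scales, d[4..7]: d * sum(quants)
        int8_t    qs[4*QK8_1];  // 4 x 32 quants
    } block_q8_1_x4;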
Also added an optimization in ggml_compute_forward_mul_mat for the
case ne12*ne13 > 1 (K*Q and V*softmax(K*Q)): we now process
GCD(ne12*ne13, nthread) heads simultaneously, with
nthread/GCD(ne12*ne13, nthread) threads per head (see the sketch
after this item). This results in a non-negligible performance gain
for large contexts.
Question: why is it not allowed to use quantized V-cache when
not using FA?
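A minimal sketch of the head distribution described above (simplified from the
change to ggml_compute_forward_mul_mat in the diff below; process_head is a
placeholder for the per-head iqk_mul_mat call, not the real API):

    #include <stdint.h>

    // placeholder for the per-head iqk_mul_mat call (assumption)
    static void process_head(int head, int ith, int nth) {
        (void)head; (void)ith; (void)nth;
    }

    static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
        while (a != b) {
            if (a > b) a -= b;
            else b -= a;
        }
        return a;
    }

    // nheads = ne12*ne13 independent matrix multiplications, nth threads total,
    // this thread has index ith. ntg heads are worked on simultaneously; each
    // head is processed by nth/ntg threads.
    static void distribute_heads(int nheads, int nth, int ith) {
        int ntg = simple_gcd(nheads, nth);
        int counter = 0;
        for (int ih = 0; ih < nheads; ++ih) {
            if (counter++ % ntg == ith % ntg) {
                // this thread is worker ith/ntg out of nth/ntg for head ih
                process_head(ih, ith/ntg, nth/ntg);
            }
        }
    }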
* Fix q8_0 KV cache when not using FA - NEON
* Fix AVX2
Again the issue with _mm256_maddubs_epi16 saturating the int16_t
range that I keep forgetting (see the example below).
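A small standalone demonstration (not part of this commit) of how
_mm256_maddubs_epi16 can silently saturate: it multiplies unsigned by signed
bytes and adds adjacent pairs with saturating int16 arithmetic, so large
unsigned operands overflow the int16_t range.

    // Compile with -mavx2. Prints 32767 instead of the true sum 64770.
    #include <immintrin.h>
    #include <stdio.h>

    int main(void) {
        __m256i a = _mm256_set1_epi8((char)0xff); // treated as unsigned 255
        __m256i b = _mm256_set1_epi8(127);        // treated as signed  127
        // each 16-bit lane: 255*127 + 255*127 = 64770 > INT16_MAX -> saturates
        __m256i p = _mm256_maddubs_epi16(a, b);
        short out[16];
        _mm256_storeu_si256((__m256i *)out, p);
        printf("%d\n", out[0]);
        return 0;
    }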
* FA: don't use large Q steps on AVX2 for fp16 K-cache
* On Zen4 it is also better to not use large Q steps for fp16 K-cache
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  ggml/include/ggml.h            |   2
-rw-r--r--  ggml/src/ggml-quants.c         | 105
-rw-r--r--  ggml/src/ggml.c                | 167
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp   | 817
-rw-r--r--  ggml/src/iqk/iqk_quantize.cpp  | 251
-rw-r--r--  ggml/src/iqk/iqk_quantize.h    |   2
6 files changed, 738 insertions, 606 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 118ce969..5eea7dcd 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -396,6 +396,8 @@ extern "C" { // GGML_TYPE_I2_S = 36, // + GGML_TYPE_Q8_0_X4 = 98, + GGML_TYPE_Q8_1_X4 = 99, GGML_TYPE_Q6_0 = 133, GGML_TYPE_IQ1_BN = 134, GGML_TYPE_IQ2_BN = 135, diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d460b84a..23ac9915 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -934,13 +934,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) block_q8_0 * restrict y = vy; -#if GGML_USE_IQK_MULMAT - const int nb4 = 4*(nb/4); -#else - const int nb4 = -1; -#endif #if defined(__ARM_NEON) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; for (int i = 0; i < nb; i++) { int i4 = i/4, ir = i%4; float32x4_t srcv [8]; @@ -959,27 +953,16 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - if (i < nb4) { - y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } #elif defined(__wasm_simd128__) @@ -1016,14 +999,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } } #elif defined(__AVX2__) || defined(__AVX__) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; -#ifdef __AVX2__ - const bool pack = true; -#else - const bool pack = false; -#endif for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -1045,11 +1021,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = maxScalar / 127.f; - if (pack && i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1084,11 +1056,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } + _mm256_storeu_si256((__m256i *)y[i].qs, i0); #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE @@ -1287,15 +1255,8 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) block_q8_1 * restrict y = vy; -#if GGML_USE_IQK_MULMAT - const int nb4 = 4*(nb/4); -#else - const int nb4 = -1; -#endif #if defined(__ARM_NEON) - block_q8_1_x4 * restrict y4 = vy; for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; @@ -1312,11 +1273,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); int32x4_t accv = vdupq_n_s32(0); @@ -1324,26 +1281,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - if (i < nb4) { - y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); accv = vaddq_s32(accv, vi); } - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { @@ -1389,14 +1335,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) wasm_i32x4_extract_lane(accv, 3))); } #elif defined(__AVX2__) || defined(__AVX__) - block_q8_1_x4 * restrict y4 = vy; -#ifdef __AVX2__ - const bool pack = true; -#else - const bool pack = false; -#endif for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -1418,11 +1357,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = max_scalar / 127.f; - if (pack && i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1446,11 +1381,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -1464,11 +1395,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } + _mm256_storeu_si256((__m256i *)y[i].qs, i0); #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b9f9b3d8..bcb8bf41 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -714,8 +714,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_0, .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = ggml_vec_dot_q4_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -735,7 +739,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_1, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, +#if GGML_USE_IQK_MULMAT + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1, +#endif #if defined (__ARM_FEATURE_MATMUL_INT8) .nrows = 2, #else @@ -778,8 +786,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_0, .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = ggml_vec_dot_q5_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -795,7 +807,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_1, .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, +#if GGML_USE_IQK_MULMAT + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1, +#endif .nrows = 1, .row_meta_size = 0, }, @@ -808,8 +824,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_0, .from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref, .vec_dot = ggml_vec_dot_q6_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, 
#endif @@ -826,8 +846,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2 + // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range + // (and it gets satured if it does), leading to wrong results. + // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above. + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -849,6 +877,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q8_0_X4] = { + .type_name = "q8_0_x4", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .from_float = quantize_row_q8_0_x4, + .from_float_ref = quantize_row_q8_0_x4, + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q8_1_X4] = { + .type_name = "q8_1_x4", + .blck_size = QK8_1, + .type_size = sizeof(block_q8_1), + .is_quantized = true, + .from_float = quantize_row_q8_1_x4, + .from_float_ref = quantize_row_q8_1_x4, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1196,8 +1244,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_nl, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1516,8 +1568,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_nl_r4, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_r4_ref, .vec_dot = vec_dot_iq4_nl_r4_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1546,8 +1602,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_0_r4, .from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref, .vec_dot = vec_dot_q4_0_r4_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1563,8 +1623,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_0_r4, .from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref, .vec_dot = vec_dot_q8_0_r4_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = 
GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1580,8 +1644,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_0_r4, .from_float_ref = (ggml_from_float_t)quantize_row_q5_0_r4_ref, .vec_dot = vec_dot_q5_0_r4_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1597,8 +1665,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_0_r4, .from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref, .vec_dot = vec_dot_q6_0_r4_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -11280,6 +11352,8 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -11443,6 +11517,8 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -13889,6 +13965,14 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } +static inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + while (a != b) { + if (a > b) a -= b; + else b -= a; + } + return a; +} + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -13905,10 +13989,12 @@ static void ggml_compute_forward_mul_mat( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; int64_t const vec_dot_num_rows = type_traits[type].nrows; int64_t const matmul_num_cols = type_traits[type].ncols; +#if !GGML_USE_IQK_MULMAT + ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; +#endif ggml_gemv_t const gemv = type_traits[type].gemv; ggml_gemm_t const gemm = type_traits[type].gemm; @@ -14011,6 +14097,7 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { int64_t i11_processed = 0; +#if !GGML_USE_IQK_MULMAT if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), @@ -14019,6 +14106,7 @@ UseGgmlGemm1:; } i11_processed = ne11 - ne11 % 4; } +#endif for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), @@ -14049,14 +14137,31 @@ AlreadyQuantized:; #if GGML_USE_IQK_MULMAT if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { + // When K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing + // one matrix multiplication, but work on several heads at 
once. + // Hence, we find the GCD(n12*ne13, nth) and have nth/GCD(n12*ne13, nth) threads per head. + // Leaving the previous version commented out for now just in case. const size_t row_size = ggml_row_size(vec_dot_type, ne10); - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available2; + int ntg = simple_gcd(ne12*ne13, nth); + int counter = 0; + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + if (counter++ % ntg == ith%ntg) { + if (!iqk_mul_mat(ne01, ne11, ne00, + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), + (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2; + } + } + } + //for (int64_t i13 = 0; i13 < ne13; i13++) + // for (int64_t i12 = 0; i12 < ne12; i12++) + // if (!iqk_mul_mat(ne01, ne11, ne00, + // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), + // (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + // ith, nth)) goto IQK_MulMat_Not_Available2; return; } IQK_MulMat_Not_Available2:; @@ -15055,6 +15160,8 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15352,6 +15459,8 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15977,6 +16086,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cfca477d..5577ea99 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6755,10 +6755,10 @@ struct Q_Unpacker { } }; -struct Q8_0_x4_Unpacker { +struct Q8_0_x4_Unpacker_256 { using Sum4T = Sum4TypeQ80; inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} const char * cx_0; const block_q8_0_x4 * x; @@ -6784,6 +6784,44 @@ struct Q8_0_x4_Unpacker { } }; +#ifdef HAVE_FANCY_SIMD +struct Q8_0_x4_Unpacker_512 { + using Sum4T = Sum4TypeQ81; + inline static int block_size() { return QK8_0; } + Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + + const char * cx_0; + const block_q8_0_x4 * 
x; + size_t bx; + const __m128 min = _mm_set1_ps(-128.f); + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80)); + } + return _mm256_set_m128(_mm_mul_ps(scales, min), scales); + } + inline auto set_block(int i) { + auto q8 = (const block_q8_0 *)(x + i); + qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); + qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80)); + float d = GGML_FP16_TO_FP32(q8->d); + return std::make_pair(d, -128.f*d); + } +}; +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512; +#else +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256; +#endif + struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -7414,37 +7452,42 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions<Q4_0_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions<Q4_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions<Q5_0_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions<Q5_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q6_0: assert (ne00 % QK6_0 == 0); MulMat::set_functions<Q6_0_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); +#ifdef HAVE_FANCY_SIMD MulMat::set_functions<Q8_0_1_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; +#else + MulMat::set_functions<Q8_0_Unpacker>(mm); + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif break; case GGML_TYPE_IQ4_NL: assert (ne00 % QK4_NL == 0); MulMat::set_functions<IQ4_NL_Unpacker>(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_IQ4_NL_R4: assert (ne00 % QK4_NL == 0); @@ -7456,7 +7499,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_iq4_nl_r4_q8_1<6>; mm.funcs[6] = mul_mat_iq4_nl_r4_q8_1<7>; mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_IQ4_XS_R4: assert (ne00 % QK_K == 0); @@ -7689,7 +7732,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>; mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>; mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0_R4: assert (ne00 % QK4_NL == 0); @@ -7701,7 +7744,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q5_0_r4_q8_1<6>; mm.funcs[6] = mul_mat_q5_0_r4_q8_1<7>; mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1; + 
expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q6_0_R4: assert (ne00 % QK4_NL == 0); @@ -7713,7 +7756,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>; mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>; mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q8_0_R4: assert (ne00 % QK4_NL == 0); @@ -7725,7 +7768,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[5] = mul_mat_q8_0_r4_q8_1<6>; mm.funcs[6] = mul_mat_q8_0_r4_q8_1<7>; mm.funcs[7] = mul_mat_q8_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; default: @@ -11998,35 +12041,35 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { break; case GGML_TYPE_Q4_0: MulMat::set_functions<DequantizerQ40>(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q4_1: MulMat::set_functions<DequantizerQ41>(m); - expected_Btype = GGML_TYPE_Q8_1; + expected_Btype = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: MulMat::set_functions<DequantizerQ50>(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q5_1: MulMat::set_functions<DequantizerQ51>(m); - expected_Btype = GGML_TYPE_Q8_1; + expected_Btype = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q6_0: MulMat::set_functions<DequantizerQ60>(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q8_0: MulMat::set_functions<DequantizerQ80>(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_IQ4_NL: MulMat::set_functions<DequantizerIQ4NL>(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_IQ4_NL_R4: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_IQ4_XS_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k); @@ -12103,19 +12146,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { break; case GGML_TYPE_Q4_0_R4: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q5_0_R4: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q6_0_R4: SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q8_0_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; default: return false; @@ -12407,281 +12450,36 @@ struct HelperF16 final : public BaseHelper<step> { } }; -void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) { - const int nb = k / QK8_0; - const int nb4 = 4*(nb/4); - -#if defined(__aarch64__) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 
0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - if (i < nb4) { - y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } - } -#else - block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - const float d = maxScalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } - } -#endif -} - -void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - const int nb4 = 4*(nb/4); - block_q8_1_x4 * y4 = (block_q8_1_x4 *)y; -#if defined(__aarch64__) - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - if (i < nb4) { - y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - - accv = vaddq_s32(accv, vi); - } - - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } - } -#else - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float max_scalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = max_scalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } - } -#endif -} - template <int D, int step> struct HelperQ80 final : public BaseHelper<step> { using Base = BaseHelper<step>; +#ifdef HAVE_FANCY_SIMD + //using block_q8 = block_q8_1; + using block_q8 = block_q8_1; +#else using block_q8 = block_q8_0; +#endif HelperQ80(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { int j = F16::block_size*i; - auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0); - int ii = (j/QK8_0)%4; + auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0; #ifdef __aarch64__ - const float16_t * d = (const float16_t *)dl->d; - auto vd = F16::set1(d[ii]); - auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32); + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); + int ii = j%QK8_0; + auto qs = vld1_s8_x2(dl->qs + ii); v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); #else - auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii])); + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); #ifdef HAVE_FANCY_SIMD - v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0)))); - v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1)))); + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0)))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1)))); #else - v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32))))); - v2 = _mm256_mul_ps(vd, 
_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8))))); + int ii = j%QK8_0; + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+0)))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii)+1)))); #endif #endif } @@ -12689,7 +12487,7 @@ struct HelperQ80 final : public BaseHelper<step> { static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) { GGML_ASSERT(nq <= step); for (int i = 0; i < nq; ++i) { - quantize_row_q8_0(q, y, D); + quantize_row_q8_0_x4(q, y, D); q += stride_q; y += D/QK8_0; } @@ -12698,7 +12496,7 @@ struct HelperQ80 final : public BaseHelper<step> { static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) { GGML_ASSERT(nq <= step); for (int i = 0; i < nq; ++i) { - quantize_row_q8_1(q, y, D); + quantize_row_q8_1_x4(q, y, D); q += stride_q; y += D/QK8_1; } @@ -13144,13 +12942,15 @@ struct FlashQKV { } } } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + F16::Data v1, v2, v3, v4; + for (int l1 = 0; l1 < k_step; l1 += 2) { + vh.load(l1+0, i, v1, v2); + vh.load(l1+1, i, v3, v4); for (int j = 0; j < q_step; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); + auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); + vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); + vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); } } for (int j = 0; j < q_step; ++j) { @@ -13178,13 +12978,15 @@ struct FlashQKV { } } } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + F16::Data v1, v2, v3, v4; + for (int l1 = 0; l1 < k_step; l1 += 2) { + vh.load(l1+0, i, v1, v2); + vh.load(l1+1, i, v3, v4); for (int j = 0; j < nq1; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto vs1 = F16::set1(fms.cache[k_step*j + l1+0]); + auto vs2 = F16::set1(fms.cache[k_step*j + l1+1]); + vk[2*j+0] = F16::fmadd(F16::fmadd(vk[2*j+0], v1, vs1), v3, vs2); + vk[2*j+1] = F16::fmadd(F16::fmadd(vk[2*j+1], v2, vs1), v4, vs2); } } for (int j = 0; j < nq1; ++j) { @@ -13388,57 +13190,83 @@ struct FlashQKfp32 { } #endif - template <typename KHelper, typename block_q8> - static inline void mul_mask_kq(const KHelper& kh, int stride_m, - const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { - static_assert(q_step <= 8); + template <typename KHelper> + static inline std::pair<mul_mat_t, int> mul_mat_kernel(int nq) { + constexpr int kMaxQ = 8; +#define MAKE_FUNCS(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat, 1>, 1);\ + case 2: return std::make_pair(mul_mat, 2>, 2);\ + case 3: return std::make_pair(mul_mat, 3>, 3);\ + case 4: return std::make_pair(mul_mat, 4>, 4);\ + case 5: return std::make_pair(mul_mat, 5>, 5);\ + case 6: return std::make_pair(mul_mat, 6>, 6);\ + case 7: return std::make_pair(mul_mat, 7>, 7);\ + }\ + } if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step); 
+ MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ40, nq); #else - mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, nq); #endif } else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq); #else if constexpr (D >= 128) { - mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); +#ifdef HAVE_FANCY_SIMD + MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q8_0_1_Unpacker, nq); +#else + MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq); +#endif } else { - mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); + // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0 + MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq); } #endif } else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1<DequantizerQ41, nq); #else - mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, nq); #endif } else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq); #else - mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, nq); #endif } else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0<DequantizerQ60, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ60, nq); #else - mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, nq); #endif } else { GGML_ASSERT(false); } + return std::make_pair<mul_mat_t, int>(nullptr, 0); + } + + template <typename KHelper, typename block_q8> + static inline void mul_mask_kq(const KHelper& kh, int stride_m, + const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { + constexpr int kMaxQ = 8; + static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); + auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(q_step); + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + for (int iq = 0; iq < q_step/nrc_q; ++iq) { + mul_mat(D, kh.block, kh.stride, info, k_step); + info.cur_y += nrc_q; + } #ifdef __aarch64__ float32x4_t vk[k_step/4]; for (int j = 0; j < q_step; ++j) { @@ -13451,136 +13279,21 @@ struct FlashQKfp32 { } #endif } + template <typename KHelper, typename block_q8> static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS<q_step, k_step>& fms) { - GGML_ASSERT(nq < 8); - if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>>) { - DataInfo info{fms.cache, 
(const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0<DequantizerQ40, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0<DequantizerQ40, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0<DequantizerQ40, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0<DequantizerQ40, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0<DequantizerQ40, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0<DequantizerQ40, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T<Q4_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; -#ifdef __aarch64__ - switch (nq) { - case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break; - } -#else - if constexpr (D >= 128) { - switch (nq) { - case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; - } - } else { - switch (nq) { - case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 
6: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; - } - } -#endif - } - else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0<DequantizerQ60, 1>(D, 
kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0<DequantizerQ60, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0<DequantizerQ60, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0<DequantizerQ60, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0<DequantizerQ60, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0<DequantizerQ60, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0<DequantizerQ60, 7>(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else { - GGML_ASSERT(false); + auto [mul_mat, nrc_q] = mul_mat_kernel<KHelper>(nq); + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + for (int iq = 0; iq < nq/nrc_q; ++iq) { + mul_mat(D, kh.block, kh.stride, info, k_step); + info.cur_y += nrc_q; + } + int iq = nrc_q*(nq/nrc_q); + if (iq < nq) { + auto [mul_mat1, nrc_q1] = mul_mat_kernel<KHelper>(nq - iq); + GGML_ASSERT(nrc_q1 == nq - iq); + mul_mat1(D, kh.block, kh.stride, info, k_step); } #ifdef __aarch64__ float32x4_t vk[k_step/4]; @@ -13864,6 +13577,11 @@ struct FlashQKbf16 { } } + static inline __m128 hsum_float_4x4(__m128 * a) { + for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2])); + return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1])); + } + template <typename KHelper> static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, FlashMS<q_step, k_step>& fms) { @@ -13893,6 +13611,34 @@ struct FlashQKbf16 { } } + static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i)); + __m128 sum[4]; + for (int k = 0; k < 4; ++k) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]); + auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1)); + } + //auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY)); + //_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4); + _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum)); + } + + static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q, + __m512bh * qv, const __m512bh * vkh, FlashMS<q_step, k_step>& fms) { + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i)); + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, 
vkh[i], qv[i]); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + template <typename KHelper> static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, const char * mask, FlashMS<q_step, k_step>& fms) { @@ -13902,23 +13648,44 @@ struct FlashQKbf16 { __m512bh vkh[D/8]; for (int l1 = 0; l1 < k_step; l1 += 4) { kh.load_4(l1, vkh); - for (int j = 0; j < q_step; ++j) { - mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms); - } + for (int j = 0; j < q_step; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms); } } else { __m512bh vkh[D/16]; for (int l1 = 0; l1 < k_step; l1 += 2) { kh.load_2(l1, vkh); - for (int j = 0; j < q_step; ++j) { - mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms); - } + for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms); } } } - __m512 vk[k_step/16]; + F16::Data vk[k_step/16]; for (int j = 0; j < q_step; ++j) { - fms.update_M_S(j, vk); + fms.update_M_S(j, vk, mask + stride_m*j); + } + } + + template <typename KHelper> + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q, + const char * mask, FlashMS<q_step, k_step>& fms) { + { + __m512bh qv[D/32]; + if constexpr (D <= 128) { + __m512bh vkh[D/8]; + for (int l1 = 0; l1 < k_step; l1 += 4) { + kh.load_4(l1, vkh); + for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms); + } + } else { + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms); + } + } + } + F16::Data vk[k_step/16]; + for (int j = 0; j < nq; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); } } @@ -13953,6 +13720,19 @@ struct FlashQKbf16 { bf16 += D; } } + + static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) { + auto qr = q; + for (int j = 0; j < nq; ++j) { + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1)); + } + qr += stride_q; + bf16 += D; + } + } }; template <int D, int q_step, int k_step> @@ -13991,9 +13771,10 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); + FlashQKbf16<D, q_step, k_step>::convert(n_left, stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); + FlashQKbf16<D, q_step, k_step>::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); @@ -14008,28 +13789,59 @@ struct FlashAttnBF16 { }; #endif -template <int D, int q_step, int k_step, typename KHelper, typename VHelper> +template <int D, int k_step, typename KHelper, typename VHelper> inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv) { - if (nq1 >= q_step) { - FlashAttn<D, q_step, k_step> fa(scale, softcap); +#if defined __AVX2__ + constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>; +#else + constexpr bool kUseLargeStepsQ = true; +#endif + if constexpr (kUseLargeStepsQ) { + if (nk1 >= 4096) { + if (nq1 >= 
32) { + FlashAttn<D, 32, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + else if (nq1 >= 8) { + FlashAttn<D, 8, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + } + } + if (nq1 >= 8) { + FlashAttn<D, 8, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - } else { + } + else { FlashAttn<D, 1, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #ifdef __AVX512BF16__ -template <int D, int q_step, int k_step> +template <int D, int k_step> inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv) { HelperBF16<D, k_step> kh(k, stride_k); HelperBF16<D, k_step> vh(v, stride_v); - if (nq1 >= q_step) { - FlashAttnBF16<D, q_step, k_step> fa(scale, softcap); + if (nk1 >= 4096) { + if (nq1 >= 64) { + FlashAttnBF16<D, 64, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + else if (nq1 >= 16) { + FlashAttnBF16<D, 16, k_step> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return; + } + if (nq1 >= 8) { + FlashAttnBF16<D, 8, k_step> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { FlashAttnBF16<D, 1, k_step> fa(scale, softcap); @@ -14038,7 +13850,7 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int } #endif -template <int D, int q_step, int k_step, typename KHelper> +template <int D, int k_step, typename KHelper> inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, @@ -14047,33 +13859,39 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, switch (type_v) { case GGML_TYPE_F16: { HelperF16<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; +#ifdef HAVE_FANCY_SIMD + case GGML_TYPE_BF16: { + HelperBF16<D, k_step> vh(v, stride_v); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; +#endif case GGML_TYPE_Q8_0: { HelperQ80<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_0: { HelperQ40<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { HelperQ41<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + 
iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { HelperQ60<D, k_step> vh(v, stride_v); - iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper<D, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; default: break; } } -template <int D, int q_step, int k_step> +template <int D, int k_step> inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, @@ -14082,27 +13900,27 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, switch (type_k) { case GGML_TYPE_F16: { HelperF16<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q8_0: { HelperQ80<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_0: { HelperQ40<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { HelperQ41<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { HelperQ60<D, k_step> kh(k, stride_k); - iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T<D, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; default: break; } @@ -14149,17 +13967,17 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float #ifdef __AVX512BF16__ - if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) { - if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types + if (type_k == 
GGML_TYPE_BF16) { + if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types switch (D) { case 64: - iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } @@ -14168,21 +13986,42 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k } #endif + if (nk1%64 == 0) { + switch (D) { + case 64: + iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 80: + // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 112: + // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + return true; + } switch (D) { case 64: - iqk_flash_helper_T< 64, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: - // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // 
Disable until we fix accumulate_qkv for odd D/16 //case 112: - // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 2404246d..a2ade6a7 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -654,6 +654,257 @@ void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) { #endif } +void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) { + const int nb = k / QK8_0; + const int nb4 = 4*(nb/4); + + block_q8_0 * y = (block_q8_0 *)vy; + block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; +#if defined(__aarch64__) + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } + } +#else + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + const float d = maxScalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } + } +#endif +} + +void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + const int nb4 = 4*(nb/4); + block_q8_1 * y = (block_q8_1 *)vy; + block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy; +#if defined(__aarch64__) + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 
0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + + accv = vaddq_s32(accv, vi); + } + + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } else { + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } + } +#else + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( max_scalar != 0.0f ) ? 
127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Compute the sum of the quants and set y[i].s + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } else { + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } + } +#endif +} + // // ============================================== iq2_K // diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 68a38811..729b0ec0 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -211,6 +211,8 @@ void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); |
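The wall of per-type switch statements removed at the top of the iqk_mul_mat.cpp hunk is replaced by a single mul_mat_kernel<KHelper>(nq) lookup that returns the matching GEMM routine together with the number of query rows it handles per call (nrc_q); the caller then walks the nq rows in chunks of nrc_q, advancing info.cur_y, and finishes any leftover rows with a second lookup whose width must match the remainder exactly. A skeleton of that pattern, with pick_kernel and fake_mul_mat as hypothetical stand-ins for mul_mat_kernel and the mul_mat_qX_* templates:

    #include <cstdio>
    #include <utility>

    using kernel_t = void (*)(int first_row);

    // One kernel per "rows handled at once" count, like the mul_mat_qX_*<..., nrc_q> templates.
    template <int nrc_q>
    void fake_mul_mat(int first_row) {
        std::printf("kernel<%d>: query rows [%d, %d)\n", nrc_q, first_row, first_row + nrc_q);
    }

    // Stand-in for mul_mat_kernel<KHelper>(nq): widest kernel that fits, plus its width.
    std::pair<kernel_t, int> pick_kernel(int nq) {
        switch (nq) {
            case 1:  return {fake_mul_mat<1>, 1};
            case 2:  return {fake_mul_mat<2>, 2};
            case 3:  return {fake_mul_mat<3>, 3};
            default: return {fake_mul_mat<4>, 4};
        }
    }

    void process_rows(int nq) {
        auto [mul_mat, nrc_q] = pick_kernel(nq);
        int iq = 0;
        for (; iq + nrc_q <= nq; iq += nrc_q) mul_mat(iq);   // full chunks of nrc_q rows
        if (iq < nq) {                                       // remainder gets an exact-width kernel
            auto [mul_mat1, nrc_q1] = pick_kernel(nq - iq);
            (void)nrc_q1;                                    // == nq - iq, as the patch asserts
            mul_mat1(iq);
        }
    }

    int main() { process_rows(7); return 0; }                // kernel<4> on [0,4), then kernel<3> on [4,7)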
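In the AVX512-BF16 K*Q path, mult_mask_kq_4 now takes a pre-converted bf16 query row, accumulates four K rows with _mm512_dpbf16_ps, reduces each __m512 accumulator to a __m128 of partial sums, and hands those to the new hsum_float_4x4: lane k of its return value is the full horizontal sum of input a[k], so the four K*Q results can be written with a single _mm_storeu_ps. A standalone sanity check of that helper (SSE only; the test scaffolding is not taken from the patch):

    #include <cstdio>
    #include <immintrin.h>

    // Copied from the patch: reduces four 4-lane accumulators so that result lane k
    // holds the horizontal sum of a[k].
    static inline __m128 hsum_float_4x4(__m128 * a) {
        for (int i = 0; i < 2; ++i)
            a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
        return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
    }

    int main() {
        __m128 a[4];
        for (int k = 0; k < 4; ++k)                      // a[k] = {1, 2, 3, 4} * (k+1), so sum(a[k]) = 10*(k+1)
            a[k] = _mm_setr_ps(1.f*(k+1), 2.f*(k+1), 3.f*(k+1), 4.f*(k+1));
        float out[4];
        _mm_storeu_ps(out, hsum_float_4x4(a));
        std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);   // expect: 10 20 30 40
        return 0;
    }

The two unpacklo/unpackhi rounds amount to a 4x4 transpose fused with the additions, which is why no extra shuffles are needed before the store.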
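The q_step template parameter is gone from the iqk_flash_helper* signatures; the helpers now pick the query step at run time. For long contexts (nk1 >= 4096) the fp32 path uses 32 query rows per block when at least 32 are available, and the AVX512-BF16 path uses 64 (falling back to 16); otherwise the step stays at 8, or 1 for single-row decode. On AVX2 the large-step branch is compiled out for an fp16 K-cache, and iqk_flash_attn_noalibi additionally selects k_step = 64 whenever nk1 is a multiple of 64. A stripped-down sketch of that dispatch shape, with compute_fa and flash_dispatch as hypothetical stand-ins for FlashAttn<D, q_step, k_step>::compute and the helper itself:

    #include <cstdio>

    // Stand-in for FlashAttn<D, q_step, k_step>::compute.
    template <int q_step, int k_step>
    void compute_fa(int nq1, int nk1) {
        std::printf("nq1=%d nk1=%d -> q_step=%d, k_step=%d\n", nq1, nk1, q_step, k_step);
    }

    // Shape of the run-time step selection in iqk_flash_helper (fp32 flavor).
    template <int k_step, bool use_large_q_steps>
    void flash_dispatch(int nq1, int nk1) {
        if constexpr (use_large_q_steps) {          // disabled on AVX2 for an fp16 K-cache
            if (nk1 >= 4096) {                      // long context: larger Q blocks pay off
                if (nq1 >= 32) { compute_fa<32, k_step>(nq1, nk1); return; }
                if (nq1 >=  8) { compute_fa< 8, k_step>(nq1, nk1); return; }
            }
        }
        if (nq1 >= 8) compute_fa<8, k_step>(nq1, nk1);
        else          compute_fa<1, k_step>(nq1, nk1);
    }

    int main() {
        flash_dispatch<32, true>(512, 1024);        // short context -> q_step 8
        flash_dispatch<64, true>(512, 32768);       // long context; caller picked k_step 64 since nk1 % 64 == 0
        flash_dispatch<32, false>(512, 32768);      // e.g. fp16 K-cache on AVX2 -> stays at q_step 8
        return 0;
    }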
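The new quantize_row_q8_0_x4 / quantize_row_q8_1_x4 routines produce the interleaved activation layout behind the GGML_TYPE_Q8_0_X4 / GGML_TYPE_Q8_1_X4 vec_dot types: every four consecutive blocks share one block_q8_0_x4 record, so the four scales sit next to each other (indexed d[ir]) ahead of the four 32-byte quant rows (qs[32*ir + ...]), and any blocks left after the last complete group of four are written in the plain block_q8_0 layout; the q8_1 variant additionally stores the scaled block sums in d[ir+4]. A scalar sketch of the q8_0 variant under those assumptions (the reference structs below use float scales for brevity; the repo's real blocks store fp16 scales):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    constexpr int QK8_0 = 32;

    // Reference layouts only; the repo's block_q8_0 / block_q8_0_x4 use ggml_half scales.
    struct blk_q8_0    { float d;    int8_t qs[QK8_0];   };
    struct blk_q8_0_x4 { float d[4]; int8_t qs[4*QK8_0]; };   // 4 scales together, then 4 quant rows

    void quantize_q8_0_x4_ref(const float * x, void * vy, int64_t k) {
        const int nb  = int(k / QK8_0);
        const int nb4 = 4*(nb/4);                  // blocks that form complete groups of 4
        auto * y  = (blk_q8_0    *)vy;
        auto * y4 = (blk_q8_0_x4 *)vy;
        for (int i = 0; i < nb; ++i) {
            const int i4 = i/4, ir = i%4;
            float amax = 0.f;
            for (int j = 0; j < QK8_0; ++j) amax = std::max(amax, std::fabs(x[i*QK8_0 + j]));
            const float d  = amax/127.f;
            const float id = d ? 1.f/d : 0.f;
            int8_t q[QK8_0];
            for (int j = 0; j < QK8_0; ++j) q[j] = (int8_t)std::nearbyint(id*x[i*QK8_0 + j]);
            if (i < nb4) {                         // interleaved: scales of 4 blocks side by side
                y4[i4].d[ir] = d;
                std::memcpy(y4[i4].qs + QK8_0*ir, q, QK8_0);
            } else {                               // tail that does not fill a group of 4: plain q8_0
                y[i].d = d;
                std::memcpy(y[i].qs, q, QK8_0);
            }
        }
    }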