path: root/ggml/src/ggml.c
author     Kawrakow <iwankawrakow@gmail.com>   2025-01-15 18:19:22 +0200
committer  GitHub <noreply@github.com>         2025-01-15 18:19:22 +0200
commit     0b74397d596bbcdfba27299393406d2b6330b133 (patch)
tree       2101d059f79b6b268086c71878aa2da1c328c73d /ggml/src/ggml.c
parent     49b27069fd267d3dac8de5d13141b4274e4be16b (diff)
CPU Flash Attention improvements (#172)
* Slightly faster FA for bf16 KV cache: ~2-3% sort of thing. Sadly, when we go beyond 8k tokens the advantage kind of goes away.
* Slightly faster FA for Q8_0 KV cache.
* FA: allow bf16 for the V-cache with any supported K-cache. E.g., -ctk q8_0 -ctv bf16 is slightly faster than -ctk q8_0 -ctv q8_0 on Zen4 for not too long context lengths (say, <= 4096).
* FA: much better bf16 KV-cache speed for large contexts. We now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and run-time-repacked) with a context of 32768. IIRC, the previous best for such a large context was ~90 t/s. Non-negligible improvements at 16384 and 8192 as well: 173.4 and 214 t/s.
* FA: slightly better quantized KV-cache speed for large contexts. E.g., for q8_0 and a context of 32768 we are now at 113 t/s for LLaMA-3.1-8B. Also simplified the quantized K*Q multiplication.
* Fix q8_0 KV cache when not using FA - WIP (AVX2):
  1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4, and use those to quantize activations for quants that use Q8_0 or Q8_1 as their vec_dot type.
  2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1.
  3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type.
  4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than GGML_TYPE_Q8_0_X4 as the K and V types.
  5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4 in iqk_mul_mat.
  Also added an optimization in ggml_compute_forward_mul_mat when ne12*ne13 > 1 (K*Q and V*softmax(K*Q)): we process GCD(ne12*ne13, nthread) heads simultaneously, using nthread/GCD(ne12*ne13, nthread) threads per head. This results in a non-negligible performance gain for large contexts. Question: why is it not allowed to use a quantized V-cache when not using FA?
* Fix q8_0 KV cache when not using FA - NEON.
* Fix AVX2: again the issue with _mm256_maddubs_epi16 overflowing that I keep forgetting.
* FA: don't use large Q steps on AVX2 for fp16 K-cache.
* On Zen4 it is also better to not use large Q steps for fp16 K-cache.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
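The head/thread split described in the last point above can be illustrated with a small standalone sketch in C. It mirrors the simple_gcd() helper and the counter % ntg selection added to ggml_compute_forward_mul_mat in the diff below; the head and thread counts are illustrative assumptions, not values from the commit:

#include <stdint.h>
#include <stdio.h>

// Subtraction-based GCD, same as the helper added in the diff.
static uint32_t simple_gcd(uint32_t a, uint32_t b) {
    while (a != b) {
        if (a > b) a -= b;
        else       b -= a;
    }
    return a;
}

int main(void) {
    const int n_head = 32;  // stands in for ne12*ne13, e.g. 32 attention heads (assumed)
    const int nth    = 8;   // total threads (assumed)
    const int ntg    = (int) simple_gcd(n_head, nth);  // heads processed simultaneously
    printf("%d heads at a time, %d thread(s) per head\n", ntg, nth/ntg);
    for (int ith = 0; ith < nth; ++ith) {
        printf("thread %2d -> sub-thread %d of %d for heads:", ith, ith%ntg ? ith/ntg : ith/ntg, nth/ntg);
        for (int head = 0; head < n_head; ++head) {
            // Same selection as `counter++ % ntg == ith%ntg` in the diff.
            if (head % ntg == ith % ntg) printf(" %d", head);
        }
        printf("\n");
    }
    return 0;
}

With 32 heads and 8 threads this prints 8 heads being worked on at a time with 1 thread each; with 32 heads and 12 threads the GCD falls back to 4, i.e. 4 heads at a time with 3 threads per head.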
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--   ggml/src/ggml.c   167
1 file changed, 139 insertions, 28 deletions
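The AVX2 issue called out in the message ("the issue with _mm256_maddubs_epi16 overflowing that I keep forgetting") and in the new comment on the q8_0 entry is that the instruction multiplies unsigned by signed bytes and sums adjacent products with signed 16-bit saturation, so shifting the Q8 quants by +128 can silently clamp the products. A minimal sketch of the failure mode (assumes an AVX2-capable compiler, e.g. gcc -mavx2; the values are chosen only to trigger the saturation):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // _mm256_maddubs_epi16: unsigned 8-bit * signed 8-bit, adjacent pairs summed
    // with signed 16-bit *saturating* addition.
    __m256i a = _mm256_set1_epi8((char) 255); // unsigned operand: every byte is 255
    __m256i b = _mm256_set1_epi8(127);        // signed operand:   every byte is 127
    __m256i r = _mm256_maddubs_epi16(a, b);   // each pair: 255*127 + 255*127 = 64770
    int16_t out[16];
    _mm256_storeu_si256((__m256i *) out, r);
    // Prints 32767: the true pair sum 64770 exceeds INT16_MAX and is clamped.
    printf("pair sum = %d (exact value would be %d)\n", out[0], 255*127 + 255*127);
    return 0;
}

This is why, in the diff below, the q8_0 entry only selects GGML_TYPE_Q8_1_X4 when the AVX512/VNNI path is available and stays with GGML_TYPE_Q8_0_X4 on plain AVX2.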
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b9f9b3d8..bcb8bf41 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -714,8 +714,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -735,7 +739,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
.vec_dot = ggml_vec_dot_q4_1_q8_1,
+#if GGML_USE_IQK_MULMAT
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_1,
+#endif
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
@@ -778,8 +786,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
.vec_dot = ggml_vec_dot_q5_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -795,7 +807,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
.vec_dot = ggml_vec_dot_q5_1_q8_1,
+#if GGML_USE_IQK_MULMAT
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
.vec_dot_type = GGML_TYPE_Q8_1,
+#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -808,8 +824,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0,
.from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref,
.vec_dot = ggml_vec_dot_q6_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -826,8 +846,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
+ // Remember: for pure AVX2 we cannot add 128 to the Q8 quants and use the block sum in Q8_1 to subtract, as we do on Zen4,
+ // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range
+ // (and it gets saturated if it does), leading to wrong results.
+ // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above.
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -849,6 +877,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
+ [GGML_TYPE_Q8_0_X4] = {
+ .type_name = "q8_0_x4",
+ .blck_size = QK8_0,
+ .type_size = sizeof(block_q8_0),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_0_x4,
+ .from_float_ref = quantize_row_q8_0_x4,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
+ [GGML_TYPE_Q8_1_X4] = {
+ .type_name = "q8_1_x4",
+ .blck_size = QK8_1,
+ .type_size = sizeof(block_q8_1),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_1_x4,
+ .from_float_ref = quantize_row_q8_1_x4,
+ .nrows = 1,
+ .row_meta_size = 0,
+ },
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@@ -1196,8 +1244,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1516,8 +1568,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_nl_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_r4_ref,
.vec_dot = vec_dot_iq4_nl_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1546,8 +1602,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref,
.vec_dot = vec_dot_q4_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1563,8 +1623,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref,
.vec_dot = vec_dot_q8_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1580,8 +1644,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q5_0_r4_ref,
.vec_dot = vec_dot_q5_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -1597,8 +1665,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_0_r4,
.from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref,
.vec_dot = vec_dot_q6_0_r4_q8_0,
-#if GGML_USE_IQK_MULMAT && defined __AVX2__
- .vec_dot_type = GGML_TYPE_Q8_1,
+#if GGML_USE_IQK_MULMAT
+#if defined __AVX2__
+ .vec_dot_type = GGML_TYPE_Q8_1_X4,
+#else
+ .vec_dot_type = GGML_TYPE_Q8_0_X4,
+#endif
#else
.vec_dot_type = GGML_TYPE_Q8_0,
#endif
@@ -11280,6 +11352,8 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -11443,6 +11517,8 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -13889,6 +13965,14 @@ static void ggml_compute_forward_mul_mat_one_chunk(
}
}
+static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
+ while (a != b) {
+ if (a > b) a -= b;
+ else b -= a;
+ }
+ return a;
+}
+
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -13905,10 +13989,12 @@ static void ggml_compute_forward_mul_mat(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
- ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
int64_t const matmul_num_cols = type_traits[type].ncols;
+#if !GGML_USE_IQK_MULMAT
+ ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
+#endif
ggml_gemv_t const gemv = type_traits[type].gemv;
ggml_gemm_t const gemm = type_traits[type].gemm;
@@ -14011,6 +14097,7 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
int64_t i11_processed = 0;
+#if !GGML_USE_IQK_MULMAT
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
@@ -14019,6 +14106,7 @@ UseGgmlGemm1:;
}
i11_processed = ne11 - ne11 % 4;
}
+#endif
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
@@ -14049,14 +14137,31 @@ AlreadyQuantized:;
#if GGML_USE_IQK_MULMAT
if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
+ // When computing K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing
+ // one matrix multiplication, but work on several heads at once.
+ // Hence, we find GCD(ne12*ne13, nth) and have nth/GCD(ne12*ne13, nth) threads per head.
+ // Leaving the previous version commented out for now just in case.
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
- for (int64_t i13 = 0; i13 < ne13; i13++)
- for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
- vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
- (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
- ith, nth)) goto IQK_MulMat_Not_Available2;
+ int ntg = simple_gcd(ne12*ne13, nth);
+ int counter = 0;
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
+ if (counter++ % ntg == ith%ntg) {
+ if (!iqk_mul_mat(ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2;
+ }
+ }
+ }
+ //for (int64_t i13 = 0; i13 < ne13; i13++)
+ // for (int64_t i12 = 0; i12 < ne12; i12++)
+ // if (!iqk_mul_mat(ne01, ne11, ne00,
+ // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
+ // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type),
+ // (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ // ith, nth)) goto IQK_MulMat_Not_Available2;
return;
}
IQK_MulMat_Not_Available2:;
@@ -15055,6 +15160,8 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -15352,6 +15459,8 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K:
@@ -15977,6 +16086,8 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q8_0_X4:
+ case GGML_TYPE_Q8_1_X4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q2_K_R4:
case GGML_TYPE_Q3_K: