From c12f73ba6153d162f36434cb48e36dd3649b7701 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sun, 9 Feb 2025 19:48:44 +0200
Subject: Add optional MLA (#188)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Deepseek MLA Optimizations

Co-authored-by: Stanisław Szymczyk

* Make MLA optional

* Remove some unnecessary copies in the MLA attention

* Deepseek MLA Optimizations V2 (#195)

* Avoid allocating MHA KV cache when MLA is turned on

* Added missing gguf-py file

* Added final optimizations

Co-authored-by: Stanisław Szymczyk

* Make sure we do have wk_b and wv_b before enabling MLA

---------

Co-authored-by: Stanisław Szymczyk
Co-authored-by: Iwan Kawrakow

* Use type_k and type_v to set the types of the MLA caches

They were hard-coded at f16. On my Ryzen-7950X with native bf16 support
I get a fairly significant PP performance boost with a bf16 KV-cache:
PP-4096 = 320 t/s, up from 292 t/s with an fp16 KV-cache.

* Better gemm strategy when nth > nhead

It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads
(with or without MLA).

Before this commit, when nth > nhead, heads were processed sequentially
with all nth threads participating in each matrix multiplication. Now we
find the gcd of nhead and nth and split the threads into groups of
nth/gcd threads, with each group processing nhead/gcd heads.

---------

Co-authored-by: Saood Karim
Co-authored-by: Stanisław Szymczyk
Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml.c | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

(limited to 'ggml/src/ggml.c')

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e07dd547..3867cf00 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14064,31 +14064,22 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_IQK_MULMAT
-    if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+    if (dst->type == GGML_TYPE_F32) {
+        int gcd = simple_gcd(ne12*ne13, nth);
         int counter = 0;
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
-                if (counter++ % nth == ith) {
+                if ((counter++ % gcd) == (ith%gcd)) {
                     if (!iqk_mul_mat(ne01, ne11, ne00,
                                 src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
                                 src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
                                 (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                                0, 1)) goto IQK_MulMat_Not_Available1;
+                                ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1;
                 }
             }
         }
         return;
     }
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!iqk_mul_mat(ne01, ne11, ne00,
-                            src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
-                            src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
-                            (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                            ith, nth)) goto IQK_MulMat_Not_Available1;
-        return;
-    }
 IQK_MulMat_Not_Available1:;
 #endif
 
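Not part of the patch: a minimal standalone sketch of the gcd-based thread-grouping
scheme the commit message describes. The nhead and nth values below are made up for
illustration, and simple_gcd is re-implemented here only to keep the sketch
self-contained (the patch calls a helper of that name whose definition is not shown
in this diff).

#include <stdio.h>

/* Plays the role of simple_gcd() in the patch; definition assumed. */
static int simple_gcd(int a, int b) {
    while (b != 0) { int t = a % b; a = b; b = t; }
    return a;
}

int main(void) {
    const int nhead = 16;  /* ne12*ne13: number of independent matrix multiplications */
    const int nth   = 32;  /* number of threads; nth > nhead in this example */
    const int gcd   = simple_gcd(nhead, nth);

    printf("gcd = %d -> %d groups of %d threads, %d heads per group\n",
           gcd, gcd, nth / gcd, nhead / gcd);

    for (int ith = 0; ith < nth; ++ith) {
        /* Head selection mirrors the patched condition (counter % gcd == ith % gcd):
         * thread ith only touches heads whose index is congruent to ith modulo gcd,
         * and inside each such head it acts as sub-thread ith/gcd out of nth/gcd,
         * which are the values the patch passes to iqk_mul_mat(). */
        printf("thread %2d: group %2d, sub-thread %d of %d, heads:",
               ith, ith % gcd, ith / gcd, nth / gcd);
        for (int head = 0; head < nhead; ++head) {
            if (head % gcd == ith % gcd) printf(" %d", head);
        }
        printf("\n");
    }
    return 0;
}

With nhead = 16 and nth = 32 this prints 16 groups of 2 threads, i.e. every head's
GEMM is shared by two threads instead of all 32 threads walking the heads one by one.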