From c12f73ba6153d162f36434cb48e36dd3649b7701 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Feb 2025 19:48:44 +0200 Subject: Add optional MLA (#188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Deepseek MLA Optimizations Co-authored-by: Stanisław Szymczyk * Make MLA optional * Remove some unnecessary copies in the MLA attention * Deepseek MLA Optimizations V2 (#195) * Avoid allocating MHA KV cache when MLA is turned on * Added missing gguf-py file * Added final optimizations Co-authored-by: Stanisław Szymczyk * Make sure we do have wk_b and wv_b before enabling MLA --------- Co-authored-by: Stanisław Szymczyk Co-authored-by: Iwan Kawrakow * Use type_k and type_v to set the types of the MLA caches They were hard-coded at f16. On my Ryzen-7950X with native bf16 support I get a fairly significant PP performance boost with bf16 KV-cache: PP-4096 = 320 t/s up from 292 t/s with fp16 KV-cache. * Better gemm strategy when nth > nhead It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads (with or without MLA). Before this commit, when nth > nhead heads were processed sequentially with all nth threads participating in each matrix multiplication. Now we ind the gcd of nhead and nth and split threads into nth/gcd groups, each group processing nhead/gcd heads. --------- Co-authored-by: Saood Karim Co-authored-by: Stanisław Szymczyk Co-authored-by: Iwan Kawrakow --- include/llama.h | 1 + 1 file changed, 1 insertion(+) (limited to 'include/llama.h') diff --git a/include/llama.h b/include/llama.h index 730c087a..39251d35 100644 --- a/include/llama.h +++ b/include/llama.h @@ -374,6 +374,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted -- cgit v1.2.3