summaryrefslogtreecommitdiff
path: root/src/llama.cpp
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-03-21 07:24:22 +0100
committerGitHub <noreply@github.com>2025-03-21 07:24:22 +0100
commitddc8eee10ee9216de57429167e6f74e618577d93 (patch)
tree1dc5d251aa32e1bc6b1e3c4e68ea8619a89a5697 /src/llama.cpp
parentb8d1fac97b756968b86b470d44bb1026ded7157a (diff)
FlashMLA-3: the best of both worlds (CPU only) (#273)
* Repack a model with the quantize tool * WIP * Fixed various issues As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM. * Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved * Fix GCC 13.3 compilation error * Another one * Add missing include * FlashMLA-3: the best of both worlds - CPU only --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src/llama.cpp')
-rw-r--r--src/llama.cpp4
1 files changed, 2 insertions, 2 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index a459cb00..33b69389 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13760,7 +13760,7 @@ struct llm_build_context {
ggml_tensor * kqv;
- if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
+ if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3
auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
@@ -13869,7 +13869,7 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
- if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) {
+ if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);