diff options
Diffstat (limited to 'src/llama.cpp')
-rw-r--r-- | src/llama.cpp | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/llama.cpp b/src/llama.cpp index a459cb00..33b69389 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13760,7 +13760,7 @@ struct llm_build_context { ggml_tensor * kqv; - if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3 auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); @@ -13869,7 +13869,7 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) { + if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); |