| author | Kawrakow <iwankawrakow@gmail.com> | 2025-04-08 08:47:24 +0200 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2025-04-08 08:47:24 +0200 |
| commit | 5f44f4b3d006a24267ea02fe65490bb760a01447 (patch) | |
| tree | b2737800378216ccb7bcee673a12bcdb7e97a785 /src | |
| parent | 22d7440ba28000874b571b4a44a8b2c39c9a5ba8 (diff) | |
Guard against attempts to use MLA for non-MLA models (#320)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/llama.cpp | 25
1 file changed, 17 insertions, 8 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index d41bd6d8..5f4642fb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3245,13 +3245,15 @@ static bool llama_kv_cache_init(
         cache.ctxs.push_back(ctx);
     }
 
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
-
-    // DeepSeek MLA
-    cache.kv_l.reserve(n_layer);
-    if (cparams.mla_attn == 1 && !cparams.flash_attn) {
-        cache.kvt_l.reserve(n_layer);
+    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
+        // DeepSeek MLA
+        cache.kv_l.reserve(n_layer);
+        if (cparams.mla_attn == 1 && !cparams.flash_attn) {
+            cache.kvt_l.reserve(n_layer);
+        }
+    } else {
+        cache.k_l.reserve(n_layer);
+        cache.v_l.reserve(n_layer);
     }
 
     bool warn = true;
@@ -3299,7 +3301,7 @@ static bool llama_kv_cache_init(
             cache.v_l.push_back(v);
         }
     }
-    if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
+    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
         LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
         LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
         GGML_ABORT("fatal error");
@@ -18568,6 +18570,13 @@ struct llama_context * llama_new_context_with_model(
         params.seed = time(NULL);
     }
 
+    if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn > 0) {
+        LLAMA_LOG_WARN("=====================================================================\n");
+        LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n");
+        LLAMA_LOG_WARN("=====================================================================\n");
+        cparams.mla_attn = 0;
+    }
+
     LLAMA_LOG_INFO("%s: n_ctx     = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch   = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch  = %u\n", __func__, cparams.n_ubatch);
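For context, a minimal caller-side sketch of what the new guard changes, assuming the public `llama_context_params` exposes the same `mla_attn` field used internally as `cparams.mla_attn` (the model path and parameter value below are placeholders, not part of this commit): requesting MLA for a model whose architecture is not `LLM_ARCH_DEEPSEEK2` no longer reserves the MLA cache tensors; context creation warns and falls back to the standard K/V cache.

```cpp
// Hypothetical usage sketch (not part of the commit): exercise the new guard by
// asking for MLA with a model whose architecture is not LLM_ARCH_DEEPSEEK2.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // Placeholder path to some non-DeepSeek2 GGUF model.
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.mla_attn = 2; // request MLA anyway (field name assumed from the internal cparams.mla_attn)

    // Before this commit the KV-cache init could reserve MLA tensors (kv_l/kvt_l)
    // for a model that has no MLA; now llama_new_context_with_model() logs
    // "MLA is only available for LLM_ARCH_DEEPSEEK2", resets mla_attn to 0,
    // and sets up the regular k_l/v_l cache instead.
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

The design choice is to degrade gracefully: an invalid MLA request is downgraded with a loud warning at context creation rather than reaching the "unexpected situation" error path in `llama_kv_cache_init` and aborting.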