From 2cf12eb12dcd82cdfe4785c1bcd8dc6255621790 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 30 May 2025 11:08:17 +0300 Subject: Replace MLA-specific KV cache with the standard KV cache (#469) * Remove kv_l, kvt_l and just use k_l and v_l * Hopefully take care of missing V cache (MLA) * Replace MLA-specific KV cache with the standard KV cache V2 (#473) * Fix save and restore when there is no V cache * Fix double print --------- Co-authored-by: Iwan Kawrakow Co-authored-by: saood06 --- src/llama.cpp | 197 +++++++++++++++++++++++++++------------------------------- 1 file changed, 92 insertions(+), 105 deletions(-) (limited to 'src/llama.cpp') diff --git a/src/llama.cpp b/src/llama.cpp index 0a620164..b8555677 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2991,10 +2991,6 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; - // DeepSeek MLA - std::vector kv_l; - std::vector kvt_l; - std::vector ctxs; std::vector bufs; @@ -3493,16 +3489,12 @@ static bool llama_kv_cache_init( } } + cache.k_l.reserve(n_layer); + bool needs_v_cache = true; if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) { - // DeepSeek MLA - cache.kv_l.reserve(n_layer); - if (cparams.mla_attn == 1 && !cparams.flash_attn) { - cache.kvt_l.reserve(n_layer); - } - } else { - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn; } + if (needs_v_cache) cache.v_l.reserve(n_layer); bool warn = true; int n_mla = 0; @@ -3525,17 +3517,17 @@ static bool llama_kv_cache_init( //LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_kv_l%d", i); - cache.kv_l.push_back(kv); + ggml_format_name(kv, "cache_k_l%d", i); + cache.k_l.push_back(kv); } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_kv_l%d", i); - cache.kv_l.push_back(kv); + ggml_format_name(kv, "cache_k_l%d", i); + cache.k_l.push_back(kv); if (cparams.mla_attn == 1) { ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kvt_l.push_back(kvt); + ggml_format_name(kvt, "cache_v_l%d", i); + cache.v_l.push_back(kvt); } } n_mla++; @@ -10355,34 +10347,39 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + ggml_tensor * view_v_src = nullptr; + ggml_tensor * view_v_dst = nullptr; + + if (kv_self.v_l.size() > il) { + // Note: with MLA the V cache may not be present. + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); + if (view_v_src && view_v_dst) { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); + } } i += nm - 1; @@ -15371,16 +15368,16 @@ struct llm_build_context { ggml_tensor * kv_cache_trans; if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) { - ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); + ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, kv_lora_rank, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self.v_l[il]->type, kv_head)); cb(kv_cache_trans_view, "kv_cache_trans_view", il); // note: storing transposed c^KV in the transposed KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); - kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], + kv_cache_trans = ggml_view_2d(ctx0, kv_self.v_l[il], n_kv, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), 0); cb(kv_cache_trans, "kv_cache_trans", il); } @@ -15389,21 +15386,21 @@ struct llm_build_context { ggml_tensor * kvr = ggml_concat(ctx0, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), kv_compressed, 0); cb(kvr, "kvr", il); - auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); - ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens, + auto row_size = ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); + ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.k_l[il], kv_self.k_l[il]->ne[0], n_tokens, row_size, row_size*kv_head); ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view)); - ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], + ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cache, "kv_cache", il); ggml_tensor * kqv; if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3 - auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], - ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); + auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank, n_kv, kv_self.k_l[il]->nb[1], + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope)); auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024); int n_max_head = n_head; @@ -15416,14 +15413,14 @@ struct llm_build_context { auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head; - auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, - kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], 0); //ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_qk_rope, n_kv, 1, + kv_self.k_l[il]->nb[1], kv_self.k_l[il]->nb[2], 0); //ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank)); // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when // the KV cache is quantized. Hence, in that case we will simply use fp16 for now. // The downside of the following line is that fp16 will be used even if attention is computed on the CPU // if the build is with CUDA enabled. - auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16; + auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.k_l[il]->type : GGML_TYPE_F16; ggml_tensor repeater; repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1; @@ -15514,10 +15511,10 @@ struct llm_build_context { cb(q, "q", il); if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) { - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), - ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); + ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope)); cb(kv_cache_lora, "kv_cache_lora", il); kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); @@ -15528,10 +15525,10 @@ struct llm_build_context { } else { if (lctx.cparams.mla_attn > 1) { - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), - ggml_row_size(kv_self.kv_l[il]->type, n_embd_head_qk_rope)); + ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope)); cb(kv_cache, "kv_cache_lora", il); kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); @@ -20702,42 +20699,21 @@ struct llama_context * llama_new_context_with_model( } if (memory_size_k + memory_size_v > 0) { - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } - } - - { - size_t memory_size_kv = 0; - size_t memory_size_kvt = 0; - - ggml_type kv_type = GGML_TYPE_COUNT; - ggml_type kvt_type = GGML_TYPE_COUNT; - - for (auto & kv : ctx->kv_self.kv_l) { - memory_size_kv += ggml_nbytes(kv); - kv_type = kv->type; - } - - for (auto & kvt : ctx->kv_self.kvt_l) { - memory_size_kvt += ggml_nbytes(kvt); - kvt_type = kvt->type; - } - - if (memory_size_kv + memory_size_kvt > 0) { - if (cparams.mla_attn == 1 && !cparams.flash_attn) { + if (cparams.mla_attn != 0 && !cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), - ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); - } else { - GGML_ASSERT(memory_size_kvt == 0); + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } else if (cparams.mla_attn != 0 && cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f)); - } + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f)); + } else { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } } } @@ -21454,10 +21430,13 @@ struct llama_data_write { const struct llama_kv_cache & kv_self = ctx->kv_self; const struct llama_hparams & hparams = ctx->model.hparams; - const uint32_t v_trans = kv_self.v_trans ? 1 : 0; + // v_state: 0 -> not transposed V cache + // 1 -> transposed V cache + // 2 -> no V cache (as it may be the case with MLA) + const uint32_t v_state = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0; const uint32_t n_layer = hparams.n_layer; - write(&v_trans, sizeof(v_trans)); + write(&v_state, sizeof(v_state)); write(&n_layer, sizeof(n_layer)); std::vector tmp_buf; @@ -21483,7 +21462,7 @@ struct llama_data_write { } } - if (!kv_self.v_trans) { + if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); @@ -21502,7 +21481,8 @@ struct llama_data_write { write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size); } } - } else { + } + else if (v_state == 1) { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { @@ -21748,9 +21728,13 @@ struct llama_data_read { bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { const struct llama_hparams & hparams = ctx->model.hparams; struct llama_kv_cache & kv_self = ctx->kv_self; - uint32_t v_trans; + + // v_state: 0 -> not transposed V cache + // 1 -> transposed V cache + // 2 -> no V cache (as it may be the case with MLA) + uint32_t v_state; uint32_t n_layer; - read_to(&v_trans, sizeof(v_trans)); + read_to(&v_state, sizeof(v_state)); read_to(&n_layer, sizeof(n_layer)); if (n_layer != hparams.n_layer) { @@ -21761,7 +21745,9 @@ struct llama_data_read { LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size); return false; } - if (kv_self.v_trans != (bool) v_trans) { + + // Currently the only way there is no V cache (and thus v_state is 2) requires flash_attn, and flash_attn sets kv_self.v_trans to false + if (kv_self.v_trans != (v_state == 1)) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; } @@ -21794,7 +21780,7 @@ struct llama_data_read { } } - if (!kv_self.v_trans) { + if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); @@ -21821,7 +21807,8 @@ struct llama_data_read { ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); } } - } else { + } + else if (v_state == 1) { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); -- cgit v1.2.3