From 81748fb55e474ef1ddb3c64c14f7c378f0f6cd8b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 8 Mar 2025 19:33:41 +0200 Subject: Faster FlashMLA prompt processing (#246) * FlashMLA-2: faster prompt processing The current MLA implementation computes wv_b * (k_cache * softmax(k_cache * (wk_b*q))) This leads to 3.4X more multiply-adds (madds) compared to standard attention. Due to the resulting tensor shapes, TG is still faster than standard attention because the k_cache*(wk_b*q) and k_cache*(softmax(k_cache * (wk_b*q))) multiplications become GEMMs, so the additional madds are more than compensated for due to the much higher performance of GEMMs compared to GEMVs. But for PP, where we are dealing with GEMMs in both cases, the additional madds needed for MLA lead to lower performance, with the performance gap increasing with context length. So, then, when we are dealing with PP, we can rearrange the above to (wv_b * k_cache) * softmax( (wk_b^T*k_cache) * q), thus transforming it into the standard attention mechanism. We do need two additional matrix multiplications (which in practice is done as a single wkv_b * k_cache GEMM) with the *entire* K cache. But this is still cheaper than MLA, as we end up with 1.8X the madds required by standard attention. Oh, these figures are for the DeepSeek-V3/R1/Lite attention architecture. This leads to a significant PP performance increase compared to standard MLA with FA. There are many upsides to this: * If we only apply the above trick when we are processing more than X tokens (with suitable chosen X), TG performance stays the same as MLA with FA * We still need to store just the K-cache, so 576 entries per layer for DeepSeek-V3/R1/Lite * We get significantly better PP performance * We can use MLA+FA on CUDA. It works already with this commit for PP, something is not yet quite right for TG. The downside is that it only works with fp16 cache (for now). This is so because we need to convert the cache to fp32, else we cannot do the wkv_b * k_cache matrix multiplication (which in ggml requires the second operand to be fp32). But converting (copying) to fp32 only works for f16, bf16 and f32 tensors, so no luck with quantized cache. Another reason that we need to convert to fp32 is that the cache contains the RoPE'd portion, which we need to concatenate to the result of the wkv_b * k_cache matrix multiplication. Also this op works only when the tensors being concatenated are both fp32. So much about ggml being a general purpose ML library. * FlashMLA-2: on the CPU it now works for quantized cache except for q8_KV (q8_KV has row meta data, and there is still some confusion with row sizes because of that). * FlashMLA-2: on the CPU it now works also with q8_KV --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 135 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/llama.cpp b/src/llama.cpp index 9c9739e9..b45d9e70 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2691,6 +2691,9 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kv_l; std::vector kvt_l; + ggml_tensor * kv_aux_f32 = nullptr; + ggml_tensor * k_aux = nullptr; + ggml_tensor * v_aux = nullptr; std::vector ctxs; std::vector bufs; @@ -3199,12 +3202,35 @@ static bool llama_kv_cache_init( if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); + if (cparams.mla_attn > 1 && cache.kv_aux_f32 == nullptr) { + + cache.kv_aux_f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + kv_lora_rank + n_embd_head_qk_rope, kv_size); + //(n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head, kv_size); + ggml_format_name(cache.kv_aux_f32, "kv_aux_f32%d", 0); + + cache.k_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_k, n_head, kv_size); + ggml_format_name(cache.k_aux, "k_aux%d", 0); + + cache.v_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_v, n_head, kv_size); + ggml_format_name(cache.v_aux, "v_aux%d", 0); + + //cache.kv_aux_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + // (hparams.n_embd_head_k + hparams.n_embd_head_v)*n_head, kv_size); + //ggml_format_name(cache.kv_aux, "kv_aux%d", 0); + //ggml_format_name(cache.kv_aux_2, "kv_aux%d", 2); + LLAMA_LOG_INFO("%s: allocated kv auxilary tensors as %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, + cache.kv_aux_f32->ne[0], cache.kv_aux_f32->ne[1], + cache.k_aux->ne[0], cache.k_aux->ne[1], cache.k_aux->ne[2], + cache.v_aux->ne[0], cache.v_aux->ne[1], cache.v_aux->ne[2]); + } } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); @@ -13625,6 +13651,111 @@ struct llm_build_context { ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cache, "kv_cache", il); + ggml_tensor * kqv; + + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + + // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not + // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix + // multiplication, which *must* be f32. + auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_view_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_self.kv_aux_f32->ne[0], n_kv, kv_self.kv_aux_f32->nb[1], 0); + kv_cache_view_f32 = ggml_cpy(ctx0, kv_cache_view, kv_cache_view_f32); + + // The no- and rorational position encoding portions of the KV cache + auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, + kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); + + auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + + //// split into {n_head * n_embd_head_qk_nope, n_tokens} + //struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // 0); + //cb(k_nope, "k_nope", il); + + //// and {n_head * n_embd_head_v, n_tokens} + //struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + // ggml_row_size(kv->type, (n_embd_head_qk_nope))); + //cb(v_states, "v_states", il); + + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; + auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); + + auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); + + auto k_row_size = ggml_row_size(kv_self.k_aux->type, k_f32->ne[0]); + auto k = ggml_view_3d(ctx0, kv_self.k_aux, k_f32->ne[0], k_f32->ne[1], k_f32->ne[2], + k_row_size, k_row_size*k_f32->ne[1], 0); + //kv_self.k_aux->nb[1], k_row_size, 0); + //k_row_size, kv_self.k_aux->nb[1], 0); + k = ggml_cpy(ctx0, k_f32, k); + + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + + auto v_row_size = ggml_row_size(kv_self.v_aux->type, v_f32->ne[0]); + auto v = ggml_view_3d(ctx0, kv_self.v_aux, v_f32->ne[0], v_f32->ne[1], v_f32->ne[2], + v_row_size, v_row_size*v_f32->ne[1], 0); + //kv_self.v_aux->nb[1], v_row_size, 0); + //v_row_size, kv_self.v_aux->nb[1], 0); + v = ggml_cpy(ctx0, v_f32, v); + + //auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_cache->nb[1], 0); + //auto kv_cache_rope = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, kv_cache->nb[1], + // ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + ////kv_cache_rope = ggml_permute(ctx0, kv_cache_rope, 0, 2, 1, 3); + + //auto kv_cache_nope_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_lora_rank, n_kv, kv_self.kv_aux_f32->nb[1], 0); + //kv_cache_nope_f32 = ggml_cpy(ctx0, kv_cache_nope, kv_cache_nope_f32); + //auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope_f32); + + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, model.layers[il].wkv_b->ne[1], n_kv); + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, kv_f32->ne[0], kv_f32->ne[1]); + //auto kv = ggml_view_2d(ctx0, kv_self.kv_aux, kv_self.kv_aux->ne[0], n_kv, kv_self.kv_aux->nb[1], 0); + //kv = ggml_cpy(ctx0, kv_f32, kv); + + //auto k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + //ggml_tensor repeater; + //repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; + //auto k = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0); + + //auto v = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_embd_head_qk_nope)); + + auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + + kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv, "kqv", il); + + cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); + + //kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + //cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + + ggml_tensor * kqv_compressed; + + //printf("wkv_b: %ld x %ld x %ld kv_cache: %ld x %ld x %ld\n", model.layers[il].wkv_b->ne[0], model.layers[il].wkv_b->ne[1], model.layers[il].wkv_b->ne[2], kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2]); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); @@ -13636,12 +13767,12 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); + //printf("q_nope2 (%ld x %ld x %ld) = wk_b (%ld x %ld x %ld) * q_nope (%ld x %ld x %ld)\n", q_nope2->ne[0], q_nope2->ne[1], q_nope2->ne[2], + // wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], q_nope->ne[0], q_nope->ne[1], q_nope->ne[2]); ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - ggml_tensor * kqv_compressed; - if (lctx.cparams.flash_attn) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, @@ -13729,7 +13860,7 @@ struct llm_build_context { ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); cb(wv_b, "wv_b", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); @@ -13737,6 +13868,7 @@ struct llm_build_context { cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); cb(cur, "kqv_2d", il); + } ggml_build_forward_expand(gf, cur); -- cgit v1.2.3