diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-08 19:33:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-08 19:33:41 +0200 |
commit | 81748fb55e474ef1ddb3c64c14f7c378f0f6cd8b (patch) | |
tree | e36c0b1490b9e751086f1aa798596fc2710ff0b1 | |
parent | 3d85a1d66302989401f92a5ae347577b03cbdaa7 (diff) |
Faster FlashMLA prompt processing (#246)
* FlashMLA-2: faster prompt processing
The current MLA implementation computes
wv_b * (k_cache * softmax(k_cache * (wk_b*q)))
This leads to 3.4X more multiply-adds (madds)
compared to standard attention. Due to the resulting
tensor shapes, TG is still faster than standard attention
because the k_cache*(wk_b*q) and k_cache*(softmax(k_cache * (wk_b*q)))
multiplications become GEMMs, so the additional madds are
more than compensated for due to the much higher performance
of GEMMs compared to GEMVs. But for PP, where we are dealing
with GEMMs in both cases, the additional madds needed for MLA
lead to lower performance, with the performance gap increasing
with context length.
So, then, when we are dealing with PP, we can rearrange the
above to (wv_b * k_cache) * softmax( (wk_b^T*k_cache) * q),
thus transforming it into the standard attention mechanism.
We do need two additional matrix multiplications (which in practice
is done as a single wkv_b * k_cache GEMM) with the *entire*
K cache. But this is still cheaper than MLA, as we end up with
1.8X the madds required by standard attention. Oh, these figures
are for the DeepSeek-V3/R1/Lite attention architecture.
This leads to a significant PP performance increase compared
to standard MLA with FA.
There are many upsides to this:
* If we only apply the above trick when we are processing more than
X tokens (with suitable chosen X), TG performance stays the same
as MLA with FA
* We still need to store just the K-cache, so 576 entries per layer
for DeepSeek-V3/R1/Lite
* We get significantly better PP performance
* We can use MLA+FA on CUDA. It works already with this commit
for PP, something is not yet quite right for TG.
The downside is that it only works with fp16 cache (for now).
This is so because we need to convert the cache to fp32,
else we cannot do the wkv_b * k_cache matrix multiplication
(which in ggml requires the second operand to be fp32).
But converting (copying) to fp32 only works for f16, bf16 and
f32 tensors, so no luck with quantized cache. Another reason
that we need to convert to fp32 is that the cache contains the
RoPE'd portion, which we need to concatenate to the result of
the wkv_b * k_cache matrix multiplication. Also this op
works only when the tensors being concatenated are both fp32.
So much about ggml being a general purpose ML library.
* FlashMLA-2: on the CPU it now works for quantized cache
except for q8_KV (q8_KV has row meta data, and there is still
some confusion with row sizes because of that).
* FlashMLA-2: on the CPU it now works also with q8_KV
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml.c | 41 | ||||
-rw-r--r-- | src/llama.cpp | 138 |
2 files changed, 176 insertions, 3 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e5ad15f2..4089e9b7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10565,6 +10565,42 @@ static void ggml_compute_forward_dup_bytes( } } +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + struct ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type)); + + ggml_to_float_t to_float = type_traits[src0->type].to_float; + GGML_ASSERT(to_float != NULL); + + int64_t nrows = ggml_nrows(dst); + int ith = params->ith; + int nth = params->nth; + + int64_t n_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*n_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + n_per_thread, nrows); + + for (int64_t ir = first_row; ir < last_row; ++ir) { + int64_t i03 = ir/(src0->ne[1]*src0->ne[2]); + int64_t i02 = (ir - i03*src0->ne[1]*src0->ne[2])/src0->ne[1]; + int64_t i01 = ir - i03*src0->ne[1]*src0->ne[2] - i02*src0->ne[1]; + int64_t i3 = ir/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (ir - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = ir - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + + const char * q = (const char *)src0->data + i03*src0->nb[3] + i02*src0->nb[2] + i01*src0->nb[1]; + char * f = ( char *)dst->data + i3* dst->nb[3] + i2* dst->nb[2] + i1* dst->nb[1]; + + to_float((const void *)q, (float *)f, src0->ne[0]); + } + +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -10576,6 +10612,11 @@ static void ggml_compute_forward_dup( return; } + if (ggml_is_quantized(src0->type)) { + ggml_compute_forward_dup_q(params, dst); + return; + } + switch (src0->type) { case GGML_TYPE_F16: { diff --git a/src/llama.cpp b/src/llama.cpp index 9c9739e9..b45d9e70 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2691,6 +2691,9 @@ struct llama_kv_cache { // DeepSeek MLA std::vector<struct ggml_tensor *> kv_l; std::vector<struct ggml_tensor *> kvt_l; + ggml_tensor * kv_aux_f32 = nullptr; + ggml_tensor * k_aux = nullptr; + ggml_tensor * v_aux = nullptr; std::vector<struct ggml_context *> ctxs; std::vector<ggml_backend_buffer_t> bufs; @@ -3199,12 +3202,35 @@ static bool llama_kv_cache_init( if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); + if (cparams.mla_attn > 1 && cache.kv_aux_f32 == nullptr) { + + cache.kv_aux_f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + kv_lora_rank + n_embd_head_qk_rope, kv_size); + //(n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head, kv_size); + ggml_format_name(cache.kv_aux_f32, "kv_aux_f32%d", 0); + + cache.k_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_k, n_head, kv_size); + ggml_format_name(cache.k_aux, "k_aux%d", 0); + + cache.v_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_v, n_head, kv_size); + ggml_format_name(cache.v_aux, "v_aux%d", 0); + + //cache.kv_aux_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + // (hparams.n_embd_head_k + hparams.n_embd_head_v)*n_head, kv_size); + //ggml_format_name(cache.kv_aux, "kv_aux%d", 0); + //ggml_format_name(cache.kv_aux_2, "kv_aux%d", 2); + LLAMA_LOG_INFO("%s: allocated kv auxilary tensors as %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, + cache.kv_aux_f32->ne[0], cache.kv_aux_f32->ne[1], + cache.k_aux->ne[0], cache.k_aux->ne[1], cache.k_aux->ne[2], + cache.v_aux->ne[0], cache.v_aux->ne[1], cache.v_aux->ne[2]); + } } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); @@ -13625,6 +13651,111 @@ struct llm_build_context { ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cache, "kv_cache", il); + ggml_tensor * kqv; + + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + + // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not + // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix + // multiplication, which *must* be f32. + auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_view_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_self.kv_aux_f32->ne[0], n_kv, kv_self.kv_aux_f32->nb[1], 0); + kv_cache_view_f32 = ggml_cpy(ctx0, kv_cache_view, kv_cache_view_f32); + + // The no- and rorational position encoding portions of the KV cache + auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, + kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); + + auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + + //// split into {n_head * n_embd_head_qk_nope, n_tokens} + //struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // 0); + //cb(k_nope, "k_nope", il); + + //// and {n_head * n_embd_head_v, n_tokens} + //struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + // ggml_row_size(kv->type, (n_embd_head_qk_nope))); + //cb(v_states, "v_states", il); + + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; + auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); + + auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); + + auto k_row_size = ggml_row_size(kv_self.k_aux->type, k_f32->ne[0]); + auto k = ggml_view_3d(ctx0, kv_self.k_aux, k_f32->ne[0], k_f32->ne[1], k_f32->ne[2], + k_row_size, k_row_size*k_f32->ne[1], 0); + //kv_self.k_aux->nb[1], k_row_size, 0); + //k_row_size, kv_self.k_aux->nb[1], 0); + k = ggml_cpy(ctx0, k_f32, k); + + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + + auto v_row_size = ggml_row_size(kv_self.v_aux->type, v_f32->ne[0]); + auto v = ggml_view_3d(ctx0, kv_self.v_aux, v_f32->ne[0], v_f32->ne[1], v_f32->ne[2], + v_row_size, v_row_size*v_f32->ne[1], 0); + //kv_self.v_aux->nb[1], v_row_size, 0); + //v_row_size, kv_self.v_aux->nb[1], 0); + v = ggml_cpy(ctx0, v_f32, v); + + //auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_cache->nb[1], 0); + //auto kv_cache_rope = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, kv_cache->nb[1], + // ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + ////kv_cache_rope = ggml_permute(ctx0, kv_cache_rope, 0, 2, 1, 3); + + //auto kv_cache_nope_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_lora_rank, n_kv, kv_self.kv_aux_f32->nb[1], 0); + //kv_cache_nope_f32 = ggml_cpy(ctx0, kv_cache_nope, kv_cache_nope_f32); + //auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope_f32); + + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, model.layers[il].wkv_b->ne[1], n_kv); + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, kv_f32->ne[0], kv_f32->ne[1]); + //auto kv = ggml_view_2d(ctx0, kv_self.kv_aux, kv_self.kv_aux->ne[0], n_kv, kv_self.kv_aux->nb[1], 0); + //kv = ggml_cpy(ctx0, kv_f32, kv); + + //auto k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + //ggml_tensor repeater; + //repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; + //auto k = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0); + + //auto v = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_embd_head_qk_nope)); + + auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + + kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv, "kqv", il); + + cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); + + //kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + //cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + + ggml_tensor * kqv_compressed; + + //printf("wkv_b: %ld x %ld x %ld kv_cache: %ld x %ld x %ld\n", model.layers[il].wkv_b->ne[0], model.layers[il].wkv_b->ne[1], model.layers[il].wkv_b->ne[2], kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2]); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); @@ -13636,12 +13767,12 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); + //printf("q_nope2 (%ld x %ld x %ld) = wk_b (%ld x %ld x %ld) * q_nope (%ld x %ld x %ld)\n", q_nope2->ne[0], q_nope2->ne[1], q_nope2->ne[2], + // wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], q_nope->ne[0], q_nope->ne[1], q_nope->ne[2]); ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - ggml_tensor * kqv_compressed; - if (lctx.cparams.flash_attn) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, @@ -13729,7 +13860,7 @@ struct llm_build_context { ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); cb(wv_b, "wv_b", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); @@ -13737,6 +13868,7 @@ struct llm_build_context { cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); cb(cur, "kqv_2d", il); + } ggml_build_forward_expand(gf, cur); |