diff options
author    | Kawrakow <iwankawrakow@gmail.com> | 2025-03-08 19:33:41 +0200
committer | GitHub <noreply@github.com> | 2025-03-08 19:33:41 +0200
commit    | 81748fb55e474ef1ddb3c64c14f7c378f0f6cd8b (patch)
tree      | e36c0b1490b9e751086f1aa798596fc2710ff0b1 /examples/eval-callback
parent    | 3d85a1d66302989401f92a5ae347577b03cbdaa7 (diff)
Faster FlashMLA prompt processing (#246)
* FlashMLA-2: faster prompt processing
The current MLA implementation computes
wv_b * (k_cache * softmax(k_cache * (wk_b*q)))
This requires 3.4X more multiply-adds (madds) than standard
attention. Due to the resulting tensor shapes, TG (token
generation) is still faster than standard attention because the
k_cache*(wk_b*q) and k_cache*(softmax(k_cache * (wk_b*q)))
multiplications become GEMMs, so the additional madds are more
than compensated for by the much higher performance of GEMMs
compared to GEMVs. But for PP (prompt processing), where we are
dealing with GEMMs in both cases, the additional madds needed for
MLA lead to lower performance, and the performance gap grows with
context length.
So, for PP we can rearrange the above to
(wv_b * k_cache) * softmax( (wk_b^T * k_cache) * q),
thus transforming it into the standard attention mechanism.
We do need two additional matrix multiplications (which in practice
are done as a single wkv_b * k_cache GEMM) with the *entire*
K cache. But this is still cheaper than MLA, as we end up with
1.8X the madds required by standard attention. (These figures
are for the DeepSeek-V3/R1/Lite attention architecture.)
This leads to a significant PP performance increase compared
to standard MLA with FA.
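For reference, a minimal NumPy sketch (not the ggml implementation;
one head, illustrative sizes, RoPE'd part omitted) showing that the
two formulations above are the same computation, just with the matrix
products associated differently:

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_lat, d_head = 32, 512, 128            # illustrative sizes only

k_cache = rng.standard_normal((n_tokens, d_lat))  # latent KV cache (RoPE'd part omitted)
q       = rng.standard_normal(d_head)             # one head's query
wk_b    = rng.standard_normal((d_lat, d_head))    # maps the query head dim into the latent space
wv_b    = rng.standard_normal((d_head, d_lat))    # maps the latent space to the value head dim

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# MLA path (used for TG): everything stays in the latent space
out_mla = wv_b @ (k_cache.T @ softmax(k_cache @ (wk_b @ q)))

# Rearranged path (used for PP): materialize per-head K and V from
# the whole cache once, then do standard attention
K = k_cache @ wk_b                                # (wk_b^T * k_cache)^T
V = k_cache @ wv_b.T                              # (wv_b * k_cache)^T
out_std = V.T @ softmax(K @ q)

assert np.allclose(out_mla, out_std)              # identical up to fp rounding
```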
There are many upsides to this:
* If we only apply the above trick when we are processing more than
X tokens (with a suitably chosen X, see the dispatch sketch after
this list), TG performance stays the same as MLA with FA
* We still need to store just the K-cache, so 576 entries per token
per layer for DeepSeek-V3/R1/Lite
* We get significantly better PP performance
* We can use MLA+FA on CUDA. PP already works with this commit;
something is not yet quite right for TG.
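A hypothetical dispatch sketch for the first point above (the
threshold X is purely illustrative; the commit does not name a value):

```python
# Pick the attention formulation per batch. X is illustrative only.
X = 32

def choose_attention_path(n_tokens: int) -> str:
    """Which attention formulation to build into the graph for this batch."""
    if n_tokens > X:
        # PP: (wv_b * k_cache) * softmax((wk_b^T * k_cache) * q)
        return "standard"
    # TG: wv_b * (k_cache * softmax(k_cache * (wk_b * q)))
    return "mla"
```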
The downside is that it only works with an fp16 cache (for now).
This is because we need to convert the cache to fp32,
else we cannot do the wkv_b * k_cache matrix multiplication
(which in ggml requires the second operand to be fp32).
But converting (copying) to fp32 only works for f16, bf16 and
f32 tensors, so no luck with a quantized cache. Another reason
we need to convert to fp32 is that the cache contains the
RoPE'd portion, which we need to concatenate to the result of
the wkv_b * k_cache matrix multiplication, and this op
works only when the tensors being concatenated are both fp32.
So much for ggml being a general purpose ML library.
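To make the data flow concrete, a hedged NumPy sketch of the PP path
for one head (illustrative sizes; the actual graph builds the
analogous ggml ops, and the f32 copy in step 1 is the step that
currently rules out quantized caches):

```python
import numpy as np

rng = np.random.default_rng(1)
n_tokens, d_lat, d_rope, d_nope, d_v = 32, 512, 64, 128, 128   # illustrative sizes

# each cache row holds the latent part plus the already-RoPE'd key part
cache_f16 = rng.standard_normal((n_tokens, d_lat + d_rope)).astype(np.float16)

# hypothetical per-head weights; wk_b and wv_b stacked into a single wkv_b
wk_b  = rng.standard_normal((d_lat, d_nope)).astype(np.float32)
wv_b  = rng.standard_normal((d_lat, d_v)).astype(np.float32)
wkv_b = np.concatenate([wk_b, wv_b], axis=1)       # (d_lat, d_nope + d_v)

# step 1: copy the cache to f32 -- only possible for f16/bf16/f32 caches
cache_f32 = cache_f16.astype(np.float32)
c_lat, k_rope = cache_f32[:, :d_lat], cache_f32[:, d_lat:]

# step 2: a single wkv_b * k_cache GEMM over the *entire* cache yields
# both the no-RoPE part of K and V for this head
kv = c_lat @ wkv_b
k_nope, v = kv[:, :d_nope], kv[:, d_nope:]

# step 3: concatenate the RoPE'd portion back onto K; both operands are f32
k = np.concatenate([k_nope, k_rope], axis=1)       # (n_tokens, d_nope + d_rope)
```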
* FlashMLA-2: on the CPU it now works for a quantized cache,
except for q8_KV (q8_KV has row metadata, and there is still
some confusion with row sizes because of that).
* FlashMLA-2: on the CPU it now works also with q8_KV
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples/eval-callback')
0 files changed, 0 insertions, 0 deletions