author     Kawrakow <iwankawrakow@gmail.com>   2025-03-05 07:27:49 +0200
committer  GitHub <noreply@github.com>         2025-03-05 07:27:49 +0200
commit     7bdbf99bbdbfe46b01f7783a7c98a30a1558e2c3 (patch)
tree       ecd49fd2b7bccc1b3d70819241d0eaf9c4dddede /src
parent     a87e54db6ec2409284a55f029d4abe9e50990064 (diff)
DeepSeek CUDA Flash Attention (#241)
* WIP CUDA FA with Dk != Dv
* WIP
* CUDA FA WIP - It actually works! No TG yet, but for PP I can run FA with fp16 cache and it gets the same answer.
* CUDA FA WIP - it now works for Q8_0 + Q8_0 for KV cache
* CUDA FA WIP - TG, not working yet.
* CUDA FA with Dk != Dv: it works now for DeepSeek

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
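For context, here is a minimal ggml-style sketch of what building an attention node with Dk != Dv can look like. The helper name and all dimensions below are illustrative assumptions, not taken from the commit; only the two API calls (ggml_flash_attn_ext and ggml_flash_attn_ext_set_prec) are the ones this diff actually touches.

// Hypothetical helper (not part of the commit): builds a flash-attention
// node where the K head size Dk differs from the V head size Dv, as in
// DeepSeek-style attention. All dimensions are illustrative assumptions.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * build_fa_dk_ne_dv(struct ggml_context * ctx,
                                              struct ggml_tensor  * kq_mask) {
    const int64_t Dk     = 192;   // K head size (assumed)
    const int64_t Dv     = 128;   // V head size (assumed, != Dk)
    const int64_t n_head = 16;    // number of attention heads (assumed)
    const int64_t n_tok  = 1;     // one token per step during TG
    const int64_t n_kv   = 4096;  // KV cache length (assumed)

    // q is [Dk, n_tok, n_head]; k is [Dk, n_kv, n_head]; v is [Dv, n_kv, n_head].
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, Dk, n_tok, n_head);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, Dk, n_kv,  n_head);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, Dv, n_kv,  n_head);

    const float kq_scale = 1.0f/sqrtf((float)Dk);

    // Same call the diff shows in llm_build_kqv(); the result has head size Dv.
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale,
                                                   /*max_alibi_bias*/ 0.0f,
                                                   /*logit_softcap */ 0.0f);

    // Mirror of the new condition in the diff below: small q batches (TG)
    // are forced to run in f32 precision.
    if (q->ne[1] <= 8) {
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
    }
    return cur;
}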
Diffstat (limited to 'src')
-rw-r--r--  src/llama.cpp | 6
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 5ac44055..e246dec9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8777,7 +8777,11 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ // Some models produce NaNs/gibberish when FA is computed with f16 precision on CUDA.
+ // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
+ // Not sure if it is really a matter of insufficient precision, or if I have made a mistake in the fattn-vec-f16 kernel.
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
+ (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
//ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
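Note on the new condition: q->ne[1] is the number of tokens in the batch, so q->ne[1] <= 8 selects the small-batch (TG) path served by the fattn-vec kernels, while prompt processing with larger batches keeps the default fp16 precision, consistent with the PP-vs-TG observation in the comment above.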