summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-04-25 09:21:03 +0200
committerGitHub <noreply@github.com>2025-04-25 09:21:03 +0200
commitf176122a3d50c781414458b498b9426086a91647 (patch)
tree1f1fb9520b81fa7fdb35bc8a4396c7b7b11f5fff
parentc9eec1729fe95a5fcfd4ce47df440c2445abb17e (diff)
Fix LLaMA-4 attention (#342)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--src/llama.cpp11
1 files changed, 9 insertions, 2 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index d0fd9c48..c870b09e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9974,7 +9974,12 @@ struct llm_build_context {
}
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ?
+ ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ ggml_tensor * KQ_mask_swa = nullptr;
+ if (hparams.n_swa > 0 && hparams.n_swa_pattern > 0) {
+ KQ_mask_swa = build_inp_KQ_mask_swa();
+ }
//const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f;
@@ -9982,6 +9987,8 @@ struct llm_build_context {
struct ggml_tensor * inpSA = inpL;
bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
+ auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ?
+ KQ_mask_swa : KQ_mask;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
@@ -10046,7 +10053,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
- Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
if (il == n_layer - 1) {