summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/llama.cpp b/llama.cpp
index c5a1fa0f..e11f0ac4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv(
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
@@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+ ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
@@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
- cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}