summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfairydreaming <166155368+fairydreaming@users.noreply.github.com>2024-05-17 13:24:38 +0200
committerGitHub <noreply@github.com>2024-05-17 14:24:38 +0300
commit27b040691cbe45314147c2745e891a38e9c048d4 (patch)
tree85f2ade442a6e316a6700a89bb121924e051c593
parent29c60d8cddcfd14fa8a6bf023a6c4eb8692c76ba (diff)
llama : use n_embd_head_v when reshaping kqv (#7327)
* llama : use n_embd_head_v instead of n_embd_head_k when reshaping kqv * llama : use n_embd_v_gqa and n_embd_head_v instead of n_embd_k_gqa and n_embd_head_k when making a view of cached value vectors. --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
-rw-r--r--llama.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/llama.cpp b/llama.cpp
index c5a1fa0f..e11f0ac4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv(
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
@@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+ ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
@@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
- cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}