summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-08-17 10:47:09 +0300
committerGitHub <noreply@github.com>2023-08-17 10:47:09 +0300
commita73ccf1aa34de49f61bfeb7f8a679c3bfdb3abe3 (patch)
tree32b0834658476238b1c7209474f8606a7947b350 /llama.cpp
parent7cf54e1f746941279d81d485796777c01f88049c (diff)
llama : replace (permute + reshape + view_1d) with (view_3d) (#2538)
ggml-ci
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/llama.cpp b/llama.cpp
index 34524399..b8cc2294 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1609,11 +1609,11 @@ static struct ggml_cgraph * llama_build_graph(
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
- n_embd_head, n_head_kv, n_past + N),
- 0, 2, 1, 3);
+ ggml_view_3d(ctx0, kv_self.k,
+ n_embd_head, n_past + N, n_head_kv,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
@@ -1642,9 +1642,9 @@ static struct ggml_cgraph * llama_build_graph(
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,
- n_ctx*ggml_element_size(kv_self.v),
- n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
- n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
+ ggml_element_size(kv_self.v)*n_ctx,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");