summaryrefslogtreecommitdiff
path: root/src/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/llama.cpp')
-rw-r--r--src/llama.cpp22
1 files changed, 17 insertions, 5 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 605e5d36..76039f8e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13771,9 +13771,21 @@ struct llm_build_context {
auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+ // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when
+ // the KV cache is quantized. Hence, in that case we will simply use fp16 for now.
+ // The downside of the following line is that fp16 will be used even if attention is computed on the CPU
+ // if the build is with CUDA enabled.
+ auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16;
+
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
- auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+ ggml_tensor * k_rope;
+ if (kv_cache_rope->type == kv_type) {
+ k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+ } else {
+ auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
+ k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
+ }
cb(k_rope, "k_rope", il);
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
@@ -13796,15 +13808,15 @@ struct llm_build_context {
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
- auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
- cb(v, "v", il);
-
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
- auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
+ auto v = ggml_cast(ctx0, v_f32, kv_type);
+ cb(v, "v", il);
+
+ auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
cb(k_nope, "k_nope", il);
ggml_build_forward_expand(gf, k_nope);