Diffstat (limited to 'src')
-rw-r--r--  src/llama.cpp  120
1 file changed, 47 insertions, 73 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 34934a15..605e5d36 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13755,31 +13755,52 @@ struct llm_build_context {
if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) {
+ auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
- ggml_tensor * k;
- ggml_tensor * v;
+ auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
+ int n_max_head = n_head;
+ if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
+ while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
+ n_max_head /= 2; kv_f32_size /= 2;
+ }
+ }
+ GGML_ASSERT(n_head % n_max_head == 0);
+
+ auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
+
+ auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
+ kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
+
+ ggml_tensor repeater;
+ repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
+ auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+ cb(k_rope, "k_rope", il);
- // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend.
- // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not
- // be taken even if no layers were offloaded to the GPU.
- if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) {
+ auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+ cb(q, "q_concat", il);
+
+ ggml_build_forward_expand(gf, q);
- auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0);
+ for (int iter = 0; iter < n_head/n_max_head; ++iter) {
- auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
+ auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
+ model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter);
+
+ auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
cb(kv_f32, "kv_f32", il);
- auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head,
+ ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
- v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
+ auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
- auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
+ ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
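
The hunk above caps the memory taken by the f32 kv_f32 intermediate: n_max_head starts at n_head and is halved while the projected buffer exceeds cparams.attn_max_batch (treated as a MiB budget), and the code then runs n_head/n_max_head flash-attention passes over per-group views of wkv_b (continued in the next hunk). Below is a minimal standalone sketch of that size cap, with hypothetical dimensions standing in for model.layers[il].wkv_b->ne[1], n_kv and the budget; it is not the actual llama.cpp code, only the head-halving arithmetic.

// Standalone sketch of the head-group size cap (hypothetical numbers, not the
// actual llama.cpp tensors): halve the number of heads processed per
// flash-attention pass until the f32 kv_f32 buffer fits under the MiB budget.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical stand-ins for model.layers[il].wkv_b->ne[1], n_kv and
    // cparams.attn_max_batch from the diff.
    const int64_t wkv_b_rows     = 128 * (128 + 128); // n_head * (qk_nope + v) rows
    const int64_t n_kv           = 32768;             // cached tokens
    const int64_t attn_max_batch = 1024;              // budget in MiB
    const int     n_head         = 128;

    int64_t kv_f32_size = wkv_b_rows * n_kv * sizeof(float) / (1024*1024); // MiB
    int n_max_head = n_head;
    if (attn_max_batch > 0 && kv_f32_size > attn_max_batch) {
        // Only halve while the head count stays even, so the per-pass views of
        // wkv_b split into n_head/n_max_head equal chunks.
        while (n_max_head % 2 == 0 && kv_f32_size > attn_max_batch) {
            n_max_head /= 2;
            kv_f32_size /= 2;
        }
    }
    assert(n_head % n_max_head == 0);

    printf("heads per pass: %d, passes: %d, kv_f32 per pass: ~%lld MiB\n",
           n_max_head, n_head / n_max_head, (long long) kv_f32_size);
    return 0;
}

With these stand-in numbers the loop settles on 32 heads per pass, i.e. four passes whose per-pass kv_f32 lands at the 1024 MiB budget; if n_max_head reaches an odd value before fitting, the loop stops and the budget is simply exceeded.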
@@ -13789,74 +13810,27 @@ struct llm_build_context {
ggml_build_forward_expand(gf, k_nope);
ggml_build_forward_expand(gf, v);
- auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
- kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
-
- ggml_tensor repeater;
- repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
- auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
- cb(k_rope, "k_rope", il);
-
- k = ggml_concat(ctx0, k_nope, k_rope, 0);
+ auto k = ggml_concat(ctx0, k_nope, k_rope, 0);
cb(k, "k", il);
ggml_build_forward_expand(gf, k);
- }
- else {
- // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not
- // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
- // multiplication, which *must* be f32.
- auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
- auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
- cb(kv_cache_view_f32, "kv_cache_view_f32", il);
-
- // The no- and rotational position encoding portions of the KV cache
- auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
- auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
- kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
-
- auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
- cb(kv_f32, "kv_f32", il);
-
- auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
- cb(k_nope_f32, "k_nope_f32", il);
- ggml_tensor repeater;
- repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
- auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
- cb(k_rope_f32, "k_rope_f32", il);
+ auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head,
+ q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter);
- auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
- cb(k_f32, "k_f32", il);
-
- k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
- cb(k, "k", il);
-
- auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
- ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
- ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
- cb(v_f32, "v_f32", il);
-
- v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
- cb(v, "v", il);
- }
-
- auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
- q = ggml_permute(ctx0, q, 0, 2, 1, 3);
- cb(q, "q_concat", il);
+ kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+ if (q->ne[1] <= 8) {
+ ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+ }
+ cb(kqv, "kqv", il);
- ggml_build_forward_expand(gf, q);
+ if (iter == 0) {
+ cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens);
+ } else {
+ cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0);
+ }
- kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
- if (q->ne[1] <= 8) {
- ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
- cb(kqv, "kqv", il);
-
- cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
}
else {