summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-02-13 11:50:20 +0200
committerGitHub <noreply@github.com>2025-02-13 11:50:20 +0200
commit05242ff17d3685321ea0ea12021f77609219f2a6 (patch)
tree9bcbe0b2e7785b195c4678f86903a1e0c69830c0
parent1bbb543478bbc0997c3f86581c4f95338a5eb5c3 (diff)
Faster MLA prompt processing (#205)
* Do not allocate / report caches that are not used It is either the standard KV cache or MLA cache, not both. * Rename X_pe to X_rope Much easier to follow, at least for my brain, when we have X_rope : rotational position encoding X_nope : no position encoding instead of X_pe and X_nope, where I was wondering wtf is 'pe' and 'nope'. * WIP * WIP * WIP * WIP * Warn user when disabling MLA * MLA: compile time option to not use transposed KV cache Cuts KV cache size in nearly half at the expense of slower TG performance for long contexts (it becomes similar to no-MLA). --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--src/llama.cpp325
1 files changed, 156 insertions, 169 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 0817c53c..498bb437 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -109,6 +109,14 @@
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV2
//
+// === MLA cache
+// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0.
+// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half.
+// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1.
+//
+#define MLA_USE_TRANSPOSED_CACHE 1
+
+//
// helpers
//
@@ -2547,7 +2555,7 @@ struct llama_layer {
struct ggml_tensor * wkv_a_mqa;
struct ggml_tensor * wkv_b;
struct ggml_tensor * wk_b;
- struct ggml_tensor * wv_b;
+ struct ggml_tensor * wv_b;
struct ggml_tensor * wq_cross;
struct ggml_tensor * wk_cross;
struct ggml_tensor * wv_cross;
@@ -2676,18 +2684,16 @@ struct llama_kv_cache {
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
- ggml_type type_kr = GGML_TYPE_F16;
- ggml_type type_kv = GGML_TYPE_F16;
-
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
// DeepSeek MLA
- std::vector<struct ggml_tensor *> kr_l; // per layer
std::vector<struct ggml_tensor *> kv_l;
+#if MLA_USE_TRANSPOSED_CACHE
std::vector<struct ggml_tensor *> kvt_l;
+#endif
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3121,8 +3127,6 @@ static bool llama_kv_cache_init(
cache.type_k = type_k;
cache.type_v = type_v;
- cache.type_kr = type_k;
- cache.type_kv = type_v;
cache.cells.clear();
cache.cells.resize(kv_size);
@@ -3166,10 +3170,13 @@ static bool llama_kv_cache_init(
cache.v_l.reserve(n_layer);
// DeepSeek MLA
- cache.kr_l.reserve(n_layer);
cache.kv_l.reserve(n_layer);
+#if MLA_USE_TRANSPOSED_CACHE
cache.kvt_l.reserve(n_layer);
+#endif
+ bool warn = true;
+ int n_mla = 0;
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -3177,34 +3184,53 @@ static bool llama_kv_cache_init(
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
+ if (cparams.mla_attn) {
+ if (!model.layers[i].wk_b || !model.layers[i].wv_b) {
+ if (warn) {
+ LLAMA_LOG_WARN("=======================================================================================\n");
+ LLAMA_LOG_WARN("%s: missing MLA tensors => disabling MLA\n", __func__);
+ LLAMA_LOG_WARN("%s: you need to reconvert your model in order to use MLA\n", __func__);
+ LLAMA_LOG_WARN("=======================================================================================\n");
+ warn = false;
+ }
+ }
+ }
if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
- k = ggml_new_tensor_1d(ctx, type_k, 1);
- v = ggml_new_tensor_1d(ctx, type_v, 1);
+ // DeepSeek MLA
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
+ LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+#if MLA_USE_TRANSPOSED_CACHE
+ // TODO: The k-cache is contiguous and not permuted, so strictly speaking, it should be possible to quantize it.
+ // Sadly, at this point something goes wrong with quantized k-cache, so for now we set the k-cache
+ // type to type_v, which is guaranteed to be f16 or bf16 without FA.
+ //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+ ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+#else
+ ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+#endif
+ ggml_format_name(kv, "cache_kv_l%d", i);
+ cache.kv_l.push_back(kv);
+#if MLA_USE_TRANSPOSED_CACHE
+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
+ ggml_format_name(kvt, "cache_kvt_l%d", i);
+ cache.kvt_l.push_back(kvt);
+#endif
+ n_mla++;
}
else {
- k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
- }
-
- ggml_format_name(k, "cache_k_l%d", i);
- ggml_format_name(v, "cache_v_l%d", i);
- cache.k_l.push_back(k);
- cache.v_l.push_back(v);
-
-
- // DeepSeek MLA
- const uint32_t n_embd_head_qk_rope = hparams.n_rot;
- const uint32_t kv_lora_rank = hparams.n_lora_kv;
- LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
- ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
- ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
- ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
- ggml_format_name(kr, "cache_kr_l%d", i);
- ggml_format_name(kv, "cache_kv_l%d", i);
- ggml_format_name(kvt, "cache_kvt_l%d", i);
- cache.kr_l.push_back(kr);
- cache.kv_l.push_back(kv);
- cache.kvt_l.push_back(kvt);
+ k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ ggml_format_name(k, "cache_k_l%d", i);
+ ggml_format_name(v, "cache_v_l%d", i);
+ cache.k_l.push_back(k);
+ cache.v_l.push_back(v);
+ }
+ }
+ if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
+ LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer));
+ LLAMA_LOG_ERROR("%s: bailing out\n", __func__);
+ GGML_ABORT("fatal error");
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -13422,94 +13448,80 @@ struct llm_build_context {
cb(q_nope, "q_nope", il);
// and {n_head * n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+ struct ggml_tensor * q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
ggml_row_size(q->type, n_embd_head_qk_nope));
- cb(q_pe, "q_pe", il);
+ cb(q_rope, "q_rope", il);
+
+ q_rope = ggml_rope_ext(
+ ctx0, q_rope, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_rope, "q_rope", il);
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
- cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+ struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+ cb(kv_rope_compresseed, "kv_rope_compresseed", il);
+
+ // and {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+ kv_rope_compresseed->nb[1],
+ kv_rope_compresseed->nb[1],
+ ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
+ cb(k_rope, "k_rope", il);
+
+ // shared RoPE key
+ k_rope = ggml_rope_ext(
+ ctx0, k_rope, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(k_rope, "k_rope", il);
// split into {kv_lora_rank, n_tokens}
- struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
- kv_pe_compresseed->nb[1],
+ struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens,
+ kv_rope_compresseed->nb[1],
0);
cb(kv_compressed, "kv_compressed", il);
- if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
-
- // and {n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
- kv_pe_compresseed->nb[1],
- kv_pe_compresseed->nb[1],
- ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
- cb(k_pe, "k_pe", il);
-
- //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
- kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
- model.layers[il].attn_kv_a_norm, NULL,
- LLM_NORM_RMS, cb, il);
- cb(kv_compressed, "kv_compressed", il);
-
- struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
- cb(kv_cache_view, "kv_cache_view", il);
+ kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+ model.layers[il].attn_kv_a_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(kv_compressed, "kv_compressed", il);
- // note: storing c^KV in the KV cache
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
+ if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
- struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+#if MLA_USE_TRANSPOSED_CACHE
+ ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
cb(kv_cache_trans_view, "kv_cache_trans_view", il);
// note: storing transposed c^KV in the transposed KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
- struct ggml_tensor * kv_cache =
- ggml_view_2d(ctx0, kv_self.kv_l[il],
- kv_lora_rank, n_kv,
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
- 0);
- cb(kv_cache, "kv_cache", il);
-
- struct ggml_tensor * kv_cache_trans =
- ggml_view_2d(ctx0, kv_self.kvt_l[il],
- n_kv, kv_lora_rank,
- ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
- 0);
+ ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
+ n_kv, kv_lora_rank,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+ 0);
cb(kv_cache_trans, "kv_cache_trans", il);
+#endif
- //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- q_pe = ggml_rope_ext(
- ctx0, q_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(q_pe, "q_pe", il);
-
- // shared RoPE key
- //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- k_pe = ggml_rope_ext(
- ctx0, k_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(k_pe, "k_pe", il);
-
- struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
- cb(kr_cache_view, "kr_cache_view", il);
-
- // note: storing RoPE-ed version of K^R in the KV cache
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+ ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
+ cb(kvr, "kvr", il);
- struct ggml_tensor * kr_cache =
- ggml_view_2d(ctx0, kv_self.kr_l[il],
- n_embd_head_qk_rope, n_kv,
- ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
- 0);
- cb(kr_cache, "kr_cache", il);
+ ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope),
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head);
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
+ ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank + n_embd_head_qk_rope, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ cb(kv_cache, "kv_cache", il);
- struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+ struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head,
+ ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope),
+ ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0);
cb(wk_b, "wk_b", il);
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
@@ -13518,33 +13530,20 @@ struct llm_build_context {
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
cb(q_nope2, "q_nope2", il);
+ ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
+ cb(q, "q", il);
if (!pp_opt) {
- q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
- cb(q_nope2, "q_nope2_perm", il);
- }
- struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
- cb(kq_nope, "kq_nope", il);
-
- if (!pp_opt) {
- kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
- cb(kq_nope, "kq_nope_perm", il);
- }
-
- if (pp_opt) {
- q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
- cb(q_pe, "q_pe_perm", il);
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+ cb(q, "q_perm", il);
}
- struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
- cb(kq_pe, "kq_pe", il);
+ ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
+ cb(kq, "kq", il);
- if (!pp_opt) {
- kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
- cb(kq_pe, "kq_pe_perm", il);
+ if (!pp_opt) {
+ kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+ cb(kq, "kq_perm", il);
}
- struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
- cb(kq, "kq", il);
-
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
@@ -13553,6 +13552,16 @@ struct llm_build_context {
cb(kq, "kq_soft_max_ext_perm", il);
}
+#if !MLA_USE_TRANSPOSED_CACHE
+ ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ cb(kv_cache, "kv_cache_lora", il);
+
+ ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
+ cb(kv_cache_trans, "kv_cache_trans", il);
+#endif
+
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
@@ -13561,7 +13570,9 @@ struct llm_build_context {
cb(kqv_compressed, "kqv_compressed_perm", il);
}
- struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+ struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
+ ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank),
+ ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0);
cb(wv_b, "wv_b", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
@@ -13581,19 +13592,6 @@ struct llm_build_context {
}
else {
- // and {n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
- kv_pe_compresseed->nb[1],
- kv_pe_compresseed->nb[1],
- ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
- cb(k_pe, "k_pe", il);
-
- //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
- kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
- model.layers[il].attn_kv_a_norm, NULL,
- LLM_NORM_RMS, cb, il);
- cb(kv_compressed, "kv_compressed", il);
-
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
cb(kv, "kv", il);
@@ -13620,27 +13618,10 @@ struct llm_build_context {
0);
cb(v_states, "v_states", il);
- //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- q_pe = ggml_rope_ext(
- ctx0, q_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(q_pe, "q_pe", il);
-
- // shared RoPE key
- //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- k_pe = ggml_rope_ext(
- ctx0, k_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(k_pe, "k_pe", il);
-
- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+ struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0);
cb(q_states, "q_states", il);
- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+ struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0);
cb(k_states, "k_states", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
@@ -18054,28 +18035,34 @@ struct llama_context * llama_new_context_with_model(
memory_size_v += ggml_nbytes(v);
}
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+ if (memory_size_k + memory_size_v > 0) {
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+ ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+ }
}
- {
- size_t memory_size_kr = 0;
+ {
size_t memory_size_kv = 0;
-
- for (auto & kr : ctx->kv_self.kr_l) {
- memory_size_kr += ggml_nbytes(kr);
- }
+ size_t memory_size_kvt = 0;
for (auto & kv : ctx->kv_self.kv_l) {
memory_size_kv += ggml_nbytes(kv);
}
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
- ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
- ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+#if MLA_USE_TRANSPOSED_CACHE
+ for (auto & kvt : ctx->kv_self.kvt_l) {
+ memory_size_kvt += ggml_nbytes(kvt);
+ }
+#endif
+
+ if (memory_size_kv + memory_size_kvt > 0) {
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_kvt / (1024.0f * 1024.0f));
+ }
}
// graph outputs buffer