Diffstat (limited to 'src')
-rw-r--r--  src/llama.cpp  338
1 file changed, 279 insertions(+), 59 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 29926a94..00e6c934 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -539,6 +539,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_K_B,
+ LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
@@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+ { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+ { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -2503,6 +2507,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
+ bool mla_attn;
enum llama_pooling_type pooling_type;
@@ -2541,6 +2546,8 @@ struct llama_layer {
struct ggml_tensor * wq_b;
struct ggml_tensor * wkv_a_mqa;
struct ggml_tensor * wkv_b;
+ struct ggml_tensor * wk_b;
+ struct ggml_tensor * wv_b;
struct ggml_tensor * wq_cross;
struct ggml_tensor * wk_cross;
struct ggml_tensor * wv_cross;
@@ -2669,11 +2676,19 @@ struct llama_kv_cache {
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
+ ggml_type type_kr = GGML_TYPE_F16;
+ ggml_type type_kv = GGML_TYPE_F16;
+
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
+ // DeepSeek MLA
+ std::vector<struct ggml_tensor *> kr_l; // per layer
+ std::vector<struct ggml_tensor *> kv_l;
+ std::vector<struct ggml_tensor *> kvt_l;
+
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3104,8 +3119,10 @@ static bool llama_kv_cache_init(
cache.size = kv_size;
cache.used = 0;
- cache.type_k = type_k;
- cache.type_v = type_v;
+ cache.type_k = type_k;
+ cache.type_v = type_v;
+ cache.type_kr = type_k;
+ cache.type_kv = type_v;
cache.cells.clear();
cache.cells.resize(kv_size);
@@ -3132,7 +3149,7 @@ static bool llama_kv_cache_init(
for (auto & it : buft_layer_count) {
int n_layers = it.second;
struct ggml_init_params params = {
- /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(),
+ /*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -3148,17 +3165,46 @@ static bool llama_kv_cache_init(
cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);
+ // DeepSeek MLA
+ cache.kr_l.reserve(n_layer);
+ cache.kv_l.reserve(n_layer);
+ cache.kvt_l.reserve(n_layer);
+
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ ggml_tensor * k;
+ ggml_tensor * v;
+ if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
+ k = ggml_new_tensor_1d(ctx, type_k, 1);
+ v = ggml_new_tensor_1d(ctx, type_v, 1);
+ }
+ else {
+ k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ }
+
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
+
+
+ // DeepSeek MLA
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
+ LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+ ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+ ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+ ggml_format_name(kr, "cache_kr_l%d", i);
+ ggml_format_name(kv, "cache_kv_l%d", i);
+ ggml_format_name(kvt, "cache_kvt_l%d", i);
+ cache.kr_l.push_back(kr);
+ cache.kv_l.push_back(kv);
+ cache.kvt_l.push_back(kvt);
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
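To get a rough sense of what the MLA-specific buffers allocated above cost relative to the regular K/V cache, here is a small self-contained sketch (not part of the patch); the head count and head dimensions are assumed DeepSeek-V2-style values rather than anything read from a GGUF:

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed DeepSeek-V2-style hyperparameters (illustration only).
    const uint32_t n_head              = 128;
    const uint32_t n_embd_head_qk_nope = 128;
    const uint32_t n_embd_head_qk_rope = 64;   // hparams.n_rot
    const uint32_t n_embd_head_v       = 128;
    const uint32_t kv_lora_rank        = 512;  // hparams.n_lora_kv
    const uint64_t bytes_per_elem      = 2;    // F16 cache types

    // Regular cache: full per-head K and V rows per token (n_head_kv == n_head here).
    const uint64_t k_per_tok = (uint64_t)n_head * (n_embd_head_qk_nope + n_embd_head_qk_rope);
    const uint64_t v_per_tok = (uint64_t)n_head * n_embd_head_v;

    // MLA cache: shared RoPE-ed key K^R plus the compressed latent c^KV,
    // which is stored twice (kv_l and its transposed copy kvt_l).
    const uint64_t mla_per_tok = n_embd_head_qk_rope + 2ull * kv_lora_rank;

    printf("regular: %llu bytes/token/layer\n", (unsigned long long)((k_per_tok + v_per_tok) * bytes_per_elem));
    printf("MLA    : %llu bytes/token/layer\n", (unsigned long long)(mla_per_tok * bytes_per_elem));
    return 0;
}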
@@ -7644,6 +7690,8 @@ static bool llm_load_tensors(
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+ layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1);
+ layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1);
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
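A hedged note on the shapes used above (ggml lists the fastest-varying dimension first): attn_k_b appears to be the k_nope half of attn_kv_b stored transposed per head, and attn_v_b the value half kept in its original orientation, so together they carry the same number of weights. A toy check of the element counts, with assumed DeepSeek-V2-style dimensions:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed DeepSeek-V2-style dimensions (illustration only).
    const uint64_t kv_lora_rank        = 512;
    const uint64_t n_head              = 128;
    const uint64_t n_embd_head_qk_nope = 128;
    const uint64_t n_embd_head_v       = 128;

    // attn_kv_b : {kv_lora_rank, n_head * (qk_nope + v)}
    const uint64_t wkv_b = kv_lora_rank * n_head * (n_embd_head_qk_nope + n_embd_head_v);
    // attn_k_b  : {qk_nope, n_head * kv_lora_rank}  -- the k_nope half, transposed per head
    const uint64_t wk_b  = n_embd_head_qk_nope * n_head * kv_lora_rank;
    // attn_v_b  : {kv_lora_rank, n_head * v}        -- the value half, same orientation as wkv_b
    const uint64_t wv_b  = kv_lora_rank * n_head * n_embd_head_v;

    assert(wk_b + wv_b == wkv_b); // the two new tensors together hold exactly the wkv_b weights
    printf("wkv_b = %llu elements, wk_b + wv_b = %llu elements\n",
           (unsigned long long)wkv_b, (unsigned long long)(wk_b + wv_b));
    return 0;
}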
@@ -8815,6 +8863,7 @@ struct llm_build_context {
const int32_t n_ctx_orig;
const bool flash_attn;
+ const bool mla_attn;
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
@@ -8864,6 +8913,7 @@ struct llm_build_context {
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
+ mla_attn (cparams.mla_attn),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
@@ -13329,6 +13379,10 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+    // whether to use n_tokens or n_head as the outer matrix dimension during the multiplications below
+    // n_tokens is larger during prompt processing, so this lets the graph be optimized for that case
+ bool pp_opt = n_tokens > n_head;
+
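A short illustration of the pp_opt heuristic added above, with made-up sizes (the real values come from the current ubatch and the model hyperparameters):

#include <cstdio>

// Same rule as the patch: prefer batching the matmuls over tokens when there are more tokens than heads.
static bool pp_opt_for(int n_tokens, int n_head) { return n_tokens > n_head; }

int main() {
    printf("prompt processing (512 tokens, 128 heads): pp_opt = %d\n", pp_opt_for(512, 128)); // 1
    printf("token generation  (  1 token,  128 heads): pp_opt = %d\n", pp_opt_for(  1, 128)); // 0
    return 0;
}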
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@@ -13383,71 +13437,216 @@ struct llm_build_context {
0);
cb(kv_compressed, "kv_compressed", il);
- // and {n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
- kv_pe_compresseed->nb[1],
- kv_pe_compresseed->nb[1],
- ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
- cb(k_pe, "k_pe", il);
+ if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
- //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
- kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
- model.layers[il].attn_kv_a_norm, NULL,
- LLM_NORM_RMS, cb, il);
- cb(kv_compressed, "kv_compressed", il);
+ // and {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+ kv_pe_compresseed->nb[1],
+ kv_pe_compresseed->nb[1],
+ ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+ cb(k_pe, "k_pe", il);
- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
- struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
- cb(kv, "kv", il);
+ //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+ kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+ model.layers[il].attn_kv_a_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(kv_compressed, "kv_compressed", il);
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
- struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
- ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
- ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- 0);
- cb(k_nope, "k_nope", il);
+ struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+ cb(kv_cache_view, "kv_cache_view", il);
- // and {n_head * n_embd_head_v, n_tokens}
- struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
- ggml_row_size(kv->type, (n_embd_head_qk_nope)));
- cb(v_states, "v_states", il);
+ // note: storing c^KV in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
- v_states = ggml_cont(ctx0, v_states);
- cb(v_states, "v_states", il);
+ struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+ cb(kv_cache_trans_view, "kv_cache_trans_view", il);
- v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
- ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
- 0);
- cb(v_states, "v_states", il);
+ // note: storing transposed c^KV in the transposed KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
- //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- q_pe = ggml_rope_ext(
- ctx0, q_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(q_pe, "q_pe", il);
+ struct ggml_tensor * kv_cache =
+ ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+ 0);
+ cb(kv_cache, "kv_cache", il);
- // shared RoPE key
- //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- k_pe = ggml_rope_ext(
- ctx0, k_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(k_pe, "k_pe", il);
+ struct ggml_tensor * kv_cache_trans =
+ ggml_view_2d(ctx0, kv_self.kvt_l[il],
+ n_kv, kv_lora_rank,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+ 0);
+ cb(kv_cache_trans, "kv_cache_trans", il);
- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
- cb(q_states, "q_states", il);
+ //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ q_pe = ggml_rope_ext(
+ ctx0, q_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_pe, "q_pe", il);
+
+ // shared RoPE key
+ //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ k_pe = ggml_rope_ext(
+ ctx0, k_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(k_pe, "k_pe", il);
- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
- cb(k_states, "k_states", il);
+ struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+ cb(kr_cache_view, "kr_cache_view", il);
- cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
- k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ // note: storing RoPE-ed version of K^R in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
+ struct ggml_tensor * kr_cache =
+ ggml_view_2d(ctx0, kv_self.kr_l[il],
+ n_embd_head_qk_rope, n_kv,
+ ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+ 0);
+ cb(kr_cache, "kr_cache", il);
+
+ struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+ cb(wk_b, "wk_b", il);
+
+ q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+ cb(q_nope, "q_nope_perm", il);
+
+ struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
+ cb(q_nope2, "q_nope2", il);
+
+ if (!pp_opt) {
+ q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+ cb(q_nope2, "q_nope2_perm", il);
+ }
+ struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
+ cb(kq_nope, "kq_nope", il);
+
+ if (!pp_opt) {
+ kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
+ cb(kq_nope, "kq_nope_perm", il);
+ }
+
+ if (pp_opt) {
+ q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+ cb(q_pe, "q_pe_perm", il);
+ }
+ struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+ cb(kq_pe, "kq_pe", il);
+
+ if (!pp_opt) {
+ kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
+ cb(kq_pe, "kq_pe_perm", il);
+ }
+
+ struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ if (!pp_opt) {
+ kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+ cb(kq, "kq_soft_max_ext_perm", il);
+ }
+
+ struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+ cb(kqv_compressed, "kqv_compressed", il);
+
+ if (!pp_opt) {
+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+ cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
+
+ struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+ cb(wv_b, "wv_b", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+ cb(kqv, "kqv", il);
+
+ kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+ cb(kqv, "kqv_perm", il);
+
+ cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+ cb(cur, "kqv_2d", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+ cb(cur, "kqv_out", il);
+
+ }
+ else {
+
+ // and {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+ kv_pe_compresseed->nb[1],
+ kv_pe_compresseed->nb[1],
+ ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+ cb(k_pe, "k_pe", il);
+
+ //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+ kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+ model.layers[il].attn_kv_a_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(kv_compressed, "kv_compressed", il);
+
+ // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+ struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+ cb(kv, "kv", il);
+
+ // split into {n_head * n_embd_head_qk_nope, n_tokens}
+ struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+ ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+ ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ 0);
+ cb(k_nope, "k_nope", il);
+
+ // and {n_head * n_embd_head_v, n_tokens}
+ struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+ ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+ ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+ cb(v_states, "v_states", il);
+
+ v_states = ggml_cont(ctx0, v_states);
+ cb(v_states, "v_states", il);
+
+ v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+ ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+ 0);
+ cb(v_states, "v_states", il);
+
+ //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ q_pe = ggml_rope_ext(
+ ctx0, q_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_pe, "q_pe", il);
+
+ // shared RoPE key
+ //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ k_pe = ggml_rope_ext(
+ ctx0, k_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(k_pe, "k_pe", il);
+
+ struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+ cb(q_states, "q_states", il);
+
+ struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+ cb(k_states, "k_states", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, NULL,
+ k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+
+ }
}
if (il == n_layer - 1) {
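The matmuls in the MLA branch above rely on weight absorption: the per-token key k_nope = W*c^KV is never materialized, because q_nope.(W*c^KV) = (W^T*q_nope).c^KV, so the query is pushed through W^T once (the attn_k_b tensor) and scored directly against the cached latents; attn_v_b is applied to the softmax-weighted latent sum afterwards. A minimal numerical check of that identity on toy sizes, independent of ggml:

#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int d_nope = 8, r = 4; // toy sizes; real models use e.g. 128 and 512
    std::mt19937 rng(42);
    std::uniform_real_distribution<double> dist(-1.0, 1.0);

    std::vector<double> W(d_nope * r), q(d_nope), c(r);
    for (auto & x : W) x = dist(rng);
    for (auto & x : q) x = dist(rng);
    for (auto & x : c) x = dist(rng);

    // lhs: expand the key k = W*c, then dot with q
    double lhs = 0.0;
    for (int i = 0; i < d_nope; i++) {
        double k_i = 0.0;
        for (int j = 0; j < r; j++) k_i += W[i*r + j] * c[j];
        lhs += q[i] * k_i;
    }

    // rhs: absorb W into the query, q2 = W^T*q, then dot with the cached latent c
    double rhs = 0.0;
    for (int j = 0; j < r; j++) {
        double q2_j = 0.0;
        for (int i = 0; i < d_nope; i++) q2_j += W[i*r + j] * q[i];
        rhs += q2_j * c[j];
    }

    printf("q.(W c) = %.12f, (W^T q).c = %.12f\n", lhs, rhs);
    return 0;
}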
@@ -17391,6 +17590,7 @@ struct llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
+ /*.mla_attn =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
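For completeness, a hedged usage sketch from the API side; it assumes the new mla_attn field is also exposed in llama_context_params in llama.h (not shown in this diff), mirroring how flash_attn is exposed:

#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) return 1;

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == nullptr) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.mla_attn = true; // opt in to the MLA KV cache and attention path (assumed new field)

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... run decoding as usual ...
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}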
@@ -17589,6 +17789,7 @@ struct llama_context * llama_new_context_with_model(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -17655,6 +17856,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -17853,6 +18055,24 @@ struct llama_context * llama_new_context_with_model(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
+ {
+ size_t memory_size_kr = 0;
+ size_t memory_size_kv = 0;
+
+ for (auto & kr : ctx->kv_self.kr_l) {
+ memory_size_kr += ggml_nbytes(kr);
+ }
+
+ for (auto & kv : ctx->kv_self.kv_l) {
+ memory_size_kv += ggml_nbytes(kv);
+ }
+
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+ ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f));
+ }
+
// graph outputs buffer
{
// resized during inference when a batch uses more outputs