-rw-r--r--   common/common.cpp                       7
-rw-r--r--   common/common.h                         1
-rwxr-xr-x   convert_hf_to_gguf.py                  42
-rw-r--r--   examples/llama-bench/llama-bench.cpp   35
-rw-r--r--   ggml/src/ggml.c                        17
-rw-r--r--   gguf-py/gguf/constants.py               6
-rw-r--r--   gguf-py/gguf/tensor_mapping.py          8
-rw-r--r--   include/llama.h                         1
-rw-r--r--   src/llama.cpp                         338
9 files changed, 380 insertions(+), 75 deletions(-)
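
This change adds an experimental MLA (Multi-head Latent Attention) path for the DEEPSEEK2 architecture (and registers DeepseekV3ForCausalLM in the converter): kv_b_proj is split into attn_k_b/attn_v_b tensors at conversion time, the KV cache stores the compressed latent c^KV and the shared RoPE key K^R instead of full per-head K/V, and a -mla flag is threaded through common, llama-bench and the public API. A minimal sketch of enabling the option from client code, assuming a program built against this branch (the field comes from the llama.h hunk further down):

// Sketch only: enable the experimental MLA path via llama_context_params.
#include "llama.h"

static llama_context_params ctx_params_with_mla(void) {
    llama_context_params cparams = llama_context_default_params();
    cparams.mla_attn = true;   // added by this patch, defaults to false [EXPERIMENTAL]
    return cparams;
}

On the command line the same switch is exposed as -mla / --mla-use in the common argument parser and as -mla / --mla-attn in llama-bench.
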
diff --git a/common/common.cpp b/common/common.cpp
index 95e91bc1..6219f0ce 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.flash_attn = true;
return true;
}
+ if (arg == "-mla" || arg == "--mla-use") {
+ params.mla_attn = true;
+ return true;
+ }
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+ options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
+ fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
diff --git a/common/common.h b/common/common.h
index 73d7d650..fc1ae619 100644
--- a/common/common.h
+++ b/common/common.h
@@ -174,6 +174,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
+ bool mla_attn = false; // MLA
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 3910aa1d..1ee82724 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3123,6 +3123,7 @@ class ArcticModel(Model):
@Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
class DeepseekV2Model(Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -3144,6 +3145,15 @@ class DeepseekV2Model(Model):
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+ self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+ if hparams["scoring_func"] == "sigmoid":
+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+ elif hparams["scoring_func"] == "softmax":
+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+ else:
+ raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3156,6 +3166,17 @@ class DeepseekV2Model(Model):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # rename e_score_correction_bias tensors
+ if name.endswith("e_score_correction_bias"):
+ name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+ # skip Multi-Token Prediction (MTP) layers
+ block_count = self.hparams["num_hidden_layers"]
+ match = re.match(r"model\.layers\.(\d+)", name)
+ if match and int(match.group(1)) >= block_count:
+ return []
+
+
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ class DeepseekV2Model(Model):
return tensors
else:
return []
+ if name.endswith("kv_b_proj.weight"):
+ name_kb = name.replace("kv_b_proj", "k_b_proj")
+ name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+ n_head_kv = self.hparams["num_key_value_heads"]
+ v_head_dim = self.hparams["v_head_dim"]
+ qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+ assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+ kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+ k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+ k_b = k_b.transpose(1, 2)
+ k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+ v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+ return [
+ (self.map_tensor_name(name), data_torch),
+ (self.map_tensor_name(name_kb), k_b),
+ (self.map_tensor_name(name_vb), v_b)
+ ]
return [(self.map_tensor_name(name), data_torch)]
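
The kv_b_proj handling above splits the combined up-projection into a per-head k_b part, transposed so the MLA graph can multiply it directly against the non-RoPE part of the query, and a v_b part used for the output up-projection. A small sanity-check sketch in C++ showing that the split merely reorders the weights; the dimensions are illustrative assumptions, not values read from any particular checkpoint:

// Sketch: the kv_b_proj split only reorders data, it does not change
// the number of weights. All dimensions below are assumed/illustrative.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_head_kv        = 128;  // assumption
    const int64_t kv_lora_rank     = 512;  // assumption
    const int64_t qk_nope_head_dim = 128;  // assumption
    const int64_t v_head_dim       = 128;  // assumption

    // kv_b_proj.weight : [n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
    const int64_t n_kv_b = n_head_kv * (qk_nope_head_dim + v_head_dim) * kv_lora_rank;
    // k_b_proj.weight   : [n_head_kv * kv_lora_rank, qk_nope_head_dim] (per-head transpose)
    const int64_t n_k_b  = n_head_kv * kv_lora_rank * qk_nope_head_dim;
    // v_b_proj.weight   : [n_head_kv * v_head_dim, kv_lora_rank]
    const int64_t n_v_b  = n_head_kv * v_head_dim * kv_lora_rank;

    assert(n_kv_b == n_k_b + n_v_b);
    return 0;
}

These two tensors surface below as attn_k_b / attn_v_b in gguf-py and as wk_b / wv_b in src/llama.cpp.
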
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 42320da8..41b93df5 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -232,6 +232,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
+ std::vector<bool> mla_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
+ /* mla_attn */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
+ printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+ } else if (arg == "-mla" || arg == "--mla-attn") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split<bool>(argv[i], split_delim);
+ params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
+ if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
+ bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
+ cparams.mla_attn = mla_attn;
cparams.embeddings = embeddings;
return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
+ for (const auto & mla : params.mla_attn)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -866,6 +884,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
+ bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -895,6 +914,7 @@ struct test {
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
+ mla_attn = inst.mla_attn;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -988,7 +1008,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
- "main_gpu", "no_kv_offload", "flash_attn",
+ "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
"tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1010,7 +1030,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
+ field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1044,7 +1064,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
- std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+ std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return 2;
}
+ if (field == "mla_attn") {
+ return 3;
+ }
if (field == "use_mmap") {
return 4;
}
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return "fa";
}
+ if (field == "mla_attn") {
+ return "mla";
+ }
if (field == "use_mmap") {
return "mmap";
}
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
+ if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
+ fields.emplace_back("mla_attn");
+ }
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e07dd547..3867cf00 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14064,31 +14064,22 @@ static void ggml_compute_forward_mul_mat(
#endif
#if GGML_USE_IQK_MULMAT
- if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+ if (dst->type == GGML_TYPE_F32) {
+ int gcd = simple_gcd(ne12*ne13, nth);
int counter = 0;
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
- if (counter++ % nth == ith) {
+ if ((counter++ % gcd) == (ith%gcd)) {
if (!iqk_mul_mat(ne01, ne11, ne00,
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
- 0, 1)) goto IQK_MulMat_Not_Available1;
+ ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1;
}
}
}
return;
}
- if (dst->type == GGML_TYPE_F32) {
- for (int64_t i13 = 0; i13 < ne13; i13++)
- for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!iqk_mul_mat(ne01, ne11, ne00,
- src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
- src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
- (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
- ith, nth)) goto IQK_MulMat_Not_Available1;
- return;
- }
IQK_MulMat_Not_Available1:;
#endif
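
The rewritten dispatch above drops the old requirement that ne12*ne13 be divisible by nth: with g = simple_gcd(ne12*ne13, nth), threads sharing the same ith % g form a group that takes the batched matrices round-robin, while the nth/g threads inside a group split each matrix's rows via the (ith/g, nth/g) pair passed to iqk_mul_mat. A standalone sketch of that partitioning (simple_gcd re-implemented here just for illustration):

// Sketch of the thread partitioning used above: assign matrices by ith % g,
// split rows inside each matrix across the nth / g threads of a group.
#include <cstdio>

static int simple_gcd(int a, int b) {        // illustrative reimplementation
    while (b != 0) { const int t = a % b; a = b; b = t; }
    return a;
}

int main() {
    const int nmat = 6;                      // ne12*ne13, illustrative
    const int nth  = 8;                      // number of threads, illustrative
    const int g    = simple_gcd(nmat, nth);  // here g == 2

    for (int ith = 0; ith < nth; ++ith) {
        std::printf("thread %d -> matrices with index %% %d == %d, row slice %d of %d\n",
                    ith, g, ith % g, ith / g, nth / g);
    }
    return 0;
}
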
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 90d5efec..037837da 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
+ ATTN_K_B = auto()
+ ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -403,6 +405,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+ MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+ MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -967,6 +971,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_K_B,
+ MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index a70b69c5..e8725426 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -446,6 +446,14 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),
+ MODEL_TENSOR.ATTN_K_B: (
+ "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+ ),
+
+ MODEL_TENSOR.ATTN_V_B: (
+ "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+ ),
+
MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
diff --git a/include/llama.h b/include/llama.h
index 730c087a..39251d35 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -374,6 +374,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+ bool mla_attn; // whether to use MLA attention [EXPERIMENTAL]
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
diff --git a/src/llama.cpp b/src/llama.cpp
index 29926a94..00e6c934 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -539,6 +539,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_K_B,
+ LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
@@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+ { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+ { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -2503,6 +2507,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
+ bool mla_attn;
enum llama_pooling_type pooling_type;
@@ -2541,6 +2546,8 @@ struct llama_layer {
struct ggml_tensor * wq_b;
struct ggml_tensor * wkv_a_mqa;
struct ggml_tensor * wkv_b;
+ struct ggml_tensor * wk_b;
+ struct ggml_tensor * wv_b;
struct ggml_tensor * wq_cross;
struct ggml_tensor * wk_cross;
struct ggml_tensor * wv_cross;
@@ -2669,11 +2676,19 @@ struct llama_kv_cache {
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
+ ggml_type type_kr = GGML_TYPE_F16;
+ ggml_type type_kv = GGML_TYPE_F16;
+
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
+ // DeepSeek MLA
+ std::vector<struct ggml_tensor *> kr_l; // per layer
+ std::vector<struct ggml_tensor *> kv_l;
+ std::vector<struct ggml_tensor *> kvt_l;
+
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3104,8 +3119,10 @@ static bool llama_kv_cache_init(
cache.size = kv_size;
cache.used = 0;
- cache.type_k = type_k;
- cache.type_v = type_v;
+ cache.type_k = type_k;
+ cache.type_v = type_v;
+ cache.type_kr = type_k;
+ cache.type_kv = type_v;
cache.cells.clear();
cache.cells.resize(kv_size);
@@ -3132,7 +3149,7 @@ static bool llama_kv_cache_init(
for (auto & it : buft_layer_count) {
int n_layers = it.second;
struct ggml_init_params params = {
- /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(),
+ /*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -3148,17 +3165,46 @@ static bool llama_kv_cache_init(
cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);
+ // DeepSeek MLA
+ cache.kr_l.reserve(n_layer);
+ cache.kv_l.reserve(n_layer);
+ cache.kvt_l.reserve(n_layer);
+
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ ggml_tensor * k;
+ ggml_tensor * v;
+ if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
+ k = ggml_new_tensor_1d(ctx, type_k, 1);
+ v = ggml_new_tensor_1d(ctx, type_v, 1);
+ }
+ else {
+ k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+ }
+
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
+
+
+ // DeepSeek MLA
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
+ LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+ ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+ ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+ ggml_format_name(kr, "cache_kr_l%d", i);
+ ggml_format_name(kv, "cache_kv_l%d", i);
+ ggml_format_name(kvt, "cache_kvt_l%d", i);
+ cache.kr_l.push_back(kr);
+ cache.kv_l.push_back(kv);
+ cache.kvt_l.push_back(kvt);
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
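
With cparams.mla_attn enabled (and wk_b/wv_b present), the regular k_l/v_l tensors above are allocated as 1-element placeholders, and each layer instead caches K^R (n_embd_head_qk_rope values per token) plus c^KV and its transpose (kv_lora_rank values each per token). A back-of-the-envelope comparison of the per-token footprint, using assumed DeepSeek-V2-style dimensions rather than numbers read from a model:

// Sketch: per-token, per-layer KV-cache footprint with an f16 cache.
// All dimensions are assumed/illustrative, not read from any GGUF.
#include <cstdio>

int main() {
    const int bytes_f16           = 2;
    const int n_embd_head_qk_rope = 64;    // hparams.n_rot (assumption)
    const int kv_lora_rank        = 512;   // hparams.n_lora_kv (assumption)
    const int n_head              = 128;   // assumption
    const int n_embd_head_k       = 192;   // qk_nope + qk_rope (assumption)
    const int n_embd_head_v       = 128;   // assumption

    // MLA path: K^R + c^KV + transposed c^KV
    const int mla  = (n_embd_head_qk_rope + 2 * kv_lora_rank) * bytes_f16;
    // standard path: full K and V for every head
    const int full = n_head * (n_embd_head_k + n_embd_head_v) * bytes_f16;

    std::printf("MLA: %d B/token/layer, standard: %d B/token/layer (%.1fx)\n",
                mla, full, (double)full / mla);
    return 0;
}

The transposed copy kvt_l holds the same data as kv_l, laid out so the later kv_cache_trans matmul in the graph has a directly usable operand.
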
@@ -7644,6 +7690,8 @@ static bool llm_load_tensors(
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+ layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1);
+ layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1);
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -8815,6 +8863,7 @@ struct llm_build_context {
const int32_t n_ctx_orig;
const bool flash_attn;
+ const bool mla_attn;
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
@@ -8864,6 +8913,7 @@ struct llm_build_context {
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
+ mla_attn (cparams.mla_attn),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
@@ -13329,6 +13379,10 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ // whether to use n_tokens or n_head as the matrix dimension during multiplication
+ // n_tokens is larger during prompt processing, so the graph is laid out to favor that case
+ bool pp_opt = n_tokens > n_head;
+
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@@ -13383,71 +13437,216 @@ struct llm_build_context {
0);
cb(kv_compressed, "kv_compressed", il);
- // and {n_embd_head_qk_rope, n_tokens}
- struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
- kv_pe_compresseed->nb[1],
- kv_pe_compresseed->nb[1],
- ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
- cb(k_pe, "k_pe", il);
+ if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
- //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
- kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
- model.layers[il].attn_kv_a_norm, NULL,
- LLM_NORM_RMS, cb, il);
- cb(kv_compressed, "kv_compressed", il);
+ // and {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+ kv_pe_compresseed->nb[1],
+ kv_pe_compresseed->nb[1],
+ ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+ cb(k_pe, "k_pe", il);
- // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
- struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
- cb(kv, "kv", il);
+ //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+ kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+ model.layers[il].attn_kv_a_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(kv_compressed, "kv_compressed", il);
- // split into {n_head * n_embd_head_qk_nope, n_tokens}
- struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
- ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
- ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- 0);
- cb(k_nope, "k_nope", il);
+ struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+ cb(kv_cache_view, "kv_cache_view", il);
- // and {n_head * n_embd_head_v, n_tokens}
- struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
- ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
- ggml_row_size(kv->type, (n_embd_head_qk_nope)));
- cb(v_states, "v_states", il);
+ // note: storing c^KV in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
- v_states = ggml_cont(ctx0, v_states);
- cb(v_states, "v_states", il);
+ struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+ cb(kv_cache_trans_view, "kv_cache_trans_view", il);
- v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
- ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
- 0);
- cb(v_states, "v_states", il);
+ // note: storing transposed c^KV in the transposed KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
- //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- q_pe = ggml_rope_ext(
- ctx0, q_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(q_pe, "q_pe", il);
+ struct ggml_tensor * kv_cache =
+ ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+ 0);
+ cb(kv_cache, "kv_cache", il);
- // shared RoPE key
- //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
- k_pe = ggml_rope_ext(
- ctx0, k_pe, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor_scaled, beta_fast, beta_slow
- );
- cb(k_pe, "k_pe", il);
+ struct ggml_tensor * kv_cache_trans =
+ ggml_view_2d(ctx0, kv_self.kvt_l[il],
+ n_kv, kv_lora_rank,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+ 0);
+ cb(kv_cache_trans, "kv_cache_trans", il);
- struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
- cb(q_states, "q_states", il);
+ //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ q_pe = ggml_rope_ext(
+ ctx0, q_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_pe, "q_pe", il);
+
+ // shared RoPE key
+ //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ k_pe = ggml_rope_ext(
+ ctx0, k_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(k_pe, "k_pe", il);
- struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
- cb(k_states, "k_states", il);
+ struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+ cb(kr_cache_view, "kr_cache_view", il);
- cur = llm_build_kv(ctx0, lctx, kv_self, gf,
- model.layers[il].wo, NULL,
- k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ // note: storing RoPE-ed version of K^R in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
+ struct ggml_tensor * kr_cache =
+ ggml_view_2d(ctx0, kv_self.kr_l[il],
+ n_embd_head_qk_rope, n_kv,
+ ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+ 0);
+ cb(kr_cache, "kr_cache", il);
+
+ struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+ cb(wk_b, "wk_b", il);
+
+ q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+ cb(q_nope, "q_nope_perm", il);
+
+ struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
+ cb(q_nope2, "q_nope2", il);
+
+ if (!pp_opt) {
+ q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+ cb(q_nope2, "q_nope2_perm", il);
+ }
+ struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
+ cb(kq_nope, "kq_nope", il);
+
+ if (!pp_opt) {
+ kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
+ cb(kq_nope, "kq_nope_perm", il);
+ }
+
+ if (pp_opt) {
+ q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+ cb(q_pe, "q_pe_perm", il);
+ }
+ struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+ cb(kq_pe, "kq_pe", il);
+
+ if (!pp_opt) {
+ kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
+ cb(kq_pe, "kq_pe_perm", il);
+ }
+
+ struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ if (!pp_opt) {
+ kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+ cb(kq, "kq_soft_max_ext_perm", il);
+ }
+
+ struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+ cb(kqv_compressed, "kqv_compressed", il);
+
+ if (!pp_opt) {
+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+ cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
+
+ struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+ cb(wv_b, "wv_b", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+ cb(kqv, "kqv", il);
+
+ kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+ cb(kqv, "kqv_perm", il);
+
+ cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+ cb(cur, "kqv_2d", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+ cb(cur, "kqv_out", il);
+
+ }
+ else {
+
+ // and {n_embd_head_qk_rope, n_tokens}
+ struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+ kv_pe_compresseed->nb[1],
+ kv_pe_compresseed->nb[1],
+ ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+ cb(k_pe, "k_pe", il);
+
+ //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+ kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+ model.layers[il].attn_kv_a_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(kv_compressed, "kv_compressed", il);
+
+ // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+ struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+ cb(kv, "kv", il);
+
+ // split into {n_head * n_embd_head_qk_nope, n_tokens}
+ struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+ ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+ ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ 0);
+ cb(k_nope, "k_nope", il);
+
+ // and {n_head * n_embd_head_v, n_tokens}
+ struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+ ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+ ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+ ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+ cb(v_states, "v_states", il);
+
+ v_states = ggml_cont(ctx0, v_states);
+ cb(v_states, "v_states", il);
+
+ v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+ ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+ 0);
+ cb(v_states, "v_states", il);
+
+ //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ q_pe = ggml_rope_ext(
+ ctx0, q_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(q_pe, "q_pe", il);
+
+ // shared RoPE key
+ //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+ k_pe = ggml_rope_ext(
+ ctx0, k_pe, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor_scaled, beta_fast, beta_slow
+ );
+ cb(k_pe, "k_pe", il);
+
+ struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+ cb(q_states, "q_states", il);
+
+ struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+ cb(k_states, "k_states", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, NULL,
+ k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+
+ }
}
if (il == n_layer - 1) {
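
For reference, the MLA branch added above implements the absorbed ("weight-absorption") form of DeepSeek's Multi-head Latent Attention. A rough sketch of the math it realizes, with notation loosely following the DeepSeek-V2 paper (W^{UK} and W^{UV} correspond to the wk_b and wv_b tensors split out by the converter):

\begin{aligned}
k_{t,i} &= \big[\, W^{UK}_i\, c^{KV}_t \;;\; k^{R}_t \,\big], \qquad
v_{t,i} = W^{UV}_i\, c^{KV}_t, \\
q_{t,i}^{\top} k_{s,i} &= \big( (W^{UK}_i)^{\top} q^{\text{nope}}_{t,i} \big)^{\top} c^{KV}_s
  \;+\; \big( q^{\text{rope}}_{t,i} \big)^{\top} k^{R}_s, \\
o_{t,i} &= W^{UV}_i \sum_{s} \operatorname{softmax}_s\!\Big( \tfrac{q_{t,i}^{\top} k_{s,i}}{\sqrt{d}} \Big)\, c^{KV}_s .
\end{aligned}

In the graph above, the first logit term is kq_nope (with (W^{UK})^T absorbed into the query as q_nope2), the second is kq_pe computed against kr_cache, and the final per-head projection through wv_b is the kqv matmul; this is why only c^{KV} (kv_l, plus its transpose kvt_l for the second matmul) and K^R (kr_l) ever need to be stored in the cache.
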
@@ -17391,6 +17590,7 @@ struct llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
+ /*.mla_attn =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -17589,6 +17789,7 @@ struct llama_context * llama_new_context_with_model(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -17655,6 +17856,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -17853,6 +18055,24 @@ struct llama_context * llama_new_context_with_model(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
+ {
+ size_t memory_size_kr = 0;
+ size_t memory_size_kv = 0;
+
+ for (auto & kr : ctx->kv_self.kr_l) {
+ memory_size_kr += ggml_nbytes(kr);
+ }
+
+ for (auto & kv : ctx->kv_self.kv_l) {
+ memory_size_kv += ggml_nbytes(kv);
+ }
+
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+ ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+ ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f));
+ }
+
// graph outputs buffer
{
// resized during inference when a batch uses more outputs