-rw-r--r--   common/common.cpp                      |   7
-rw-r--r--   common/common.h                        |   1
-rwxr-xr-x   convert_hf_to_gguf.py                  |  42
-rw-r--r--   examples/llama-bench/llama-bench.cpp   |  35
-rw-r--r--   ggml/src/ggml.c                        |  17
-rw-r--r--   gguf-py/gguf/constants.py              |   6
-rw-r--r--   gguf-py/gguf/tensor_mapping.py         |   8
-rw-r--r--   include/llama.h                        |   1
-rw-r--r--   src/llama.cpp                          | 338
9 files changed, 380 insertions, 75 deletions
diff --git a/common/common.cpp b/common/common.cpp index 95e91bc1..6219f0ce 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-mla" || arg == "--mla-use") { + params.mla_attn = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 73d7d650..fc1ae619 100644 --- a/common/common.h +++ b/common/common.h @@ -174,6 +174,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool mla_attn = false; // MLA bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3910aa1d..1ee82724 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3123,6 +3123,7 @@ class ArcticModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3144,6 +3145,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3156,6 +3166,17 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] @@ -3188,6 +3209,27 @@ class DeepseekV2Model(Model): return tensors else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] return 
[(self.map_tensor_name(name), data_torch)] diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 42320da8..41b93df5 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -232,6 +232,7 @@ struct cmd_params { std::vector<int> main_gpu; std::vector<bool> no_kv_offload; std::vector<bool> flash_attn; + std::vector<bool> mla_attn; std::vector<std::vector<float>> tensor_split; std::vector<bool> use_mmap; std::vector<bool> embeddings; @@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* mla_attn */ {false}, /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa <distribute|isolate|numactl> (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split<bool>(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-mla" || arg == "--mla-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split<bool>(argv[i], split_delim); + params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -656,6 +667,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; + bool mla_attn; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -698,6 +710,7 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.mla_attn = mla_attn; cparams.embeddings = embeddings; return cparams; @@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & mla : params.mla_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if 
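The converter changes above also record expert_weights_norm and whether the router scores experts with softmax or, as in DeepSeek-V3, a sigmoid, and they keep the renamed e_score_correction.bias tensor. The sketch below shows the usual DeepSeek-V3-style use of these pieces; how llama.cpp consumes them is outside this excerpt, so the selection-versus-weighting split and the normalization step are assumptions, not code from this patch.

    import torch

    def route(logits, bias, n_top, norm_topk_prob=True):
        # sigmoid scoring (gguf.ExpertGatingFuncType.SIGMOID)
        scores = torch.sigmoid(logits)
        # the e_score_correction bias only influences which experts are picked ...
        _, idx = torch.topk(scores + bias, n_top, dim=-1)
        # ... while the gate weights come from the unbiased scores
        w = torch.gather(scores, -1, idx)
        if norm_topk_prob:                      # expert_weights_norm
            w = w / w.sum(dim=-1, keepdim=True)
        return idx, w

    logits = torch.randn(2, 8)                  # (tokens, experts), toy sizes
    bias   = torch.zeros(8)
    print(route(logits, bias, n_top=2))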
(n_prompt == 0) { @@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -866,6 +884,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + bool mla_attn; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -895,6 +914,7 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + mla_attn = inst.mla_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -988,7 +1008,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "tensor_split", "use_mmap", "embeddings", "repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1010,7 +1030,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1044,7 +1064,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1208,6 +1228,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "mla_attn") { + return 3; + } if (field == "use_mmap") { return 4; } @@ -1242,6 +1265,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "mla_attn") { + return "mla"; + } if (field == "use_mmap") { return "mmap"; } @@ -1294,6 +1320,9 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { 
fields.emplace_back("flash_attn"); } + if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { + fields.emplace_back("mla_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e07dd547..3867cf00 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14064,31 +14064,22 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_IQK_MULMAT - if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) { + if (dst->type == GGML_TYPE_F32) { + int gcd = simple_gcd(ne12*ne13, nth); int counter = 0; for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - if (counter++ % nth == ith) { + if ((counter++ % gcd) == (ith%gcd)) { if (!iqk_mul_mat(ne01, ne11, ne00, src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - 0, 1)) goto IQK_MulMat_Not_Available1; + ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1; } } } return; } - if (dst->type == GGML_TYPE_F32) { - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available1; - return; - } IQK_MulMat_Not_Available1:; #endif diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 90d5efec..037837da 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -403,6 +405,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -967,6 +971,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a70b69c5..e8725426 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -446,6 +446,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index 730c087a..39251d35 
100644 --- a/include/llama.h +++ b/include/llama.h @@ -374,6 +374,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 29926a94..00e6c934 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -539,6 +539,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, @@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -2503,6 +2507,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool mla_attn; enum llama_pooling_type pooling_type; @@ -2541,6 +2546,8 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * wk_b; + struct ggml_tensor * wv_b; struct ggml_tensor * wq_cross; struct ggml_tensor * wk_cross; struct ggml_tensor * wv_cross; @@ -2669,11 +2676,19 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; + ggml_type type_kr = GGML_TYPE_F16; + ggml_type type_kv = GGML_TYPE_F16; + std::vector<llama_kv_cell> cells; std::vector<struct ggml_tensor *> k_l; // per layer std::vector<struct ggml_tensor *> v_l; + // DeepSeek MLA + std::vector<struct ggml_tensor *> kr_l; // per layer + std::vector<struct ggml_tensor *> kv_l; + std::vector<struct ggml_tensor *> kvt_l; + std::vector<struct ggml_context *> ctxs; std::vector<ggml_backend_buffer_t> bufs; @@ -3104,8 +3119,10 @@ static bool llama_kv_cache_init( cache.size = kv_size; cache.used = 0; - cache.type_k = type_k; - cache.type_v = type_v; + cache.type_k = type_k; + cache.type_v = type_v; + cache.type_kr = type_k; + cache.type_kv = type_v; cache.cells.clear(); cache.cells.resize(kv_size); @@ -3132,7 +3149,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -3148,17 +3165,46 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); + // DeepSeek MLA + cache.kr_l.reserve(n_layer); + cache.kv_l.reserve(n_layer); + cache.kvt_l.reserve(n_layer); + for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); struct ggml_context * ctx = offload ? 
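After conversion, each layer of a DeepSeek2 GGUF carries blk.{bid}.attn_k_b and blk.{bid}.attn_v_b alongside the existing blk.{bid}.attn_kv_b. A quick way to confirm they are present, sketched with gguf-py's reader; the file path is a placeholder.

    from gguf import GGUFReader

    reader = GGUFReader("deepseek-v2-lite.gguf")  # placeholder path
    for t in reader.tensors:
        if ".attn_k_b." in t.name or ".attn_v_b." in t.name:
            print(t.name, list(t.shape), t.tensor_type.name)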
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k; + ggml_tensor * v; + if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { + k = ggml_new_tensor_1d(ctx, type_k, 1); + v = ggml_new_tensor_1d(ctx, type_v, 1); + } + else { + k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + } + ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); + + + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); + ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_format_name(kr, "cache_kr_l%d", i); + ggml_format_name(kv, "cache_kv_l%d", i); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kr_l.push_back(kr); + cache.kv_l.push_back(kv); + cache.kvt_l.push_back(kvt); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -7644,6 +7690,8 @@ static bool llm_load_tensors( layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); + layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); + layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); @@ -8815,6 +8863,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const bool mla_attn; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8864,6 +8913,7 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 
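With mla_attn enabled, llama_kv_cache_init shrinks the per-layer k_l/v_l tensors to a single element and instead allocates three MLA buffers per layer: k^R (n_embd_head_qk_rope cells per position), c^KV (kv_lora_rank cells per position) and a transposed copy of c^KV. A back-of-envelope comparison of the per-position cache cost against the regular path, using assumed DeepSeek-V2-like dimensions, f16 cells, and the fact that on the non-MLA path the compressed KV is expanded to full per-head K and V:

    bytes_per_cell   = 2      # f16
    n_layer          = 60     # assumed
    n_head           = 128    # assumed; also the KV head count on the regular path
    qk_nope, qk_rope = 128, 64
    v_head_dim       = 128
    kv_lora_rank     = 512

    # regular KV cache: full K and V heads per layer per position
    regular = n_layer * n_head * ((qk_nope + qk_rope) + v_head_dim) * bytes_per_cell

    # MLA cache: shared k^R plus c^KV and its transposed copy per layer per position
    mla = n_layer * (qk_rope + 2 * kv_lora_rank) * bytes_per_cell

    print(f"regular: {regular / 1024:8.1f} KiB per position")
    print(f"MLA:     {mla / 1024:8.1f} KiB per position")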
0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + mla_attn (cparams.mla_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -13329,6 +13379,10 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // whether to use n_tokens as the matrix dimension during multiplication or n_head + // n_tokens is higher during prompt processing, this allows to optimize for this case + bool pp_opt = n_tokens > n_head; + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -13383,71 +13437,216 @@ struct llm_build_context { 0); cb(kv_compressed, "kv_compressed", il); - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); + if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { - //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm - kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(kv_compressed, "kv_compressed", il); + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); + struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head); + cb(kv_cache_view, "kv_cache_view", il); - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); + // note: storing c^KV in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); + struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), 
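pp_opt simply records which of the two batch dimensions is larger, so the graph can permute its operands accordingly: during prompt processing n_tokens dominates and is used as the matrix dimension of the multiplications, while during single-token decode the head count is. A trivial illustration of the switch, with an assumed head count:

    n_head = 128                      # assumed head count
    for n_tokens in (512, 1):         # prompt processing vs. decode
        pp_opt = n_tokens > n_head
        rows, batch = (n_tokens, n_head) if pp_opt else (n_head, n_tokens)
        print(f"n_tokens={n_tokens:3d}: pp_opt={pp_opt}, "
              f"matrix dim={rows}, batched over {batch}")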
ggml_row_size(kv_self.kv_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); - //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); + struct ggml_tensor * kv_cache = + ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank), + 0); + cb(kv_cache, "kv_cache", il); - // shared RoPE key - //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); + struct ggml_tensor * kv_cache_trans = + ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head); + cb(kr_cache_view, "kr_cache_view", il); - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + // note: storing RoPE-ed version of K^R in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); + + struct ggml_tensor * kr_cache = + ggml_view_2d(ctx0, kv_self.kr_l[il], + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope), + 0); + cb(kr_cache, "kr_cache", il); + + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); + cb(wk_b, "wk_b", il); + + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); + cb(q_nope2, "q_nope2", il); + + if (!pp_opt) { + q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); + 
cb(q_nope2, "q_nope2_perm", il); + } + struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); + cb(kq_nope, "kq_nope", il); + + if (!pp_opt) { + kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3); + cb(kq_nope, "kq_nope_perm", il); + } + + if (pp_opt) { + q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); + cb(q_pe, "q_pe_perm", il); + } + struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); + cb(kq_pe, "kq_pe", il); + + if (!pp_opt) { + kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3); + cb(kq_pe, "kq_pe_perm", il); + } + + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); + + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); + + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + + } + else { + + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, 
hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + + } } if (il == n_layer - 1) { @@ -17391,6 +17590,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17589,6 +17789,7 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17655,6 +17856,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -17853,6 +18055,24 @@ struct llama_context * llama_new_context_with_model( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + { + size_t memory_size_kr = 0; + size_t memory_size_kv = 0; + + for (auto & kr : ctx->kv_self.kr_l) { + memory_size_kr += ggml_nbytes(kr); + } + + for (auto & kv : ctx->kv_self.kv_l) { + memory_size_kv += ggml_nbytes(kv); + } + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f)); + } + // graph outputs buffer { // resized during inference when a batch uses more outputs |