author     Kawrakow <iwankawrakow@gmail.com>        2025-05-12 07:49:51 +0300
committer  GitHub <noreply@github.com>              2025-05-12 07:49:51 +0300
commit     f27cd405422307e02dffa8949ac30bc56b4d2900 (patch)
tree       722b742827684815ca2cc0fb6379edd4edd2f3fd
parent     465569dff8b49a195450a0eb1974fd72a32fcebc (diff)
Enable faster prompt processing with mainline llama.cpp GGUFs (#409)
* Enable MLA-3 in crippled GGUFs: WIP
* Enable MLA-3 in crippled GGUFs: seems to work
* Add newly created tensors to model.tensors_by_name
Else they don't get run-time repacked.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
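
The third bullet is the subtle part of the change: tensors computed at load time do not come from the GGUF, so a later pass that walks the model's tensor list never sees them unless they are registered explicitly. A toy illustration of that pattern (simplified types and a hypothetical post-load pass; not the fork's actual repacking code):

```cpp
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

struct ggml_tensor;  // treated as opaque in this sketch

struct toy_model {
    std::vector<std::pair<std::string, ggml_tensor *>> tensors_by_name;
};

// Mirrors what the diff does right after computing blk.<il>.attn_k_b.weight
// (and the attn_v_b / attn_kv_b variants): register the new tensor by name.
void register_computed_tensor(toy_model & model, const std::string & name, ggml_tensor * t) {
    model.tensors_by_name.push_back(std::make_pair(name, t));
}

// Hypothetical post-load pass in the spirit of run-time repacking: it can only
// touch tensors that appear in tensors_by_name, which is why registration matters.
void post_load_pass(toy_model & model) {
    for (auto & [name, tensor] : model.tensors_by_name) {
        std::printf("would repack %s\n", name.c_str());
        (void) tensor;
    }
}
```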
-rw-r--r--  common/common.cpp |   1
-rw-r--r--  include/llama.h   |   1
-rw-r--r--  src/llama.cpp     | 432
3 files changed, 294 insertions, 140 deletions
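
The diff below threads the new `mla` setting from `gpt_params::mla_attn` through `llama_model_params` into tensor loading. As a caller-side illustration only (a minimal sketch using the fork's public C API; not part of this commit), the new field would be set like this before loading a DeepSeek GGUF:

```cpp
// Hypothetical caller sketch: request MLA-3 at model load time so that the
// missing wk_b/wv_b (or wkv_b) tensors get computed from what the GGUF contains.
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) return 1;

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 0;  // keep everything on the CPU for this sketch
    mparams.mla          = 3;  // new field added by this commit (default is 0)

    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (model == nullptr) return 1;

    // ... create a context and run prompt processing / generation as usual ...

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

In the fork itself the value comes from the command-line setting stored in `gpt_params::mla_attn`, as the common/common.cpp hunk shows.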
diff --git a/common/common.cpp b/common/common.cpp
index ab936ee7..0dbde58f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2334,6 +2334,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+    mparams.mla          = params.mla_attn;
     mparams.rpc_servers  = params.rpc_servers.c_str();
     mparams.main_gpu     = params.main_gpu;
     mparams.split_mode   = params.split_mode;
diff --git a/include/llama.h b/include/llama.h
index f1511548..0f3ae862 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -325,6 +325,7 @@ extern "C" {

     struct llama_model_params {
         int32_t n_gpu_layers; // number of layers to store in VRAM
+        int32_t mla;          // MLA implementation to use (only applicable to DeepSeek models at this point)
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs

         // main_gpu interpretation depends on split_mode:
diff --git a/src/llama.cpp b/src/llama.cpp
index b4d42c84..9369d10e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2942,6 +2942,7 @@ struct llama_layer {

     std::unique_ptr<ggml_tensor> computed_wk_b;
     std::unique_ptr<ggml_tensor> computed_wv_b;
+    std::unique_ptr<ggml_tensor> computed_wkv_b;
 };

 struct llama_kv_cell {
@@ -6756,11 +6757,299 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 }

+static void llm_prepare_mla(llama_model & model, int mla) {
+    if (model.arch != LLM_ARCH_DEEPSEEK2) return;
+    const auto& hparams = model.hparams;
+    const int n_layer = model.layers.size();
+    int n_to_compute = 0;
+    for (auto& l : model.layers) {
+        if (!l.wk_b) ++n_to_compute;
+    }
+    if (mla > 0 && n_to_compute > 0) {
+        // Prepare wk_b tensors to enable MLA usage also for model files that do not include
+        // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp).
+        // We do it here because otherwise wkv_b may get run-time-repacked, which will make
+        // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically
+        // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b
+        // even if it is not needed (because MLA is not being used). If we wanted to avoid
+        // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters
+        // to the model loading function. On the other hand, in some hypothetical bright future,
+        // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite
+        // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful
+        // to change the MLA setting on the fly, depending on context. In that case, having prepared
+        // the MLA tensors here is the right thing to do^TM.
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        const int32_t n_embd_head_v = hparams.n_embd_head_v;
+        const int32_t n_head = hparams.n_head(0);
+        std::vector<uint8_t> work_data;
+        LLAMA_LOG_INFO("============ %s: need to compute %d wk_b/wv_b tensors\n", __func__, n_to_compute);
+        for (int il = 1; il < n_layer; ++il) {
+            // Somehow the number of heads is being defined as being per layer. Not sure why this is the
+            // case, but for now we do not support strange models that have different numbers of heads
+            // in different model layers.
+            if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
+        }
+        auto total_size_wkb = 0;
+        size_t max_wkv_size = 0;
+        size_t max_wk_size = 0;
+        for (auto& l : model.layers) {
+            if (!l.wk_b) {
+                auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type;
+                auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head;
+                max_wk_size = std::max(max_wk_size, size);
+                if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
+                    max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b));
+                }
+            }
+        }
+        auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
+        context_size *= 2; // just in case
+        std::vector<uint8_t> wkv_buffer;
+        if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size);
+        // Transposing tensors and then making them contiguous as needed for wk_b may or may not
+        // be supported on all backends. Hence, to be sure that the preparation of wk_b will
+        // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to
+        // the backend where wkv_b is stored.
+        ggml_init_params params{context_size, nullptr, true};
+        auto ctx = ggml_init(params);
+        auto graph = ggml_new_graph_custom(ctx, 8, false);
+        std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
+        for (int il = 0; il < n_layer; ++il) {
+            auto& l = model.layers[il];
+            if (l.wk_b) continue;
+            auto wkv_b = *l.wkv_b;
+            if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
+                ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b));
+                wkv_b.data = wkv_buffer.data();
+            }
+            auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head,
+                    l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0);
+            auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32);
+            wk_b_f32->data = tensor_data.data();
+            auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32);
+            auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview);
+            wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32);
+
+            auto new_type = ggml_is_quantized(wkv_b.type) ?
+                wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type;
+            auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type);
+            wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t);
+
+            ggml_build_forward_expand(graph, wk_b);
+
+            auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+            if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+            plan.work_data = work_data.data();
+
+            auto status = ggml_graph_compute(graph, &plan);
+            if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b");
+
+            auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight";
+
+            l.computed_wk_b = std::make_unique<ggml_tensor>(*wk_b);
+            l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b));
+            l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer);
+            l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+                                                // of wk_b, which no longer exist, and will therefore crash.
+            for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr;
+            ggml_set_name(l.computed_wk_b.get(), name.c_str());
+            ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+            ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b));
+            if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) {
+                iqk_modify_tensor(l.computed_wk_b.get());
+            }
+
+            l.wk_b = l.computed_wk_b.get();
+            model.tensors_by_name.push_back(std::make_pair(name, l.wk_b));
+
+            ggml_graph_clear(graph);
+            auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head,
+                    l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope));
+            wv_b->data = tensor_data.data();
+            ggml_build_forward_expand(graph, wv_b);
+            plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+            if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+            plan.work_data = work_data.data();
+            status = ggml_graph_compute(graph, &plan);
+            if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b");
+
+            name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight";
+
+            l.computed_wv_b = std::make_unique<ggml_tensor>(*wv_b);
+            l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b));
+            l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer);
+            l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+                                                // of wv_b, which no longer exist, and will therefore crash.
+            for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr;
+            ggml_set_name(l.computed_wv_b.get(), name.c_str());
+            ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+            ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b));
+            if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) {
+                iqk_modify_tensor(l.computed_wv_b.get());
+            }
+
+            l.wv_b = l.computed_wv_b.get();
+            model.tensors_by_name.push_back(std::make_pair(name, l.wv_b));
+
+            printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2],
+                    ggml_backend_buffer_name(l.computed_wk_b->buffer));
+
+            ggml_graph_clear(graph);
+        }
+        ggml_free(ctx);
+    }
+    if (mla == 1) return;
+
+    n_to_compute = 0;
+    for (auto& l : model.layers) {
+        if (l.wk_b && l.wv_b && !l.wkv_b) ++n_to_compute;
+    }
+    if (n_to_compute == 0) return;
+
+    //
+    // Prepare wkv_b tensors to enable MLA=2,3 usage also for model files that have been
+    // crippled to the mainline llama.cpp MLA implementation (MLA=1 here).
+    // We do it here because otherwise wk_b and wv_b may get run-time-repacked, which will make
+    // preparation of wkv_b impossible. It also has the benefit that wkv_b will get automatically
+    // run-time repacked if the rtr option is set.
+    //
+    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+    const uint32_t kv_lora_rank = hparams.n_lora_kv;
+    const int32_t n_embd_head_v = hparams.n_embd_head_v;
+    const int32_t n_head = hparams.n_head(0);
+    std::vector<uint8_t> work_data;
+    LLAMA_LOG_INFO("============ %s: need to compute %d wkv_b tensors\n", __func__, n_to_compute);
+    for (int il = 1; il < n_layer; ++il) {
+        // Somehow the number of heads is being defined as being per layer. Not sure why this is the
+        // case, but for now we do not support strange models that have different numbers of heads
+        // in different model layers.
+        if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
+    }
+
+    size_t context_size = ggml_tensor_overhead()*16*n_layer;
+
+    ggml_init_params params{context_size, nullptr, true};
+    auto ctx = ggml_init(params);
+    auto graph = ggml_new_graph_custom(ctx, 8, false);
+
+    //layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
+    //layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v, n_head}, 0);
+
+    std::vector<char> wk_buffer, wv_buffer;
+    std::vector<char> tmp_buffer;
+    //std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
+    for (int il = 0; il < n_layer; ++il) {
+        auto& l = model.layers[il];
+        if (l.wkv_b || !l.wk_b || !l.wv_b) continue;
+        auto wk_b = *l.wk_b;
+        auto wv_b = *l.wv_b;
+        if (!ggml_backend_buffer_is_host(l.wk_b->buffer)) {
+            auto nbytes = ggml_nbytes(l.wk_b);
+            if (wk_buffer.size() < nbytes) wk_buffer.resize(nbytes);
+            ggml_backend_tensor_get(l.wk_b, wk_buffer.data(), 0, nbytes);
+            wk_b.data = wk_buffer.data();
+        }
+        if (!ggml_backend_buffer_is_host(l.wv_b->buffer)) {
+            auto nbytes = ggml_nbytes(l.wv_b);
+            if (wv_buffer.size() < nbytes) wv_buffer.resize(nbytes);
+            ggml_backend_tensor_get(l.wv_b, wv_buffer.data(), 0, nbytes);
+            wv_b.data = wv_buffer.data();
+        }
+
+        auto n_wk = ggml_nelements(&wk_b);
+        auto n_wv = ggml_nelements(&wv_b);
+
+        size_t tot_size = 0;
+        if (wk_b.type != GGML_TYPE_F32) {
+            tot_size += n_wk*sizeof(float);
+        }
+        tot_size += n_wk*sizeof(float); // ggml_cont(ctx, ggml_transpose(ctx, wk_b_used));
+        if (wv_b.type != GGML_TYPE_F32) {
+            tot_size += n_wv*sizeof(float);
+        }
+        tot_size += (n_wk + n_wv)*sizeof(float); // ggml_concat(ctx, wk_b_transposed, wv_b_used, 0);
+        tot_size += (n_wk + n_wv)*sizeof(float); // ggml_cast(ctx, wkv_b_f32, new_type);
+
+        if (tmp_buffer.size() < tot_size) tmp_buffer.resize(tot_size);
+
+        auto ptr = tmp_buffer.data();
+
+        auto wk_b_used = &wk_b;
+        if (wk_b.type != GGML_TYPE_F32) {
+            wk_b_used = ggml_cast(ctx, &wk_b, GGML_TYPE_F32);
+            wk_b_used->data = ptr;
+            ptr += ggml_nbytes(wk_b_used);
+        }
+        auto wk_b_transposed = ggml_cont(ctx, ggml_transpose(ctx, wk_b_used));
+        wk_b_transposed->data = ptr;
+        ptr += ggml_nbytes(wk_b_transposed);
+
+        auto wv_b_used = &wv_b;
+        if (wv_b.type != GGML_TYPE_F32) {
+            wv_b_used = ggml_cast(ctx, &wv_b, GGML_TYPE_F32);
+            wv_b_used->data = ptr;
+            ptr += ggml_nbytes(wv_b_used);
+        }
+
+        auto wkv_b_f32_3d = ggml_concat(ctx, wk_b_transposed, wv_b_used, 1);
+        wkv_b_f32_3d->data = ptr;
+        ptr += ggml_nbytes(wkv_b_f32_3d);
+
+        auto wkv_b_f32 = ggml_view_2d(ctx, wkv_b_f32_3d, wkv_b_f32_3d->ne[0], wkv_b_f32_3d->ne[1]*wkv_b_f32_3d->ne[2],
+                wkv_b_f32_3d->nb[1], 0);
+
+        auto new_type = wk_b.type == GGML_TYPE_BF16 && wv_b.type == GGML_TYPE_BF16 ? GGML_TYPE_BF16
+                      : wk_b.type == GGML_TYPE_F16  && wv_b.type == GGML_TYPE_F16  ? GGML_TYPE_F16
+                      : GGML_TYPE_Q8_0;
+
+        auto wkv_b = ggml_cast(ctx, wkv_b_f32, new_type);
+        wkv_b->data = ptr;
+        ptr += ggml_nbytes(wkv_b);
+
+        ggml_build_forward_expand(graph, wkv_b);
+
+        auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+        if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+        plan.work_data = work_data.data();
+
+        auto status = ggml_graph_compute(graph, &plan);
+        if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wkv_b");
+
+        auto name = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight";
+
+        l.computed_wkv_b = std::make_unique<ggml_tensor>(*wkv_b);
+        l.computed_wkv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wk_b->buffer), ggml_nbytes(wkv_b));
+        l.computed_wkv_b->data = ggml_backend_buffer_get_base(l.computed_wkv_b->buffer);
+        l.computed_wkv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+                                             // of wkv_b, which no longer exist, and will therefore crash.
+        for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wkv_b->src[j] = nullptr;
+        ggml_set_name(l.computed_wkv_b.get(), name.c_str());
+        ggml_backend_buffer_set_usage(l.computed_wkv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        ggml_backend_tensor_set(l.computed_wkv_b.get(), wkv_b->data, 0, ggml_nbytes(wkv_b));
+        if (ggml_backend_buffer_is_host(l.computed_wkv_b->buffer)) {
+            iqk_modify_tensor(l.computed_wkv_b.get());
+        }
+
+        l.wkv_b = l.computed_wkv_b.get();
+        model.tensors_by_name.push_back(std::make_pair(name, l.wkv_b));
+
+        printf("Computed %s as %ld x %ld and stored in buffer %s\n", name.c_str(), wkv_b->ne[0], wkv_b->ne[1],
+                ggml_backend_buffer_name(l.computed_wkv_b->buffer));
+
+        ggml_graph_clear(graph);
+    }
+    ggml_free(ctx);
+}
+
 // Returns false if cancelled by progress_callback
 static bool llm_load_tensors(
         llama_model_loader & ml,
         llama_model & model,
         int n_gpu_layers,
+        int mla_attn,
         enum llama_split_mode split_mode,
         int main_gpu,
         const float * tensor_split,
@@ -8997,145 +9286,7 @@ static bool llm_load_tensors(
         }
     }

-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
-        int n_to_compute = 0;
-        for (auto& l : model.layers) {
-            if (!l.wk_b) ++n_to_compute;
-        }
-        if (n_to_compute > 0) {
-            // Prepare wk_b tensors to enable MLA usage also for model files that do not include
-            // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp)
-            // We do it here because otherwise wkv_b may get run-time-repacked, which will make
-            // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically
-            // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b
-            // even if it is not needed (because MLA is not being used). If we wanted to avoid
-            // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters
-            // to the model loading function. On the other hand, in some hypothetical bright future,
-            // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite
-            // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful
-            // to change the MLA setting on the fly, depending on context. In that case, having prepared
-            // the MLA tensors here is the right ting to do^TM.
-            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-            const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-            const uint32_t kv_lora_rank = hparams.n_lora_kv;
-            const int32_t n_embd_head_v = hparams.n_embd_head_v;
-            const int32_t n_head = hparams.n_head(0);
-            std::vector<uint8_t> work_data;
-            LLAMA_LOG_INFO("============ %s: need to compute %d wk_b tensors\n", __func__, n_to_compute);
-            for (int il = 1; il < n_layer; ++il) {
-                // Somehow the number of heads is being defined as being per layer. Not sure why this is the
-                // case, but for now we do not support strange models that have different numbers of heads
-                // in different model layers.
-                if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
-            }
-            auto total_size_wkb = 0;
-            size_t max_wkv_size = 0;
-            size_t max_wk_size = 0;
-            for (auto& l : model.layers) {
-                if (!l.wk_b) {
-                    auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type;
-                    auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head;
-                    max_wk_size = std::max(max_wk_size, size);
-                    if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
-                        max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b));
-                    }
-                }
-            }
-            auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
-            context_size *= 2; // just in case;
-            std::vector<uint8_t> wkv_buffer;
-            if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size);
-            // So, transposing tensors and then making them contiguous as needed for wk_b may or may not
-            // be supported on all backends. Hence, to be sure that the preparation of wk_b will
-            // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to
-            // the bacikend where wkv_b is stored.
-            ggml_init_params params{context_size, nullptr, true};
-            auto ctx = ggml_init(params);
-            auto graph = ggml_new_graph_custom(ctx, 8, false);
-            std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
-            for (int il = 0; il < n_layer; ++il) {
-                auto& l = model.layers[il];
-                if (l.wk_b) continue;
-                auto wkv_b = *l.wkv_b;
-                if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
-                    ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b));
-                    wkv_b.data = wkv_buffer.data();
-                }
-                auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head,
-                        l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0);
-                auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32);
-                wk_b_f32->data = tensor_data.data();
-                auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32);
-                auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview);
-                wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32);
-
-                auto new_type = ggml_is_quantized(wkv_b.type) ?
-                    wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type;
-                auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type);
-                wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t);
-
-                ggml_build_forward_expand(graph, wk_b);
-
-                auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
-                if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
-                plan.work_data = work_data.data();
-
-                auto status = ggml_graph_compute(graph, &plan);
-                if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b");
-
-                auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight";
-
-                l.computed_wk_b = std::make_unique<ggml_tensor>(*wk_b);
-                l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b));
-                l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer);
-                l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
-                                                    // of wk_b, which no longer exist, and will therefore crash.
-                for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr;
-                ggml_set_name(l.computed_wk_b.get(), name.c_str());
-                ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-                ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b));
-                if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) {
-                    iqk_modify_tensor(l.computed_wk_b.get());
-                }
-
-                l.wk_b = l.computed_wk_b.get();
-
-                ggml_graph_clear(graph);
-                auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head,
-                        l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope));
-                wv_b->data = tensor_data.data();
-                ggml_build_forward_expand(graph, wv_b);
-                plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
-                if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
-                plan.work_data = work_data.data();
-                status = ggml_graph_compute(graph, &plan);
-                if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b");
-
-                name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight";
-
-                l.computed_wv_b = std::make_unique<ggml_tensor>(*wv_b);
-                l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b));
-                l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer);
-                l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
-                                                    // of wk_b, which no longer exist, and will therefore crash.
-                for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr;
-                ggml_set_name(l.computed_wv_b.get(), name.c_str());
-                ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-                ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b));
-                if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) {
-                    iqk_modify_tensor(l.computed_wv_b.get());
-                }
-
-                l.wv_b = l.computed_wv_b.get();
-
-                printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2],
-                        ggml_backend_buffer_name(l.computed_wk_b->buffer));
-
-                ggml_graph_clear(graph);
-            }
-            ggml_free(ctx);
-        }
-    }
+    llm_prepare_mla(model, mla_attn);

     if (use_mmap_buffer) {
         for (auto & mapping : ml.mappings) {
@@ -9252,7 +9403,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
 #endif

     if (!llm_load_tensors(
-        ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
+        ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
         params.progress_callback, params.progress_callback_user_data
     )) {
         return -2;
@@ -19928,6 +20079,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {

 struct llama_model_params llama_model_default_params() {
     struct llama_model_params result = {
         /*.n_gpu_layers                =*/ 0,
+        /*.mla                         =*/ 0,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
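
To make the tensor surgery above easier to follow: per attention head, wkv_b stores n_embd_head_qk_nope rows of the K projection followed by n_embd_head_v rows of the V projection, each row of length kv_lora_rank; wk_b is the transposed K slice and wv_b is the V slice, which is also why rebuilding wkv_b is just transpose-and-concatenate. The following self-contained toy round trip with plain ggml (a sketch only: made-up small dimensions, F32 weights, no backend buffers, no quantization) mirrors both directions:

```cpp
// Toy round trip of the splits/joins done in llm_prepare_mla(), on the CPU only.
#include "ggml.h"
#include <cstdio>
#include <cstring>

int main() {
    const int64_t kv_lora_rank = 8, n_nope = 4, n_v = 6, n_head = 2;

    ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
    ggml_context * ctx = ggml_init(ip);

    // wkv_b as stored in the GGUF: a 2D tensor [kv_lora_rank, (n_nope + n_v)*n_head]
    ggml_tensor * wkv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kv_lora_rank, (n_nope + n_v)*n_head);
    float * d = (float *) wkv_b->data;
    for (int64_t i = 0; i < ggml_nelements(wkv_b); ++i) d[i] = (float) i;

    // Split: the K slice is the first n_nope rows of each per-head block; wk_b is its transpose.
    ggml_tensor * wk_view = ggml_view_3d(ctx, wkv_b, kv_lora_rank, n_nope, n_head,
                                         wkv_b->nb[1], wkv_b->nb[1]*(n_nope + n_v), 0);
    ggml_tensor * wk_b = ggml_cont(ctx, ggml_transpose(ctx, ggml_cont(ctx, wk_view)));

    // The V slice is the remaining n_v rows of each per-head block; wv_b is used as-is.
    ggml_tensor * wv_b = ggml_cont(ctx, ggml_view_3d(ctx, wkv_b, kv_lora_rank, n_v, n_head,
                                         wkv_b->nb[1], wkv_b->nb[1]*(n_nope + n_v), wkv_b->nb[1]*n_nope));

    // Rebuild: transpose wk_b back and concatenate with wv_b along the row dimension,
    // which is what the MLA=2,3 branch does for GGUFs that only ship wk_b/wv_b.
    ggml_tensor * wkv_rebuilt = ggml_concat(ctx, ggml_cont(ctx, ggml_transpose(ctx, wk_b)), wv_b, 1);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, wk_b);
    ggml_build_forward_expand(gf, wv_b);
    ggml_build_forward_expand(gf, wkv_rebuilt);
    ggml_graph_compute_with_ctx(ctx, gf, 4);

    printf("wk_b : %lld x %lld x %lld\n", (long long) wk_b->ne[0], (long long) wk_b->ne[1], (long long) wk_b->ne[2]);
    printf("wv_b : %lld x %lld x %lld\n", (long long) wv_b->ne[0], (long long) wv_b->ne[1], (long long) wv_b->ne[2]);
    printf("round trip %s\n",
           memcmp(wkv_rebuilt->data, wkv_b->data, ggml_nbytes(wkv_b)) == 0 ? "matches" : "does NOT match");

    ggml_free(ctx);
    return 0;
}
```

The commit does the same thing, but it computes on the CPU backend into scratch buffers and then uploads the result into a newly allocated buffer of the same buffer type as the source tensor, since transpose/cont/cast are not guaranteed to be supported on every backend.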