author    Kawrakow <iwankawrakow@gmail.com>  2025-05-12 07:49:51 +0300
committer GitHub <noreply@github.com>        2025-05-12 07:49:51 +0300
commit    f27cd405422307e02dffa8949ac30bc56b4d2900 (patch)
tree      722b742827684815ca2cc0fb6379edd4edd2f3fd
parent    465569dff8b49a195450a0eb1974fd72a32fcebc (diff)
Enable faster prompt processing with mainline llama.cpp GGUFs (#409)
* Enable MLA-3 in crippled GGUFs: WIP
* Enable MLA-3 in crippled GGUFs: seems to work
* Add newly created tensors to model.tensors_by_name
  Else they don't get run-time repacked.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  common/common.cpp    1
-rw-r--r--  include/llama.h      1
-rw-r--r--  src/llama.cpp      432
3 files changed, 294 insertions, 140 deletions
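
For orientation, here is a minimal usage sketch (not part of the diff) of the new model-level MLA setting this commit adds to llama_model_params. The value 3 mirrors the "MLA-3" mentioned in the commit message; the exact meaning of the individual mla values is defined elsewhere in ik_llama.cpp and is assumed here, as is the n_gpu_layers value.

    #include "llama.h"

    int main(int argc, char ** argv) {
        if (argc < 2) return 1;
        struct llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers = 99;  // offload layers as usual (illustrative value)
        mparams.mla          = 3;   // new field from this commit; if the GGUF lacks
                                    // attn_k_b/attn_v_b (or attn_kv_b), llm_prepare_mla()
                                    // computes the missing tensors at load time
        struct llama_model * model = llama_load_model_from_file(argv[1], mparams);
        if (model == nullptr) return 1;
        // ... create a llama_context and run inference as usual ...
        llama_free_model(model);
        return 0;
    }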
diff --git a/common/common.cpp b/common/common.cpp
index ab936ee7..0dbde58f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2334,6 +2334,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
+ mparams.mla = params.mla_attn;
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
diff --git a/include/llama.h b/include/llama.h
index f1511548..0f3ae862 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -325,6 +325,7 @@ extern "C" {
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
+ int32_t mla; // MLA implementation to use (only applicable to DeepSeek models at this point)
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// main_gpu interpretation depends on split_mode:
diff --git a/src/llama.cpp b/src/llama.cpp
index b4d42c84..9369d10e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2942,6 +2942,7 @@ struct llama_layer {
std::unique_ptr<ggml_tensor> computed_wk_b;
std::unique_ptr<ggml_tensor> computed_wv_b;
+ std::unique_ptr<ggml_tensor> computed_wkv_b;
};
struct llama_kv_cell {
@@ -6756,11 +6757,299 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
}
+static void llm_prepare_mla(llama_model & model, int mla) {
+ if (model.arch != LLM_ARCH_DEEPSEEK2) return;
+ const auto& hparams = model.hparams;
+ const int n_layer = model.layers.size();
+ int n_to_compute = 0;
+ for (auto& l : model.layers) {
+ if (!l.wk_b) ++n_to_compute;
+ }
+ if (mla > 0 && n_to_compute > 0) {
+ // Prepare wk_b tensors to enable MLA usage also for model files that do not include
+ // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp)
+ // We do it here because otherwise wkv_b may get run-time-repacked, which will make
+ // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically
+ // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b
+ // even if it is not needed (because MLA is not being used). If we wanted to avoid
+ // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters
+ // to the model loading function. On the other hand, in some hypothetical bright future,
+ // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite
+ // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful
+ // to change the MLA setting on the fly, depending on context. In that case, having prepared
+ // the MLA tensors here is the right thing to do^TM.
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+ const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
+ const int32_t n_embd_head_v = hparams.n_embd_head_v;
+ const int32_t n_head = hparams.n_head(0);
+ std::vector<uint8_t> work_data;
+ LLAMA_LOG_INFO("============ %s: need to compute %d wk_b/wv_b tensors\n", __func__, n_to_compute);
+ for (int il = 1; il < n_layer; ++il) {
+ // Somehow the number of heads is being defined as being per layer. Not sure why this is the
+ // case, but for now we do not support strange models that have different numbers of heads
+ // in different model layers.
+ if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
+ }
+ auto total_size_wkb = 0;
+ size_t max_wkv_size = 0;
+ size_t max_wk_size = 0;
+ for (auto& l : model.layers) {
+ if (!l.wk_b) {
+ auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type;
+ auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head;
+ max_wk_size = std::max(max_wk_size, size);
+ if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
+ max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b));
+ }
+ }
+ }
+ auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
+ context_size *= 2; // just in case;
+ std::vector<uint8_t> wkv_buffer;
+ if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size);
+ // So, transposing tensors and then making them contiguous as needed for wk_b may or may not
+ // be supported on all backends. Hence, to be sure that the preparation of wk_b will
+ // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to
+ // the backend where wkv_b is stored.
+ ggml_init_params params{context_size, nullptr, true};
+ auto ctx = ggml_init(params);
+ auto graph = ggml_new_graph_custom(ctx, 8, false);
+ std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
+ for (int il = 0; il < n_layer; ++il) {
+ auto& l = model.layers[il];
+ if (l.wk_b) continue;
+ auto wkv_b = *l.wkv_b;
+ if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
+ ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b));
+ wkv_b.data = wkv_buffer.data();
+ }
+ auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head,
+ l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0);
+ auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32);
+ wk_b_f32->data = tensor_data.data();
+ auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32);
+ auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview);
+ wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32);
+
+ auto new_type = ggml_is_quantized(wkv_b.type) ?
+ wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type;
+ auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type);
+ wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t);
+
+ ggml_build_forward_expand(graph, wk_b);
+
+ auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+ if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+ plan.work_data = work_data.data();
+
+ auto status = ggml_graph_compute(graph, &plan);
+ if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b");
+
+ auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight";
+
+ l.computed_wk_b = std::make_unique<ggml_tensor>(*wk_b);
+ l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b));
+ l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer);
+ l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+ // of wk_b, which no longer exist, and will therefore crash.
+ for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr;
+ ggml_set_name(l.computed_wk_b.get(), name.c_str());
+ ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+ ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b));
+ if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) {
+ iqk_modify_tensor(l.computed_wk_b.get());
+ }
+
+ l.wk_b = l.computed_wk_b.get();
+ model.tensors_by_name.push_back(std::make_pair(name, l.wk_b));
+
+ ggml_graph_clear(graph);
+ auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head,
+ l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope));
+ wv_b->data = tensor_data.data();
+ ggml_build_forward_expand(graph, wv_b);
+ plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+ if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+ plan.work_data = work_data.data();
+ status = ggml_graph_compute(graph, &plan);
+ if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b");
+
+ name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight";
+
+ l.computed_wv_b = std::make_unique<ggml_tensor>(*wv_b);
+ l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b));
+ l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer);
+ l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+ // of wv_b, which no longer exist, and will therefore crash.
+ for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr;
+ ggml_set_name(l.computed_wv_b.get(), name.c_str());
+ ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+ ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b));
+ if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) {
+ iqk_modify_tensor(l.computed_wv_b.get());
+ }
+
+ l.wv_b = l.computed_wv_b.get();
+ model.tensors_by_name.push_back(std::make_pair(name, l.wv_b));
+
+ printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2],
+ ggml_backend_buffer_name(l.computed_wk_b->buffer));
+
+ ggml_graph_clear(graph);
+ }
+ ggml_free(ctx);
+ }
+ if (mla == 1) return;
+
+ n_to_compute = 0;
+ for (auto& l : model.layers) {
+ if (l.wk_b && l.wv_b && !l.wkv_b) ++n_to_compute;
+ }
+ if (n_to_compute == 0) return;
+
+ //
+ // Prepare wkv_b tensors to enable MLA=2,3 usage also for model files that have been
+ // crippled to the mainline llama.cpp MLA implementation (MLA=1 here).
+ // We do it here because otherwise wk_b and wv_b may get run-time-repacked, which will make
+ // preparation of wkv_b impossible. It also has the benefit that wkv_b will get automatically
+ // run-time repacked if the rtr option is set.
+ //
+ const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+ const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+ const uint32_t kv_lora_rank = hparams.n_lora_kv;
+ const int32_t n_embd_head_v = hparams.n_embd_head_v;
+ const int32_t n_head = hparams.n_head(0);
+ std::vector<uint8_t> work_data;
+ LLAMA_LOG_INFO("============ %s: need to compute %d wkv_b tensors\n", __func__, n_to_compute);
+ for (int il = 1; il < n_layer; ++il) {
+ // Somehow the number of heads is being defined as being per layer. Not sure why this is the
+ // case, but for now we do not support strange models that have different numbers of heads
+ // in different model layers.
+ if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
+ }
+
+ size_t context_size = ggml_tensor_overhead()*16*n_layer;
+
+ ggml_init_params params{context_size, nullptr, true};
+ auto ctx = ggml_init(params);
+ auto graph = ggml_new_graph_custom(ctx, 8, false);
+
+ //layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
+ //layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v, n_head}, 0);
+
+ std::vector<char> wk_buffer, wv_buffer;
+ std::vector<char> tmp_buffer;
+ //std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
+ for (int il = 0; il < n_layer; ++il) {
+ auto& l = model.layers[il];
+ if (l.wkv_b || !l.wk_b || !l.wv_b) continue;
+ auto wk_b = *l.wk_b;
+ auto wv_b = *l.wv_b;
+ if (!ggml_backend_buffer_is_host(l.wk_b->buffer)) {
+ auto nbytes = ggml_nbytes(l.wk_b);
+ if (wk_buffer.size() < nbytes) wk_buffer.resize(nbytes);
+ ggml_backend_tensor_get(l.wk_b, wk_buffer.data(), 0, nbytes);
+ wk_b.data = wk_buffer.data();
+ }
+ if (!ggml_backend_buffer_is_host(l.wv_b->buffer)) {
+ auto nbytes = ggml_nbytes(l.wv_b);
+ if (wv_buffer.size() < nbytes) wv_buffer.resize(nbytes);
+ ggml_backend_tensor_get(l.wv_b, wv_buffer.data(), 0, nbytes);
+ wv_b.data = wv_buffer.data();
+ }
+
+ auto n_wk = ggml_nelements(&wk_b);
+ auto n_wv = ggml_nelements(&wv_b);
+
+ size_t tot_size = 0;
+ if (wk_b.type != GGML_TYPE_F32) {
+ tot_size += n_wk*sizeof(float);
+ }
+ tot_size += n_wk*sizeof(float); // ggml_cont(ctx, ggml_transpose(ctx, wk_b_used));
+ if (wv_b.type != GGML_TYPE_F32) {
+ tot_size += n_wv*sizeof(float);
+ }
+ tot_size += (n_wk + n_wv)*sizeof(float); // ggml_concat(ctx, wk_b_transposed, wv_b_used, 1);
+ tot_size += (n_wk + n_wv)*sizeof(float); // ggml_cast(ctx, wkv_b_f32, new_type);
+
+ if (tmp_buffer.size() < tot_size) tmp_buffer.resize(tot_size);
+
+ auto ptr = tmp_buffer.data();
+
+ auto wk_b_used = &wk_b;
+ if (wk_b.type != GGML_TYPE_F32) {
+ wk_b_used = ggml_cast(ctx, &wk_b, GGML_TYPE_F32);
+ wk_b_used->data = ptr;
+ ptr += ggml_nbytes(wk_b_used);
+ }
+ auto wk_b_transposed = ggml_cont(ctx, ggml_transpose(ctx, wk_b_used));
+ wk_b_transposed->data = ptr;
+ ptr += ggml_nbytes(wk_b_transposed);
+
+ auto wv_b_used = &wv_b;
+ if (wv_b.type != GGML_TYPE_F32) {
+ wv_b_used = ggml_cast(ctx, &wv_b, GGML_TYPE_F32);
+ wv_b_used->data = ptr;
+ ptr += ggml_nbytes(wv_b_used);
+ }
+
+ auto wkv_b_f32_3d = ggml_concat(ctx, wk_b_transposed, wv_b_used, 1);
+ wkv_b_f32_3d->data = ptr;
+ ptr += ggml_nbytes(wkv_b_f32_3d);
+
+ auto wkv_b_f32 = ggml_view_2d(ctx, wkv_b_f32_3d, wkv_b_f32_3d->ne[0], wkv_b_f32_3d->ne[1]*wkv_b_f32_3d->ne[2],
+ wkv_b_f32_3d->nb[1], 0);
+
+ auto new_type = wk_b.type == GGML_TYPE_BF16 && wv_b.type == GGML_TYPE_BF16 ? GGML_TYPE_BF16
+ : wk_b.type == GGML_TYPE_F16 && wv_b.type == GGML_TYPE_F16 ? GGML_TYPE_F16
+ : GGML_TYPE_Q8_0;
+
+ auto wkv_b = ggml_cast(ctx, wkv_b_f32, new_type);
+ wkv_b->data = ptr;
+ ptr += ggml_nbytes(wkv_b);
+
+ ggml_build_forward_expand(graph, wkv_b);
+
+ auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
+ if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
+ plan.work_data = work_data.data();
+
+ auto status = ggml_graph_compute(graph, &plan);
+ if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wkv_b");
+
+ auto name = std::string{"blk."} + std::to_string(il) + ".attn_kv_b.weight";
+
+ l.computed_wkv_b = std::make_unique<ggml_tensor>(*wkv_b);
+ l.computed_wkv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wk_b->buffer), ggml_nbytes(wkv_b));
+ l.computed_wkv_b->data = ggml_backend_buffer_get_base(l.computed_wkv_b->buffer);
+ l.computed_wkv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
+ // of wkv_b, which no longer exist, and will therefore crash.
+ for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wkv_b->src[j] = nullptr;
+ ggml_set_name(l.computed_wkv_b.get(), name.c_str());
+ ggml_backend_buffer_set_usage(l.computed_wkv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+ ggml_backend_tensor_set(l.computed_wkv_b.get(), wkv_b->data, 0, ggml_nbytes(wkv_b));
+ if (ggml_backend_buffer_is_host(l.computed_wkv_b->buffer)) {
+ iqk_modify_tensor(l.computed_wkv_b.get());
+ }
+
+ l.wkv_b = l.computed_wkv_b.get();
+ model.tensors_by_name.push_back(std::make_pair(name, l.wkv_b));
+
+ printf("Computed %s as %ld x %ld and stored in buffer %s\n", name.c_str(), wkv_b->ne[0], wkv_b->ne[1],
+ ggml_backend_buffer_name(l.computed_wkv_b->buffer));
+
+ ggml_graph_clear(graph);
+ }
+ ggml_free(ctx);
+}
+
// Returns false if cancelled by progress_callback
static bool llm_load_tensors(
llama_model_loader & ml,
llama_model & model,
int n_gpu_layers,
+ int mla_attn,
enum llama_split_mode split_mode,
int main_gpu,
const float * tensor_split,
@@ -8997,145 +9286,7 @@ static bool llm_load_tensors(
}
}
- if (model.arch == LLM_ARCH_DEEPSEEK2) {
- int n_to_compute = 0;
- for (auto& l : model.layers) {
- if (!l.wk_b) ++n_to_compute;
- }
- if (n_to_compute > 0) {
- // Prepare wk_b tensors to enable MLA usage also for model files that do not include
- // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp)
- // We do it here because otherwise wkv_b may get run-time-repacked, which will make
- // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically
- // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b
- // even if it is not needed (because MLA is not being used). If we wanted to avoid
- // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters
- // to the model loading function. On the other hand, in some hypothetical bright future,
- // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite
- // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful
- // to change the MLA setting on the fly, depending on context. In that case, having prepared
- // the MLA tensors here is the right thing to do^TM.
- const uint32_t n_embd_head_qk_rope = hparams.n_rot;
- const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
- const uint32_t kv_lora_rank = hparams.n_lora_kv;
- const int32_t n_embd_head_v = hparams.n_embd_head_v;
- const int32_t n_head = hparams.n_head(0);
- std::vector<uint8_t> work_data;
- LLAMA_LOG_INFO("============ %s: need to compute %d wk_b tensors\n", __func__, n_to_compute);
- for (int il = 1; il < n_layer; ++il) {
- // Somehow the number of heads is being defined as being per layer. Not sure why this is the
- // case, but for now we do not support strange models that have different numbers of heads
- // in different model layers.
- if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration");
- }
- auto total_size_wkb = 0;
- size_t max_wkv_size = 0;
- size_t max_wk_size = 0;
- for (auto& l : model.layers) {
- if (!l.wk_b) {
- auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type;
- auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head;
- max_wk_size = std::max(max_wk_size, size);
- if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
- max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b));
- }
- }
- }
- auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float);
- context_size *= 2; // just in case;
- std::vector<uint8_t> wkv_buffer;
- if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size);
- // So, transposing tensors and then making them contiguous as needed for wk_b may or may not
- // be supported on all backends. Hence, to be sure that the preparation of wk_b will
- // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to
- // the backend where wkv_b is stored.
- ggml_init_params params{context_size, nullptr, true};
- auto ctx = ggml_init(params);
- auto graph = ggml_new_graph_custom(ctx, 8, false);
- std::vector<uint8_t> tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size);
- for (int il = 0; il < n_layer; ++il) {
- auto& l = model.layers[il];
- if (l.wk_b) continue;
- auto wkv_b = *l.wkv_b;
- if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) {
- ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b));
- wkv_b.data = wkv_buffer.data();
- }
- auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head,
- l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0);
- auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32);
- wk_b_f32->data = tensor_data.data();
- auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32);
- auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview);
- wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32);
-
- auto new_type = ggml_is_quantized(wkv_b.type) ?
- wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type;
- auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type);
- wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t);
-
- ggml_build_forward_expand(graph, wk_b);
-
- auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
- if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
- plan.work_data = work_data.data();
-
- auto status = ggml_graph_compute(graph, &plan);
- if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b");
-
- auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight";
-
- l.computed_wk_b = std::make_unique<ggml_tensor>(*wk_b);
- l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b));
- l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer);
- l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
- // of wk_b, which no longer exist, and will therefore crash.
- for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr;
- ggml_set_name(l.computed_wk_b.get(), name.c_str());
- ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
- ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b));
- if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) {
- iqk_modify_tensor(l.computed_wk_b.get());
- }
-
- l.wk_b = l.computed_wk_b.get();
-
- ggml_graph_clear(graph);
- auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head,
- l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope));
- wv_b->data = tensor_data.data();
- ggml_build_forward_expand(graph, wv_b);
- plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2);
- if (plan.work_size > work_data.size()) work_data.resize(plan.work_size);
- plan.work_data = work_data.data();
- status = ggml_graph_compute(graph, &plan);
- if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b");
-
- name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight";
-
- l.computed_wv_b = std::make_unique<ggml_tensor>(*wv_b);
- l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b));
- l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer);
- l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents
- // of wv_b, which no longer exist, and will therefore crash.
- for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr;
- ggml_set_name(l.computed_wv_b.get(), name.c_str());
- ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
- ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b));
- if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) {
- iqk_modify_tensor(l.computed_wv_b.get());
- }
-
- l.wv_b = l.computed_wv_b.get();
-
- printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2],
- ggml_backend_buffer_name(l.computed_wk_b->buffer));
-
- ggml_graph_clear(graph);
- }
- ggml_free(ctx);
- }
- }
+ llm_prepare_mla(model, mla_attn);
if (use_mmap_buffer) {
for (auto & mapping : ml.mappings) {
@@ -9252,7 +9403,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
#endif
if (!llm_load_tensors(
- ml, model, params.n_gpu_layers, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
+ ml, model, params.n_gpu_layers, params.mla, params.split_mode, params.main_gpu, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
)) {
return -2;
@@ -19928,6 +20079,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
+ /*.mla =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
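
As a closing note, the following standalone sketch (no ggml dependency, not part of the diff) illustrates the shape bookkeeping that llm_prepare_mla() performs in both directions. The DeepSeek-V2-Lite-like dimensions are illustrative assumptions, not values read from any GGUF.

    #include <cstdio>
    #include <cstdint>

    int main() {
        // Illustrative, assumed dimensions (roughly DeepSeek-V2-Lite sized)
        const int64_t kv_lora_rank        = 512;  // hparams.n_lora_kv
        const int64_t n_embd_head_qk_nope = 128;  // n_embd_head_k - n_rot
        const int64_t n_embd_head_v       = 128;  // hparams.n_embd_head_v
        const int64_t n_head              = 16;   // hparams.n_head(0)

        // attn_kv_b (wkv_b) stores, per head, the K-nope rows followed by the V rows:
        //   ne = { kv_lora_rank, n_head*(n_embd_head_qk_nope + n_embd_head_v) }
        const int64_t wkv_b_rows  = n_head*(n_embd_head_qk_nope + n_embd_head_v);
        const int64_t wkv_b_elems = kv_lora_rank*wkv_b_rows;

        // Forward direction (missing wk_b/wv_b): view the K part per head and transpose it,
        //   { kv_lora_rank, n_embd_head_qk_nope, n_head } -> { n_embd_head_qk_nope, kv_lora_rank, n_head },
        // and view the V part at a per-head row offset of n_embd_head_qk_nope,
        //   { kv_lora_rank, n_embd_head_v, n_head }.
        const int64_t wk_b_elems = n_embd_head_qk_nope*kv_lora_rank*n_head;
        const int64_t wv_b_elems = kv_lora_rank*n_embd_head_v*n_head;

        // Reverse direction (mla >= 2 with only wk_b/wv_b present): transpose wk_b back,
        // concatenate with wv_b along dim 1 per head, then view the result as the 2D wkv_b.
        const int64_t rebuilt_elems = wk_b_elems + wv_b_elems;

        printf("wkv_b : %lld x %lld = %lld elements\n",
               (long long)kv_lora_rank, (long long)wkv_b_rows, (long long)wkv_b_elems);
        printf("wk_b  : %lld elements, wv_b: %lld elements, rebuilt wkv_b: %lld elements\n",
               (long long)wk_b_elems, (long long)wv_b_elems, (long long)rebuilt_elems);
        return 0;
    }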