From ffd87f282e76ff9d34f47efd6d3f6af2071d416a Mon Sep 17 00:00:00 2001 From: saood06 Date: Fri, 6 Jun 2025 03:33:47 -0500 Subject: Make prompt cache saving and restoring MLA aware (#497) * Remove kv_l, kvt_l and just use k_l and v_l * Hopefully take care of missing V cache (MLA) * Fix save and restore when there is no V cache * Fix double print * Update write_kv_cache_data and read_kv_cache_data to be MLA aware --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/llama.cpp b/src/llama.cpp index be404500..dfd53337 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21448,13 +21448,15 @@ struct llama_data_write { // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; // Write key type const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out @@ -21758,6 +21760,9 @@ struct llama_data_read { // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // Read type of key int32_t k_type_i_ref; @@ -21771,7 +21776,7 @@ struct llama_data_read { // Read row size of key uint64_t k_size_row_ref; read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; -- cgit v1.2.3