diff options
author | saood06 <saood05@gmail.com> | 2025-06-06 03:33:47 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-06 11:33:47 +0300 |
commit | ffd87f282e76ff9d34f47efd6d3f6af2071d416a (patch) | |
tree | f593753dd699b03ba50fc42a71214ca614f1f517 | |
parent | eded4e20d4decdc6e8c18e645fd1db0833ad251d (diff) |
Make prompt cache saving and restoring MLA aware (#497)
* Remove kv_l, kvt_l and just use k_l and v_l
* Hopefully take care of missing V cache (MLA)
* Fix save and restore when there is no V cache
* Fix double print
* Update write_kv_cache_data and read_kv_cache_data to be MLA aware
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | src/llama.cpp | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/src/llama.cpp b/src/llama.cpp index be404500..dfd53337 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -21448,13 +21448,15 @@ struct llama_data_write { // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; // Write key type const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out @@ -21758,6 +21760,9 @@ struct llama_data_read { // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // Read type of key int32_t k_type_i_ref; @@ -21771,7 +21776,7 @@ struct llama_data_read { // Read row size of key uint64_t k_size_row_ref; read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; |