author     saood06 <saood05@gmail.com>               2025-06-06 03:33:47 -0500
committer  GitHub <noreply@github.com>               2025-06-06 11:33:47 +0300
commit     ffd87f282e76ff9d34f47efd6d3f6af2071d416a
tree       f593753dd699b03ba50fc42a71214ca614f1f517
parent     eded4e20d4decdc6e8c18e645fd1db0833ad251d
Make prompt cache saving and restoring MLA aware (#497)
* Remove kv_l, kvt_l and just use k_l and v_l
* Hopefully take care of missing V cache (MLA)
* Fix save and restore when there is no V cache
* Fix double print
* Update write_kv_cache_data and read_kv_cache_data to be MLA aware

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
 src/llama.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)
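Background for the hunks below: with MLA (multi-head latent attention) enabled, the K cache stores the compressed KV latent (kv_lora_rank entries) plus the RoPE'd key head (n_embd_head_qk_rope entries) per cell, rather than the full GQA key embedding, so the serialized row size differs. A minimal sketch of the row-size selection both hunks apply, assuming the hparams/cparams fields named in the diff; the helper name k_cache_row_size is illustrative, not part of the commit:

#include <cstdint>
#include "ggml.h"   // ggml_row_size, ggml_type

// Row size of one K-cache cell as written to / read from a prompt cache file.
// mla_attn == 0: full per-layer GQA key embedding (n_embd_k_gqa).
// mla_attn != 0: compressed KV latent plus RoPE'd key part.
static uint64_t k_cache_row_size(int mla_attn, enum ggml_type k_type,
                                 uint32_t n_embd_k_gqa,
                                 uint32_t kv_lora_rank,
                                 uint32_t n_embd_head_qk_rope) {
    return (mla_attn == 0)
        ? ggml_row_size(k_type, n_embd_k_gqa)
        : ggml_row_size(k_type, kv_lora_rank + n_embd_head_qk_rope);
}

Keeping this computation identical on the write and read paths is what lets the read-side k_size_row_ref check (third hunk) succeed for caches saved with MLA enabled.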
diff --git a/src/llama.cpp b/src/llama.cpp
index be404500..dfd53337 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21448,13 +21448,15 @@ struct llama_data_write {
         // Get whole range at a time
         for (uint32_t il = 0; il < n_layer; ++il) {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+            const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
             // Write key type
             const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
             write(&k_type_i, sizeof(k_type_i));
 
             // Write row size of key
-            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
             write(&k_size_row, sizeof(k_size_row));
 
             // Read each range of cells of k_size length each into tmp_buf and write out
@@ -21758,6 +21760,9 @@ struct llama_data_read {
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+            const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
 
             // Read type of key
             int32_t k_type_i_ref;
@@ -21771,7 +21776,7 @@ struct llama_data_read {
             // Read row size of key
             uint64_t k_size_row_ref;
             read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
             if (k_size_row != k_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
                 return false;
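The hunks above cover only the K-cache row size; the V-cache fixes mentioned in the commit message ("Fix save and restore when there is no V cache") fall outside this excerpt. The underlying point is that with MLA the values are reconstructed from the compressed latent, so there are no separate V tensors to serialize. A hypothetical sketch of that guard, with a predicate name and shape that are assumptions, not code from this commit:

// Hypothetical helper, not from the hunks above: decide whether the
// save/restore paths should serialize a per-layer V-cache section at all.
static bool has_separate_v_cache(int mla_attn) {
    // MLA derives V from the compressed KV latent, so only the latent K
    // cache is stored; a conventional cache keeps explicit V tensors.
    return mla_attn == 0;
}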