From bf08e00643fd529f748f0a858fd79f3061e3fa18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 Feb 2024 22:12:24 +0200 Subject: llama : refactor k-shift implementation + KV defragmentation (#5691) * llama : refactor k-shift implementation ggml-ci * llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add * llama : cont k-shift refactoring + normalize type names ggml-ci * minor : fix MPI builds * llama : reuse n_rot from the build context ggml-ci * llama : revert enum name changes from this PR ggml-ci * llama : update llama_rope_type * llama : add comment about rope values * llama : fix build * passkey : apply kv cache updates explicitly ggml-ci * llama : change name to llama_kv_cache_update() * llama : add llama_kv_cache_seq_pos_max() * passkey : fix llama_kv_cache_seq_pos_max() usage * llama : some llama_kv_cell simplifications * llama : add llama_kv_cache_compress (EXPERIMENTAL) * llama : add alternative KV cache merging (EXPERIMENTAL) * llama : add llama_kv_cache_defrag * llama : comments * llama : remove llama_kv_cache_compress will add in a separate PR ggml-ci * llama : defragment via non-overlapping moves * llama : ggml_graph based defrag implementation ggml-ci * llama : switch the loop order in build_defrag * llama : add comments --- examples/passkey/passkey.cpp | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) (limited to 'examples/passkey/passkey.cpp') diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index e12a1cdf..47de67a9 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,7 +126,7 @@ int main(int argc, char ** argv) { const int n_batch = ctx_params.n_batch; const int n_batch_grp = ctx_params.n_batch/n_grp; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); // print the prompt token-by-token @@ -146,10 +146,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update (ctx); - n_past -= bd; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -179,10 +180,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -208,10 +211,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } } -- cgit v1.2.3