diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-02-25 22:12:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-25 22:12:24 +0200 |
commit | bf08e00643fd529f748f0a858fd79f3061e3fa18 (patch) | |
tree | 0043ee582e83a19c8f1ca6d75d1519038f866e1c /examples/passkey/passkey.cpp | |
parent | f7625019c51ca437a5840576d92362cfa710e4a2 (diff) |
llama : refactor k-shift implementation + KV defragmentation (#5691)
* llama : refactor k-shift implementation
ggml-ci
* llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add
* llama : cont k-shift refactoring + normalize type names
ggml-ci
* minor : fix MPI builds
* llama : reuse n_rot from the build context
ggml-ci
* llama : revert enum name changes from this PR
ggml-ci
* llama : update llama_rope_type
* llama : add comment about rope values
* llama : fix build
* passkey : apply kv cache updates explicitly
ggml-ci
* llama : change name to llama_kv_cache_update()
* llama : add llama_kv_cache_seq_pos_max()
* passkey : fix llama_kv_cache_seq_pos_max() usage
* llama : some llama_kv_cell simplifications
* llama : add llama_kv_cache_compress (EXPERIMENTAL)
* llama : add alternative KV cache merging (EXPERIMENTAL)
* llama : add llama_kv_cache_defrag
* llama : comments
* llama : remove llama_kv_cache_compress
will add in a separate PR
ggml-ci
* llama : defragment via non-overlapping moves
* llama : ggml_graph based defrag implementation
ggml-ci
* llama : switch the loop order in build_defrag
* llama : add comments
Diffstat (limited to 'examples/passkey/passkey.cpp')
-rw-r--r-- | examples/passkey/passkey.cpp | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index e12a1cdf..47de67a9 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,7 +126,7 @@ int main(int argc, char ** argv) { const int n_batch = ctx_params.n_batch; const int n_batch_grp = ctx_params.n_batch/n_grp; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); // print the prompt token-by-token @@ -146,10 +146,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update (ctx); - n_past -= bd; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -179,10 +180,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -208,10 +211,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past -= n_discard; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } } |