diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-11-23 19:07:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-23 19:07:56 +0200 |
commit | 6b0a7420d03b9d13cb0e9439a01ce8476d8bf093 (patch) | |
tree | f184d281cb47e357e4ead4a93a0d1fe504c74bbe /examples/parallel/parallel.cpp | |
parent | d103d935c0e75769a6a597f7a64cab72c6cc3e79 (diff) |
llama : KV cache view API + better KV cache management (#4170)
* llama : keep track of used KV cells + better KV cache management
* llama : zero KV cache used upon clear
ggml-ci
* llama : allow exporting a view of the KV cache (#4180)
* Allow exporting a view of the KV cache
* Allow dumping the sequences per cell in common
* Track max contiguous cells value and position as well
* Fix max contiguous empty cells index calculation
Make dump functions deal with lengths or sequences counts > 10 better
* Fix off by one error in dump_kv_cache_view
* Add doc comments for KV cache view functions
Eliminate cell sequence struct; use llama_seq_id directly
Minor cleanups
* common : add -dkvc arg for enabling kv cache dumps
---------
Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
Diffstat (limited to 'examples/parallel/parallel.cpp')
-rw-r--r-- | examples/parallel/parallel.cpp | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 9b89bdfe..d2e074d9 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -113,6 +113,8 @@ int main(int argc, char ** argv) { // insert new requests as soon as the previous one is done const bool cont_batching = params.cont_batching; + const bool dump_kv_cache = params.dump_kv_cache; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("parallel", "log")); LOG_TEE("Log start\n"); @@ -172,6 +174,8 @@ int main(int argc, char ** argv) { int32_t n_total_gen = 0; int32_t n_cache_miss = 0; + struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients); + const auto t_main_start = ggml_time_us(); LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__); @@ -201,6 +205,11 @@ int main(int argc, char ** argv) { LOG_TEE("Processing requests ...\n\n"); while (true) { + if (dump_kv_cache) { + llama_kv_cache_view_update(ctx, &kvc_view); + dump_kv_cache_view_seqs(kvc_view, 40); + } + llama_batch_clear(batch); // decode any currently ongoing sequences |