diff options
Diffstat (limited to 'examples/parallel')
-rw-r--r-- | examples/parallel/parallel.cpp | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7d11fcd5..a2ef0fb0 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -107,6 +107,9 @@ int main(int argc, char ** argv) { // number of simultaneous "clients" to simulate const int32_t n_clients = params.n_parallel; + // dedicate one sequence to the system prompt + params.n_parallel += 1; + // requests to simulate const int32_t n_seq = params.n_sequences; @@ -196,8 +199,8 @@ int main(int argc, char ** argv) { } // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system); + for (int32_t i = 1; i <= n_clients; ++i) { + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -221,15 +224,17 @@ int main(int argc, char ** argv) { client.i_batch = batch.n_tokens; - llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true); + llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); client.n_decoded += 1; } if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache - for (int i = 0; i < n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); + for (int i = 1; i <= n_clients; ++i) { + llama_kv_cache_seq_rm(ctx, i, -1, -1); + // but keep the system prompt + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -255,7 +260,7 @@ int main(int argc, char ** argv) { tokens_prompt = ::llama_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false); + llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); } // extract the logits only for the last token @@ -366,7 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); |