Diffstat (limited to 'examples/parallel/parallel.cpp')
 examples/parallel/parallel.cpp | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 7d11fcd5..a2ef0fb0 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;
+ // dedicate one sequence to the system prompt
+ params.n_parallel += 1;
+
// requests to simulate
const int32_t n_seq = params.n_sequences;
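Note: with this change, sequence 0 of the KV cache is dedicated to the shared system prompt and client i uses sequence i + 1. A minimal sketch of the mapping this implies (the helper name is illustrative, not part of the example):

    #include "llama.h"

    // seq 0            -> system prompt (shared, kept resident)
    // seq 1..n_clients -> one sequence per simulated client
    static llama_seq_id client_seq_id(int32_t client_id) {
        return client_id + 1; // sequence 0 is reserved for the system prompt
    }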
@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
}
// assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i < n_clients; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+ for (int32_t i = 1; i <= n_clients; ++i) {
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("\n");
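For reference, llama_kv_cache_seq_cp(ctx, src, dst, p0, p1) copies the cached cells of sequence src within the given position range into sequence dst, and passing -1 for both p0 and p1 selects the whole sequence. A minimal sketch of the loop above as a standalone helper (the function name is illustrative):

    #include "llama.h"

    // assumes sequence 0 already holds the evaluated system prompt
    static void assign_system_kv(llama_context * ctx, int32_t n_clients) {
        for (int32_t i = 1; i <= n_clients; ++i) {
            // -1, -1: copy the entire source range, so the caller no longer
            // needs to know n_tokens_system
            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
        }
    }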
@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
client.i_batch = batch.n_tokens;
- llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+ llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
- for (int i = 0; i < n_clients; ++i) {
- llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+ for (int i = 1; i <= n_clients; ++i) {
+ llama_kv_cache_seq_rm(ctx, i, -1, -1);
+ // but keep the system prompt
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
- llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
+ llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}
// extract the logits only for the last token
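For reference, llama_batch_add is the helper from common.h used here: it appends one token to the batch at a given position, tags it with the listed sequence ids, and marks whether logits should be returned for it. A minimal sketch of queuing a client prompt the way the hunk above does (the helper name is illustrative):

    #include <vector>
    #include "common.h"

    static void queue_prompt(llama_batch & batch, const std::vector<llama_token> & tokens_prompt,
                             int32_t n_tokens_system, llama_seq_id seq) {
        for (size_t i = 0; i < tokens_prompt.size(); ++i) {
            // positions continue after the shared system prompt
            llama_batch_add(batch, tokens_prompt[i], n_tokens_system + (llama_pos) i, { seq }, false);
        }
        // logits are only needed for the last prompt token
        if (batch.n_tokens > 0) {
            batch.logits[batch.n_tokens - 1] = true;
        }
    }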
@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
+ llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+ llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();