summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
author: compilade <113953597+compilade@users.noreply.github.com> 2024-02-25 13:43:50 -0500
committer: GitHub <noreply@github.com>2024-02-25 20:43:50 +0200
commitf7625019c51ca437a5840576d92362cfa710e4a2 (patch)
tree6bfc6ccfd3f00857759192a1458a31f1d0b755d9 /examples/server/server.cpp
parentabbabc5e51d0d4656b438aec10b7fae9479ef37d (diff)
server : fix crash when system prompt is bigger than batch size (#5714)
The system prompt is now decoded in batches. Also: server : fix off-by-one n_past when the start of the prompt matches the whole cache — the tokens right after the matching part would otherwise skip a pos value.
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp28
1 file changed, 25 insertions, 3 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d970202d..c1eb6167 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -902,10 +902,24 @@ struct llama_server_context
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
- if (llama_decode(ctx, batch) != 0)
+ for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
{
- LOG_TEE("%s: llama_decode() failed\n", __func__);
- return;
+ const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+ llama_batch batch_view = {
+ n_tokens,
+ batch.token + i,
+ nullptr,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
+ 0, 0, 0, // unused
+ };
+ if (llama_decode(ctx, batch_view) != 0)
+ {
+ LOG_TEE("%s: llama_decode() failed\n", __func__);
+ return;
+ }
}
// assign the system KV cache to all parallel sequences
@@ -1785,6 +1799,14 @@ struct llama_server_context
}
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+ // the last token of the cache is not in the KV cache until the next call to llama_decode
+ // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
+ if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
+ {
+ slot.n_past -= 1;
+ }
+
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
if (slot.ga_n != 1)