author      Alexey Parfenov <zxed@alkatrazstudio.net>    2023-11-11 05:48:21 +0000
committer   GitHub <noreply@github.com>                  2023-11-10 23:48:21 -0600
commit      d96ca7ded77df764db797b68b4a29e34c5b56285 (patch)
tree        1a3c0ace565643e47933b3e241773e73b94d273b /examples/server/server.cpp
parent      34b0a082074b073eb14c2bd93c0c070e20ddcd16 (diff)
server : fix crash when prompt exceeds context size (#3996)
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--    examples/server/server.cpp    58
1 file changed, 29 insertions(+), 29 deletions(-)
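
For context, the block this commit hoists out of the cache_prompt branch performs a middle-truncation of the prompt: keep the first n_keep tokens, split the remaining context budget (n_ctx - n_keep) into two half-sized blocks, and drop whole blocks from just after the kept prefix until the prefix plus the remaining tail fits inside the context window. Read in isolation, the arithmetic works roughly as in the standalone sketch below; the truncate_prompt helper, the plain int token type, and the example sizes are illustrative only and are not part of the patch.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Illustrative stand-in for the logic moved by this commit: keep the first
// n_keep tokens, then skip whole blocks of n_block_size tokens so that the
// kept prefix plus the remaining tail ends up below n_ctx.
static std::vector<int> truncate_prompt(std::vector<int> prompt, int n_ctx, int n_keep) {
    if (n_keep < 0) {
        n_keep = (int) prompt.size();
    }
    n_keep = std::min(n_ctx - 4, n_keep);

    if ((int) prompt.size() < n_ctx) {
        return prompt; // already fits, nothing to truncate
    }

    const int n_left        = n_ctx - n_keep;
    const int n_block_size  = n_left / 2;
    const int erased_blocks = ((int) prompt.size() - n_keep - n_block_size) / n_block_size;

    std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.begin() + n_keep + erased_blocks * n_block_size, prompt.end());

    assert((int) out.size() < n_ctx); // same invariant the patch checks with GGML_ASSERT
    return out;
}

int main() {
    std::vector<int> prompt(100);
    for (int i = 0; i < (int) prompt.size(); ++i) {
        prompt[i] = i;
    }
    const auto truncated = truncate_prompt(prompt, /*n_ctx =*/ 64, /*n_keep =*/ 8);
    std::printf("%zu tokens after truncation\n", truncated.size());
    return 0;
}

With a 100-token prompt, n_ctx = 64 and n_keep = 8, two 28-token blocks are erased: the result keeps the first 8 tokens plus the last 36, for 44 tokens, which satisfies the assertion. The patch itself applies this same computation before the cache_prompt branch, so the prompt is truncated (and the crash avoided) whether or not prompt caching is enabled.
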
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index cbf36ad6..46862a84 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1557,6 +1557,35 @@ struct llama_server_context
slot.num_prompt_tokens = prompt_tokens.size();
+ if (slot.params.n_keep < 0)
+ {
+ slot.params.n_keep = slot.num_prompt_tokens;
+ }
+ slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+ // if input prompt is too big, truncate it
+ if (slot.num_prompt_tokens >= slot.n_ctx)
+ {
+ const int n_left = slot.n_ctx - slot.params.n_keep;
+ const int n_block_size = n_left / 2;
+ const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+ std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+ new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+ LOG_VERBOSE("input truncated", {
+ {"n_ctx", slot.n_ctx},
+ {"n_keep", slot.params.n_keep},
+ {"n_left", n_left},
+ {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+ });
+ slot.truncated = true;
+ prompt_tokens = new_tokens;
+
+ slot.num_prompt_tokens = prompt_tokens.size();
+ GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+ }
+
if (!slot.params.cache_prompt)
{
llama_sampling_reset(slot.ctx_sampling);
@@ -1566,35 +1595,6 @@ struct llama_server_context
}
else
{
- if (slot.params.n_keep < 0)
- {
- slot.params.n_keep = slot.num_prompt_tokens;
- }
- slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
- // if input prompt is too big, truncate it
- if (slot.num_prompt_tokens >= slot.n_ctx)
- {
- const int n_left = slot.n_ctx - slot.params.n_keep;
- const int n_block_size = n_left / 2;
- const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
- std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
- new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
- LOG_VERBOSE("input truncated", {
- {"n_ctx", slot.n_ctx},
- {"n_keep", slot.params.n_keep},
- {"n_left", n_left},
- {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
- });
- slot.truncated = true;
- prompt_tokens = new_tokens;
-
- slot.num_prompt_tokens = prompt_tokens.size();
- GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
- }
-
// push the prompt into the sampling context (do not apply grammar)
for (auto &token : prompt_tokens)
{