summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRőczey Barnabás <31726601+An0nie@users.noreply.github.com>2024-02-16 11:00:56 +0100
committerGitHub <noreply@github.com>2024-02-16 12:00:56 +0200
commit5f5808ca7b7f23a1fa7a77241842bb84a0e55108 (patch)
tree1699ff6045b4dadc86bf72e66abfd14e0831ff0b
parentf486f6e1e5e9d01603d9325ab3e05f1edb362a95 (diff)
server : fix system prompt cli (#5516)
-rw-r--r--examples/server/server.cpp45
1 file changed, 21 insertions, 24 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 912c750c..0cb802ce 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -436,10 +436,6 @@ struct llama_server_context
default_generation_settings_for_props["seed"] = -1;
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
-
- // empty system prompt
- system_prompt = "";
- system_tokens.clear();
}
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -765,27 +761,30 @@ struct llama_server_context
}
void update_system_prompt() {
- system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+ kv_cache_clear();
+ system_tokens.clear();
- llama_batch_clear(batch);
+ if (!system_prompt.empty()) {
+ system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
- kv_cache_clear();
+ llama_batch_clear(batch);
- for (int i = 0; i < (int) system_tokens.size(); ++i)
- {
- llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
- }
+ for (int i = 0; i < (int)system_tokens.size(); ++i)
+ {
+ llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
+ }
- if (llama_decode(ctx, batch) != 0)
- {
- LOG_TEE("%s: llama_decode() failed\n", __func__);
- return;
- }
+ if (llama_decode(ctx, batch) != 0)
+ {
+ LOG_TEE("%s: llama_decode() failed\n", __func__);
+ return;
+ }
- // assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i < params.n_parallel; ++i)
- {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+ // assign the system KV cache to all parallel sequences
+ for (int32_t i = 1; i < params.n_parallel; ++i)
+ {
+ llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+ }
}
LOG_TEE("system prompt updated\n");
@@ -807,10 +806,8 @@ struct llama_server_context
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
- if (slots.size() > 0)
- {
- notify_system_prompt_changed();
- }
+
+ notify_system_prompt_changed();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,