| author | Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> | 2024-02-16 11:00:56 +0100 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2024-02-16 12:00:56 +0200 |
| commit | 5f5808ca7b7f23a1fa7a77241842bb84a0e55108 (patch) | |
| tree | 1699ff6045b4dadc86bf72e66abfd14e0831ff0b | |
| parent | f486f6e1e5e9d01603d9325ab3e05f1edb362a95 (diff) | |
server : fix system prompt cli (#5516)
-rw-r--r-- | examples/server/server.cpp | 45 |
1 file changed, 21 insertions, 24 deletions
```diff
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 912c750c..0cb802ce 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -436,10 +436,6 @@ struct llama_server_context
         default_generation_settings_for_props["seed"] = -1;
 
         batch = llama_batch_init(n_ctx, 0, params.n_parallel);
-
-        // empty system prompt
-        system_prompt = "";
-        system_tokens.clear();
     }
 
     std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
@@ -765,27 +761,30 @@ struct llama_server_context
     }
 
     void update_system_prompt() {
-        system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+        kv_cache_clear();
+        system_tokens.clear();
 
-        llama_batch_clear(batch);
+        if (!system_prompt.empty()) {
+            system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
 
-        kv_cache_clear();
+            llama_batch_clear(batch);
 
-        for (int i = 0; i < (int) system_tokens.size(); ++i)
-        {
-            llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-        }
+            for (int i = 0; i < (int)system_tokens.size(); ++i)
+            {
+                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
+            }
 
-        if (llama_decode(ctx, batch) != 0)
-        {
-            LOG_TEE("%s: llama_decode() failed\n", __func__);
-            return;
-        }
+            if (llama_decode(ctx, batch) != 0)
+            {
+                LOG_TEE("%s: llama_decode() failed\n", __func__);
+                return;
+            }
 
-        // assign the system KV cache to all parallel sequences
-        for (int32_t i = 1; i < params.n_parallel; ++i)
-        {
-            llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            // assign the system KV cache to all parallel sequences
+            for (int32_t i = 1; i < params.n_parallel; ++i)
+            {
+                llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            }
         }
 
         LOG_TEE("system prompt updated\n");
@@ -807,10 +806,8 @@ struct llama_server_context
         name_user = sys_props.value("anti_prompt", "");
         name_assistant = sys_props.value("assistant_name", "");
 
-        if (slots.size() > 0)
-        {
-            notify_system_prompt_changed();
-        }
+
+        notify_system_prompt_changed();
     }
 
     static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
```
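Reading the hunks together: the member reset removed in the first hunk appears to have run after the command line had already been parsed, wiping a prompt supplied via the CLI, and the `slots.size() > 0` guard removed in the last hunk meant `notify_system_prompt_changed()` never fired when the prompt was set before any slots existed. The sketch below distills the corrected `update_system_prompt()` flow into a standalone function for illustration; it is a paraphrase, not part of the patch. The name `apply_system_prompt` is hypothetical, the public `llama_kv_cache_clear()` stands in for the server's own `kv_cache_clear()` member, and `llama_tokenize`, `llama_batch_clear`, and `llama_batch_add` are the helpers from llama.cpp's `common.h`.

```cpp
#include <string>
#include <vector>

#include "common.h" // llama_tokenize, llama_batch_clear, llama_batch_add
#include "llama.h"

// Hypothetical standalone version of the patched update_system_prompt().
static bool apply_system_prompt(llama_context * ctx, llama_batch & batch,
                                const std::string & system_prompt,
                                std::vector<llama_token> & system_tokens,
                                int32_t n_parallel, bool add_bos_token) {
    // clear unconditionally, so an empty prompt *unsets* a previous one
    // (stands in for the server's kv_cache_clear() member)
    llama_kv_cache_clear(ctx);
    system_tokens.clear();

    if (system_prompt.empty()) {
        return true; // nothing to decode
    }

    system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);

    // queue every system token on sequence 0
    llama_batch_clear(batch);
    for (int i = 0; i < (int) system_tokens.size(); ++i) {
        llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
    }

    if (llama_decode(ctx, batch) != 0) {
        return false; // the patched code logs via LOG_TEE and returns
    }

    // share the decoded system KV cache with every parallel sequence
    for (int32_t i = 1; i < n_parallel; ++i) {
        llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
    }

    return true;
}
```

The key behavioural change is the order of operations: the KV cache and token list are now cleared before the emptiness check, so setting an empty system prompt clears previously decoded state instead of leaving stale tokens shared across parallel sequences.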