summaryrefslogtreecommitdiff
path: root/examples/server/oai.hpp
diff options
context:
space:
mode:
authorXuan Son Nguyen <thichthat@gmail.com>2024-02-20 15:58:27 +0100
committerGitHub <noreply@github.com>2024-02-20 15:58:27 +0100
commit9c405c9f9a7cfd23511fd6b2de05dc72481119b4 (patch)
tree694b5a169d63eb4640df2d6f536d384cc481b300 /examples/server/oai.hpp
parent5207b3fbc500f89dfe528693e96540956dbaed96 (diff)
Server: use llama_chat_apply_template (#5593)
* server: use llama_chat_apply_template * server: remove trailing space * server: fix format_chat * server: fix help message Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: fix formatted_chat --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/oai.hpp')
-rw-r--r--examples/server/oai.hpp6
1 files changed, 2 insertions, 4 deletions
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 2eca8a9f..ff4ad699 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -15,13 +15,11 @@
using json = nlohmann::json;
inline static json oaicompat_completion_params_parse(
+ const struct llama_model * model,
const json &body, /* openai api json semantics */
const std::string &chat_template)
{
json llama_params;
- std::string formatted_prompt = chat_template == "chatml"
- ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
- : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
llama_params["__oaicompat"] = true;
@@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
- llama_params["prompt"] = formatted_prompt;
+ llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);