Diffstat (limited to 'examples/server/oai.hpp')
-rw-r--r--  examples/server/oai.hpp  |  8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 43410f80..2eca8a9f 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -15,9 +15,13 @@
 using json = nlohmann::json;
 
 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
 
+    std::string formatted_prompt = chat_template == "chatml"
+        ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
+        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
     llama_params["__oaicompat"] = true;
 
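The format_chatml and format_llama2 helpers selected above live elsewhere in the server sources and are not part of this diff. As orientation only, the sketch below is a hypothetical illustration of what a chatml-style formatter does, assuming the OpenAI message layout (objects with "role" and "content" string fields); it is not the actual llama.cpp implementation.

// Hypothetical sketch, not the llama.cpp code: turn an OpenAI-style
// 'messages' array into a single chatml prompt string.
#include <string>
#include <nlohmann/json.hpp>

inline std::string format_chatml_sketch(const nlohmann::json &messages) {
    std::string prompt;
    for (const auto &msg : messages) {
        // chatml frames each turn as <|im_start|>role\ncontent<|im_end|>\n
        prompt += "<|im_start|>" + msg.value("role", "user") + "\n";
        prompt += msg.value("content", "") + "<|im_end|>\n";
    }
    // leave an open assistant turn for the model to complete
    prompt += "<|im_start|>assistant\n";
    return prompt;
}
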
@@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]        = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]       = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"]       = formatted_prompt;
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"]  = json_value(body, "temperature", 0.0);
     llama_params["top_k"]        = json_value(body, "top_k", default_sparams.top_k);
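With this change, every caller must now supply the template name. The snippet below is an assumed, illustrative invocation (the request body and the "chatml" argument are examples, not taken from this diff):

// Hypothetical caller: parse an OpenAI-style request body into
// llama.cpp completion params, selecting the chatml formatter.
json body = json::parse(R"({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user",   "content": "Hello!"}
    ],
    "temperature": 0.7
})");

// any value other than "chatml" falls back to the llama2 formatter
json llama_params = oaicompat_completion_params_parse(body, "chatml");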