Diffstat (limited to 'examples/server/oai.hpp')
-rw-r--r--  examples/server/oai.hpp  6
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 2eca8a9f..ff4ad699 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -15,13 +15,11 @@
 using json = nlohmann::json;
 
 inline static json oaicompat_completion_params_parse(
+    const struct llama_model * model,
     const json &body, /* openai api json semantics */
     const std::string &chat_template)
 {
     json llama_params;
-    std::string formatted_prompt = chat_template == "chatml"
-        ? format_chatml(body["messages"])  // OpenAI 'messages' to chatml (with <|im_start|>,...)
-        : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
 
     llama_params["__oaicompat"] = true;
 
@@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = formatted_prompt;
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
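For context, a hedged caller-side sketch (assumed code, not part of this commit): after this change the parser needs the server's loaded model so that format_chat can render body["messages"] with the selected chat template, replacing the removed hard-coded chatml/llama2 branches. The wrapper name handle_chat_request and the variables req_body and chat_template below are hypothetical.

// Hypothetical caller sketch, using only the signatures visible in the diff.
#include "oai.hpp" // brings in nlohmann::json and the parser above

static json handle_chat_request(const struct llama_model * model,
                                const std::string & req_body,
                                const std::string & chat_template) {
    json body = json::parse(req_body); // OpenAI-style request JSON
    // The model pointer is forwarded so format_chat() can apply the
    // template when turning body["messages"] into a single prompt string.
    return oaicompat_completion_params_parse(model, body, chat_template);
}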