summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXuan Son Nguyen <thichthat@gmail.com>2024-03-25 09:42:17 +0100
committerGitHub <noreply@github.com>2024-03-25 09:42:17 +0100
commitad3a0505e3b6cd777259ee35e61d428357ffc565 (patch)
treeae3976c33914df984df4f0b0ae5445422a0dd30d
parent95ad616cddda50273e955bfe192328acd9aa4896 (diff)
Server: clean up OAI params parsing function (#6284)
* server: clean up oai parsing function * fix response_format * fix empty response_format * minor fixes * add TODO for logprobs * update docs
-rw-r--r--examples/server/README.md2
-rw-r--r--examples/server/server.cpp13
-rw-r--r--examples/server/utils.hpp86
3 files changed, 63 insertions, 38 deletions
diff --git a/examples/server/README.md b/examples/server/README.md
index dfea2b90..49121a46 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -360,7 +360,7 @@ Notice that each `probs` is an array of length `n_probs`.
- `default_generation_settings` - the default generation settings for the `/completion` endpoint, has the same fields as the `generation_settings` response object from the `/completion` endpoint.
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
-- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc can be used with this endpoint.
+- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only model with [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, ChatML template will be used.
*Options:*
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b02c2546..338e60f2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -847,9 +847,16 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.seed = json_value(data, "seed", default_params.seed);
- if (data.contains("json_schema") && !data.contains("grammar")) {
+ slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+ slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+
+ // process "json_schema" and "grammar"
+ if (data.contains("json_schema") && data.contains("grammar")) {
+ send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ } else if (data.contains("json_schema") && !data.contains("grammar")) {
try {
- auto schema = json_value(data, "json_schema", json::object());
+ auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -858,8 +865,6 @@ struct server_context {
} else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
}
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
if (slot.params.cache_prompt && slot.ga_n != 1) {
LOG_WARNING("cache_prompt is not supported with group-attention", {});
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8f20ff61..7d9ab622 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -352,51 +352,71 @@ static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
- llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
- llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
- llama_params["temperature"] = json_value(body, "temperature", 0.0);
- llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
- llama_params["top_p"] = json_value(body, "top_p", 1.0);
- llama_params["n_predict"] = json_value(body, "max_tokens", -1);
- llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
+ llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
+ llama_params["n_predict"] = json_value(body, "max_tokens", -1);
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
llama_params["stream"] = json_value(body, "stream", false);
- llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat);
- llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
- llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
- llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl);
- llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p);
- llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
- llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
- llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
- llama_params["n_keep"] = json_value(body, "n_keep", 0);
-
- if (body.contains("grammar")) {
- llama_params["grammar"] = json_value(body, "grammar", json::object());
- }
+ llama_params["temperature"] = json_value(body, "temperature", 0.0);
+ llama_params["top_p"] = json_value(body, "top_p", 1.0);
- if (body.contains("response_format")) {
- auto response_format = json_value(body, "response_format", json::object());
- if (response_format.contains("type")) {
- if (response_format["type"] == "json_object") {
- llama_params["json_schema"] = json_value(response_format, "schema", json::object());
- } else {
- throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
- }
- }
- }
+ // Apply chat template to the list of messages
+ llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
- // Handle 'stop' field
+ // Handle "stop" field
if (body.contains("stop") && body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
+ // Some chat templates don't use EOS token to stop generation
+ // We must add their end sequences to list of stop words
+ llama_params["stop"].push_back("<|im_end|>"); // chatml
+ llama_params["stop"].push_back("<end_of_turn>"); // gemma
- // Ensure there is ChatML-specific end sequence among stop words
- llama_params["stop"].push_back("<|im_end|>");
+ // Handle "response_format" field
+ if (body.contains("response_format")) {
+ json response_format = json_value(body, "response_format", json::object());
+ std::string response_type = json_value(response_format, "type", std::string());
+ if (response_type == "json_object") {
+ llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+ } else if (!response_type.empty() && response_type != "text") {
+ throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+ }
+ }
+
+ // Handle "n" field
+ int n_choices = json_value(body, "n", 1);
+ if (n_choices != 1) {
+ throw std::runtime_error("Only one completion choice is allowed");
+ }
+
+ // Handle "logprobs" field
+ // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
+ if (body.contains("logprobs")) {
+ llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
+ } else if (body.contains("top_logprobs")) {
+ throw std::runtime_error("top_logprobs requires logprobs to be set to true");
+ }
+
+ // Params supported by OAI but unsupported by llama.cpp
+ static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+ for (auto & param : unsupported_params) {
+ if (body.contains(param)) {
+ throw std::runtime_error("Unsupported param: " + param);
+ }
+ }
+
+ // Copy remaining properties to llama_params
+ // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+ // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+ for (const auto & item : body.items()) {
+ // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+ if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+ llama_params[item.key()] = item.value();
+ }
+ }
return llama_params;
}