diff options
author | Xuan Son Nguyen <thichthat@gmail.com> | 2024-02-11 11:16:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-11 12:16:22 +0200 |
commit | 907e08c1109f498b01036367804cff3082c44524 (patch) | |
tree | 333330e10039769bb10fa2f387c39a5ec2c6d98e /examples/server/server.cpp | |
parent | f026f8120f97090d34a52b3dc023c82e0ede3f7d (diff) |
server : add llama2 chat template (#5425)
* server: add mistral chat template
* server: fix typo
* server: rename template mistral to llama2
* server: format_llama2: remove BOS
* server: validate "--chat-template" argument
* server: clean up using_chatml variable
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
---------
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8d668f79..4d212f1f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -36,6 +36,7 @@ struct server_params std::string hostname = "127.0.0.1"; std::vector<std::string> api_keys; std::string public_path = "examples/server/public"; + std::string chat_template = "chatml"; int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; @@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); + printf(" --chat-template FORMAT_NAME"); + printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str()); printf("\n"); } @@ -2290,6 +2293,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, log_set_target(stdout); LOG_INFO("logging to file is disabled.", {}); } + else if (arg == "--chat-template") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + std::string value(argv[i]); + if (value != "chatml" && value != "llama2") { + fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str()); + invalid_param = true; + break; + } + sparams.chat_template = value; + } else if (arg == "--override-kv") { if (++i >= argc) { @@ -2743,13 +2761,13 @@ int main(int argc, char **argv) // TODO: add mount point without "/v1" prefix -- how? - svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) + svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } - json data = oaicompat_completion_params_parse(json::parse(req.body)); + json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template); const int task_id = llama.queue_tasks.get_new_id(); llama.queue_results.add_waiting_task_id(task_id); |