From efc72253f7987ed7bdc8bde9d9fa5c7cac2f6292 Mon Sep 17 00:00:00 2001 From: Jorge A <161275481+jorgealias@users.noreply.github.com> Date: Wed, 28 Feb 2024 01:39:15 -0700 Subject: server : add "/chat/completions" alias for "/v1/...` (#5722) * Add "/chat/completions" as alias for "/v1/chat/completions" * merge to upstream master * minor : fix trailing whitespace --------- Co-authored-by: Georgi Gerganov --- examples/server/server.cpp | 131 +++++++++++++++++++++++---------------------- 1 file changed, 66 insertions(+), 65 deletions(-) (limited to 'examples/server/server.cpp') diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 846ef7e5..6b3ee531 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3211,87 +3211,88 @@ int main(int argc, char **argv) res.set_content(models.dump(), "application/json; charset=utf-8"); }); + const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) + { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!validate_api_key(req, res)) { + return; + } + json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); - // TODO: add mount point without "/v1" prefix -- how? - svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); - - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, false, false, -1); + const int task_id = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(task_id); + llama.request_completion(task_id, data, false, false, -1); - if (!json_value(data, "stream", false)) { - std::string completion_text; - task_result result = llama.queue_results.recv(task_id); + if (!json_value(data, "stream", false)) { + std::string completion_text; + task_result result = llama.queue_results.recv(task_id); - if (!result.error && result.stop) { - json oaicompat_result = format_final_response_oaicompat(data, result); + if (!result.error && result.stop) { + json oaicompat_result = format_final_response_oaicompat(data, result); - res.set_content(oaicompat_result.dump(-1, ' ', false, - json::error_handler_t::replace), - "application/json; charset=utf-8"); - } else { - res.status = 500; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); - } - llama.queue_results.remove_waiting_task_id(task_id); - } else { - const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) { - while (true) { - task_result llama_result = llama.queue_results.recv(task_id); - if (!llama_result.error) { - std::vector result_array = format_partial_response_oaicompat( llama_result); + res.set_content(oaicompat_result.dump(-1, ' ', false, + json::error_handler_t::replace), + "application/json; charset=utf-8"); + } else { + res.status = 500; + res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + } + llama.queue_results.remove_waiting_task_id(task_id); + } else { + const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) { + while (true) { + task_result llama_result = llama.queue_results.recv(task_id); + if (!llama_result.error) { + std::vector result_array = format_partial_response_oaicompat( llama_result); - for (auto it = result_array.begin(); it != result_array.end(); ++it) - { - if (!it->empty()) { - const std::string str = - "data: " + - it->dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", {{"to_send", str}}); - if (!sink.write(str.c_str(), str.size())) { - llama.queue_results.remove_waiting_task_id(task_id); - return false; - } - } - } - if (llama_result.stop) { - break; - } - } else { + for (auto it = result_array.begin(); it != result_array.end(); ++it) + { + if (!it->empty()) { const std::string str = - "error: " + - llama_result.result_json.dump(-1, ' ', false, - json::error_handler_t::replace) + + "data: " + + it->dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) { llama.queue_results.remove_waiting_task_id(task_id); return false; } - break; } } - sink.done(); - llama.queue_results.remove_waiting_task_id(task_id); - return true; - }; + if (llama_result.stop) { + break; + } + } else { + const std::string str = + "error: " + + llama_result.result_json.dump(-1, ' ', false, + json::error_handler_t::replace) + + "\n\n"; + LOG_VERBOSE("data stream", {{"to_send", str}}); + if (!sink.write(str.c_str(), str.size())) { + llama.queue_results.remove_waiting_task_id(task_id); + return false; + } + break; + } + } + sink.done(); + llama.queue_results.remove_waiting_task_id(task_id); + return true; + }; - auto on_complete = [task_id, &llama](bool) { - // cancel request - llama.request_cancel(task_id); - llama.queue_results.remove_waiting_task_id(task_id); - }; + auto on_complete = [task_id, &llama](bool) { + // cancel request + llama.request_cancel(task_id); + llama.queue_results.remove_waiting_task_id(task_id); + }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }; + + svr.Post("/chat/completions", chat_completions); + svr.Post("/v1/chat/completions", chat_completions); svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { -- cgit v1.2.3