author     Jorge A <161275481+jorgealias@users.noreply.github.com>  2024-02-28 01:39:15 -0700
committer  GitHub <noreply@github.com>  2024-02-28 10:39:15 +0200
commit     efc72253f7987ed7bdc8bde9d9fa5c7cac2f6292 (patch)
tree       9f208051c3b76fa9817b748e9d2b805b439d75a5 /examples/server/server.cpp
parent     7c4263d4261d6ee6f0539d53eb9e1b4d120ba8af (diff)
server : add "/chat/completions" alias for "/v1/..." (#5722)
* Add "/chat/completions" as alias for "/v1/chat/completions"
* merge to upstream master
* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
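The change amounts to naming the previously inline handler lambda and registering it on both routes, which resolves the old "TODO: add mount point without '/v1' prefix" note visible in the diff below. A minimal standalone sketch of that pattern with cpp-httplib (the echo body, bind address, and port are placeholders, not the server's real handler):

    #include "httplib.h"

    int main() {
        httplib::Server svr;

        // define the handler once as a named lambda ...
        const auto chat_completions = [](const httplib::Request &req, httplib::Response &res) {
            // stand-in: the real server parses an OpenAI-compatible request body here
            res.set_content(req.body, "application/json; charset=utf-8");
        };

        // ... then mount it at both the bare and the "/v1"-prefixed path
        svr.Post("/chat/completions", chat_completions);
        svr.Post("/v1/chat/completions", chat_completions);

        svr.listen("0.0.0.0", 8080);
    }

Because both routes share one std::function-compatible callable, the two endpoints cannot drift apart, which is the whole point of the refactor.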
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  131
1 file changed, 66 insertions(+), 65 deletions(-)
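After this commit, both endpoints should accept the same OpenAI-style request. A hedged client-side check, again with cpp-httplib (localhost and port 8080 are assumptions based on the server's defaults; the request body is a minimal chat payload):

    #include <iostream>
    #include "httplib.h"

    int main() {
        httplib::Client cli("localhost", 8080);  // adjust host/port to your server
        const std::string body =
            R"({"messages": [{"role": "user", "content": "Hello"}]})";

        // the new alias and the original route should behave identically
        for (const char *path : {"/chat/completions", "/v1/chat/completions"}) {
            auto res = cli.Post(path, body, "application/json");
            std::cout << path << " -> "
                      << (res ? std::to_string(res->status) : "connection error") << "\n";
        }
    }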
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 846ef7e5..6b3ee531 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8");
});
+ const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+ {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ if (!validate_api_key(req, res)) {
+ return;
+ }
+ json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
- // TODO: add mount point without "/v1" prefix -- how?
- svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
- {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!validate_api_key(req, res)) {
- return;
- }
- json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
-
- const int task_id = llama.queue_tasks.get_new_id();
- llama.queue_results.add_waiting_task_id(task_id);
- llama.request_completion(task_id, data, false, false, -1);
+ const int task_id = llama.queue_tasks.get_new_id();
+ llama.queue_results.add_waiting_task_id(task_id);
+ llama.request_completion(task_id, data, false, false, -1);
- if (!json_value(data, "stream", false)) {
- std::string completion_text;
- task_result result = llama.queue_results.recv(task_id);
+ if (!json_value(data, "stream", false)) {
+ std::string completion_text;
+ task_result result = llama.queue_results.recv(task_id);
- if (!result.error && result.stop) {
- json oaicompat_result = format_final_response_oaicompat(data, result);
+ if (!result.error && result.stop) {
+ json oaicompat_result = format_final_response_oaicompat(data, result);
- res.set_content(oaicompat_result.dump(-1, ' ', false,
- json::error_handler_t::replace),
- "application/json; charset=utf-8");
- } else {
- res.status = 500;
- res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
- }
- llama.queue_results.remove_waiting_task_id(task_id);
- } else {
- const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
- while (true) {
- task_result llama_result = llama.queue_results.recv(task_id);
- if (!llama_result.error) {
- std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
+ res.set_content(oaicompat_result.dump(-1, ' ', false,
+ json::error_handler_t::replace),
+ "application/json; charset=utf-8");
+ } else {
+ res.status = 500;
+ res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
+ }
+ llama.queue_results.remove_waiting_task_id(task_id);
+ } else {
+ const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+ while (true) {
+ task_result llama_result = llama.queue_results.recv(task_id);
+ if (!llama_result.error) {
+ std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
- for (auto it = result_array.begin(); it != result_array.end(); ++it)
- {
- if (!it->empty()) {
- const std::string str =
- "data: " +
- it->dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
- LOG_VERBOSE("data stream", {{"to_send", str}});
- if (!sink.write(str.c_str(), str.size())) {
- llama.queue_results.remove_waiting_task_id(task_id);
- return false;
- }
- }
- }
- if (llama_result.stop) {
- break;
- }
- } else {
+ for (auto it = result_array.begin(); it != result_array.end(); ++it)
+ {
+ if (!it->empty()) {
const std::string str =
- "error: " +
- llama_result.result_json.dump(-1, ' ', false,
- json::error_handler_t::replace) +
+ "data: " +
+ it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
- break;
}
}
- sink.done();
- llama.queue_results.remove_waiting_task_id(task_id);
- return true;
- };
+ if (llama_result.stop) {
+ break;
+ }
+ } else {
+ const std::string str =
+ "error: " +
+ llama_result.result_json.dump(-1, ' ', false,
+ json::error_handler_t::replace) +
+ "\n\n";
+ LOG_VERBOSE("data stream", {{"to_send", str}});
+ if (!sink.write(str.c_str(), str.size())) {
+ llama.queue_results.remove_waiting_task_id(task_id);
+ return false;
+ }
+ break;
+ }
+ }
+ sink.done();
+ llama.queue_results.remove_waiting_task_id(task_id);
+ return true;
+ };
- auto on_complete = [task_id, &llama](bool) {
- // cancel request
- llama.request_cancel(task_id);
- llama.queue_results.remove_waiting_task_id(task_id);
- };
+ auto on_complete = [task_id, &llama](bool) {
+ // cancel request
+ llama.request_cancel(task_id);
+ llama.queue_results.remove_waiting_task_id(task_id);
+ };
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
- }
- });
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+ }
+ };
+
+ svr.Post("/chat/completions", chat_completions);
+ svr.Post("/v1/chat/completions", chat_completions);
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{