author     Xuan Son Nguyen <thichthat@gmail.com>    2024-03-09 11:27:53 +0100
committer  GitHub <noreply@github.com>              2024-03-09 11:27:53 +0100
commit     950ba1ab84db199f0bbdecdb2bb911f35261b321 (patch)
tree       e1633dcc8ed207b71d0b6c60af099ae344a0cb66 /examples/server/server.cpp
parent     e1fa9569ba8ce276bc7801a3cebdcf8b1aa116ea (diff)
Server: reorganize some http logic (#5939)
* refactor static file handler
* use set_pre_routing_handler for validate_api_key (see the sketch after this list)
* merge embedding handlers
* correct http verb for endpoints
* fix embedding response
* fix test case CORS Options
* fix code style
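
For context on the second bullet: instead of calling a validation lambda inside every protected handler, the commit registers the check once as a pre-routing hook. Below is a standalone sketch of that pattern with cpp-httplib (not the committed code; the single key and the endpoint list are illustrative assumptions):

    // Sketch of the pre-routing middleware pattern this commit adopts.
    // Assumptions: one configured key, illustrative endpoint subset.
    #include "httplib.h"

    #include <set>
    #include <string>

    int main() {
        httplib::Server svr;

        // paths that require an API key
        const std::set<std::string> protected_endpoints = {
            "/completions", "/v1/chat/completions", "/v1/embeddings",
        };
        const std::string api_key = "secret";

        // runs before routing; returning Handled short-circuits the request
        svr.set_pre_routing_handler([&](const httplib::Request & req, httplib::Response & res) {
            if (protected_endpoints.count(req.path) == 0) {
                return httplib::Server::HandlerResponse::Unhandled; // public path
            }
            if (req.get_header_value("Authorization") == "Bearer " + api_key) {
                return httplib::Server::HandlerResponse::Unhandled; // key accepted
            }
            res.status = 401;
            res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
            return httplib::Server::HandlerResponse::Handled;       // stop here
        });

        svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
            res.set_content(R"({"status": "ok"})", "application/json");
        });

        svr.listen("127.0.0.1", 8080);
    }

Centralizing the check this way is what allows the individual route handlers in the diff to drop their per-handler validate_api_key calls.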
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp | 637
1 file changed, 328 insertions(+), 309 deletions(-)
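
The merged embedding handler in the diff below accepts both the legacy body ({"content": ...}) and the OpenAI-style body ({"input": ...}, where input may be a string or an array). A minimal sketch of that normalization step, using nlohmann::json (already a server dependency); collect_prompts is an illustrative helper, not part of the commit:

    // Sketch of the request-shape normalization done by the merged handler.
    #include <nlohmann/json.hpp>

    #include <iostream>
    #include <utility>
    #include <vector>

    using json = nlohmann::json;

    // returns the prompts to embed plus a flag: was the request OpenAI-style?
    // (the flag selects the OpenAI response format vs. the legacy one)
    static std::pair<std::vector<json>, bool> collect_prompts(const json & body) {
        std::vector<json> prompts;
        if (body.count("input") != 0) {            // OpenAI style
            if (body["input"].is_array()) {
                for (const json & elem : body["input"]) {
                    prompts.push_back(elem);       // each entry: string or token list
                }
            } else {
                prompts.push_back(body["input"]);  // single input prompt
            }
            return {prompts, true};
        }
        if (body.count("content") != 0) {          // legacy style, single prompt
            prompts.push_back(body["content"]);
            return {prompts, false};
        }
        prompts.push_back("");                     // the commit leaves a TODO: error out
        return {prompts, false};
    }

    int main() {
        const json oai    = {{"input", {"hello", "world"}}};
        const json legacy = {{"content", "hello"}};

        std::cout << collect_prompts(oai).first.size()    << "\n"; // 2
        std::cout << collect_prompts(legacy).first.size() << "\n"; // 1
    }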
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c3b87c84..6e0f8328 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -113,7 +113,7 @@ struct server_params {
     int32_t n_threads_http = -1;
 
     std::string hostname = "127.0.0.1";
-    std::string public_path = "examples/server/public";
+    std::string public_path = "";
     std::string chat_template = "";
     std::string system_prompt = "";
 
@@ -2145,7 +2145,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf("  --lora-base FNAME         optional model to use as a base for the layers modified by the LoRA adapter\n");
     printf("  --host                    ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     printf("  --port PORT               port to listen (default (default: %d)\n", sparams.port);
-    printf("  --path PUBLIC_PATH        path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --path PUBLIC_PATH        path from which to serve static files (default: disabled)\n");
    printf("  --api-key API_KEY         optional api key to enhance server security. If set, requests must include this key for access.\n");
     printf("  --api-key-file FNAME      path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@@ -2211,7 +2211,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.api_keys.emplace_back(argv[i]);
+            sparams.api_keys.push_back(argv[i]);
         } else if (arg == "--api-key-file") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -2712,180 +2712,6 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Headers", "*");
     });
 
-    svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
-        server_state current_state = state.load();
-        switch (current_state) {
-            case SERVER_STATE_READY:
-                {
-                    // request slots data using task queue
-                    server_task task;
-                    task.id = ctx_server.queue_tasks.get_new_id();
-                    task.type = SERVER_TASK_TYPE_METRICS;
-                    task.id_target = -1;
-
-                    ctx_server.queue_results.add_waiting_task_id(task.id);
-                    ctx_server.queue_tasks.post(task);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(task.id);
-                    ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-                    const int n_idle_slots       = result.data["idle"];
-                    const int n_processing_slots = result.data["processing"];
-
-                    json health = {
-                        {"status",           "ok"},
-                        {"slots_idle",       n_idle_slots},
-                        {"slots_processing", n_processing_slots}
-                    };
-
-                    res.status = 200; // HTTP OK
-                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
-                        health["slots"] = result.data["slots"];
-                    }
-
-                    if (n_idle_slots == 0) {
-                        health["status"] = "no slot available";
-                        if (req.has_param("fail_on_no_slot")) {
-                            res.status = 503; // HTTP Service Unavailable
-                        }
-                    }
-
-                    res.set_content(health.dump(), "application/json");
-                    break;
-                }
-            case SERVER_STATE_LOADING_MODEL:
-                {
-                    res.set_content(R"({"status": "loading model"})", "application/json");
-                    res.status = 503; // HTTP Service Unavailable
-                } break;
-            case SERVER_STATE_ERROR:
-                {
-                    res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
-                    res.status = 500; // HTTP Internal Server Error
-                } break;
-        }
-    });
-
-    if (sparams.slots_endpoint) {
-        svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi  = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            res.set_content(result.data["slots"].dump(), "application/json");
-            res.status = 200; // HTTP OK
-        });
-    }
-
-    if (sparams.metrics_endpoint) {
-        svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
-            // request slots data using task queue
-            server_task task;
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.id_multi  = -1;
-            task.id_target = -1;
-            task.type = SERVER_TASK_TYPE_METRICS;
-            task.data.push_back({{"reset_bucket", true}});
-
-            ctx_server.queue_results.add_waiting_task_id(task.id);
-            ctx_server.queue_tasks.post(task);
-
-            // get the result
-            server_task_result result = ctx_server.queue_results.recv(task.id);
-            ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-            json data = result.data;
-
-            const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
-            const uint64_t t_prompt_processing       = data["t_prompt_processing"];
-
-            const uint64_t n_tokens_predicted  = data["n_tokens_predicted"];
-            const uint64_t t_tokens_generation = data["t_tokens_generation"];
-
-            const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
-
-            // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
-            json all_metrics_def = json {
-                {"counter", {{
-                        {"name",  "prompt_tokens_total"},
-                        {"help",  "Number of prompt tokens processed."},
-                        {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
-                }, {
-                        {"name",  "prompt_seconds_total"},
-                        {"help",  "Prompt process time"},
-                        {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
-                }, {
-                        {"name",  "tokens_predicted_total"},
-                        {"help",  "Number of generation tokens processed."},
-                        {"value", (uint64_t) data["n_tokens_predicted_total"]}
-                }, {
-                        {"name",  "tokens_predicted_seconds_total"},
-                        {"help",  "Predict process time"},
-                        {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
-                }}},
-                {"gauge", {{
-                        {"name",  "prompt_tokens_seconds"},
-                        {"help",  "Average prompt throughput in tokens/s."},
-                        {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
-                },{
-                        {"name",  "predicted_tokens_seconds"},
-                        {"help",  "Average generation throughput in tokens/s."},
-                        {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
-                },{
-                        {"name",  "kv_cache_usage_ratio"},
-                        {"help",  "KV-cache usage. 1 means 100 percent usage."},
-                        {"value", 1. * kv_cache_used_cells / params.n_ctx}
-                },{
-                        {"name",  "kv_cache_tokens"},
-                        {"help",  "KV-cache tokens."},
-                        {"value", (uint64_t) data["kv_cache_tokens_count"]}
-                },{
-                        {"name",  "requests_processing"},
-                        {"help",  "Number of request processing."},
-                        {"value", (uint64_t) data["processing"]}
-                },{
-                        {"name",  "requests_deferred"},
-                        {"help",  "Number of request deferred."},
-                        {"value", (uint64_t) data["deferred"]}
-                }}}
-            };
-
-            std::stringstream prometheus;
-
-            for (const auto & el : all_metrics_def.items()) {
-                const auto & type        = el.key();
-                const auto & metrics_def = el.value();
-
-                for (const auto & metric_def : metrics_def) {
-                    const std::string name = metric_def["name"];
-                    const std::string help = metric_def["help"];
-
-                    auto value = json_value(metric_def, "value", 0.);
-                    prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
-                               << "# TYPE llamacpp:" << name << " " << type << "\n"
-                               << "llamacpp:" << name << " " << value << "\n";
-                }
-            }
-
-            const int64_t t_start = data["t_start"];
-            res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
-
-            res.set_content(prometheus.str(), "text/plain; version=0.0.4");
-            res.status = 200; // HTTP OK
-        });
-    }
-
     svr->set_logger(log_server_request);
 
     svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
@@ -2925,16 +2751,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // Set the base directory for serving static files
-    svr->set_base_dir(sparams.public_path);
-
     std::unordered_map<std::string, std::string> log_data;
 
     log_data["hostname"] = sparams.hostname;
     log_data["port"]     = std::to_string(sparams.port);
 
     if (sparams.api_keys.size() == 1) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
+        auto key = sparams.api_keys[0];
+        log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
     } else if (sparams.api_keys.size() > 1) {
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
@@ -2959,13 +2783,37 @@ int main(int argc, char ** argv) {
         }
     }
 
-    // Middleware for API key validation
-    auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
+    //
+    // Middlewares
+    //
+
+    auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
+        static const std::set<std::string> protected_endpoints = {
+            "/props",
+            "/completion",
+            "/completions",
+            "/v1/completions",
+            "/chat/completions",
+            "/v1/chat/completions",
+            "/infill",
+            "/tokenize",
+            "/detokenize",
+            "/embedding",
+            "/embeddings",
+            "/v1/embeddings",
+        };
+
         // If API key is not set, skip validation
         if (sparams.api_keys.empty()) {
             return true;
         }
 
+        // If path is not in protected_endpoints list, skip validation
+        if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
+            return true;
+        }
+
         // Check for API key in the header
         auto auth_header = req.get_header_value("Authorization");
 
@@ -2978,6 +2826,8 @@ int main(int argc, char ** argv) {
         }
 
         // API key is invalid or not provided
+        // TODO: make another middleware for CORS related logic
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
         res.status = 401; // Unauthorized
 
@@ -2986,31 +2836,201 @@ int main(int argc, char ** argv) {
         return false;
     };
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
-        return false;
+    // register server middlewares
+    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+        if (!middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
     });
 
-    // this is only called if no index.js is found in the public --path
-    svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
-        return false;
-    });
+    //
+    // Route handlers (or controllers)
+    //
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
-        return false;
-    });
+    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        switch (current_state) {
+            case SERVER_STATE_READY:
+                {
+                    // request slots data using task queue
+                    server_task task;
+                    task.id = ctx_server.queue_tasks.get_new_id();
+                    task.type = SERVER_TASK_TYPE_METRICS;
+                    task.id_target = -1;
 
-    // this is only called if no index.html is found in the public --path
-    svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
-        res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
-        return false;
-    });
+                    ctx_server.queue_results.add_waiting_task_id(task.id);
+                    ctx_server.queue_tasks.post(task);
+
+                    // get the result
+                    server_task_result result = ctx_server.queue_results.recv(task.id);
+                    ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+                    const int n_idle_slots       = result.data["idle"];
+                    const int n_processing_slots = result.data["processing"];
+
+                    json health = {
+                        {"status",           "ok"},
+                        {"slots_idle",       n_idle_slots},
+                        {"slots_processing", n_processing_slots}
+                    };
+
+                    res.status = 200; // HTTP OK
+                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
+                        health["slots"] = result.data["slots"];
+                    }
 
-    svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+                    if (n_idle_slots == 0) {
+                        health["status"] = "no slot available";
+                        if (req.has_param("fail_on_no_slot")) {
+                            res.status = 503; // HTTP Service Unavailable
+                        }
+                    }
+
+                    res.set_content(health.dump(), "application/json");
+                    break;
+                }
+            case SERVER_STATE_LOADING_MODEL:
+                {
+                    res.set_content(R"({"status": "loading model"})", "application/json");
+                    res.status = 503; // HTTP Service Unavailable
+                } break;
+            case SERVER_STATE_ERROR:
+                {
+                    res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
+                    res.status = 500; // HTTP Internal Server Error
+                } break;
+        }
+    };
+
+    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.slots_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi  = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        res.set_content(result.data["slots"].dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
+        if (!sparams.metrics_endpoint) {
+            res.status = 501;
+            res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+            return;
+        }
+
+        // request slots data using task queue
+        server_task task;
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.id_multi  = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+        task.data.push_back({{"reset_bucket", true}});
+
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        json data = result.data;
+
+        const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
+        const uint64_t t_prompt_processing       = data["t_prompt_processing"];
+
+        const uint64_t n_tokens_predicted  = data["n_tokens_predicted"];
+        const uint64_t t_tokens_generation = data["t_tokens_generation"];
+
+        const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+
+        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
+        json all_metrics_def = json {
+            {"counter", {{
+                    {"name",  "prompt_tokens_total"},
+                    {"help",  "Number of prompt tokens processed."},
+                    {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
+            }, {
+                    {"name",  "prompt_seconds_total"},
+                    {"help",  "Prompt process time"},
+                    {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+            }, {
+                    {"name",  "tokens_predicted_total"},
+                    {"help",  "Number of generation tokens processed."},
+                    {"value", (uint64_t) data["n_tokens_predicted_total"]}
+            }, {
+                    {"name",  "tokens_predicted_seconds_total"},
+                    {"help",  "Predict process time"},
+                    {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+            }}},
+            {"gauge", {{
+                    {"name",  "prompt_tokens_seconds"},
+                    {"help",  "Average prompt throughput in tokens/s."},
+                    {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+            },{
+                    {"name",  "predicted_tokens_seconds"},
+                    {"help",  "Average generation throughput in tokens/s."},
+                    {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+            },{
+                    {"name",  "kv_cache_usage_ratio"},
+                    {"help",  "KV-cache usage. 1 means 100 percent usage."},
+                    {"value", 1. * kv_cache_used_cells / params.n_ctx}
+            },{
+                    {"name",  "kv_cache_tokens"},
+                    {"help",  "KV-cache tokens."},
+                    {"value", (uint64_t) data["kv_cache_tokens_count"]}
+            },{
+                    {"name",  "requests_processing"},
+                    {"help",  "Number of request processing."},
+                    {"value", (uint64_t) data["processing"]}
+            },{
+                    {"name",  "requests_deferred"},
+                    {"help",  "Number of request deferred."},
+                    {"value", (uint64_t) data["deferred"]}
+            }}}
+        };
+
+        std::stringstream prometheus;
+
+        for (const auto & el : all_metrics_def.items()) {
+            const auto & type        = el.key();
+            const auto & metrics_def = el.value();
+
+            for (const auto & metric_def : metrics_def) {
+                const std::string name = metric_def["name"];
+                const std::string help = metric_def["help"];
+
+                auto value = json_value(metric_def, "value", 0.);
+                prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
+                           << "# TYPE llamacpp:" << name << " " << type << "\n"
+                           << "llamacpp:" << name << " " << value << "\n";
+            }
+        }
+
+        const int64_t t_start = data["t_start"];
+        res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+
+        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "user_name",      ctx_server.name_user.c_str() },
@@ -3020,13 +3040,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto completions = [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3102,11 +3119,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/completion", completions); // legacy
-    svr->Post("/completions", completions);
-    svr->Post("/v1/completions", completions);
-
-    svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json models = {
@@ -3123,14 +3136,10 @@ int main(int argc, char ** argv) {
         };
 
         res.set_content(models.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
-
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
 
         const int id_task = ctx_server.queue_tasks.get_new_id();
@@ -3201,14 +3210,8 @@ int main(int argc, char ** argv) {
         }
     };
 
-    svr->Post("/chat/completions", chat_completions);
-    svr->Post("/v1/chat/completions", chat_completions);
-
-    svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!validate_api_key(req, res)) {
-            return;
-        }
 
         json data = json::parse(req.body);
 
@@ -3266,13 +3269,9 @@ int main(int argc, char ** argv) {
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
-    });
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        return res.set_content("", "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3282,9 +3281,9 @@ int main(int argc, char ** argv) {
         }
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         const json body = json::parse(req.body);
 
@@ -3296,9 +3295,9 @@ int main(int argc, char ** argv) {
 
         const json data = format_detokenized_response(content);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
-    });
+    };
 
-    svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!params.embedding) {
             res.status = 501;
@@ -3307,94 +3306,114 @@ int main(int argc, char ** argv) {
         }
 
         const json body = json::parse(req.body);
+        bool is_openai = false;
 
-        json prompt;
-        if (body.count("content") != 0) {
-            prompt = body["content"];
+        // an input prompt can string or a list of tokens (integer)
+        std::vector<json> prompts;
+        if (body.count("input") != 0) {
+            is_openai = true;
+            if (body["input"].is_array()) {
+                // support multiple prompts
+                for (const json & elem : body["input"]) {
+                    prompts.push_back(elem);
+                }
+            } else {
+                // single input prompt
+                prompts.push_back(body["input"]);
+            }
+        } else if (body.count("content") != 0) {
+            // only support single prompt here
+            std::string content = body["content"];
+            prompts.push_back(content);
         } else {
-            prompt = "";
-        }
-
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true);
-
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        // send the result
-        return res.set_content(result.data.dump(), "application/json; charset=utf-8");
-    });
-
-    svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
+            // TODO @ngxson : should return an error here
+            prompts.push_back("");
         }
 
-        const json body = json::parse(req.body);
-
-        json prompt;
-        if (body.count("input") != 0) {
-            prompt = body["input"];
-            if (prompt.is_array()) {
-                json data = json::array();
-
-                int i = 0;
-                for (const json & elem : prompt) {
-                    const int id_task = ctx_server.queue_tasks.get_new_id();
-
-                    ctx_server.queue_results.add_waiting_task_id(id_task);
-                    ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-                    json embedding = json{
-                        {"embedding", json_value(result.data, "embedding", json::array())},
-                        {"index",     i++},
-                        {"object",    "embedding"}
-                    };
+        // process all prompts
+        json responses = json::array();
+        for (auto & prompt : prompts) {
+            // TODO @ngxson : maybe support multitask for this endpoint?
+            // create and queue the task
+            const int id_task = ctx_server.queue_tasks.get_new_id();
 
-                    data.push_back(embedding);
-                }
-
-                json result = format_embeddings_response_oaicompat(body, data);
+            ctx_server.queue_results.add_waiting_task_id(id_task);
+            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
 
-                return res.set_content(result.dump(), "application/json; charset=utf-8");
+            // get the result
+            server_task_result result = ctx_server.queue_results.recv(id_task);
+            ctx_server.queue_results.remove_waiting_task_id(id_task);
+            responses.push_back(result.data);
+        }
+
+        // write JSON response
+        json root;
+        if (is_openai) {
+            json res_oai = json::array();
+            int i = 0;
+            for (auto & elem : responses) {
+                res_oai.push_back(json{
+                    {"embedding", json_value(elem, "embedding", json::array())},
+                    {"index",     i++},
+                    {"object",    "embedding"}
+                });
             }
+            root = format_embeddings_response_oaicompat(body, res_oai);
         } else {
-            prompt = "";
+            root = responses[0];
         }
+        return res.set_content(root.dump(), "application/json; charset=utf-8");
+    };
 
-        // create and queue the task
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+    //
+    // Router
+    //
 
-        // get the result
-        server_task_result result = ctx_server.queue_results.recv(id_task);
-        ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-        json data = json::array({json{
-            {"embedding", json_value(result.data, "embedding", json::array())},
-            {"index",     0},
-            {"object",    "embedding"}
-        }}
-        );
+    // register static assets routes
+    if (!sparams.public_path.empty()) {
+        // Set the base directory for serving static files
+        svr->set_base_dir(sparams.public_path);
+    }
 
-        json root = format_embeddings_response_oaicompat(body, data);
+    // using embedded static files
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
 
-        return res.set_content(root.dump(), "application/json; charset=utf-8");
+    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
+        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
+        return res.set_content("", "application/json; charset=utf-8");
     });
+    svr->Get("/",                           handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+    svr->Get("/index.js",                   handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/completion.js",              handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+        json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
+    // register API routes
+    svr->Get ("/health",              handle_health);
+    svr->Get ("/slots",               handle_slots);
+    svr->Get ("/metrics",             handle_metrics);
+    svr->Get ("/props",               handle_props);
+    svr->Get ("/v1/models",           handle_models);
+    svr->Post("/completion",          handle_completions); // legacy
+    svr->Post("/completions",         handle_completions);
+    svr->Post("/v1/completions",      handle_completions);
+    svr->Post("/chat/completions",    handle_chat_completions);
+    svr->Post("/v1/chat/completions", handle_chat_completions);
+    svr->Post("/infill",              handle_infill);
+    svr->Post("/embedding",           handle_embeddings); // legacy
+    svr->Post("/embeddings",          handle_embeddings);
+    svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/tokenize",            handle_tokenize);
+    svr->Post("/detokenize",          handle_detokenize);
+
+    //
+    // Start the server
+    //
     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);