author     Xuan Son Nguyen <thichthat@gmail.com>  2024-03-11 10:56:41 +0100
committer  GitHub <noreply@github.com>            2024-03-11 10:56:41 +0100
commit     caa106d4e05a0ab94225c220b81f9e2cd522339b (patch)
tree       d9acbea2801af1260358cfb1e1b964a26ea6fa1f /examples/server/server.cpp
parent     3202361c5b1ba15e695b31209567ef42c22c5c32 (diff)
Server: format error to json (#5961)
* server: format error to json
* server: do not crash on grammar error
* fix api key test case
* revert limit max n_predict
* small fix
* correct coding style
* update completion.js
* launch_slot_with_task
* update docs
* update_slots
* update webui
* update readme
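The JSON error envelope itself is built by format_error_response() and the error_type enum, which live outside examples/server/server.cpp and therefore do not appear in the diff below (the diffstat is limited to that one file). A minimal sketch of what such a helper could look like — only the "code" field is directly confirmed by this diff, where the res_error lambda reads it to set the HTTP status; the other field names and the type strings are assumptions modeled on the OpenAI-style errors this PR appears to emulate:

// error_type sketch: all six enumerators below are used in the diff,
// but the enum itself is defined outside server.cpp in this PR
enum error_type {
    ERROR_TYPE_INVALID_REQUEST,
    ERROR_TYPE_AUTHENTICATION,
    ERROR_TYPE_NOT_FOUND,
    ERROR_TYPE_NOT_SUPPORTED,
    ERROR_TYPE_UNAVAILABLE,
    ERROR_TYPE_SERVER,
};

static json format_error_response(const std::string & message, const enum error_type type) {
    // NOTE: sketch only - field names other than "code" are an assumption
    std::string type_str;
    int code = 500;
    switch (type) {
        case ERROR_TYPE_INVALID_REQUEST: type_str = "invalid_request_error"; code = 400; break; // e.g. empty "prompt" array
        case ERROR_TYPE_AUTHENTICATION:  type_str = "authentication_error";  code = 401; break; // invalid API key
        case ERROR_TYPE_NOT_FOUND:       type_str = "not_found_error";       code = 404; break; // unknown route
        case ERROR_TYPE_NOT_SUPPORTED:   type_str = "not_supported_error";   code = 501; break; // disabled endpoint
        case ERROR_TYPE_UNAVAILABLE:     type_str = "unavailable_error";     code = 503; break; // model still loading
        case ERROR_TYPE_SERVER:          type_str = "server_error";          code = 500; break; // internal failure
    }
    return json {
        {"code",    code},
        {"message", message},
        {"type",    type_str},
    };
}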
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  159
1 file changed, 89 insertions(+), 70 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3951507a..b63a6f24 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -396,7 +396,7 @@ struct server_queue {
     // callback functions
     std::function<void(server_task &)> callback_new_task;
     std::function<void(server_task_multi &)> callback_finish_multitask;
-    std::function<void(void)> callback_run_slots;
+    std::function<void(void)> callback_update_slots;
 
     // Add a new task to the end of the queue
     int post(server_task task) {
@@ -435,8 +435,8 @@
     }
 
     // Register the function to be called when all slots data is ready to be processed
-    void on_run_slots(std::function<void(void)> callback) {
-        callback_run_slots = std::move(callback);
+    void on_update_slots(std::function<void(void)> callback) {
+        callback_update_slots = std::move(callback);
     }
 
     // Call when the state of one slot is changed
@@ -461,7 +461,7 @@
      * - Wait until a new task arrives
      * - Process the task (i.e. maybe copy data into slot)
      * - Check if multitask is finished
-     * - Run all slots
+     * - Update all slots
      */
     void start_loop() {
         running = true;
@@ -499,9 +499,9 @@
             }
 
             // all tasks in the current loop is processed, slots data is now ready
-            LOG_VERBOSE("callback_run_slots", {});
+            LOG_VERBOSE("callback_update_slots", {});
 
-            callback_run_slots();
+            callback_update_slots();
 
             LOG_VERBOSE("wait for new task", {});
             {
@@ -805,9 +805,10 @@ struct server_context {
         return last_used;
     }
 
-    bool launch_slot_with_data(server_slot & slot, json data) const {
+    bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         llama_sampling_params default_sparams;
+        auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -864,10 +865,15 @@
         {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
-                slot.prompt = "";
+                send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+                return false;
             } else {
                 slot.prompt = *prompt;
             }
+            if (slot.prompt.is_array() && slot.prompt.size() == 0) {
+                send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
         }
 
         // penalize user-provided tokens
@@ -926,6 +932,7 @@
             if (logit_bias != data.end() && logit_bias->is_array()) {
                 const int n_vocab = llama_n_vocab(model);
                 for (const auto & el : *logit_bias) {
+                    // TODO: we may want to throw errors here, in case "el" is incorrect
                     if (el.is_array() && el.size() == 2) {
                         float bias;
                         if (el[1].is_number()) {
@@ -985,6 +992,11 @@
                 llama_sampling_free(slot.ctx_sampling);
             }
             slot.ctx_sampling = llama_sampling_init(slot.sparams);
+            if (slot.ctx_sampling == nullptr) {
+                // for now, the only error that may happen here is invalid grammar
+                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
             llama_set_rng_seed(ctx, slot.params.seed);
         }
@@ -1226,15 +1238,23 @@
         };
     }
 
-    void send_error(const server_task & task, const std::string & error) {
-        LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
+    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+        send_error(task.id, task.id_multi, error, type);
+    }
+
+    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+        send_error(slot.id_task, slot.id_multi, error, type);
+    }
+
+    void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+        LOG_TEE("task %i - error: %s\n", id_task, error.c_str());
 
         server_task_result res;
-        res.id       = task.id;
-        res.id_multi = task.id_multi;
+        res.id       = id_task;
+        res.id_multi = id_multi;
         res.stop     = false;
         res.error    = true;
-        res.data     = { { "content", error } };
+        res.data     = format_error_response(error, type);
 
         queue_results.send(res);
     }
@@ -1468,9 +1488,8 @@
                     slot->infill = task.infill;
                     slot->embedding = task.embedding;
 
-                    if (!launch_slot_with_data(*slot, task.data)) {
-                        // send error result
-                        send_error(task, "internal_error");
+                    if (!launch_slot_with_task(*slot, task)) {
+                        LOG_ERROR("error while launching slot", task.data);
                         break;
                     }
                 } break;
@@ -1587,7 +1606,7 @@
         queue_results.send(result);
     }
 
-    bool update_slots() {
+    void update_slots() {
         if (system_need_update) {
             system_prompt_update();
         }
@@ -1630,7 +1649,7 @@
                     kv_cache_clear();
                 }
 
-                return true;
+                return;
             }
         }
@@ -1975,8 +1994,7 @@
         if (batch.n_tokens == 0) {
             LOG_VERBOSE("no tokens to decode", {});
-
-            return true;
+            return;
         }
 
         LOG_VERBOSE("decoding batch", {
@@ -2033,7 +2051,13 @@
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
                     LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                    return false;
+                    for (auto & slot : slots) {
+                        slot.state = SLOT_STATE_PROCESSING;
+                        slot.command = SLOT_COMMAND_NONE;
+                        slot.release();
+                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                    }
+                    break; // break loop of n_batch
                 }
 
                 LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@@ -2042,12 +2066,12 @@
                 n_batch /= 2;
                 i -= n_batch;
 
-                continue;
+                continue; // continue loop of n_batch
             }
 
             for (auto & slot : slots) {
                 if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
-                    continue;
+                    continue; // continue loop of slots
                 }
 
                 // prompt evaluated for embedding
@@ -2055,7 +2079,7 @@
                     send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
-                    continue;
+                    continue; // continue loop of slots
                 }
 
                 completion_token_output result;
@@ -2097,9 +2121,7 @@
             }
         }
 
-        LOG_VERBOSE("slots updated", {});
-
-        return true;
+        LOG_VERBOSE("run slots completed", {});
     }
 
     json model_meta() const {
@@ -2745,32 +2767,32 @@ int main(int argc, char ** argv) {
     svr->set_logger(log_server_request);
 
-    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
-        const char fmt[] = "500 Internal Server Error\n%s";
+    auto res_error = [](httplib::Response & res, json error_data) {
+        json final_response {{"error", error_data}};
+        res.set_content(final_response.dump(), "application/json; charset=utf-8");
+        res.status = json_value(error_data, "code", 500);
+    };
 
-        char buf[BUFSIZ];
+    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+        std::string message;
         try {
             std::rethrow_exception(std::move(ep));
-        } catch (std::exception &e) {
-            snprintf(buf, sizeof(buf), fmt, e.what());
+        } catch (std::exception & e) {
+            message = e.what();
        } catch (...) {
-            snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
+            message = "Unknown Exception";
        }
 
-        res.set_content(buf, "text/plain; charset=utf-8");
-        res.status = 500;
+        json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+        LOG_VERBOSE("Got exception", formatted_error);
+        res_error(res, formatted_error);
    });
 
-    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
-        if (res.status == 401) {
-            res.set_content("Unauthorized", "text/plain; charset=utf-8");
-        }
-        if (res.status == 400) {
-            res.set_content("Invalid request", "text/plain; charset=utf-8");
-        }
+    svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
        if (res.status == 404) {
-            res.set_content("File Not Found", "text/plain; charset=utf-8");
+            res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
        }
+        // for other error codes, we skip processing here because it's already done by res_error()
    });
 
    // set timeouts and change hostname and port
@@ -2835,7 +2857,7 @@ int main(int argc, char ** argv) {
    // Middlewares
    //
 
-    auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_validate_api_key = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
        static const std::set<std::string> protected_endpoints = {
            "/props",
@@ -2876,8 +2898,7 @@
        // API key is invalid or not provided
        // TODO: make another middleware for CORS related logic
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
-        res.status = 401; // Unauthorized
+        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
 
        LOG_WARNING("Unauthorized: Invalid API Key", {});
@@ -2940,21 +2961,18 @@
            }
            case SERVER_STATE_LOADING_MODEL:
                {
-                    res.set_content(R"({"status": "loading model"})", "application/json");
-                    res.status = 503; // HTTP Service Unavailable
+                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
                } break;
            case SERVER_STATE_ERROR:
                {
-                    res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
-                    res.status = 500; // HTTP Internal Server Error
+                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
                } break;
        }
    };
 
    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.slots_endpoint) {
-            res.status = 501;
-            res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+            res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -2978,8 +2996,7 @@
 
    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.metrics_endpoint) {
-            res.status = 501;
-            res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+            res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -3090,7 +3107,7 @@
        res.set_content(data.dump(), "application/json; charset=utf-8");
    };
 
-    const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
        json data = json::parse(req.body);
@@ -3105,8 +3122,7 @@
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-                res.status = 500;
-                res.set_content(result.data["content"], "text/plain; charset=utf-8");
+                res_error(res, result.data);
            }
 
            ctx_server.queue_results.remove_waiting_task_id(id_task);
@@ -3186,7 +3202,7 @@
        res.set_content(models.dump(), "application/json; charset=utf-8");
    };
 
-    const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@@ -3204,8 +3220,7 @@
                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-                res.status = 500;
-                res.set_content(result.data["content"], "text/plain; charset=utf-8");
+                res_error(res, result.data);
            }
            ctx_server.queue_results.remove_waiting_task_id(id_task);
        } else {
@@ -3259,7 +3274,7 @@
        }
    };
 
-    const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
        json data = json::parse(req.body);
@@ -3274,8 +3289,7 @@
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-                res.status = 404;
-                res.set_content(result.data["content"], "text/plain; charset=utf-8");
+                res_error(res, result.data);
            }
 
            ctx_server.queue_results.remove_waiting_task_id(id_task);
@@ -3346,7 +3360,7 @@
        return res.set_content(data.dump(), "application/json; charset=utf-8");
    };
 
-    const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (!params.embedding) {
            res.status = 501;
@@ -3375,8 +3389,8 @@
            std::string content = body["content"];
            prompts.push_back(content);
        } else {
-            // TODO @ngxson : should return an error here
-            prompts.push_back("");
+            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
        }
 
        // process all prompts
@@ -3392,9 +3406,14 @@
            // get the result
            server_task_result result = ctx_server.queue_results.recv(id_task);
            ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-            // append to the responses
-            responses.push_back(result.data);
+            if (!result.error) {
+                // append to the responses
+                responses.push_back(result.data);
+            } else {
+                // error received, ignore everything else
+                res_error(res, result.data);
+                return;
+            }
        }
 
        // write JSON response
@@ -3488,7 +3507,7 @@
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_finish_multitask(std::bind(
        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
-    ctx_server.queue_tasks.on_run_slots(std::bind(
+    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));
    ctx_server.queue_results.on_multitask_update(std::bind(
        &server_queue::update_multitask,
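With res_error() in place, every failure path returns the same JSON envelope with an appropriate HTTP status instead of an ad-hoc plain-text body. For example, a request with a bad API key should now receive a 401 response whose body looks roughly like this ("error", "code", and the "Invalid API Key" message follow directly from the diff; the "type" string depends on format_error_response and is illustrative only):

{"error": {"code": 401, "message": "Invalid API Key", "type": "authentication_error"}}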