diff options
author | Jan Boon <jan.boon@kaetemi.be> | 2024-04-08 20:43:30 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 15:43:30 +0300 |
commit | beea6e1b16e783a0886e78dec01002a8c00db24d (patch) | |
tree | a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/server/server.cpp | |
parent | 87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff) |
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot
---------
Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 228 |
1 files changed, 227 insertions, 1 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 183f0d3c..6c64fe3e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,7 +61,10 @@ enum server_task_type { SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, }; struct server_task { @@ -128,6 +131,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; + std::string slot_save_path; }; struct server_slot { @@ -1612,6 +1616,107 @@ struct server_context { } queue_results.send(res); } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", token_count }, // tokens saved + { "n_written", nwrite }, // bytes written + { "timings", { + { "save_ms", t_save_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", token_count }, // tokens restored + { "n_read", nread }, // bytes read + { "timings", { + { "restore_ms", t_restore_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + slot->cache_tokens.clear(); + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "n_erased", n_erased } + }; + queue_results.send(result); + } break; } } @@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); + printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, sparams.slots_endpoint = false; } else if (arg == "--metrics") { sparams.metrics_endpoint = true; + } else if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + sparams.slot_save_path += DIRECTORY_SEPARATOR; + } } else if (arg == "--chat-template") { if (++i >= argc) { invalid_param = true; @@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); + if (!sparams.slot_save_path.empty()) { + // only enable slot endpoints if slot_save_path is set + svr->Post("/slots/:id_slot", handle_slots_action); + } // // Start the server |