summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
author Jan Boon <jan.boon@kaetemi.be> 2024-04-08 20:43:30 +0800
committer GitHub <noreply@github.com> 2024-04-08 15:43:30 +0300
commit beea6e1b16e783a0886e78dec01002a8c00db24d (patch)
tree a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/server/server.cpp
parent 87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff)
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot

---------

Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- examples/server/server.cpp 228
1 files changed, 227 insertions, 1 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 183f0d3c..6c64fe3e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -61,7 +61,10 @@ enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
- SERVER_TASK_TYPE_METRICS
+ SERVER_TASK_TYPE_METRICS,
+ SERVER_TASK_TYPE_SLOT_SAVE,
+ SERVER_TASK_TYPE_SLOT_RESTORE,
+ SERVER_TASK_TYPE_SLOT_ERASE,
};
struct server_task {
@@ -128,6 +131,7 @@ struct server_params {
bool slots_endpoint = true;
bool metrics_endpoint = false;
+ std::string slot_save_path;
};
struct server_slot {
@@ -1612,6 +1616,107 @@ struct server_context {
}
queue_results.send(res);
} break;
+ case SERVER_TASK_TYPE_SLOT_SAVE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ const size_t token_count = slot->cache_tokens.size();
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data["filename"];
+ std::string filepath = task.data["filepath"];
+
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_save_ms = (t_end - t_start) / 1000.0;
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_saved", token_count }, // tokens saved
+ { "n_written", nwrite }, // bytes written
+ { "timings", {
+ { "save_ms", t_save_ms }
+ } }
+ };
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_RESTORE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data["filename"];
+ std::string filepath = task.data["filepath"];
+
+ slot->cache_tokens.resize(slot->n_ctx);
+ size_t token_count = 0;
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ if (nread == 0) {
+ slot->cache_tokens.resize(0);
+ send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ slot->cache_tokens.resize(token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_restore_ms = (t_end - t_start) / 1000.0;
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", token_count }, // tokens restored
+ { "n_read", nread }, // bytes read
+ { "timings", {
+ { "restore_ms", t_restore_ms }
+ } }
+ };
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_ERASE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ // Erase token cache
+ const size_t n_erased = slot->cache_tokens.size();
+ llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ slot->cache_tokens.clear();
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "n_erased", n_erased }
+ };
+ queue_results.send(result);
+ } break;
}
}
@@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --log-disable disables logging to a file.\n");
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
+ printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
printf("\n");
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
printf(" --override-kv KEY=TYPE:VALUE\n");
@@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
sparams.slots_endpoint = false;
} else if (arg == "--metrics") {
sparams.metrics_endpoint = true;
+ } else if (arg == "--slot-save-path") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.slot_save_path = argv[i];
+ // if doesn't end with DIRECTORY_SEPARATOR, add it
+ if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
+ sparams.slot_save_path += DIRECTORY_SEPARATOR;
+ }
} else if (arg == "--chat-template") {
if (++i >= argc) {
invalid_param = true;
@@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
+ const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { // "save" action: validate the requested filename, queue a SERVER_TASK_TYPE_SLOT_SAVE task for id_slot and relay its result
+ json request_data = json::parse(req.body); // NOTE(review): no error handling around the parse in this lambda — confirm how malformed bodies are reported
+ std::string filename = request_data["filename"]; // the request body is expected to carry a "filename" field
+ if (!validate_file_name(filename)) { // validate_file_name is defined outside this hunk; names it rejects are answered with an invalid-request error
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = sparams.slot_save_path + filename; // the file always lives under the configured --slot-save-path prefix
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_SAVE;
+ task.data = { // the task handler echoes "filename" back in its response and passes "filepath" to llama_state_seq_save_file
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "filepath", filepath }
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task); // the id returned by post() is used below to pick up the matching result
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task); // NOTE(review): presumably blocks until the task's result is available — confirm
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) { // on error, result.data is forwarded to res_error as-is
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json"); // success: the task's JSON payload becomes the response body
+ }
+ };
+
+ const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { // "restore" action: validate the requested filename, queue a SERVER_TASK_TYPE_SLOT_RESTORE task for id_slot and relay its result
+ json request_data = json::parse(req.body); // NOTE(review): no error handling around the parse in this lambda — confirm how malformed bodies are reported
+ std::string filename = request_data["filename"]; // the request body is expected to carry a "filename" field
+ if (!validate_file_name(filename)) { // validate_file_name is defined outside this hunk; names it rejects are answered with an invalid-request error
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = sparams.slot_save_path + filename; // the file is always looked up under the configured --slot-save-path prefix
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
+ task.data = { // the task handler echoes "filename" back in its response and passes "filepath" to llama_state_seq_load_file
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "filepath", filepath }
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task); // the id returned by post() is used below to pick up the matching result
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task); // NOTE(review): presumably blocks until the task's result is available — confirm
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) { // on error, result.data is forwarded to res_error as-is
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json"); // success: the task's JSON payload becomes the response body
+ }
+ };
+
+ const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { // "erase" action: queue a SERVER_TASK_TYPE_SLOT_ERASE task for id_slot and relay its result; the request itself is not read
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_ERASE;
+ task.data = { // erase needs only the slot id, no filename
+ { "id_slot", id_slot },
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task); // the id returned by post() is used below to pick up the matching result
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task); // NOTE(review): presumably blocks until the task's result is available — confirm
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) { // on error, result.data is forwarded to res_error as-is
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json"); // success: the task's JSON payload becomes the response body
+ }
+ };
+
+ const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { // entry point for POST /slots/:id_slot — parses the slot id and dispatches on the "action" request parameter
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); // echoes the request's Origin header value back
+
+ std::string id_slot_str = req.path_params.at("id_slot"); // ":id_slot" segment of the registered route
+ int id_slot;
+
+ try {
+ id_slot = std::stoi(id_slot_str); // std::stoi throws on non-numeric or out-of-range input
+ } catch (const std::exception &) { // any conversion failure is reported as an invalid slot ID
+ res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::string action = req.get_param_value("action"); // selects which of the three slot handlers runs
+
+ if (action == "save") {
+ handle_slots_save(req, res, id_slot);
+ } else if (action == "restore") {
+ handle_slots_restore(req, res, id_slot);
+ } else if (action == "erase") {
+ handle_slots_erase(req, res, id_slot);
+ } else { // anything other than save/restore/erase (including a missing value) is rejected
+ res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+ }
+ };
+
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
@@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
+ if (!sparams.slot_save_path.empty()) {
+ // only enable slot endpoints if slot_save_path is set
+ svr->Post("/slots/:id_slot", handle_slots_action);
+ }
//
// Start the server