diff options
Diffstat (limited to 'examples/server')
-rw-r--r-- | examples/server/README.md | 52 | ||||
-rw-r--r-- | examples/server/server.cpp | 228 | ||||
-rw-r--r-- | examples/server/tests/features/slotsave.feature | 58 | ||||
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 60 |
4 files changed, 397 insertions, 1 deletions
diff --git a/examples/server/README.md b/examples/server/README.md index 0d8564a1..a6fc92ea 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1` - `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. - `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled +- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled. - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` @@ -517,6 +518,57 @@ Available metrics: - `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_deferred`: Number of requests deferred. +- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. + + *Options:* + + `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_saved": 1745, + "n_written": 14309796, + "timings": { + "save_ms": 49.865 + } +} +``` + +- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file. + + *Options:* + + `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_restored": 1745, + "n_read": 14309796, + "timings": { + "restore_ms": 42.937 + } +} +``` + +- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot. + +### Result JSON + +```json +{ + "id_slot": 0, + "n_erased": 1745 +} +``` + ## More examples ### Change system prompt on runtime diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 183f0d3c..6c64fe3e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,7 +61,10 @@ enum server_task_type { SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, }; struct server_task { @@ -128,6 +131,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; + std::string slot_save_path; }; struct server_slot { @@ -1612,6 +1616,107 @@ struct server_context { } queue_results.send(res); } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", token_count }, // tokens saved + { "n_written", nwrite }, // bytes written + { "timings", { + { "save_ms", t_save_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", token_count }, // tokens restored + { "n_read", nread }, // bytes read + { "timings", { + { "restore_ms", t_restore_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + slot->cache_tokens.clear(); + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "n_erased", n_erased } + }; + queue_results.send(result); + } break; } } @@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); + printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, sparams.slots_endpoint = false; } else if (arg == "--metrics") { sparams.metrics_endpoint = true; + } else if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + sparams.slot_save_path += DIRECTORY_SEPARATOR; + } } else if (arg == "--chat-template") { if (++i >= argc) { invalid_param = true; @@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); + if (!sparams.slot_save_path.empty()) { + // only enable slot endpoints if slot_save_path is set + svr->Post("/slots/:id_slot", handle_slots_action); + } // // Start the server diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature new file mode 100644 index 00000000..1c281c07 --- /dev/null +++ b/examples/server/tests/features/slotsave.feature @@ -0,0 +1,58 @@ +@llama.cpp +@slotsave +Feature: llama.cpp server slot management + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And prompt caching is enabled + And 2 slots + And . as slot save path + And 2048 KV cache size + And 42 as server seed + And 24 max tokens to predict + Then the server is starting + Then the server is healthy + + Scenario: Save and Restore Slot + # First prompt in slot 1 should be fully processed + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed + When the slot 1 is saved with filename "slot1.bin" + Then the server responds with status code 200 + # Since we have cache, this should only process the last tokens + Given a user prompt "What is the capital of Germany?" + And a completion request with no api error + Then 24 tokens are predicted matching (Thank|special) + And 7 prompt tokens are processed + # Loading the original cache into slot 0, + # we should only be processing 1 prompt token and get the same output + When the slot 0 is restored with filename "slot1.bin" + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And using slot id 0 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 1 prompt tokens are processed + # For verification that slot 1 was not corrupted during slot 0 load, same thing + Given a user prompt "What is the capital of Germany?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Thank|special) + And 1 prompt tokens are processed + + Scenario: Erase Slot + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed + When the slot 1 is erased + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 9a6cf7d6..ca400efa 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port): context.n_predict = None context.n_prompts = 0 context.n_server_predict = None + context.slot_save_path = None + context.id_slot = None + context.cache_prompt = None context.n_slots = None context.prompt_prefix = None context.prompt_suffix = None @@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict): context.n_server_predict = n_predict +@step('{slot_save_path} as slot save path') +def step_slot_save_path(context, slot_save_path): + context.slot_save_path = slot_save_path + + +@step('using slot id {id_slot:d}') +def step_id_slot(context, id_slot): + context.id_slot = id_slot + + +@step('prompt caching is enabled') +def step_enable_prompt_cache(context): + context.cache_prompt = True + + @step('continuous batching') def step_server_continuous_batching(context): context.server_continuous_batching = True @@ -212,6 +230,8 @@ async def step_request_completion(context, api_error): context.base_url, debug=context.debug, n_predict=context.n_predict, + cache_prompt=context.cache_prompt, + id_slot=context.id_slot, seed=await completions_seed(context), expect_api_error=expect_api_error, user_api_key=context.user_api_key) @@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): await asyncio.sleep(0.1) +@step('the slot {slot_id:d} is saved with filename "{filename}"') +@async_run_until_complete +async def step_save_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is restored with filename "{filename}"') +@async_run_until_complete +async def step_restore_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is erased') +@async_run_until_complete +async def step_erase_slot(context, slot_id): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the server responds with status code {status_code:d}') +def step_server_responds_with_status_code(context, status_code): + assert context.response.status == status_code + + async def request_completion(prompt, base_url, debug=False, prompt_prefix=None, prompt_suffix=None, n_predict=None, + cache_prompt=False, + id_slot=None, seed=None, expect_api_error=None, user_api_key=None): @@ -738,6 +794,8 @@ async def request_completion(prompt, "prompt": prompt, "input_suffix": prompt_suffix, "n_predict": n_predict if n_predict is not None else -1, + "cache_prompt": cache_prompt, + "id_slot": id_slot, "seed": seed if seed is not None else 42 }, headers=headers, @@ -1104,6 +1162,8 @@ def start_server_background(context): server_args.extend(['--parallel', context.n_slots]) if context.n_server_predict: server_args.extend(['--n-predict', context.n_server_predict]) + if context.slot_save_path: + server_args.extend(['--slot-save-path', context.slot_save_path]) if context.server_api_key: server_args.extend(['--api-key', context.server_api_key]) if context.n_ga: |