diff options
author | Jan Boon <jan.boon@kaetemi.be> | 2024-04-08 20:43:30 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 15:43:30 +0300 |
commit | beea6e1b16e783a0886e78dec01002a8c00db24d (patch) | |
tree | a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples | |
parent | 87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff) |
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot
---------
Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/main/main.cpp | 6 | ||||
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 101 | ||||
-rw-r--r-- | examples/server/README.md | 52 | ||||
-rw-r--r-- | examples/server/server.cpp | 228 | ||||
-rw-r--r-- | examples/server/tests/features/slotsave.feature | 58 | ||||
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 60 |
6 files changed, 495 insertions, 10 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e2d07a63..711f162d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -235,7 +235,7 @@ int main(int argc, char ** argv) { // The file exists and is not empty session_tokens.resize(n_ctx); size_t n_token_count_out = 0; - if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); return 1; } @@ -693,7 +693,7 @@ int main(int argc, char ** argv) { // optionally save the session on first sample (for faster prompt loading next time) if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { need_to_save_session = false; - llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); LOG("saved session to %s\n", path_session.c_str()); } @@ -935,7 +935,7 @@ int main(int argc, char ** argv) { if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); - llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } llama_print_timings(ctx); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ef952e2b..c3b76688 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -24,6 +24,7 @@ int main(int argc, char ** argv) { std::string result0; std::string result1; + std::string result2; // init llama_model * model; @@ -44,8 +45,8 @@ int main(int argc, char ** argv) { // save state (rng, logits, embedding and kv_cache) to file { - std::vector<uint8_t> state_mem(llama_get_state_size(ctx)); - const size_t written = llama_copy_state_data(ctx, state_mem.data()); + std::vector<uint8_t> state_mem(llama_state_get_size(ctx)); + const size_t written = llama_state_get_data(ctx, state_mem.data()); FILE *fp_write = fopen("dump_state.bin", "wb"); fwrite(state_mem.data(), 1, written, fp_write); @@ -97,13 +98,13 @@ int main(int argc, char ** argv) { // load state (rng, logits, embedding and kv_cache) from file { - std::vector<uint8_t> state_mem(llama_get_state_size(ctx2)); + std::vector<uint8_t> state_mem(llama_state_get_size(ctx2)); FILE * fp_read = fopen("dump_state.bin", "rb"); const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); fclose(fp_read); - if (read != llama_set_state_data(ctx2, state_mem.data())) { + if (read != llama_state_set_data(ctx2, state_mem.data())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx2); llama_free_model(model); @@ -141,16 +142,104 @@ int main(int argc, char ** argv) { n_past += 1; } - printf("\n"); + printf("\n\n"); llama_free(ctx2); - llama_free_model(model); if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } + // make new context + auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + printf("\nsingle seq run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector<uint8_t> state_mem(llama_state_get_size(ctx3)); + + FILE * fp_read = fopen("dump_state.bin", "rb"); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_state_set_data(ctx3, state_mem.data())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // save seq 0 and load into seq 1 + { + // save kv of seq 0 + std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0)); + const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0); + if (ncopy != seq_store.size()) { + fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); + + // erase whole kv + llama_kv_cache_clear(ctx3); + fprintf(stderr, "%s : kv cache cleared\n", __func__); + + // restore kv into seq 1 + const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1); + if (nset != seq_store.size()) { + fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); + } + + // third run with seq 1 instead of 0 + for (auto i = 0; i < params.n_predict; i++) { + auto * logits = llama_get_logits(ctx3); + auto n_vocab = llama_n_vocab(model); + std::vector<llama_token_data> candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token_str = llama_token_to_piece(ctx3, next_token); + + printf("%s", next_token_str.c_str()); + result2 += next_token_str; + + if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + n_past += 1; + } + + printf("\n"); + + llama_free(ctx3); + llama_free_model(model); + + if (result0 != result2) { + fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); + return 1; + } + fprintf(stderr, "\n%s : success\n", __func__); return 0; diff --git a/examples/server/README.md b/examples/server/README.md index 0d8564a1..a6fc92ea 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1` - `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. - `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled +- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled. - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` @@ -517,6 +518,57 @@ Available metrics: - `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_deferred`: Number of requests deferred. +- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. + + *Options:* + + `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_saved": 1745, + "n_written": 14309796, + "timings": { + "save_ms": 49.865 + } +} +``` + +- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file. + + *Options:* + + `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. + +### Result JSON + +```json +{ + "id_slot": 0, + "filename": "slot_save_file.bin", + "n_restored": 1745, + "n_read": 14309796, + "timings": { + "restore_ms": 42.937 + } +} +``` + +- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot. + +### Result JSON + +```json +{ + "id_slot": 0, + "n_erased": 1745 +} +``` + ## More examples ### Change system prompt on runtime diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 183f0d3c..6c64fe3e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,7 +61,10 @@ enum server_task_type { SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS + SERVER_TASK_TYPE_METRICS, + SERVER_TASK_TYPE_SLOT_SAVE, + SERVER_TASK_TYPE_SLOT_RESTORE, + SERVER_TASK_TYPE_SLOT_ERASE, }; struct server_task { @@ -128,6 +131,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; + std::string slot_save_path; }; struct server_slot { @@ -1612,6 +1616,107 @@ struct server_context { } queue_results.send(res); } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const size_t token_count = slot->cache_tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", token_count }, // tokens saved + { "n_written", nwrite }, // bytes written + { "timings", { + { "save_ms", t_save_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.data["filename"]; + std::string filepath = task.data["filepath"]; + + slot->cache_tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + if (nread == 0) { + slot->cache_tokens.resize(0); + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + slot->cache_tokens.resize(token_count); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", token_count }, // tokens restored + { "n_read", nread }, // bytes read + { "timings", { + { "restore_ms", t_restore_ms } + } } + }; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + int id_slot = task.data["id_slot"]; + server_slot * slot = get_slot(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + + // Erase token cache + const size_t n_erased = slot->cache_tokens.size(); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + slot->cache_tokens.clear(); + + server_task_result result; + result.id = task.id; + result.stop = true; + result.error = false; + result.data = json { + { "id_slot", id_slot }, + { "n_erased", n_erased } + }; + queue_results.send(result); + } break; } } @@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled"); + printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n"); printf("\n"); printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); @@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, sparams.slots_endpoint = false; } else if (arg == "--metrics") { sparams.metrics_endpoint = true; + } else if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + sparams.slot_save_path += DIRECTORY_SEPARATOR; + } } else if (arg == "--chat-template") { if (++i >= argc) { invalid_param = true; @@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_SAVE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) { + json request_data = json::parse(req.body); + std::string filename = request_data["filename"]; + if (!validate_file_name(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = sparams.slot_save_path + filename; + + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_RESTORE; + task.data = { + { "id_slot", id_slot }, + { "filename", filename }, + { "filepath", filepath } + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + server_task task; + task.type = SERVER_TASK_TYPE_SLOT_ERASE; + task.data = { + { "id_slot", id_slot }, + }; + + const int id_task = ctx_server.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + if (result.error) { + res_error(res, result.data); + } else { + res.set_content(result.data.dump(), "application/json"); + } + }; + + const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + std::string id_slot_str = req.path_params.at("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::string action = req.get_param_value("action"); + + if (action == "save") { + handle_slots_save(req, res, id_slot); + } else if (action == "restore") { + handle_slots_restore(req, res, id_slot); + } else if (action == "erase") { + handle_slots_erase(req, res, id_slot); + } else { + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + } + }; + const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { @@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) { svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); + if (!sparams.slot_save_path.empty()) { + // only enable slot endpoints if slot_save_path is set + svr->Post("/slots/:id_slot", handle_slots_action); + } // // Start the server diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature new file mode 100644 index 00000000..1c281c07 --- /dev/null +++ b/examples/server/tests/features/slotsave.feature @@ -0,0 +1,58 @@ +@llama.cpp +@slotsave +Feature: llama.cpp server slot management + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And prompt caching is enabled + And 2 slots + And . as slot save path + And 2048 KV cache size + And 42 as server seed + And 24 max tokens to predict + Then the server is starting + Then the server is healthy + + Scenario: Save and Restore Slot + # First prompt in slot 1 should be fully processed + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed + When the slot 1 is saved with filename "slot1.bin" + Then the server responds with status code 200 + # Since we have cache, this should only process the last tokens + Given a user prompt "What is the capital of Germany?" + And a completion request with no api error + Then 24 tokens are predicted matching (Thank|special) + And 7 prompt tokens are processed + # Loading the original cache into slot 0, + # we should only be processing 1 prompt token and get the same output + When the slot 0 is restored with filename "slot1.bin" + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And using slot id 0 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 1 prompt tokens are processed + # For verification that slot 1 was not corrupted during slot 0 load, same thing + Given a user prompt "What is the capital of Germany?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Thank|special) + And 1 prompt tokens are processed + + Scenario: Erase Slot + Given a user prompt "What is the capital of France?" + And using slot id 1 + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed + When the slot 1 is erased + Then the server responds with status code 200 + Given a user prompt "What is the capital of France?" + And a completion request with no api error + Then 24 tokens are predicted matching (Lily|cake) + And 22 prompt tokens are processed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 9a6cf7d6..ca400efa 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port): context.n_predict = None context.n_prompts = 0 context.n_server_predict = None + context.slot_save_path = None + context.id_slot = None + context.cache_prompt = None context.n_slots = None context.prompt_prefix = None context.prompt_suffix = None @@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict): context.n_server_predict = n_predict +@step('{slot_save_path} as slot save path') +def step_slot_save_path(context, slot_save_path): + context.slot_save_path = slot_save_path + + +@step('using slot id {id_slot:d}') +def step_id_slot(context, id_slot): + context.id_slot = id_slot + + +@step('prompt caching is enabled') +def step_enable_prompt_cache(context): + context.cache_prompt = True + + @step('continuous batching') def step_server_continuous_batching(context): context.server_continuous_batching = True @@ -212,6 +230,8 @@ async def step_request_completion(context, api_error): context.base_url, debug=context.debug, n_predict=context.n_predict, + cache_prompt=context.cache_prompt, + id_slot=context.id_slot, seed=await completions_seed(context), expect_api_error=expect_api_error, user_api_key=context.user_api_key) @@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): await asyncio.sleep(0.1) +@step('the slot {slot_id:d} is saved with filename "{filename}"') +@async_run_until_complete +async def step_save_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is restored with filename "{filename}"') +@async_run_until_complete +async def step_restore_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is erased') +@async_run_until_complete +async def step_erase_slot(context, slot_id): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the server responds with status code {status_code:d}') +def step_server_responds_with_status_code(context, status_code): + assert context.response.status == status_code + + async def request_completion(prompt, base_url, debug=False, prompt_prefix=None, prompt_suffix=None, n_predict=None, + cache_prompt=False, + id_slot=None, seed=None, expect_api_error=None, user_api_key=None): @@ -738,6 +794,8 @@ async def request_completion(prompt, "prompt": prompt, "input_suffix": prompt_suffix, "n_predict": n_predict if n_predict is not None else -1, + "cache_prompt": cache_prompt, + "id_slot": id_slot, "seed": seed if seed is not None else 42 }, headers=headers, @@ -1104,6 +1162,8 @@ def start_server_background(context): server_args.extend(['--parallel', context.n_slots]) if context.n_server_predict: server_args.extend(['--n-predict', context.n_server_predict]) + if context.slot_save_path: + server_args.extend(['--slot-save-path', context.slot_save_path]) if context.server_api_key: server_args.extend(['--api-key', context.server_api_key]) if context.n_ga: |