summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
 examples/main/main.cpp                           |   6 +-
 examples/save-load-state/save-load-state.cpp     | 101 ++++++++-
 examples/server/README.md                        |  52 +++++
 examples/server/server.cpp                       | 228 ++++++++++++++++++++-
 examples/server/tests/features/slotsave.feature  |  58 ++++++
 examples/server/tests/features/steps/steps.py    |  60 ++++++
 6 files changed, 495 insertions(+), 10 deletions(-)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e2d07a63..711f162d 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
// The file exists and is not empty
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
- if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+ if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
- llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+ llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG("saved session to %s\n", path_session.c_str());
}
@@ -935,7 +935,7 @@ int main(int argc, char ** argv) {
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
- llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+ llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_print_timings(ctx);
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index ef952e2b..c3b76688 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -24,6 +24,7 @@ int main(int argc, char ** argv) {
std::string result0;
std::string result1;
+ std::string result2;
// init
llama_model * model;
@@ -44,8 +45,8 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
- std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
- const size_t written = llama_copy_state_data(ctx, state_mem.data());
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+ const size_t written = llama_state_get_data(ctx, state_mem.data());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@@ -97,13 +98,13 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
- std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
- if (read != llama_set_state_data(ctx2, state_mem.data())) {
+ if (read != llama_state_set_data(ctx2, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -141,16 +142,104 @@ int main(int argc, char ** argv) {
n_past += 1;
}
- printf("\n");
+ printf("\n\n");
llama_free(ctx2);
- llama_free_model(model);
if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}
+ // make new context
+ auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+
+ printf("\nsingle seq run: %s", params.prompt.c_str());
+
+ // load state (rng, logits, embedding and kv_cache) from file
+ {
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+
+ FILE * fp_read = fopen("dump_state.bin", "rb");
+ const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+ fclose(fp_read);
+
+ if (read != llama_state_set_data(ctx3, state_mem.data())) {
+ fprintf(stderr, "\n%s : failed to read state\n", __func__);
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+
+ fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+ }
+
+ // restore state (last tokens)
+ n_past = n_past_saved;
+
+ // save seq 0 and load into seq 1
+ {
+ // save kv of seq 0
+ std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
+ const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+ if (ncopy != seq_store.size()) {
+ fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
+
+ // erase whole kv
+ llama_kv_cache_clear(ctx3);
+ fprintf(stderr, "%s : kv cache cleared\n", __func__);
+
+ // restore kv into seq 1
+ const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+ if (nset != seq_store.size()) {
+ fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
+ }
+
+ // third run with seq 1 instead of 0
+ for (auto i = 0; i < params.n_predict; i++) {
+ auto * logits = llama_get_logits(ctx3);
+ auto n_vocab = llama_n_vocab(model);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx3, &candidates_p);
+ auto next_token_str = llama_token_to_piece(ctx3, next_token);
+
+ printf("%s", next_token_str.c_str());
+ result2 += next_token_str;
+
+ if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+ fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ n_past += 1;
+ }
+
+ printf("\n");
+
+ llama_free(ctx3);
+ llama_free_model(model);
+
+ if (result0 != result2) {
+ fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
+ return 1;
+ }
+
fprintf(stderr, "\n%s : success\n", __func__);
return 0;
diff --git a/examples/server/README.md b/examples/server/README.md
index 0d8564a1..a6fc92ea 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
+- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
@@ -517,6 +518,57 @@ Available metrics:
- `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred.
+- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
+
+ *Options:*
+
+ `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.
+
+### Result JSON
+
+```json
+{
+ "id_slot": 0,
+ "filename": "slot_save_file.bin",
+ "n_saved": 1745,
+ "n_written": 14309796,
+ "timings": {
+ "save_ms": 49.865
+ }
+}
+```
+
+- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
+
+ *Options:*
+
+ `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.
+
+### Result JSON
+
+```json
+{
+ "id_slot": 0,
+ "filename": "slot_save_file.bin",
+ "n_restored": 1745,
+ "n_read": 14309796,
+ "timings": {
+ "restore_ms": 42.937
+ }
+}
+```
+
+- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
+
+### Result JSON
+
+```json
+{
+ "id_slot": 0,
+ "n_erased": 1745
+}
+```
+
## More examples
### Change system prompt on runtime
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 183f0d3c..6c64fe3e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -61,7 +61,10 @@ enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
- SERVER_TASK_TYPE_METRICS
+ SERVER_TASK_TYPE_METRICS,
+ SERVER_TASK_TYPE_SLOT_SAVE,
+ SERVER_TASK_TYPE_SLOT_RESTORE,
+ SERVER_TASK_TYPE_SLOT_ERASE,
};
struct server_task {
@@ -128,6 +131,7 @@ struct server_params {
bool slots_endpoint = true;
bool metrics_endpoint = false;
+ std::string slot_save_path;
};
struct server_slot {
@@ -1612,6 +1616,107 @@ struct server_context {
}
queue_results.send(res);
} break;
+ case SERVER_TASK_TYPE_SLOT_SAVE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ const size_t token_count = slot->cache_tokens.size();
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data["filename"];
+ std::string filepath = task.data["filepath"];
+
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_save_ms = (t_end - t_start) / 1000.0;
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_saved", token_count }, // tokens saved
+ { "n_written", nwrite }, // bytes written
+ { "timings", {
+ { "save_ms", t_save_ms }
+ } }
+ };
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_RESTORE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data["filename"];
+ std::string filepath = task.data["filepath"];
+
+ slot->cache_tokens.resize(slot->n_ctx);
+ size_t token_count = 0;
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ if (nread == 0) {
+ slot->cache_tokens.resize(0);
+ send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ slot->cache_tokens.resize(token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_restore_ms = (t_end - t_start) / 1000.0;
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", token_count }, // tokens restored
+ { "n_read", nread }, // bytes read
+ { "timings", {
+ { "restore_ms", t_restore_ms }
+ } }
+ };
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_ERASE:
+ {
+ int id_slot = task.data["id_slot"];
+ server_slot * slot = get_slot(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ // Erase token cache
+ const size_t n_erased = slot->cache_tokens.size();
+ llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ slot->cache_tokens.clear();
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json {
+ { "id_slot", id_slot },
+ { "n_erased", n_erased }
+ };
+ queue_results.send(result);
+ } break;
}
}
@@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --log-disable disables logging to a file.\n");
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
+ printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
printf("\n");
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
printf(" --override-kv KEY=TYPE:VALUE\n");
@@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
sparams.slots_endpoint = false;
} else if (arg == "--metrics") {
sparams.metrics_endpoint = true;
+ } else if (arg == "--slot-save-path") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.slot_save_path = argv[i];
+ // if doesn't end with DIRECTORY_SEPARATOR, add it
+ if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
+ sparams.slot_save_path += DIRECTORY_SEPARATOR;
+ }
} else if (arg == "--chat-template") {
if (++i >= argc) {
invalid_param = true;
@@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
+ const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ json request_data = json::parse(req.body);
+ std::string filename = request_data["filename"];
+ if (!validate_file_name(filename)) {
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = sparams.slot_save_path + filename;
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_SAVE;
+ task.data = {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "filepath", filepath }
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) {
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json");
+ }
+ };
+
+ const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ json request_data = json::parse(req.body);
+ std::string filename = request_data["filename"];
+ if (!validate_file_name(filename)) {
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = sparams.slot_save_path + filename;
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
+ task.data = {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "filepath", filepath }
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) {
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json");
+ }
+ };
+
+ const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_ERASE;
+ task.data = {
+ { "id_slot", id_slot },
+ };
+
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ if (result.error) {
+ res_error(res, result.data);
+ } else {
+ res.set_content(result.data.dump(), "application/json");
+ }
+ };
+
+ const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+ std::string id_slot_str = req.path_params.at("id_slot");
+ int id_slot;
+
+ try {
+ id_slot = std::stoi(id_slot_str);
+ } catch (const std::exception &) {
+ res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::string action = req.get_param_value("action");
+
+ if (action == "save") {
+ handle_slots_save(req, res, id_slot);
+ } else if (action == "restore") {
+ handle_slots_restore(req, res, id_slot);
+ } else if (action == "erase") {
+ handle_slots_erase(req, res, id_slot);
+ } else {
+ res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+ }
+ };
+
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
@@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
+ if (!sparams.slot_save_path.empty()) {
+ // only enable slot endpoints if slot_save_path is set
+ svr->Post("/slots/:id_slot", handle_slots_action);
+ }
//
// Start the server
diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature
new file mode 100644
index 00000000..1c281c07
--- /dev/null
+++ b/examples/server/tests/features/slotsave.feature
@@ -0,0 +1,58 @@
+@llama.cpp
+@slotsave
+Feature: llama.cpp server slot management
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+ And prompt caching is enabled
+ And 2 slots
+ And . as slot save path
+ And 2048 KV cache size
+ And 42 as server seed
+ And 24 max tokens to predict
+ Then the server is starting
+ Then the server is healthy
+
+ Scenario: Save and Restore Slot
+ # First prompt in slot 1 should be fully processed
+ Given a user prompt "What is the capital of France?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
+ When the slot 1 is saved with filename "slot1.bin"
+ Then the server responds with status code 200
+ # Since we have cache, this should only process the last tokens
+ Given a user prompt "What is the capital of Germany?"
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Thank|special)
+ And 7 prompt tokens are processed
+ # Loading the original cache into slot 0,
+ # we should only be processing 1 prompt token and get the same output
+ When the slot 0 is restored with filename "slot1.bin"
+ Then the server responds with status code 200
+ Given a user prompt "What is the capital of France?"
+ And using slot id 0
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 1 prompt tokens are processed
+ # For verification that slot 1 was not corrupted during slot 0 load, same thing
+ Given a user prompt "What is the capital of Germany?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Thank|special)
+ And 1 prompt tokens are processed
+
+ Scenario: Erase Slot
+ Given a user prompt "What is the capital of France?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
+ When the slot 1 is erased
+ Then the server responds with status code 200
+ Given a user prompt "What is the capital of France?"
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 9a6cf7d6..ca400efa 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
context.n_predict = None
context.n_prompts = 0
context.n_server_predict = None
+ context.slot_save_path = None
+ context.id_slot = None
+ context.cache_prompt = None
context.n_slots = None
context.prompt_prefix = None
context.prompt_suffix = None
@@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
context.n_server_predict = n_predict
+@step('{slot_save_path} as slot save path')
+def step_slot_save_path(context, slot_save_path):
+ context.slot_save_path = slot_save_path
+
+
+@step('using slot id {id_slot:d}')
+def step_id_slot(context, id_slot):
+ context.id_slot = id_slot
+
+
+@step('prompt caching is enabled')
+def step_enable_prompt_cache(context):
+ context.cache_prompt = True
+
+
@step('continuous batching')
def step_server_continuous_batching(context):
context.server_continuous_batching = True
@@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
+ cache_prompt=context.cache_prompt,
+ id_slot=context.id_slot,
seed=await completions_seed(context),
expect_api_error=expect_api_error,
user_api_key=context.user_api_key)
@@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
await asyncio.sleep(0.1)
+@step('the slot {slot_id:d} is saved with filename "{filename}"')
+@async_run_until_complete
+async def step_save_slot(context, slot_id, filename):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
+ json={"filename": filename},
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the slot {slot_id:d} is restored with filename "{filename}"')
+@async_run_until_complete
+async def step_restore_slot(context, slot_id, filename):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
+ json={"filename": filename},
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the slot {slot_id:d} is erased')
+@async_run_until_complete
+async def step_erase_slot(context, slot_id):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the server responds with status code {status_code:d}')
+def step_server_responds_with_status_code(context, status_code):
+ assert context.response.status == status_code
+
+
async def request_completion(prompt,
base_url,
debug=False,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None,
+ cache_prompt=False,
+ id_slot=None,
seed=None,
expect_api_error=None,
user_api_key=None):
@@ -738,6 +794,8 @@ async def request_completion(prompt,
"prompt": prompt,
"input_suffix": prompt_suffix,
"n_predict": n_predict if n_predict is not None else -1,
+ "cache_prompt": cache_prompt,
+ "id_slot": id_slot,
"seed": seed if seed is not None else 42
},
headers=headers,
@@ -1104,6 +1162,8 @@ def start_server_background(context):
server_args.extend(['--parallel', context.n_slots])
if context.n_server_predict:
server_args.extend(['--n-predict', context.n_server_predict])
+ if context.slot_save_path:
+ server_args.extend(['--slot-save-path', context.slot_save_path])
if context.server_api_key:
server_args.extend(['--api-key', context.server_api_key])
if context.n_ga: