author    Jan Boon <jan.boon@kaetemi.be>    2024-04-08 20:43:30 +0800
committer GitHub <noreply@github.com>       2024-04-08 15:43:30 +0300
commit    beea6e1b16e783a0886e78dec01002a8c00db24d (patch)
tree      a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/server/tests/features
parent    87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff)
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot

---------

Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/tests/features')
-rw-r--r-- examples/server/tests/features/slotsave.feature  | 58
-rw-r--r-- examples/server/tests/features/steps/steps.py    | 60
2 files changed, 118 insertions(+), 0 deletions(-)
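
For context, the new endpoints exercised by these tests take the form POST /slots/{id}?action=save|restore|erase. A minimal sketch of driving them by hand, assuming a server started with --slot-save-path . and listening on localhost:8080 (host, port, and filename here are illustrative):

# Save, restore, and erase a slot's KV cache via the new endpoints.
import requests

base_url = "http://localhost:8080"

# Persist slot 1's KV cache to a file under the configured save path.
requests.post(f"{base_url}/slots/1?action=save",
              json={"filename": "slot1.bin"}).raise_for_status()

# Load the saved cache into a different slot (slot 0).
requests.post(f"{base_url}/slots/0?action=restore",
              json={"filename": "slot1.bin"}).raise_for_status()

# Drop slot 1's cache entirely.
requests.post(f"{base_url}/slots/1?action=erase").raise_for_status()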
diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature
new file mode 100644
index 00000000..1c281c07
--- /dev/null
+++ b/examples/server/tests/features/slotsave.feature
@@ -0,0 +1,58 @@
+@llama.cpp
+@slotsave
+Feature: llama.cpp server slot management
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+ And prompt caching is enabled
+ And 2 slots
+ And . as slot save path
+ And 2048 KV cache size
+ And 42 as server seed
+ And 24 max tokens to predict
+ Then the server is starting
+ Then the server is healthy
+
+ Scenario: Save and Restore Slot
+ # First prompt in slot 1 should be fully processed
+ Given a user prompt "What is the capital of France?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
+ When the slot 1 is saved with filename "slot1.bin"
+ Then the server responds with status code 200
+ # Since we have cache, this should only process the last tokens
+ Given a user prompt "What is the capital of Germany?"
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Thank|special)
+ And 7 prompt tokens are processed
+ # Loading the original cache into slot 0,
+ # we should only be processing 1 prompt token and get the same output
+ When the slot 0 is restored with filename "slot1.bin"
+ Then the server responds with status code 200
+ Given a user prompt "What is the capital of France?"
+ And using slot id 0
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 1 prompt tokens are processed
+ # Verify that slot 1 was not corrupted by the restore into slot 0: same check as above
+ Given a user prompt "What is the capital of Germany?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Thank|special)
+ And 1 prompt tokens are processed
+
+ Scenario: Erase Slot
+ Given a user prompt "What is the capital of France?"
+ And using slot id 1
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
+ When the slot 1 is erased
+ Then the server responds with status code 200
+ Given a user prompt "What is the capital of France?"
+ And a completion request with no api error
+ Then 24 tokens are predicted matching (Lily|cake)
+ And 22 prompt tokens are processed
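
These scenarios run under the suite's existing behave harness; a sketch of invoking just this feature programmatically, assuming behave is installed and the command is run from the repository root (the tag filter matches the @slotsave tag above):

# Equivalent to `behave --tags slotsave` from examples/server/tests.
from behave.__main__ import main as behave_main

behave_main(["--tags", "slotsave", "examples/server/tests/features"])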
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 9a6cf7d6..ca400efa 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
context.n_predict = None
context.n_prompts = 0
context.n_server_predict = None
+ context.slot_save_path = None
+ context.id_slot = None
+ context.cache_prompt = None
context.n_slots = None
context.prompt_prefix = None
context.prompt_suffix = None
@@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
context.n_server_predict = n_predict
+@step('{slot_save_path} as slot save path')
+def step_slot_save_path(context, slot_save_path):
+ context.slot_save_path = slot_save_path
+
+
+@step('using slot id {id_slot:d}')
+def step_id_slot(context, id_slot):
+ context.id_slot = id_slot
+
+
+@step('prompt caching is enabled')
+def step_enable_prompt_cache(context):
+ context.cache_prompt = True
+
+
@step('continuous batching')
def step_server_continuous_batching(context):
context.server_continuous_batching = True
@@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
+ cache_prompt=context.cache_prompt,
+ id_slot=context.id_slot,
seed=await completions_seed(context),
expect_api_error=expect_api_error,
user_api_key=context.user_api_key)
@@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
await asyncio.sleep(0.1)
+@step('the slot {slot_id:d} is saved with filename "{filename}"')
+@async_run_until_complete
+async def step_save_slot(context, slot_id, filename):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
+ json={"filename": filename},
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the slot {slot_id:d} is restored with filename "{filename}"')
+@async_run_until_complete
+async def step_restore_slot(context, slot_id, filename):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
+ json={"filename": filename},
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the slot {slot_id:d} is erased')
+@async_run_until_complete
+async def step_erase_slot(context, slot_id):
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
+ headers={"Content-Type": "application/json"}) as response:
+ context.response = response
+
+
+@step('the server responds with status code {status_code:d}')
+def step_server_responds_with_status_code(context, status_code):
+ assert context.response.status == status_code
+
+
async def request_completion(prompt,
base_url,
debug=False,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None,
+ cache_prompt=False,
+ id_slot=None,
seed=None,
expect_api_error=None,
user_api_key=None):
@@ -738,6 +794,8 @@ async def request_completion(prompt,
"prompt": prompt,
"input_suffix": prompt_suffix,
"n_predict": n_predict if n_predict is not None else -1,
+ "cache_prompt": cache_prompt,
+ "id_slot": id_slot,
"seed": seed if seed is not None else 42
},
headers=headers,
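
For reference, with cache_prompt and id_slot threaded through as above, the body POSTed to the server's /completion endpoint for the first scenario step looks roughly like this (values taken from the feature file; the exact shape is a sketch):

# Completion request pinned to slot 1 with prompt caching enabled.
payload = {
    "prompt": "What is the capital of France?",
    "n_predict": 24,       # 24 max tokens to predict
    "cache_prompt": True,  # reuse a matching KV-cache prefix on repeat prompts
    "id_slot": 1,          # pin the request to slot 1
    "seed": 42,            # server seed from the Background
}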
@@ -1104,6 +1162,8 @@ def start_server_background(context):
server_args.extend(['--parallel', context.n_slots])
if context.n_server_predict:
server_args.extend(['--n-predict', context.n_server_predict])
+ if context.slot_save_path:
+ server_args.extend(['--slot-save-path', context.slot_save_path])
if context.server_api_key:
server_args.extend(['--api-key', context.server_api_key])
if context.n_ga:
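
With the new flag wired in, start_server_background assembles a command line along these lines (binary name and values are illustrative, mirroring the Background of the feature above):

# Sketch of the resulting server invocation; paths and values are examples only.
server_args = [
    "./server",
    "--parallel", "2",        # 2 slots
    "--n-predict", "24",      # 24 max tokens to predict
    "--slot-save-path", ".",  # enables /slots/{id}?action=save|restore
]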