author     Jan Boon <jan.boon@kaetemi.be>       2024-04-08 20:43:30 +0800
committer  GitHub <noreply@github.com>          2024-04-08 15:43:30 +0300
commit     beea6e1b16e783a0886e78dec01002a8c00db24d (patch)
tree       a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/server/tests/features
parent     87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff)
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot
---------
Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
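
For reference, the slot endpoints exercised by the new test steps can be driven directly over HTTP. Below is a minimal sketch using Python's requests library; the endpoint shape (/slots/<id>?action=save|restore|erase with a JSON filename) is taken from the test code in this commit, while the base URL and filename are assumptions matching the test background:

import requests

BASE_URL = "http://localhost:8080"  # assumed; matches the test background below

# Save the KV cache of slot 1 to a file under the server's --slot-save-path
resp = requests.post(f"{BASE_URL}/slots/1?action=save", json={"filename": "slot1.bin"})
assert resp.status_code == 200

# Restore the saved state into a different slot (slot 0)
resp = requests.post(f"{BASE_URL}/slots/0?action=restore", json={"filename": "slot1.bin"})
assert resp.status_code == 200

# Erase the KV cache of slot 1
resp = requests.post(f"{BASE_URL}/slots/1?action=erase")
assert resp.status_code == 200

Note that --slot-save-path both enables these endpoints and restricts where state files may be written; the filename is validated server-side (see the strict filename validation commits above).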
Diffstat (limited to 'examples/server/tests/features')
-rw-r--r--  examples/server/tests/features/slotsave.feature  58
-rw-r--r--  examples/server/tests/features/steps/steps.py    60
2 files changed, 118 insertions, 0 deletions
diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature
new file mode 100644
index 00000000..1c281c07
--- /dev/null
+++ b/examples/server/tests/features/slotsave.feature
@@ -0,0 +1,58 @@
+@llama.cpp
+@slotsave
+Feature: llama.cpp server slot management
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And prompt caching is enabled
+    And 2 slots
+    And . as slot save path
+    And 2048 KV cache size
+    And 42 as server seed
+    And 24 max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Save and Restore Slot
+    # First prompt in slot 1 should be fully processed
+    Given a user prompt "What is the capital of France?"
+    And using slot id 1
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Lily|cake)
+    And 22 prompt tokens are processed
+    When the slot 1 is saved with filename "slot1.bin"
+    Then the server responds with status code 200
+    # Since we have cache, this should only process the last tokens
+    Given a user prompt "What is the capital of Germany?"
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Thank|special)
+    And 7 prompt tokens are processed
+    # Loading the original cache into slot 0,
+    # we should only be processing 1 prompt token and get the same output
+    When the slot 0 is restored with filename "slot1.bin"
+    Then the server responds with status code 200
+    Given a user prompt "What is the capital of France?"
+    And using slot id 0
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Lily|cake)
+    And 1 prompt tokens are processed
+    # For verification that slot 1 was not corrupted during slot 0 load, same thing
+    Given a user prompt "What is the capital of Germany?"
+    And using slot id 1
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Thank|special)
+    And 1 prompt tokens are processed
+
+  Scenario: Erase Slot
+    Given a user prompt "What is the capital of France?"
+    And using slot id 1
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Lily|cake)
+    And 22 prompt tokens are processed
+    When the slot 1 is erased
+    Then the server responds with status code 200
+    Given a user prompt "What is the capital of France?"
+    And a completion request with no api error
+    Then 24 tokens are predicted matching (Lily|cake)
+    And 22 prompt tokens are processed
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 9a6cf7d6..ca400efa 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_predict = None
     context.n_prompts = 0
     context.n_server_predict = None
+    context.slot_save_path = None
+    context.id_slot = None
+    context.cache_prompt = None
     context.n_slots = None
     context.prompt_prefix = None
     context.prompt_suffix = None
@@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
     context.n_server_predict = n_predict


+@step('{slot_save_path} as slot save path')
+def step_slot_save_path(context, slot_save_path):
+    context.slot_save_path = slot_save_path
+
+
+@step('using slot id {id_slot:d}')
+def step_id_slot(context, id_slot):
+    context.id_slot = id_slot
+
+
+@step('prompt caching is enabled')
+def step_enable_prompt_cache(context):
+    context.cache_prompt = True
+
+
 @step('continuous batching')
 def step_server_continuous_batching(context):
     context.server_continuous_batching = True
@@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
+                                         cache_prompt=context.cache_prompt,
+                                         id_slot=context.id_slot,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
@@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
     await asyncio.sleep(0.1)


+@step('the slot {slot_id:d} is saved with filename "{filename}"')
+@async_run_until_complete
+async def step_save_slot(context, slot_id, filename):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
+                                json={"filename": filename},
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the slot {slot_id:d} is restored with filename "{filename}"')
+@async_run_until_complete
+async def step_restore_slot(context, slot_id, filename):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
+                                json={"filename": filename},
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the slot {slot_id:d} is erased')
+@async_run_until_complete
+async def step_erase_slot(context, slot_id):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the server responds with status code {status_code:d}')
+def step_server_responds_with_status_code(context, status_code):
+    assert context.response.status == status_code
+
+
 async def request_completion(prompt,
                              base_url,
                              debug=False,
                              prompt_prefix=None,
                              prompt_suffix=None,
                              n_predict=None,
+                             cache_prompt=False,
+                             id_slot=None,
                              seed=None,
                              expect_api_error=None,
                              user_api_key=None):
@@ -738,6 +794,8 @@ async def request_completion(prompt,
                                     "prompt": prompt,
                                     "input_suffix": prompt_suffix,
                                     "n_predict": n_predict if n_predict is not None else -1,
+                                    "cache_prompt": cache_prompt,
+                                    "id_slot": id_slot,
                                     "seed": seed if seed is not None else 42
                                 },
                                 headers=headers,
@@ -1104,6 +1162,8 @@ def start_server_background(context):
         server_args.extend(['--parallel', context.n_slots])
     if context.n_server_predict:
         server_args.extend(['--n-predict', context.n_server_predict])
+    if context.slot_save_path:
+        server_args.extend(['--slot-save-path', context.slot_save_path])
     if context.server_api_key:
         server_args.extend(['--api-key', context.server_api_key])
     if context.n_ga:
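
As a usage note, the cache_prompt and id_slot fields added to the completion payload above are what allow the scenarios to pin a prompt to a specific slot and reuse its KV cache across requests. A hedged sketch of an equivalent standalone request against the llama.cpp server's /completion endpoint, with values taken from the test background:

import requests

BASE_URL = "http://localhost:8080"  # assumed; matches the test background

payload = {
    "prompt": "What is the capital of France?",
    "n_predict": 24,        # max tokens to predict, as in the tests
    "seed": 42,             # fixed seed so the output is reproducible
    "cache_prompt": True,   # keep the evaluated prompt in the KV cache
    "id_slot": 1,           # pin this request to slot 1
}
resp = requests.post(f"{BASE_URL}/completion", json=payload)
resp.raise_for_status()
print(resp.json()["content"])

A follow-up request with the same id_slot and a shared prompt prefix should then report fewer processed prompt tokens, which is exactly what the slotsave.feature scenarios assert.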