From beea6e1b16e783a0886e78dec01002a8c00db24d Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Mon, 8 Apr 2024 20:43:30 +0800 Subject: llama : save and restore kv cache for single seq id (#6341) * llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans Co-authored-by: Georgi Gerganov --- examples/server/tests/features/steps/steps.py | 60 +++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) (limited to 'examples/server/tests/features/steps/steps.py') diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 9a6cf7d6..ca400efa 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port): context.n_predict = None context.n_prompts = 0 context.n_server_predict = None + context.slot_save_path = None + context.id_slot = None + context.cache_prompt = None context.n_slots = None context.prompt_prefix = None context.prompt_suffix = None @@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict): context.n_server_predict = n_predict +@step('{slot_save_path} as slot save path') +def step_slot_save_path(context, slot_save_path): + context.slot_save_path = slot_save_path + + +@step('using slot id {id_slot:d}') +def step_id_slot(context, id_slot): + context.id_slot = id_slot + + +@step('prompt caching is enabled') +def step_enable_prompt_cache(context): + context.cache_prompt = True + + @step('continuous batching') def step_server_continuous_batching(context): context.server_continuous_batching = True @@ -212,6 +230,8 @@ async def step_request_completion(context, api_error): context.base_url, debug=context.debug, n_predict=context.n_predict, + cache_prompt=context.cache_prompt, + id_slot=context.id_slot, seed=await completions_seed(context), expect_api_error=expect_api_error, user_api_key=context.user_api_key) @@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): await asyncio.sleep(0.1) +@step('the slot {slot_id:d} is saved with filename "{filename}"') +@async_run_until_complete +async def step_save_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is restored with filename "{filename}"') +@async_run_until_complete +async def step_restore_slot(context, slot_id, filename): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', + json={"filename": filename}, + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the slot {slot_id:d} is erased') +@async_run_until_complete +async def step_erase_slot(context, slot_id): + async with aiohttp.ClientSession() as session: + async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', + headers={"Content-Type": "application/json"}) as response: + context.response = response + + +@step('the server responds with status code {status_code:d}') +def step_server_responds_with_status_code(context, status_code): + assert context.response.status == status_code + + async def request_completion(prompt, base_url, debug=False, prompt_prefix=None, prompt_suffix=None, n_predict=None, + cache_prompt=False, + id_slot=None, seed=None, expect_api_error=None, user_api_key=None): @@ -738,6 +794,8 @@ async def request_completion(prompt, "prompt": prompt, "input_suffix": prompt_suffix, "n_predict": n_predict if n_predict is not None else -1, + "cache_prompt": cache_prompt, + "id_slot": id_slot, "seed": seed if seed is not None else 42 }, headers=headers, @@ -1104,6 +1162,8 @@ def start_server_background(context): server_args.extend(['--parallel', context.n_slots]) if context.n_server_predict: server_args.extend(['--n-predict', context.n_server_predict]) + if context.slot_save_path: + server_args.extend(['--slot-save-path', context.slot_save_path]) if context.server_api_key: server_args.extend(['--api-key', context.server_api_key]) if context.n_ga: -- cgit v1.2.3