Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 60 |
1 file changed, 60 insertions(+), 0 deletions(-)
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 9a6cf7d6..ca400efa 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_predict = None
     context.n_prompts = 0
     context.n_server_predict = None
+    context.slot_save_path = None
+    context.id_slot = None
+    context.cache_prompt = None
     context.n_slots = None
     context.prompt_prefix = None
     context.prompt_suffix = None
@@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
     context.n_server_predict = n_predict
 
 
+@step('{slot_save_path} as slot save path')
+def step_slot_save_path(context, slot_save_path):
+    context.slot_save_path = slot_save_path
+
+
+@step('using slot id {id_slot:d}')
+def step_id_slot(context, id_slot):
+    context.id_slot = id_slot
+
+
+@step('prompt caching is enabled')
+def step_enable_prompt_cache(context):
+    context.cache_prompt = True
+
+
 @step('continuous batching')
 def step_server_continuous_batching(context):
     context.server_continuous_batching = True
@@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
+                                         cache_prompt=context.cache_prompt,
+                                         id_slot=context.id_slot,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
@@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
     await asyncio.sleep(0.1)
 
 
+@step('the slot {slot_id:d} is saved with filename "{filename}"')
+@async_run_until_complete
+async def step_save_slot(context, slot_id, filename):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
+                                json={"filename": filename},
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the slot {slot_id:d} is restored with filename "{filename}"')
+@async_run_until_complete
+async def step_restore_slot(context, slot_id, filename):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
+                                json={"filename": filename},
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the slot {slot_id:d} is erased')
+@async_run_until_complete
+async def step_erase_slot(context, slot_id):
+    async with aiohttp.ClientSession() as session:
+        async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
+                                headers={"Content-Type": "application/json"}) as response:
+            context.response = response
+
+
+@step('the server responds with status code {status_code:d}')
+def step_server_responds_with_status_code(context, status_code):
+    assert context.response.status == status_code
+
+
 async def request_completion(prompt,
                              base_url,
                              debug=False,
                              prompt_prefix=None,
                              prompt_suffix=None,
                              n_predict=None,
+                             cache_prompt=False,
+                             id_slot=None,
                              seed=None,
                              expect_api_error=None,
                              user_api_key=None):
@@ -738,6 +794,8 @@ async def request_completion(prompt,
                 "prompt": prompt,
                 "input_suffix": prompt_suffix,
                 "n_predict": n_predict if n_predict is not None else -1,
+                "cache_prompt": cache_prompt,
+                "id_slot": id_slot,
                 "seed": seed if seed is not None else 42
             },
             headers=headers,
@@ -1104,6 +1162,8 @@ def start_server_background(context):
         server_args.extend(['--parallel', context.n_slots])
     if context.n_server_predict:
         server_args.extend(['--n-predict', context.n_server_predict])
+    if context.slot_save_path:
+        server_args.extend(['--slot-save-path', context.slot_save_path])
     if context.server_api_key:
         server_args.extend(['--api-key', context.server_api_key])
     if context.n_ga:
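Taken together, the new steps cover the full slot persistence cycle: a scenario can point the server at a save directory, pin a completion to a specific slot with prompt caching enabled, and then save, erase, and restore that slot through the /slots endpoint. The sketch below drives the same HTTP surface directly with aiohttp, outside of behave; it is a minimal illustration, assuming a llama.cpp server already running at localhost:8080 that was started with --slot-save-path (the address, slot id, prompt, and filename are placeholder values, not taken from the diff).

    import asyncio

    import aiohttp

    BASE_URL = 'http://localhost:8080'  # placeholder; the tests use context.base_url


    async def slot_round_trip():
        async with aiohttp.ClientSession() as session:
            # Run a completion pinned to slot 1 with prompt caching on, so the
            # slot holds a reusable KV cache (mirrors cache_prompt/id_slot above).
            async with session.post(f'{BASE_URL}/completion',
                                    json={"prompt": "Once upon a time",
                                          "n_predict": 16,
                                          "id_slot": 1,
                                          "cache_prompt": True}) as response:
                assert response.status == 200

            # Persist slot 1's state under the directory given by --slot-save-path.
            async with session.post(f'{BASE_URL}/slots/1?action=save',
                                    json={"filename": "slot1.bin"}) as response:
                assert response.status == 200

            # Wipe the slot's in-memory state, then reload it from disk.
            async with session.post(f'{BASE_URL}/slots/1?action=erase') as response:
                assert response.status == 200
            async with session.post(f'{BASE_URL}/slots/1?action=restore',
                                    json={"filename": "slot1.bin"}) as response:
                assert response.status == 200


    asyncio.run(slot_round_trip())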
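In a feature file, the step definitions compose in the same order: configuration steps ("{path} as slot save path", "using slot id {n}", "prompt caching is enabled") before the server starts, then save/erase/restore steps each followed by "the server responds with status code {n}". One detail worth noting in the step bodies: context.response is assigned inside the async with block, so by the time a later step asserts on it the response has been released; aiohttp keeps the status attribute readable after release, which is why the status-code check still works, but the response body would no longer be available at that point.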