diff options
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 42cdd898..03f55f65 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -59,6 +59,7 @@ def step_server_config(context, server_fqdn, server_port): context.seed = None context.server_seed = None context.user_api_key = None + context.response_format = None context.tasks_result = [] context.concurrent_tasks = [] @@ -269,6 +270,11 @@ def step_max_tokens(context, max_tokens): context.n_predict = max_tokens +@step('a response format {response_format}') +def step_response_format(context, response_format): + context.response_format = json.loads(response_format) + + @step('streaming is {enable_streaming}') def step_streaming(context, enable_streaming): context.enable_streaming = enable_streaming == 'enabled' @@ -384,6 +390,9 @@ async def step_oai_chat_completions(context, api_error): enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, + response_format=context.response_format + if hasattr(context, 'response_format') else None, + seed=await completions_seed(context), user_api_key=context.user_api_key @@ -443,6 +452,8 @@ async def step_oai_chat_completions(context): if hasattr(context, 'n_predict') else None, enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, + response_format=context.response_format + if hasattr(context, 'response_format') else None, seed=await completions_seed(context), user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -463,6 +474,8 @@ async def step_oai_chat_completions(context): if hasattr(context, 'n_predict') else None, enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, + response_format=context.response_format + if hasattr(context, 'response_format') else None, seed=context.seed if hasattr(context, 'seed') else context.server_seed @@ -745,6 +758,7 @@ async def oai_chat_completions(user_prompt, model=None, n_predict=None, enable_streaming=None, + response_format=None, seed=None, user_api_key=None, expect_api_error=None): @@ -770,6 +784,8 @@ async def oai_chat_completions(user_prompt, "stream": enable_streaming, "seed": seed } + if response_format is not None: + payload['response_format'] = response_format completion_response = { 'content': '', 'timings': { @@ -830,6 +846,7 @@ async def oai_chat_completions(user_prompt, model=model, max_tokens=n_predict, stream=enable_streaming, + response_format=payload.get('response_format'), seed=seed ) except openai.error.AuthenticationError as e: |