Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r--  examples/server/tests/features/steps/steps.py  35
1 file changed, 29 insertions(+), 6 deletions(-)
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index d7f00583..0076f805 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):
 
 @step(u'{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n, re_content)
 
 
 @step(u'{predicted_n:d} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
+    context.completion = context.tasks_result.pop()
+    assert_n_tokens_predicted(context.completion, predicted_n)
+
+
+@step(u'the completion is truncated')
+def step_assert_completion_truncated(context):
+    step_assert_completion_truncated(context, '')
+
+
+@step(u'the completion is {truncated} truncated')
+def step_assert_completion_truncated(context, truncated):
+    truncated = truncated != "not"
+    assert context.completion['truncated'] == truncated, f'{context.completion}'
+
+
+@step(u'{n_prompt:d} prompt tokens are processed')
+def step_impl(context, n_prompt):
+    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"
 
 
 @step(u'a user prompt {user_prompt}')
@@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt,
     completion_response = {
         'content': '',
         'timings': {
-            'predicted_n': 0
+            'predicted_n': 0,
+            'prompt_n': 0
         }
     }
     if async_client:
@@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt,
                     completion_response = {
                         'content': chat_completion_raw['choices'][0]['message'],
                         'timings': {
-                            'predicted_n': chat_completion_raw['usage']['completion_tokens']
+                            'predicted_n': chat_completion_raw['usage']['completion_tokens'],
+                            'prompt_n': chat_completion_raw['usage']['prompt_tokens']
                         }
                     }
     else:
@@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt,
                 if 'content' in delta:
                     completion_response['content'] += delta['content']
                     completion_response['timings']['predicted_n'] += 1
+                completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
         else:
             assert len(chat_completion.choices) == 1
             completion_response = {
                 'content': chat_completion.choices[0].message.content,
                 'timings': {
-                    'predicted_n': chat_completion.usage.completion_tokens
-                }
+                    'predicted_n': chat_completion.usage.completion_tokens,
+                    'prompt_n': chat_completion.usage.prompt_tokens
+                },
+                'truncated': chat_completion.choices[0].finish_reason != 'stop'
             }
     if debug:
         print("OAI response formatted to llama.cpp:", completion_response)