author    Johannes Gäßler <johannesg@5d6.de>    2024-04-24 11:08:36 +0200
committer GitHub <noreply@github.com>           2024-04-24 11:08:36 +0200
commit    28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch)
tree      8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /examples/server/tests/features/steps
parent    c0d1b3e03e27634ac2871761f5033cf9324d472d (diff)
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
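Why separating the RNG matters: with one RNG shared across slots, each slot's draws advance the stream that every other slot samples from, so a fixed seed stops reproducing the same completion once requests run concurrently. A minimal Python sketch of the per-context idea (illustrative names only; the real change lives in the C++ sampling code):

import random

class SamplingContext:
    # Each sampling context owns its own RNG, seeded per request, so
    # concurrent slots can no longer perturb each other's sample streams.
    def __init__(self, seed):
        self.rng = random.Random(seed)

    def sample(self, n_vocab):
        return self.rng.randrange(n_vocab)

# Two slots given the same seed now yield identical draws, no matter
# how the other slot interleaves its own sampling.
ctx_a = SamplingContext(seed=42)
ctx_b = SamplingContext(seed=42)
assert [ctx_a.sample(32000) for _ in range(8)] == [ctx_b.sample(32000) for _ in range(8)]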
Diffstat (limited to 'examples/server/tests/features/steps')
-rw-r--r--  examples/server/tests/features/steps/steps.py | 34
1 file changed, 34 insertions(+), 0 deletions(-)
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index ca400efa..f71e0d70 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl
 
 
+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
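Note that behave's `{draft:d}` placeholder converts the captured text to int before the step runs, so a feature line such as `Given 5 as draft` (hypothetical wording) arrives as draft=5. A standalone sketch using the third-party parse library, which behave's default step matcher is built on:

import parse  # 'parse' package; behave's default step matcher uses it

result = parse.parse('{draft:d} as draft', '5 as draft')
assert result['draft'] == 5
assert isinstance(result['draft'], int)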
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)
 
 
+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
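The new 'all predictions are equal' step is the consistency check the commit message promises: a scenario fires several concurrent completions with the same seed, gathers the results, and requires identical content. A hedged asyncio sketch of the same pattern with stand-in names (not the suite's real request helpers):

import asyncio
import random

async def fake_completion(seed):
    # Stand-in for one server request; a per-request RNG makes it deterministic.
    rng = random.Random(seed)
    return ' '.join(str(rng.randrange(100)) for _ in range(4))

async def main():
    results = await asyncio.gather(*(fake_completion(42) for _ in range(4)))
    assert len(results) >= 2, "need at least 2 completions"
    assert all(r == results[0] for r in results), "contents not equal"

asyncio.run(main())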
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
     assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                  f' {n_predicted} <> {expected_predicted_n}')
 
+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
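The helper only needs a list of dicts carrying a 'content' key, so it is easy to exercise standalone. A compact, behavior-equivalent rewrite (enumerate in place of the manual counter, shown purely for illustration) plus a passing call on mocked responses:

import os

def assert_all_predictions_equal(completion_responses):
    # Same checks as the diffed helper; enumerate replaces the manual index.
    content_0 = completion_responses[0]['content']
    for i, response in enumerate(completion_responses[1:], start=1):
        if os.environ.get('DEBUG') == 'ON':
            print(f"content {i}: {response['content']}")
        assert response['content'] == content_0, "contents not equal"

assert_all_predictions_equal([{'content': 'Paris'}, {'content': 'Paris'}])  # passes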
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:
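The final hunk forwards the step's integer to the server command line; --draft controls the server's speculative-decoding draft-token count, and the launcher presumably stringifies each argument when spawning the process. A hypothetical illustration of the resulting argv fragment:

# Illustrative only: how context.draft = 5 would surface in the arguments.
server_args = []
draft = 5  # set by the new '{draft:d} as draft' step
if draft is not None:
    server_args.extend(['--draft', draft])
print([str(arg) for arg in server_args])  # ['--draft', '5']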