diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-04-24 11:08:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-24 11:08:36 +0200 |
commit | 28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch) | |
tree | 8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /examples/server/tests/features/steps | |
parent | c0d1b3e03e27634ac2871761f5033cf9324d472d (diff) |
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
Diffstat (limited to 'examples/server/tests/features/steps')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 34 |
1 files changed, 34 insertions, 0 deletions
```diff
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index ca400efa..f71e0d70 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl
 
 
+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)
 
 
+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
 
+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:
```