author     Johannes Gäßler <johannesg@5d6.de>   2024-05-01 17:52:55 +0200
committer  GitHub <noreply@github.com>          2024-05-01 17:52:55 +0200
commit     3ea0d36000e2314baba7e9fc6a97f08670a6f7e4 (patch)
tree       c6fe0d93491ff5aa430bb5c349d08e772de63ce1 /examples/server
parent     1613ef8d8eb2479ba55c4d598e08c8f3f18a0fed (diff)
Server: add tests for batch size, different seeds (#6950)
Diffstat (limited to 'examples/server')
-rw-r--r--  examples/server/tests/features/results.feature  |  88
-rw-r--r--  examples/server/tests/features/steps/steps.py   | 148
2 files changed, 156 insertions, 80 deletions
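The property these tests pin down can be sketched outside the behave harness: with a fixed seed and a temperature of 0, repeating the same completion request should reproduce the same text. A minimal sketch, assuming a llama.cpp server already listening on localhost:8080; the payload fields mirror the seed/temperature/n_predict fields sent by request_completion in steps.py, while the /completion path and the prompt field name are assumptions here:

import json
import urllib.request

def complete(prompt: str, seed: int) -> str:
    # same fields the test harness sends: seed, temperature, n_predict
    payload = json.dumps({
        "prompt": prompt,
        "seed": seed,
        "temperature": 0.0,
        "n_predict": 32,
    }).encode()
    req = urllib.request.Request(
        "http://localhost:8080/completion",  # assumed host/port and endpoint
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())["content"]

# same prompt, same seed, temperature 0 -> identical output expected
assert complete("Write a very long story about AI.", 42) == \
       complete("Write a very long story about AI.", 42)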
diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature
index f17120f7..aa0b8d0c 100644
--- a/examples/server/tests/features/results.feature
+++ b/examples/server/tests/features/results.feature
@@ -7,44 +7,16 @@ Feature: Results
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
And a model file test-model-00001-of-00003.gguf
And 128 as batch size
- And 256 KV cache size
+ And 1024 KV cache size
And 128 max tokens to predict
+ And continuous batching
- Scenario Outline: Multi users completion
+ Scenario Outline: consistent results with same seed
Given <n_slots> slots
- And continuous batching
Then the server is starting
Then the server is healthy
- Given 42 as seed
- And a prompt:
- """
- Write a very long story about AI.
- """
-
- Given 42 as seed
- And a prompt:
- """
- Write a very long story about AI.
- """
-
- Given 42 as seed
- And a prompt:
- """
- Write a very long story about AI.
- """
-
- Given 42 as seed
- And a prompt:
- """
- Write a very long story about AI.
- """
-
- Given 42 as seed
- And a prompt:
- """
- Write a very long story about AI.
- """
+ Given 4 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
Given concurrent completion requests
Then the server is busy
@@ -55,3 +27,55 @@ Feature: Results
| n_slots |
| 1 |
| 2 |
+
+ Scenario Outline: different results with different seed
+ Given <n_slots> slots
+ Then the server is starting
+ Then the server is healthy
+
+ Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
+ Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 43
+ Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 44
+ Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 45
+
+ Given concurrent completion requests
+ Then the server is busy
+ Then the server is idle
+ And all slots are idle
+ Then all predictions are different
+ Examples:
+ | n_slots |
+ | 1 |
+ | 2 |
+
+ Scenario Outline: consistent results with same seed and varying batch size
+ Given 4 slots
+ And <temp> temperature
+ # And 0 as draft
+ Then the server is starting
+ Then the server is healthy
+
+ Given 1 prompts "Write a very long story about AI." with seed 42
+ And concurrent completion requests
+ # Then the server is busy # Not all slots will be utilized.
+ Then the server is idle
+ And all slots are idle
+
+ Given <n_parallel> prompts "Write a very long story about AI." with seed 42
+ And concurrent completion requests
+ # Then the server is busy # Not all slots will be utilized.
+ Then the server is idle
+ And all slots are idle
+
+ Then all predictions are equal
+ Examples:
+ | n_parallel | temp |
+ | 1 | 0.0 |
+ | 2 | 0.0 |
+ | 4 | 0.0 |
+ | 1 | 1.0 |
+ # FIXME: These tests fail on master. The problem seems to be the unified KV cache.
+ # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
+ # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 .
+ # | 2 | 1.0 |
+ # | 4 | 1.0 |
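The last scenario above checks that batching does not change the result: one request, then <n_parallel> concurrent requests, all with seed 42 and temperature 0.0, must produce identical text. A rough equivalent outside the Gherkin harness, using aiohttp as steps.py does (host/port, endpoint path, and the prompt field name are assumptions):

import asyncio
import aiohttp

async def complete(session, prompt, seed):
    async with session.post("http://localhost:8080/completion",  # assumed
                            json={"prompt": prompt, "seed": seed,
                                  "temperature": 0.0, "n_predict": 64}) as response:
        return (await response.json())["content"]

async def main():
    prompt, seed = "Write a very long story about AI.", 42
    async with aiohttp.ClientSession() as session:
        single = await complete(session, prompt, seed)
        # 4 concurrent requests with the same seed: the server may batch them
        # differently, yet the scenario expects the same completion as above
        batched = await asyncio.gather(*(complete(session, prompt, seed)
                                         for _ in range(4)))
    assert all(content == single for content in batched)

asyncio.run(main())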
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index f71e0d70..b8dbef21 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -65,6 +65,7 @@ def step_server_config(context, server_fqdn, server_port):
context.server_seed = None
context.user_api_key = None
context.response_format = None
+ context.temperature = None
context.tasks_result = []
context.concurrent_tasks = []
@@ -232,15 +233,17 @@ async def step_all_slots_status(context, expected_slot_status_string):
@async_run_until_complete
async def step_request_completion(context, api_error):
expect_api_error = api_error == 'raised'
+ seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
+ seeds[0] if seeds is not None else seeds,
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
cache_prompt=context.cache_prompt,
id_slot=context.id_slot,
- seed=await completions_seed(context),
expect_api_error=expect_api_error,
- user_api_key=context.user_api_key)
+ user_api_key=context.user_api_key,
+ temperature=context.temperature)
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
@@ -269,6 +272,15 @@ async def step_predictions_equal(context):
context.tasks_result = []
+@step('all predictions are different')
+@async_run_until_complete
+async def step_predictions_different(context):
+ n_completions = await gather_tasks_results(context)
+ assert n_completions >= 2, "need at least 2 completions"
+ assert_all_predictions_different(context.tasks_result)
+ context.tasks_result = []
+
+
@step('the completion is truncated')
def step_assert_completion_truncated(context):
step_assert_completion_truncated(context, '')
@@ -311,6 +323,11 @@ def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
+@step('{temperature:f} temperature')
+def step_temperature(context, temperature):
+ context.temperature = temperature
+
+
@step('streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
context.enable_streaming = enable_streaming == 'enabled'
@@ -353,7 +370,10 @@ def step_n_ubatch(context, n_ubatch):
@step('{seed:d} as seed')
def step_seed(context, seed):
- context.seed = seed
+ if context.seed is None:
+ context.seed = [seed]
+ else:
+ context.seed.append(seed)
@step('a prefix prompt')
@@ -413,7 +433,9 @@ async def step_oai_chat_completions(context, api_error):
if context.debug:
print(f"Submitting OAI compatible completions request...")
expect_api_error = api_error == 'raised'
+ seeds = await completions_seed(context, num_seeds=1)
completion = await oai_chat_completions(context.prompts.pop(),
+ seeds[0] if seeds is not None else seeds,
context.system_prompt,
context.base_url,
'/v1/chat',
@@ -429,8 +451,6 @@ async def step_oai_chat_completions(context, api_error):
response_format=context.response_format
if hasattr(context, 'response_format') else None,
- seed=await completions_seed(context),
-
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,
@@ -457,20 +477,31 @@ def step_a_prompt_prompt(context, prompt):
context.n_prompts = len(context.prompts)
+@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
+def step_many_prompts(context, num_prompts, prompt, seed):
+ if context.seed is None:
+ context.seed = []
+ for _ in range(num_prompts):
+ context.seed.append(seed)
+ context.prompts.append(prompt)
+ context.n_prompts = len(context.prompts)
+
+
@step('concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
- await concurrent_requests(context,
- request_completion,
- # prompt is inserted automatically
- context.base_url,
- debug=context.debug,
- prompt_prefix=context.prompt_prefix,
- prompt_suffix=context.prompt_suffix,
- n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
- seed=await completions_seed(context),
- user_api_key=context.user_api_key if hasattr(context,
- 'user_api_key') else None)
+ await concurrent_requests(
+ context,
+ request_completion,
+ # prompt is inserted automatically
+ context.base_url,
+ debug=context.debug,
+ prompt_prefix=context.prompt_prefix,
+ prompt_suffix=context.prompt_suffix,
+ n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+ user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None,
+ temperature=context.temperature,
+ )
@step('concurrent OAI completions requests')
@@ -490,7 +521,6 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
- seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -512,10 +542,6 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
- seed=context.seed
- if hasattr(context, 'seed') else
- context.server_seed
- if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -544,7 +570,7 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
@async_run_until_complete
async def step_compute_embedding(context):
context.n_prompts = 1
- context.embeddings = await request_embedding(context_text(context), base_url=context.base_url)
+ context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
@step('all embeddings are the same')
@@ -585,7 +611,7 @@ def step_assert_embeddings(context):
@async_run_until_complete
async def step_oai_compute_embeddings(context):
context.n_prompts = 1
- context.embeddings = await request_oai_embeddings(context_text(context),
+ context.embeddings = await request_oai_embeddings(context_text(context), None,
base_url=context.base_url,
user_api_key=context.user_api_key,
model=context.model)
@@ -594,7 +620,7 @@ async def step_oai_compute_embeddings(context):
@step('an OAI compatible embeddings computation request for multiple inputs')
@async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
- context.embeddings = await request_oai_embeddings(context.prompts,
+ context.embeddings = await request_oai_embeddings(context.prompts, None,
base_url=context.base_url,
user_api_key=context.user_api_key,
model=context.model)
@@ -740,8 +766,9 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
if context.debug:
print(f"starting {context.n_prompts} concurrent completion requests...")
assert context.n_prompts > 0
+ seeds = await completions_seed(context)
for prompt_no in range(context.n_prompts):
- shifted_args = [context.prompts.pop(), *args]
+ shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
await asyncio.sleep(0.1)
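With this change each queued prompt carries its own seed: prompts are consumed from the end of the list via pop() while seeds are read from the front, which is equivalent here because every batch in results.feature uses identical prompt text. A stand-alone illustration of the pairing (names are illustrative, not part of the harness):

prompts = ["p0", "p1", "p2", "p3"]
seeds = [42, 43, 44, 45]

# mirrors the loop above: pop() takes from the end, seeds index from the front
pairs = [(prompts.pop(), seeds[prompt_no]) for prompt_no in range(len(seeds))]
print(pairs)  # [('p3', 42), ('p2', 43), ('p1', 44), ('p0', 45)]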
@@ -781,6 +808,7 @@ def step_server_responds_with_status_code(context, status_code):
async def request_completion(prompt,
+ seed,
base_url,
debug=False,
prompt_prefix=None,
@@ -788,9 +816,9 @@ async def request_completion(prompt,
n_predict=None,
cache_prompt=False,
id_slot=None,
- seed=None,
expect_api_error=None,
- user_api_key=None):
+ user_api_key=None,
+ temperature=None):
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
@@ -811,7 +839,8 @@ async def request_completion(prompt,
"n_predict": n_predict if n_predict is not None else -1,
"cache_prompt": cache_prompt,
"id_slot": id_slot,
- "seed": seed if seed is not None else 42
+ "seed": seed if seed is not None else 42,
+ "temperature": temperature if temperature is not None else "0.8f",
},
headers=headers,
timeout=3600) as response:
@@ -824,6 +853,7 @@ async def request_completion(prompt,
async def oai_chat_completions(user_prompt,
+ seed,
system_prompt,
base_url,
base_path,
@@ -833,7 +863,6 @@ async def oai_chat_completions(user_prompt,
n_predict=None,
enable_streaming=None,
response_format=None,
- seed=None,
user_api_key=None,
expect_api_error=None):
if debug:
@@ -952,7 +981,7 @@ async def oai_chat_completions(user_prompt,
return completion_response
-async def request_embedding(content, base_url=None):
+async def request_embedding(content, seed, base_url=None):
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding',
json={
@@ -963,7 +992,7 @@ async def request_embedding(content, base_url=None):
return [response_json['embedding']]
-async def request_oai_embeddings(input,
+async def request_oai_embeddings(input, seed,
base_url=None, user_api_key=None,
model=None, async_client=False):
# openai client always expects an api_key
@@ -1036,21 +1065,31 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
f' {n_predicted} <> {expected_predicted_n}')
def assert_all_predictions_equal(completion_responses):
- content_0 = completion_responses[0]['content']
-
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
- print(f"content 0: {content_0}")
-
- i = 1
- for response in completion_responses[1:]:
- content = response['content']
-
- if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
- print(f"content {i}: {content}")
-
- assert content == content_0, "contents not equal"
-
- i += 1
+ for i, response_i in enumerate(completion_responses):
+ content_i = response_i['content']
+ print(f"content {i}: {content_i}")
+ for i, response_i in enumerate(completion_responses):
+ content_i = response_i['content']
+ for j, response_j in enumerate(completion_responses):
+ if i == j:
+ continue
+ content_j = response_j['content']
+ assert content_i == content_j, "contents not equal"
+
+
+def assert_all_predictions_different(completion_responses):
+ if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+ for i, response_i in enumerate(completion_responses):
+ content_i = response_i['content']
+ print(f"content {i}: {content_i}")
+ for i, response_i in enumerate(completion_responses):
+ content_i = response_i['content']
+ for j, response_j in enumerate(completion_responses):
+ if i == j:
+ continue
+ content_j = response_j['content']
+ assert content_i != content_j, "contents not different"
async def gather_tasks_results(context):
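Both helpers compare every ordered pair of responses; a more compact but equivalent formulation (a sketch, not the committed code) compares each unordered pair once via itertools.combinations:

from itertools import combinations

def assert_all_predictions_equal(completion_responses):
    for a, b in combinations(completion_responses, 2):
        assert a['content'] == b['content'], "contents not equal"

def assert_all_predictions_different(completion_responses):
    for a, b in combinations(completion_responses, 2):
        assert a['content'] != b['content'], "contents not different"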
@@ -1145,9 +1184,22 @@ def assert_slots_status(slots, expected_slots):
f" = {expected[key]} != {slot[key]}")
-async def completions_seed(context):
- return context.seed if hasattr(context, 'seed') and context.seed is not None \
- else context.server_seed if hasattr(context, 'server_seed') else None
+async def completions_seed(context, num_seeds=None):
+ if hasattr(context, "seed") and context.seed is not None:
+ assert len(context.seed) == context.n_prompts
+ if num_seeds is None:
+ num_seeds = context.n_prompts
+ assert num_seeds <= context.n_prompts
+ seeds = context.seed[:num_seeds]
+ context.seed = context.seed[num_seeds:] if num_seeds < context.n_prompts else None
+ return seeds
+
+ if hasattr(context, "server_seed") and context.server_seed is not None:
+ if num_seeds is None:
+ return [context.server_seed] * context.n_prompts
+ else:
+ return [context.server_seed] * num_seeds
+ return None
def context_text(context):
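completions_seed now returns a list: context.seed is expected to hold one seed per queued prompt, each call consumes num_seeds of them (all remaining ones by default), and context.server_seed is the fallback. A stand-alone model of that slicing behaviour (the function name is illustrative, not part of steps.py):

def take_seeds(seed_queue, n_prompts, num_seeds=None):
    # mirrors completions_seed: consume num_seeds seeds, keep the rest queued
    if num_seeds is None:
        num_seeds = n_prompts
    assert len(seed_queue) == n_prompts and num_seeds <= n_prompts
    return seed_queue[:num_seeds], (seed_queue[num_seeds:] or None)

print(take_seeds([42, 43, 44, 45], n_prompts=4, num_seeds=1))  # ([42], [43, 44, 45])
print(take_seeds([42, 43, 44, 45], n_prompts=4))               # ([42, 43, 44, 45], None)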