summaryrefslogtreecommitdiff
path: root/examples/server/tests/features/steps/steps.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r--examples/server/tests/features/steps/steps.py100
1 files changed, 51 insertions, 49 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 7b5dabb0..df0814cc 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -1,5 +1,4 @@
import asyncio
-import collections
import json
import os
import re
@@ -8,19 +7,23 @@ import subprocess
import sys
import threading
import time
+from collections.abc import Sequence
from contextlib import closing
from re import RegexFlag
+from typing import Any, Literal, cast
import aiohttp
import numpy as np
import openai
-from behave import step
+from openai.types.chat import ChatCompletionChunk
+from behave import step # pyright: ignore[reportAttributeAccessIssue]
from behave.api.async_step import async_run_until_complete
from prometheus_client import parser
+# pyright: reportRedeclaration=false
@step("a server listening on {server_fqdn}:{server_port}")
-def step_server_config(context, server_fqdn, server_port):
+def step_server_config(context, server_fqdn: str, server_port: str):
context.server_fqdn = server_fqdn
context.server_port = int(server_port)
context.n_threads = None
@@ -74,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):
@step('a model file {hf_file} from HF repo {hf_repo}')
-def step_download_hf_model(context, hf_file, hf_repo):
+def step_download_hf_model(context, hf_file: str, hf_repo: str):
context.model_hf_repo = hf_repo
context.model_hf_file = hf_file
context.model_file = os.path.basename(hf_file)
@step('a model file {model_file}')
-def step_model_file(context, model_file):
+def step_model_file(context, model_file: str):
context.model_file = model_file
@step('a model url {model_url}')
-def step_model_url(context, model_url):
+def step_model_url(context, model_url: str):
context.model_url = model_url
@step('a model alias {model_alias}')
-def step_model_alias(context, model_alias):
+def step_model_alias(context, model_alias: str):
context.model_alias = model_alias
@step('{seed:d} as server seed')
-def step_seed(context, seed):
+def step_seed(context, seed: int):
context.server_seed = seed
@step('{ngl:d} GPU offloaded layers')
-def step_n_gpu_layer(context, ngl):
+def step_n_gpu_layer(context, ngl: int):
if 'N_GPU_LAYERS' in os.environ:
new_ngl = int(os.environ['N_GPU_LAYERS'])
if context.debug:
@@ -111,37 +114,37 @@ def step_n_gpu_layer(context, ngl):
@step('{n_threads:d} threads')
-def step_n_threads(context, n_threads):
+def step_n_threads(context, n_threads: int):
context.n_thread = n_threads
@step('{draft:d} as draft')
-def step_draft(context, draft):
+def step_draft(context, draft: int):
context.draft = draft
@step('{n_ctx:d} KV cache size')
-def step_n_ctx(context, n_ctx):
+def step_n_ctx(context, n_ctx: int):
context.n_ctx = n_ctx
@step('{n_slots:d} slots')
-def step_n_slots(context, n_slots):
+def step_n_slots(context, n_slots: int):
context.n_slots = n_slots
@step('{n_predict:d} server max tokens to predict')
-def step_server_n_predict(context, n_predict):
+def step_server_n_predict(context, n_predict: int):
context.n_server_predict = n_predict
@step('{slot_save_path} as slot save path')
-def step_slot_save_path(context, slot_save_path):
+def step_slot_save_path(context, slot_save_path: str):
context.slot_save_path = slot_save_path
@step('using slot id {id_slot:d}')
-def step_id_slot(context, id_slot):
+def step_id_slot(context, id_slot: int):
context.id_slot = id_slot
@@ -191,7 +194,7 @@ def step_start_server(context):
@step("the server is {expecting_status}")
@async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status):
+async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
match expecting_status:
case 'healthy':
await wait_for_health_status(context, context.base_url, 200, 'ok',
@@ -221,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
@step('all slots are {expected_slot_status_string}')
@async_run_until_complete
-async def step_all_slots_status(context, expected_slot_status_string):
+async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
match expected_slot_status_string:
case 'idle':
expected_slot_status = 0
@@ -237,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):
@step('a completion request with {api_error} api error')
@async_run_until_complete
-async def step_request_completion(context, api_error):
+async def step_request_completion(context, api_error: Literal['raised'] | str):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
@@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value):
def step_available_models(context):
# openai client always expects an api_key
openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
- openai.api_base = f'{context.base_url}/v1'
- context.models = openai.Model.list().data
+ openai.base_url = f'{context.base_url}/v1/'
+ context.models = openai.models.list().data
@step('{n_model:d} models are supported')
@@ -789,7 +792,7 @@ def step_supported_models(context, n_model):
@step('model {i_model:d} is {param} {preposition} {param_value}')
-def step_supported_models(context, i_model, param, preposition, param_value):
+def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str):
assert i_model < len(context.models)
model = context.models[i_model]
@@ -798,7 +801,7 @@ def step_supported_models(context, i_model, param, preposition, param_value):
case 'identified':
value = model.id
case 'trained':
- value = str(model.meta.n_ctx_train)
+ value = str(model.meta["n_ctx_train"])
case _:
assert False, "param {param} not supported"
assert param_value == value, f"model param {param} {value} != {param_value}"
@@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
print(f"starting {context.n_prompts} concurrent completion requests...")
assert context.n_prompts > 0
seeds = await completions_seed(context)
+ assert seeds is not None
for prompt_no in range(context.n_prompts):
shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
@@ -861,7 +865,7 @@ async def request_completion(prompt,
id_slot=None,
expect_api_error=None,
user_api_key=None,
- temperature=None):
+ temperature=None) -> int | dict[str, Any]:
if debug:
print(f"Sending completion request: {prompt}")
origin = "my.super.domain"
@@ -899,8 +903,8 @@ async def request_completion(prompt,
async def oai_chat_completions(user_prompt,
seed,
system_prompt,
- base_url,
- base_path,
+ base_url: str,
+ base_path: str,
async_client,
debug=False,
temperature=None,
@@ -909,7 +913,7 @@ async def oai_chat_completions(user_prompt,
enable_streaming=None,
response_format=None,
user_api_key=None,
- expect_api_error=None):
+ expect_api_error=None) -> int | dict[str, Any]:
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
@@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt,
else:
try:
openai.api_key = user_api_key
- openai.api_base = f'{base_url}{base_path}'
- chat_completion = openai.Completion.create(
+ openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
+ assert model is not None
+ chat_completion = openai.chat.completions.create(
messages=payload['messages'],
model=model,
max_tokens=n_predict,
stream=enable_streaming,
- response_format=payload.get('response_format'),
+ response_format=payload.get('response_format') or openai.NOT_GIVEN,
seed=seed,
temperature=payload['temperature']
)
- except openai.error.AuthenticationError as e:
+ except openai.AuthenticationError as e:
if expect_api_error is not None and expect_api_error:
return 401
else:
assert False, f'error raised: {e}'
if enable_streaming:
+ chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
for chunk in chat_completion:
assert len(chunk.choices) == 1
delta = chunk.choices[0].delta
- if 'content' in delta:
- completion_response['content'] += delta['content']
+ if delta.content is not None:
+ completion_response['content'] += delta.content
completion_response['timings']['predicted_n'] += 1
completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
else:
assert len(chat_completion.choices) == 1
+ assert chat_completion.usage is not None
completion_response = {
'content': chat_completion.choices[0].message.content,
'timings': {
@@ -1028,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
return completion_response
-async def request_embedding(content, seed, base_url=None):
+async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding',
json={
@@ -1041,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
async def request_oai_embeddings(input, seed,
base_url=None, user_api_key=None,
- model=None, async_client=False):
+ model=None, async_client=False) -> list[list[float]]:
# openai client always expects an api_key
user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client:
@@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed,
response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list'
- if isinstance(input, collections.abc.Sequence):
+ if isinstance(input, Sequence):
embeddings = []
for an_oai_embeddings in response_json['data']:
embeddings.append(an_oai_embeddings['embedding'])
@@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed,
return embeddings
else:
openai.api_key = user_api_key
- openai.api_base = f'{base_url}/v1'
- oai_embeddings = openai.Embedding.create(
+ openai.base_url = f'{base_url}/v1/'
+ assert model is not None
+ oai_embeddings = openai.embeddings.create(
model=model,
input=input,
)
- if isinstance(input, collections.abc.Sequence):
- embeddings = []
- for an_oai_embeddings in oai_embeddings.data:
- embeddings.append(an_oai_embeddings.embedding)
- else:
- embeddings = [oai_embeddings.data.embedding]
- return embeddings
+ return [e.embedding for e in oai_embeddings.data]
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
@@ -1122,7 +1124,7 @@ def assert_all_predictions_equal(completion_responses):
if i == j:
continue
content_j = response_j['content']
- assert content_i == content_j, "contents not equal"
+ assert content_i == content_j, "contents not equal"
def assert_all_predictions_different(completion_responses):
@@ -1136,7 +1138,7 @@ def assert_all_predictions_different(completion_responses):
if i == j:
continue
content_j = response_j['content']
- assert content_i != content_j, "contents not different"
+ assert content_i != content_j, "contents not different"
def assert_all_token_probabilities_equal(completion_responses):
@@ -1153,7 +1155,7 @@ def assert_all_token_probabilities_equal(completion_responses):
if i == j:
continue
probs_j = response_j['completion_probabilities'][pos]['probs']
- assert probs_i == probs_j, "contents not equal"
+ assert probs_i == probs_j, "contents not equal"
async def gather_tasks_results(context):
@@ -1343,7 +1345,7 @@ def start_server_background(context):
}
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
- **pkwargs)
+ **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):