Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r--  examples/server/tests/features/steps/steps.py | 86
1 file changed, 67 insertions(+), 19 deletions(-)
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 31952780..a0b2ffdf 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -10,6 +10,7 @@ from contextlib import closing
 from re import RegexFlag
 
 import aiohttp
+import numpy as np
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
@@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port):
     if 'PORT' in os.environ:
         context.server_port = int(os.environ['PORT'])
         print(f"$PORT set, overriding server port to {context.server_port}")
+    if 'FQDN' in os.environ:
+        context.server_fqdn = os.environ['FQDN']
+        print(f"$FQDN set, overriding server fqdn to {context.server_fqdn}")
 
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
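
Note: the hunk above lets the suite target an already-running server by honoring $PORT and $FQDN. A minimal standalone sketch of the same override pattern (the function name is illustrative, not from the diff):

    import os

    def resolve_base_url(default_fqdn='localhost', default_port=8080):
        # Environment variables take precedence over the scenario defaults.
        fqdn = os.environ.get('FQDN', default_fqdn)
        port = int(os.environ.get('PORT', str(default_port)))
        return f'http://{fqdn}:{port}'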
@@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
+    context.n_prompts = 0
     context.n_server_predict = None
     context.n_slots = None
     context.prompt_prefix = None
@@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n):
 @step(u'a user prompt {user_prompt}')
 def step_user_prompt(context, user_prompt):
     context.prompts.append(user_prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a system prompt {system_prompt}')
@@ -290,6 +296,12 @@ def step_prompt_passkey(context):
     context.prompt_passkey = context.text
 
 
+@step(u'{n_prompts:d} fixed prompts')
+def step_fixed_prompts(context, n_prompts):
+    context.prompts.extend([str(0) * (context.n_batch if context.n_batch is not None else 512) for _ in range(n_prompts)])
+    context.n_prompts = n_prompts
+
+
 @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
 def step_prompt_passkey(context, passkey, i_pos):
     prompt = ""
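
Note: in the `fixed prompts` step above, `str(0) * n` builds a string of n zero characters, so each filler prompt is exactly n_batch characters long (512 when context.n_batch is unset). A quick standalone check:

    n_batch = 512  # assumed fallback from the step above
    prompts = [str(0) * n_batch for _ in range(3)]
    assert all(p == '0' * n_batch for p in prompts)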
@@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos):
     passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
     print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'an OAI compatible chat completions request with {api_error} api error')
@@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error):
 @step(u'a prompt')
 def step_a_prompt(context):
     context.prompts.append(context.text)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'a prompt {prompt}')
 def step_a_prompt_prompt(context, prompt):
     context.prompts.append(prompt)
+    context.n_prompts = len(context.prompts)
 
 
 @step(u'concurrent completion requests')
@@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
+    context.n_prompts = 1
     context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
+@step(u'all embeddings are the same')
+@async_run_until_complete
+async def step_all_embeddings_are_the_same(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    embeddings = []
+    for i in range(n_embedding_requests):
+        embedding = context.tasks_result.pop().pop()
+        embeddings.append(embedding)
+        assert_embeddings(embedding)
+    n = len(embeddings)
+    for i in range(n-1):
+        for j in range(i+1, n):
+            embedding1 = np.array(embeddings[i])
+            embedding2 = np.array(embeddings[j])
+            if context.debug:
+                print(f"embedding1: {embedding1[-8:]}\n")
+                print(f"embedding2: {embedding2[-8:]}\n")
+            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
+            if context.debug:
+                print(f"{msg}\n")
+            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
+
+
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    if len(context.prompts) == 0:
-        assert_embeddings(context.embeddings)
-    else:
-        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
-                                                                 f"context.prompts={context.prompts}\n"
-                                                                 f"context.embeddings={context.embeddings}")
-        for embedding in context.embeddings:
-            context.prompts.pop()
-            assert_embeddings(embedding)
+    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
+                                                          f"context.n_prompts={context.n_prompts}\n"
+                                                          f"context.embeddings={context.embeddings}")
+    for embedding in context.embeddings:
+        assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
 @async_run_until_complete
 async def step_oai_compute_embeddings(context):
+    context.n_prompts = 1
     context.embeddings = await request_oai_embeddings(context.text,
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
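
Note: the new `all embeddings are the same` step asserts pairwise cosine similarity of ~1.0, i.e. the vectors point in the same direction regardless of magnitude. A self-contained sketch of that check outside the behave context:

    import numpy as np

    def cosine_similarity(a, b):
        # dot(a, b) / (|a| * |b|), as in the step above
        a, b = np.array(a), np.array(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    emb1 = [0.1, 0.2, 0.3]
    emb2 = [0.2, 0.4, 0.6]  # same direction, twice the magnitude
    assert np.isclose(cosine_similarity(emb1, emb2), 1.0, rtol=1e-05, atol=1e-08)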
@@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
                                                       model=context.model)
+    context.prompts.clear()
 
 
 @step(u'concurrent embedding requests')
@@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context):
 @async_run_until_complete()
 async def all_embeddings_are_generated(context):
     n_embedding_requests = await gather_tasks_results(context)
-    assert n_embedding_requests > 0
+    assert n_embedding_requests == context.n_prompts
     for i in range(n_embedding_requests):
-        assert_embeddings(context.tasks_result.pop())
+        assert_embeddings(context.tasks_result.pop().pop())
 
 
 @step(u'tokenizing')
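
Note: both request helpers now return a list of embeddings even for a single prompt, which is why the collector above pops twice: once to take a task's result (a list) and once to take the embedding inside it. Schematically (values are made up):

    tasks_result = [[[0.1, 0.2]], [[0.3, 0.4]]]  # two requests, one embedding each
    embedding = tasks_result.pop().pop()         # -> [0.3, 0.4], a flat list of floats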
@@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):
 async def concurrent_requests(context, f_completion, *args, **kwargs):
-    n_prompts = len(context.prompts)
+    context.n_prompts = len(context.prompts)
     if context.debug:
-        print(f"starting {n_prompts} concurrent completion requests...")
-    assert n_prompts > 0
-    for prompt_no in range(n_prompts):
+        print(f"starting {context.n_prompts} concurrent completion requests...")
+    assert context.n_prompts > 0
+    for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
     await asyncio.sleep(0.1)
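
Note: `concurrent_requests` consumes `context.prompts` with pop(), so the count has to be captured into `context.n_prompts` up front for the later `== context.n_prompts` assertion to see it. In miniature:

    prompts = ['p1', 'p2', 'p3']
    n_prompts = len(prompts)  # record before consuming
    while prompts:
        prompts.pop()
    assert n_prompts == 3 and not prompts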
@@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None):
                                 }) as response:
             assert response.status == 200
             response_json = await response.json()
-            return response_json['embedding']
+            return [response_json['embedding']]
 
 
 async def request_oai_embeddings(input,
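
Note: for reference, a minimal standalone version of the call above against the server's /embedding endpoint; the 'content' payload key and everything not shown in the hunk (names, error handling) are assumptions:

    import aiohttp

    async def fetch_embedding(content, base_url):
        async with aiohttp.ClientSession() as session:
            async with session.post(f'{base_url}/embedding',
                                    json={'content': content}) as response:
                assert response.status == 200
                response_json = await response.json()
                # wrap in a list so single and batched results share one shape
                return [response_json['embedding']]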
@@ -775,6 +813,7 @@ async def request_oai_embeddings(input,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
         origin = 'llama.cpp'
+        headers = {}
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
@@ -783,14 +822,21 @@ async def request_oai_embeddings(input,
"input": input,
"model": model,
},
- headers=headers) as response:
+ headers=headers,
+ timeout=3600) as response:
assert response.status == 200, f"received status code not expected: {response.status}"
assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list'
- return response_json['data']
+ if isinstance(input, collections.abc.Sequence):
+ embeddings = []
+ for an_oai_embeddings in response_json['data']:
+ embeddings.append(an_oai_embeddings['embedding'])
+ else:
+ embeddings = [response_json['data']['embedding']]
+ return embeddings
else:
openai.api_key = user_api_key
openai.api_base = f'{base_url}/v1'
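
Note: one subtlety in the branch above: a plain str is also a collections.abc.Sequence, so a single string prompt takes the list branch; that still works because an OAI-style response's 'data' field is a list even for one input. Quick illustration:

    import collections.abc

    assert isinstance('a single prompt', collections.abc.Sequence)  # str is a Sequence
    assert isinstance(['p1', 'p2'], collections.abc.Sequence)

    response_data = [{'embedding': [0.1, 0.2]}]  # 'data' is a list either way
    embeddings = [d['embedding'] for d in response_data]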
@@ -804,7 +850,7 @@ async def request_oai_embeddings(input,
             for an_oai_embeddings in oai_embeddings.data:
                 embeddings.append(an_oai_embeddings.embedding)
         else:
-            embeddings = oai_embeddings.data.embedding
+            embeddings = [oai_embeddings.data.embedding]
         return embeddings
@@ -899,6 +945,8 @@ def assert_embeddings(embeddings):
     assert len(embeddings) > 0
     embeddings_computed = False
     for emb in embeddings:
+        assert isinstance(emb, float), f"Bad embeddings: {embeddings}"
         if emb != 0:
             embeddings_computed = True
     assert embeddings_computed, f"Embeddings: {embeddings}"
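
Note: with the float guard added above, assert_embeddings now rejects a nested list (e.g. a whole response accidentally passed in) instead of silently accepting it:

    assert_embeddings([0.1, 0.0, 0.2])  # passes: flat floats, at least one non-zero
    # assert_embeddings([[0.1, 0.2]])   # would now fail: elements are lists, not floats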