Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 86
1 file changed, 67 insertions, 19 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 31952780..a0b2ffdf 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -10,6 +10,7 @@ from contextlib import closing
 from re import RegexFlag

 import aiohttp
+import numpy as np
 import openai
 from behave import step
 from behave.api.async_step import async_run_until_complete
@@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port):
     if 'PORT' in os.environ:
         context.server_port = int(os.environ['PORT'])
         print(f"$PORT set, overriding server port with to {context.server_port}")
+    if 'FQDN' in os.environ:
+        context.server_fqdn = os.environ['FQDN']
+        print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")

     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

@@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.n_ga_w = None
     context.n_gpu_layer = None
     context.n_predict = None
+    context.n_prompts = 0
     context.n_server_predict = None
     context.n_slots = None
     context.prompt_prefix = None
@@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n):
 @step(u'a user prompt {user_prompt}')
 def step_user_prompt(context, user_prompt):
     context.prompts.append(user_prompt)
+    context.n_prompts = len(context.prompts)


 @step(u'a system prompt {system_prompt}')
@@ -290,6 +296,12 @@ def step_prompt_passkey(context):
     context.prompt_passkey = context.text


+@step(u'{n_prompts:d} fixed prompts')
+def step_fixed_prompts(context, n_prompts):
+    context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)])
+    context.n_prompts = n_prompts
+
+
 @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
 def step_prompt_passkey(context, passkey, i_pos):
     prompt = ""
@@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos):
     passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
     print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+    context.n_prompts = len(context.prompts)


 @step(u'an OAI compatible chat completions request with {api_error} api error')
@@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error):
 @step(u'a prompt')
 def step_a_prompt(context):
     context.prompts.append(context.text)
+    context.n_prompts = len(context.prompts)


 @step(u'a prompt {prompt}')
 def step_a_prompt_prompt(context, prompt):
     context.prompts.append(prompt)
+    context.n_prompts = len(context.prompts)


 @step(u'concurrent completion requests')
@@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
+    context.n_prompts = 1
     context.embeddings = await request_embedding(context.text, base_url=context.base_url)


+@step(u'all embeddings are the same')
+@async_run_until_complete
+async def step_all_embeddings_are_the_same(context):
+    n_embedding_requests = await gather_tasks_results(context)
+    assert n_embedding_requests > 0
+    embeddings = []
+    for i in range(n_embedding_requests):
+        embedding = context.tasks_result.pop().pop()
+        embeddings.append(embedding)
+        assert_embeddings(embedding)
+    n = len(embeddings)
+    for i in range(n-1):
+        for j in range(i+1, n):
+            embedding1 = np.array(embeddings[i])
+            embedding2 = np.array(embeddings[j])
+            if context.debug:
+                print(f"embedding1: {embedding1[-8:]}\n")
+                print(f"embedding2: {embedding2[-8:]}\n")
+            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
+            if context.debug:
+                print(f"{msg}\n")
+            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
+
+
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    if len(context.prompts) == 0:
-        assert_embeddings(context.embeddings)
-    else:
-        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
-                                                                 f"context.prompts={context.prompts}\n"
-                                                                 f"context.embeddings={context.embeddings}")
-        for embedding in context.embeddings:
-            context.prompts.pop()
-            assert_embeddings(embedding)
+    assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n"
+                                                          f"context.n_prompts={context.n_prompts}\n"
+                                                          f"context.embeddings={context.embeddings}")
+    for embedding in context.embeddings:
+        assert_embeddings(embedding)


 @step(u'an OAI compatible embeddings computation request for')
 @async_run_until_complete
 async def step_oai_compute_embeddings(context):
+    context.n_prompts = 1
     context.embeddings = await request_oai_embeddings(context.text,
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
@@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context):
                                                       base_url=context.base_url,
                                                       user_api_key=context.user_api_key,
                                                       model=context.model)
+    context.prompts.clear()


 @step(u'concurrent embedding requests')
@@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context):
 @async_run_until_complete()
 async def all_embeddings_are_generated(context):
     n_embedding_requests = await gather_tasks_results(context)
-    assert n_embedding_requests > 0
+    assert n_embedding_requests == context.n_prompts
     for i in range(n_embedding_requests):
-        assert_embeddings(context.tasks_result.pop())
+        assert_embeddings(context.tasks_result.pop().pop())


 @step(u'tokenizing')
@@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value):


 async def concurrent_requests(context, f_completion, *args, **kwargs):
-    n_prompts = len(context.prompts)
+    context.n_prompts = len(context.prompts)
     if context.debug:
-        print(f"starting {n_prompts} concurrent completion requests...")
-    assert n_prompts > 0
-    for prompt_no in range(n_prompts):
+        print(f"starting {context.n_prompts} concurrent completion requests...")
+    assert context.n_prompts > 0
+    for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
     await asyncio.sleep(0.1)
@@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None):
                                 }) as response:
                 assert response.status == 200
                 response_json = await response.json()
-                return response_json['embedding']
+                return [response_json['embedding']]


 async def request_oai_embeddings(input,
@@ -775,6 +813,7 @@ async def request_oai_embeddings(input,
     user_api_key = user_api_key if user_api_key is not None else 'nope'
     if async_client:
         origin = 'llama.cpp'
+        headers=[]
         if user_api_key is not None:
             headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
@@ -783,14 +822,21 @@ async def request_oai_embeddings(input,
                                         "input": input,
                                         "model": model,
                                     },
-                                    headers=headers) as response:
+                                    headers=headers,
+                                    timeout=3600) as response:
                 assert response.status == 200, f"received status code not expected: {response.status}"
                 assert response.headers['Access-Control-Allow-Origin'] == origin
                 assert response.headers['Content-Type'] == "application/json; charset=utf-8"
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
-                return response_json['data']
+                if isinstance(input, collections.abc.Sequence):
+                    embeddings = []
+                    for an_oai_embeddings in response_json['data']:
+                        embeddings.append(an_oai_embeddings['embedding'])
+                else:
+                    embeddings = [response_json['data']['embedding']]
+                return embeddings
         else:
             openai.api_key = user_api_key
             openai.api_base = f'{base_url}/v1'
@@ -804,7 +850,7 @@ async def request_oai_embeddings(input,
             for an_oai_embeddings in oai_embeddings.data:
                 embeddings.append(an_oai_embeddings.embedding)
         else:
-            embeddings = oai_embeddings.data.embedding
+            embeddings = [oai_embeddings.data.embedding]
         return embeddings


@@ -899,6 +945,8 @@ def assert_embeddings(embeddings):
     assert len(embeddings) > 0
     embeddings_computed = False
     for emb in embeddings:
+        if not isinstance(emb, float):
+            assert False, f"Bad embeddings: {embeddings}"
         if emb != 0:
             embeddings_computed = True
     assert embeddings_computed, f"Embeddings: {embeddings}"
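The pairwise check in the new step_all_embeddings_are_the_same step is plain cosine similarity: the dot product of two embedding vectors divided by the product of their L2 norms, asserted to be close to 1.0. A minimal standalone sketch of the same computation, for reference (the vectors below are made-up values, not real model output):

    import numpy as np

    def cosine_similarity(a, b):
        # dot(a, b) / (|a| * |b|); 1.0 means the vectors are collinear
        a = np.asarray(a)
        b = np.asarray(b)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    # Two embeddings of the same prompt should match within float tolerance.
    e1 = [0.12, -0.34, 0.56]  # hypothetical embedding values
    e2 = [0.12, -0.34, 0.56]
    assert np.isclose(cosine_similarity(e1, e2), 1.0, rtol=1e-05, atol=1e-08)

Using np.isclose rather than exact equality tolerates the tiny floating-point differences that can appear between runs even for identical inputs.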