diff options
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 9e348d5f..40c97001 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port): def step_download_hf_model(context, hf_file, hf_repo): context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file) if context.debug: - print(f"model file: {context.model_file}\n") + print(f"model file: {context.model_file}") @step('a model file {model_file}') @@ -137,9 +137,12 @@ def step_start_server(context): if 'GITHUB_ACTIONS' in os.environ: max_attempts *= 2 + addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM) + family, typ, proto, _, sockaddr = addrs[0] + while True: - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - result = sock.connect_ex((context.server_fqdn, context.server_port)) + with closing(socket.socket(family, typ, proto)) as sock: + result = sock.connect_ex(sockaddr) if result == 0: print("\x1b[33;46mserver started!\x1b[0m") return @@ -209,7 +212,7 @@ async def step_request_completion(context, api_error): user_api_key=context.user_api_key) context.tasks_result.append(completion) if context.debug: - print(f"Completion response: {completion}\n") + print(f"Completion response: {completion}") if expect_api_error: assert completion == 401, f"completion must be an 401 status code: {completion}" @@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos): prompt += context.prompt_junk_suffix if context.debug: passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" - print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n") + print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```") context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) context.n_prompts = len(context.prompts) @@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos): @async_run_until_complete async def step_oai_chat_completions(context, api_error): if context.debug: - print(f"Submitting OAI compatible completions request...\n") + print(f"Submitting OAI compatible completions request...") expect_api_error = api_error == 'raised' completion = await oai_chat_completions(context.prompts.pop(), context.system_prompt, @@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context): embedding1 = np.array(embeddings[i]) embedding2 = np.array(embeddings[j]) if context.debug: - print(f"embedding1: {embedding1[-8:]}\n") - print(f"embedding2: {embedding2[-8:]}\n") + print(f"embedding1: {embedding1[-8:]}") + print(f"embedding2: {embedding2[-8:]}") similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) msg = f"Similarity between {i} and {j}: {similarity:.10f}" if context.debug: - print(f"{msg}\n") + print(f"{msg}") assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg @@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context): metrics_raw = await metrics_response.text() metric_exported = False if context.debug: - print(f"/metrics answer:\n{metrics_raw}\n") + print(f"/metrics answer:\n{metrics_raw}") context.metrics = {} for metric in parser.text_string_to_metric_families(metrics_raw): match metric.name: @@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re last_match = end highlighted += content[last_match:] if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - print(f"Checking completion response: {highlighted}\n") + print(f"Checking completion response: {highlighted}") assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```' if expected_predicted_n and expected_predicted_n > 0: assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' @@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re async def gather_tasks_results(context): n_tasks = len(context.concurrent_tasks) if context.debug: - print(f"Waiting for all {n_tasks} tasks results...\n") + print(f"Waiting for all {n_tasks} tasks results...") for task_no in range(n_tasks): context.tasks_result.append(await context.concurrent_tasks.pop()) n_completions = len(context.tasks_result) @@ -959,7 +962,7 @@ async def wait_for_health_status(context, slots_processing=None, expected_slots=None): if context.debug: - print(f"Starting checking for health for expected_health_status={expected_health_status}\n") + print(f"Starting checking for health for expected_health_status={expected_health_status}") interval = 0.5 counter = 0 if 'GITHUB_ACTIONS' in os.environ: @@ -1048,8 +1051,6 @@ def start_server_background(context): if 'LLAMA_SERVER_BIN_PATH' in os.environ: context.server_path = os.environ['LLAMA_SERVER_BIN_PATH'] server_listen_addr = context.server_fqdn - if os.name == 'nt': - server_listen_addr = '0.0.0.0' server_args = [ '--host', server_listen_addr, '--port', context.server_port, @@ -1088,7 +1089,7 @@ def start_server_background(context): server_args.append('--verbose') if 'SERVER_LOG_FORMAT_JSON' not in os.environ: server_args.extend(['--log-format', "text"]) - print(f"starting server with: {context.server_path} {server_args}\n") + print(f"starting server with: {context.server_path} {server_args}") flags = 0 if 'nt' == os.name: flags |= subprocess.DETACHED_PROCESS |