diff options
author | Pierrick Hymbert <pierrick.hymbert@gmail.com> | 2024-03-17 19:12:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 19:12:37 +0100 |
commit | d01b3c4c32357567f3531d4e6ceffc5d23e87583 (patch) | |
tree | 80e0a075a8b120d6b5b095a73cc36cb2a4535aed /examples/server/tests/features/steps/steps.py | |
parent | cd776c37c945bf58efc8fe44b370456680cb1b59 (diff) |
common: llama_load_model_from_url using --model-url (#6098)
* common: llama_load_model_from_url with libcurl dependency
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r-- | examples/server/tests/features/steps/steps.py | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index a59a52d2..9e348d5f 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -5,6 +5,8 @@ import os import re import socket import subprocess +import sys +import threading import time from contextlib import closing from re import RegexFlag @@ -32,6 +34,8 @@ def step_server_config(context, server_fqdn, server_port): context.base_url = f'http://{context.server_fqdn}:{context.server_port}' context.model_alias = None + context.model_file = None + context.model_url = None context.n_batch = None context.n_ubatch = None context.n_ctx = None @@ -65,6 +69,16 @@ def step_download_hf_model(context, hf_file, hf_repo): print(f"model file: {context.model_file}\n") +@step('a model file {model_file}') +def step_model_file(context, model_file): + context.model_file = model_file + + +@step('a model url {model_url}') +def step_model_url(context, model_url): + context.model_url = model_url + + @step('a model alias {model_alias}') def step_model_alias(context, model_alias): context.model_alias = model_alias @@ -141,7 +155,8 @@ def step_start_server(context): async def step_wait_for_the_server_to_be_started(context, expecting_status): match expecting_status: case 'healthy': - await wait_for_health_status(context, context.base_url, 200, 'ok') + await wait_for_health_status(context, context.base_url, 200, 'ok', + timeout=30) case 'ready' | 'idle': await wait_for_health_status(context, context.base_url, 200, 'ok', @@ -1038,8 +1053,11 @@ def start_server_background(context): server_args = [ '--host', server_listen_addr, '--port', context.server_port, - '--model', context.model_file ] + if context.model_file: + server_args.extend(['--model', context.model_file]) + if context.model_url: + server_args.extend(['--model-url', context.model_url]) if context.n_batch: server_args.extend(['--batch-size', context.n_batch]) if context.n_ubatch: @@ -1079,8 +1097,23 @@ def start_server_background(context): pkwargs = { 'creationflags': flags, + 'stdout': subprocess.PIPE, + 'stderr': subprocess.PIPE } context.server_process = subprocess.Popen( [str(arg) for arg in [context.server_path, *server_args]], **pkwargs) + + def log_stdout(process): + for line in iter(process.stdout.readline, b''): + print(line.decode('utf-8'), end='') + thread_stdout = threading.Thread(target=log_stdout, args=(context.server_process,)) + thread_stdout.start() + + def log_stderr(process): + for line in iter(process.stderr.readline, b''): + print(line.decode('utf-8'), end='', file=sys.stderr) + thread_stderr = threading.Thread(target=log_stderr, args=(context.server_process,)) + thread_stderr.start() + print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}") |