summaryrefslogtreecommitdiff
path: root/examples/server/tests/features/steps/steps.py
diff options
context:
space:
mode:
authorPierrick Hymbert <pierrick.hymbert@gmail.com>2024-03-17 19:12:37 +0100
committerGitHub <noreply@github.com>2024-03-17 19:12:37 +0100
commitd01b3c4c32357567f3531d4e6ceffc5d23e87583 (patch)
tree80e0a075a8b120d6b5b095a73cc36cb2a4535aed /examples/server/tests/features/steps/steps.py
parentcd776c37c945bf58efc8fe44b370456680cb1b59 (diff)
common: llama_load_model_from_url using --model-url (#6098)
* common: llama_load_model_from_url with libcurl dependency Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r--examples/server/tests/features/steps/steps.py37
1 files changed, 35 insertions, 2 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index a59a52d2..9e348d5f 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -5,6 +5,8 @@ import os
import re
import socket
import subprocess
+import sys
+import threading
import time
from contextlib import closing
from re import RegexFlag
@@ -32,6 +34,8 @@ def step_server_config(context, server_fqdn, server_port):
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
context.model_alias = None
+ context.model_file = None
+ context.model_url = None
context.n_batch = None
context.n_ubatch = None
context.n_ctx = None
@@ -65,6 +69,16 @@ def step_download_hf_model(context, hf_file, hf_repo):
print(f"model file: {context.model_file}\n")
+@step('a model file {model_file}')
+def step_model_file(context, model_file):
+ context.model_file = model_file
+
+
+@step('a model url {model_url}')
+def step_model_url(context, model_url):
+ context.model_url = model_url
+
+
@step('a model alias {model_alias}')
def step_model_alias(context, model_alias):
context.model_alias = model_alias
@@ -141,7 +155,8 @@ def step_start_server(context):
async def step_wait_for_the_server_to_be_started(context, expecting_status):
match expecting_status:
case 'healthy':
- await wait_for_health_status(context, context.base_url, 200, 'ok')
+ await wait_for_health_status(context, context.base_url, 200, 'ok',
+ timeout=30)
case 'ready' | 'idle':
await wait_for_health_status(context, context.base_url, 200, 'ok',
@@ -1038,8 +1053,11 @@ def start_server_background(context):
server_args = [
'--host', server_listen_addr,
'--port', context.server_port,
- '--model', context.model_file
]
+ if context.model_file:
+ server_args.extend(['--model', context.model_file])
+ if context.model_url:
+ server_args.extend(['--model-url', context.model_url])
if context.n_batch:
server_args.extend(['--batch-size', context.n_batch])
if context.n_ubatch:
@@ -1079,8 +1097,23 @@ def start_server_background(context):
pkwargs = {
'creationflags': flags,
+ 'stdout': subprocess.PIPE,
+ 'stderr': subprocess.PIPE
}
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
**pkwargs)
+
+ def log_stdout(process):
+ for line in iter(process.stdout.readline, b''):
+ print(line.decode('utf-8'), end='')
+ thread_stdout = threading.Thread(target=log_stdout, args=(context.server_process,))
+ thread_stdout.start()
+
+ def log_stderr(process):
+ for line in iter(process.stderr.readline, b''):
+ print(line.decode('utf-8'), end='', file=sys.stderr)
+ thread_stderr = threading.Thread(target=log_stderr, args=(context.server_process,))
+ thread_stderr.start()
+
print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}")