summaryrefslogtreecommitdiff
path: root/examples/server/tests/features/steps/steps.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/server/tests/features/steps/steps.py')
-rw-r--r--examples/server/tests/features/steps/steps.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 03f55f65..86c3339d 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -16,7 +16,6 @@ import numpy as np
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
-from huggingface_hub import hf_hub_download
from prometheus_client import parser
@@ -39,6 +38,8 @@ def step_server_config(context, server_fqdn, server_port):
context.model_alias = None
context.model_file = None
+ context.model_hf_repo = None
+ context.model_hf_file = None
context.model_url = None
context.n_batch = None
context.n_ubatch = None
@@ -68,9 +69,9 @@ def step_server_config(context, server_fqdn, server_port):
@step('a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
- context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
- if context.debug:
- print(f"model file: {context.model_file}")
+ context.model_hf_repo = hf_repo
+ context.model_hf_file = hf_file
+ context.model_file = os.path.basename(hf_file)
@step('a model file {model_file}')
@@ -1079,6 +1080,10 @@ def start_server_background(context):
server_args.extend(['--model', context.model_file])
if context.model_url:
server_args.extend(['--model-url', context.model_url])
+ if context.model_hf_repo:
+ server_args.extend(['--hf-repo', context.model_hf_repo])
+ if context.model_hf_file:
+ server_args.extend(['--hf-file', context.model_hf_file])
if context.n_batch:
server_args.extend(['--batch-size', context.n_batch])
if context.n_ubatch: