summaryrefslogtreecommitdiff
path: root/examples/server/tests
diff options
context:
space:
mode:
authorPierrick Hymbert <pierrick.hymbert@gmail.com>2024-03-02 22:00:14 +0100
committerGitHub <noreply@github.com>2024-03-02 22:00:14 +0100
commit9731134296af3a6839cd682e51d9c2109a871de5 (patch)
tree882db21742d552ee948d1b5db013f02bf35ff8fa /examples/server/tests
parent4a6e2d6142ab815c964924896891e9ab3e050632 (diff)
server: tests: passkey challenge / self-extend with context shift demo (#5832)
* server: tests: add models endpoint scenario * server: /v1/models add some metadata * server: tests: add debug field in context before scenario * server: tests: download model from HF, add batch size * server: tests: add passkey test * server: tests: add group attention params * server: do not truncate prompt tokens if self-extend through group attention is enabled * server: logs: do not truncate log values * server: tests - passkey - first good working value of nga * server: tests: fix server timeout * server: tests: fix passkey, add doc, fix regex content matching, fix timeout * server: tests: fix regex content matching * server: tests: schedule slow tests on master * server: metrics: fix when no prompt processed * server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1 * server: tests: increase timeout for completion * server: tests: keep only the PHI-2 test * server: tests: passkey add a negative test
Diffstat (limited to 'examples/server/tests')
-rw-r--r--examples/server/tests/README.md50
-rw-r--r--examples/server/tests/features/environment.py5
-rw-r--r--examples/server/tests/features/issues.feature1
-rw-r--r--examples/server/tests/features/parallel.feature5
-rw-r--r--examples/server/tests/features/passkey.feature55
-rw-r--r--examples/server/tests/features/security.feature3
-rw-r--r--examples/server/tests/features/server.feature23
-rw-r--r--examples/server/tests/features/steps/steps.py259
-rw-r--r--examples/server/tests/features/wrong_usages.feature5
-rw-r--r--examples/server/tests/requirements.txt1
-rwxr-xr-xexamples/server/tests/tests.sh2
11 files changed, 321 insertions, 88 deletions
diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md
index 0b9fdc4e..95a0353b 100644
--- a/examples/server/tests/README.md
+++ b/examples/server/tests/README.md
@@ -1,22 +1,30 @@
# Server tests
-Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/):
- * [issues.feature](./features/issues.feature) Pending issues scenario
- * [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
- * [security.feature](./features/security.feature) Security, CORS and API Key
- * [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
+Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development)
+and [behave](https://behave.readthedocs.io/en/latest/):
+
+* [issues.feature](./features/issues.feature) Pending issues scenario
+* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
+* [security.feature](./features/security.feature) Security, CORS and API Key
+* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
Tests target GitHub workflows job runners with 4 vCPU.
-Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client.
+Requests are
+using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html)
+based http client.
-Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`.
+Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail.
+To mitigate it, you can increase values in `n_predict`, `kv_size`.
### Install dependencies
+
`pip install -r requirements.txt`
### Run tests
+
1. Build the server
+
```shell
cd ../../..
mkdir build
@@ -24,24 +32,36 @@ cd build
cmake ../
cmake --build . --target server
```
-2. download required models:
- 1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf`
-3. Start the test: `./tests.sh`
+
+2. Start the test: `./tests.sh`
It's possible to override some scenario steps values with environment variables:
- - `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
- - `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
- - `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose`
- - `SERVER_LOG_FORMAT_JSON` -> if set switch server logs to json format
+
+| variable | description |
+|--------------------------|------------------------------------------------------------------------------------------------|
+| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
+| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/server` |
+| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` |
+| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format |
+| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
### Run @bug, @wip or @wrong_usage annotated scenario
Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
+
- `@bug` annotation aims to link a scenario with a GitHub issue.
- `@wrong_usage` are meant to show user issue that are actually an expected behavior
- `@wip` to focus on a scenario working in progress
+- `@slow` heavy test, disabled by default
To run a scenario annotated with `@bug`, start:
-`DEBUG=ON ./tests.sh --no-skipped --tags bug`
+
+```shell
+DEBUG=ON ./tests.sh --no-skipped --tags bug
+```
After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.
+
+```shell
+./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile"
+```
diff --git a/examples/server/tests/features/environment.py b/examples/server/tests/features/environment.py
index 09e82674..9fd330db 100644
--- a/examples/server/tests/features/environment.py
+++ b/examples/server/tests/features/environment.py
@@ -7,7 +7,10 @@ from signal import SIGKILL
def before_scenario(context, scenario):
- print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
+ context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
+ if context.debug:
+ print("DEBUG=ON\n")
+ print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
port = 8080
if 'PORT' in os.environ:
port = int(os.environ['PORT'])
diff --git a/examples/server/tests/features/issues.feature b/examples/server/tests/features/issues.feature
index bf5a175a..7b13e44c 100644
--- a/examples/server/tests/features/issues.feature
+++ b/examples/server/tests/features/issues.feature
@@ -1,4 +1,5 @@
# List of ongoing issues
+# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug
@bug
Feature: Issues
# No confirmed issue at the moment
diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature
index 5f895cf9..86cdf728 100644
--- a/examples/server/tests/features/parallel.feature
+++ b/examples/server/tests/features/parallel.feature
@@ -1,11 +1,12 @@
@llama.cpp
+@parallel
Feature: Parallel
Background: Server startup
Given a server listening on localhost:8080
- And a model file stories260K.gguf
- And a model alias tinyllama-2
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And 42 as server seed
+ And 512 as batch size
And 64 KV cache size
And 2 slots
And embeddings extraction
diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature
new file mode 100644
index 00000000..1bde7aab
--- /dev/null
+++ b/examples/server/tests/features/passkey.feature
@@ -0,0 +1,55 @@
+# run with: ./tests.sh --no-skipped --tags passkey
+@passkey
+@slow
+Feature: Passkey / Self-extend with context shift
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+
+ # Generates a long text of junk and inserts a secret passkey number inside it.
+ # Then we query the LLM for the secret passkey.
+ # see #3856 and #4810
+ Scenario Outline: Passkey
+ Given a model file <hf_file> from HF repo <hf_repo>
+ And <n_batch> as batch size
+ And <n_junk> as number of junk
+ And <n_predicted> server max tokens to predict
+ And 42 as seed
+ And <n_ctx> KV cache size
+ And 1 slots
+ And <n_ga> group attention factor to extend context size through self-extend
+ And <n_ga_w> group attention width to extend context size through self-extend
+ # Can be override with N_GPU_LAYERS
+ And <ngl> GPU offloaded layers
+ Then the server is starting
+ Then the server is healthy
+ Given available models
+ Then model 0 is trained on <n_ctx_train> tokens context
+ Given a prefix prompt:
+ """
+ here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
+ """
+ And a passkey prompt template:
+ """
+ The pass key is <passkey> Remember it. <passkey> is the pass key.
+ """
+ And a junk suffix prompt:
+ """
+ The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
+ """
+ And a suffix prompt:
+ """
+ What is the pass key? The pass key is
+ """
+ Given a "<passkey>" passkey challenge prompt with the passkey inserted every <i_pos> junk
+ And a completion request with no api error
+ Then <n_predicted> tokens are predicted matching <re_content>
+
+ Examples:
+ | hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content |
+ | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 |
+ | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b |
+ #| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 |
+ #| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0
+ # 987 |
+
diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature
index db06d397..42a6709a 100644
--- a/examples/server/tests/features/security.feature
+++ b/examples/server/tests/features/security.feature
@@ -1,9 +1,10 @@
@llama.cpp
+@security
Feature: Security
Background: Server startup with an api key defined
Given a server listening on localhost:8080
- And a model file stories260K.gguf
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a server api key llama.cpp
Then the server is starting
Then the server is healthy
diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index b571582a..7c977bcc 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -1,15 +1,17 @@
@llama.cpp
+@server
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
- And a model file stories260K.gguf
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model alias tinyllama-2
And 42 as server seed
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 32 KV cache size
+ And 512 as batch size
And 1 slots
And embeddings extraction
And 32 server max tokens to predict
@@ -29,9 +31,9 @@ Feature: llama.cpp server
And prometheus metrics are exposed
Examples: Prompts
- | prompt | n_predict | re_content | n_predicted |
- | I believe the meaning of life is | 8 | (read<or>going)+ | 8 |
- | Write a joke about AI | 64 | (park<or>friends<or>scared<or>always)+ | 32 |
+ | prompt | n_predict | re_content | n_predicted |
+ | I believe the meaning of life is | 8 | (read\|going)+ | 8 |
+ | Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 |
Scenario Outline: OAI Compatibility
Given a model <model>
@@ -43,9 +45,9 @@ Feature: llama.cpp server
Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts
- | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
- | llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled |
- | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled |
+ | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
+ | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
+ | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
Scenario: Embedding
When embeddings are computed for:
@@ -75,10 +77,15 @@ Feature: llama.cpp server
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
-
Scenario: Tokenize / Detokenize
When tokenizing:
"""
What is the capital of France ?
"""
Then tokens can be detokenize
+
+ Scenario: Models available
+ Given available models
+ Then 1 models are supported
+ Then model 0 is identified by tinyllama-2
+ Then model 0 is trained on 128 tokens context
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 381da105..31952780 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -13,6 +13,7 @@ import aiohttp
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
+from huggingface_hub import hf_hub_download
from prometheus_client import parser
@@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
- context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
context.model_alias = None
+ context.n_batch = None
context.n_ctx = None
+ context.n_ga = None
+ context.n_ga_w = None
+ context.n_gpu_layer = None
context.n_predict = None
context.n_server_predict = None
context.n_slots = None
+ context.prompt_prefix = None
+ context.prompt_suffix = None
context.server_api_key = None
context.server_continuous_batching = False
context.server_embeddings = False
context.server_metrics = False
context.server_process = None
+ context.seed = None
context.server_seed = None
context.user_api_key = None
@@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
context.prompts = []
-@step(u'a model file {model_file}')
-def step_model_file(context, model_file):
- context.model_file = model_file
+@step(u'a model file {hf_file} from HF repo {hf_repo}')
+def step_download_hf_model(context, hf_file, hf_repo):
+ context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
+ if context.debug:
+ print(f"model file: {context.model_file}\n")
@step(u'a model alias {model_alias}')
@@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
context.model_alias = model_alias
-@step(u'{seed} as server seed')
+@step(u'{seed:d} as server seed')
def step_seed(context, seed):
- context.server_seed = int(seed)
+ context.server_seed = seed
+
+
+@step(u'{ngl:d} GPU offloaded layers')
+def step_n_gpu_layer(context, ngl):
+ if 'N_GPU_LAYERS' in os.environ:
+ new_ngl = int(os.environ['N_GPU_LAYERS'])
+ if context.debug:
+ print(f"-ngl upgraded from {ngl} to {new_ngl}")
+ ngl = new_ngl
+ context.n_gpu_layer = ngl
-@step(u'{n_ctx} KV cache size')
+@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
- context.n_ctx = int(n_ctx)
+ context.n_ctx = n_ctx
-@step(u'{n_slots} slots')
+@step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots):
- context.n_slots = int(n_slots)
+ context.n_slots = n_slots
-@step(u'{n_predict} server max tokens to predict')
+@step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
- context.n_server_predict = int(n_predict)
+ context.n_server_predict = n_predict
@step(u'continuous batching')
@@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
case 'ready' | 'idle':
await wait_for_health_status(context, context.base_url, 200, 'ok',
+ timeout=10,
params={'fail_on_no_slot': 0, 'include_slots': 0},
slots_idle=context.n_slots,
slots_processing=0,
expected_slots=[{'id': slot_id, 'state': 0}
- for slot_id in range(context.n_slots)])
+ for slot_id in
+ range(context.n_slots if context.n_slots else 1)])
case 'busy':
await wait_for_health_status(context, context.base_url, 503,
'no slot available',
@@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
slots_idle=0,
slots_processing=context.n_slots,
expected_slots=[{'id': slot_id, 'state': 1}
- for slot_id in range(context.n_slots)])
+ for slot_id in
+ range(context.n_slots if context.n_slots else 1)])
case _:
assert False, "unknown status"
@@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
context.base_url,
debug=context.debug,
n_predict=context.n_predict,
- server_seed=context.server_seed,
+ seed=await completions_seed(context),
expect_api_error=expect_api_error,
user_api_key=context.user_api_key)
context.tasks_result.append(completion)
if context.debug:
- print(f"Completion response: {completion}")
+ print(f"Completion response: {completion}\n")
if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}"
-@step(u'{predicted_n} tokens are predicted matching {re_content}')
+@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
- assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
+ assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
-@step(u'{predicted_n} tokens are predicted')
+@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
- assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
+ assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
@step(u'a user prompt {user_prompt}')
@@ -192,9 +214,9 @@ def step_model(context, model):
context.model = model
-@step(u'{max_tokens} max tokens to predict')
+@step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
- context.n_predict = int(max_tokens)
+ context.n_predict = max_tokens
@step(u'streaming is {enable_streaming}')
@@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
context.server_api_key = server_api_key
+@step(u'{n_junk:d} as number of junk')
+def step_n_junk(context, n_junk):
+ context.n_junk = n_junk
+
+
+@step(u'{n_batch:d} as batch size')
+def step_n_batch(context, n_batch):
+ context.n_batch = n_batch
+
+
+@step(u'{seed:d} as seed')
+def step_seed(context, seed):
+ context.seed = seed
+
+
+@step(u'a prefix prompt')
+def step_prompt_prefix(context):
+ context.prompt_prefix = context.text
+
+
+@step(u'a junk suffix prompt')
+def step_prompt_junk_suffix(context):
+ context.prompt_junk_suffix = context.text
+
+
+@step(u'a suffix prompt')
+def step_prompt_suffix(context):
+ context.prompt_suffix = context.text
+
+
+@step(u'{n_ga:d} group attention factor'
+ u' to extend context size through self-extend')
+def step_impl(context, n_ga):
+ context.n_ga = n_ga
+
+
+@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
+def step_impl(context, n_ga_w):
+ context.n_ga_w = n_ga_w
+
+
+@step(u'a passkey prompt template')
+def step_prompt_passkey(context):
+ context.prompt_passkey = context.text
+
+
+@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
+def step_prompt_passkey(context, passkey, i_pos):
+ prompt = ""
+ for i in range(context.n_junk):
+ if i % context.n_junk == i_pos:
+ prompt += context.prompt_passkey # the passkey is already substituted
+ prompt += context.prompt_junk_suffix
+ if context.debug:
+ passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
+ print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+ context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
+
+
@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
if context.debug:
- print(f"Submitting OAI compatible completions request...")
+ print(f"Submitting OAI compatible completions request...\n")
expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt,
@@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
- server_seed=context.server_seed
- if hasattr(context, 'server_seed') else None,
+ seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,
@@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
# prompt is inserted automatically
context.base_url,
debug=context.debug,
+ prompt_prefix=context.prompt_prefix,
+ prompt_suffix=context.prompt_suffix,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
- server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+ seed=await completions_seed(context),
user_api_key=context.user_api_key if hasattr(context,
'user_api_key') else None)
@@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
- server_seed=context.server_seed
- if hasattr(context, 'server_seed') else None,
+ seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
- server_seed=context.server_seed
+ seed=context.seed
+ if hasattr(context, 'seed') else
+ context.server_seed
if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
await all_prompts_are_predicted(context)
-@step(u'all prompts are predicted with {n_predict} tokens')
+@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
-async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
- expected_predicted_n = int(n_predict)
- await all_prompts_are_predicted(context, expected_predicted_n)
+async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
+ await all_prompts_are_predicted(context, n_expected_predicted)
async def all_prompts_are_predicted(context, expected_predicted_n=None):
@@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
metrics_raw = await metrics_response.text()
metric_exported = False
+ if context.debug:
+ print(f"/metrics answer:\n{metrics_raw}\n")
for metric in parser.text_string_to_metric_families(metrics_raw):
match metric.name:
case "llamacpp:kv_cache_usage_ratio":
@@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
assert metric_exported, "No metrics exported"
+@step(u'available models')
+def step_available_models(context):
+ # openai client always expects an api_key
+ openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
+ openai.api_base = f'{context.base_url}/v1'
+ context.models = openai.Model.list().data
+
+
+@step(u'{n_model:d} models are supported')
+def step_supported_models(context, n_model):
+ if context.debug:
+ print("server models available:", context.models)
+ assert len(context.models) == n_model
+
+
+@step(u'model {i_model:d} is {param} {preposition} {param_value}')
+def step_supported_models(context, i_model, param, preposition, param_value):
+ assert i_model < len(context.models)
+ model = context.models[i_model]
+
+ param_value = param_value.split(' ', 1)[0]
+ match param:
+ case 'identified':
+ value = model.id
+ case 'trained':
+ value = str(model.meta.n_ctx_train)
+ case _:
+ assert False, "param {param} not supported"
+ assert param_value == value, f"model param {param} {value} != {param_value}"
+
+
async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts)
if context.debug:
@@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
async def request_completion(prompt,
base_url,
debug=False,
+ prompt_prefix=None,
+ prompt_suffix=None,
n_predict=None,
- server_seed=None,
+ seed=None,
expect_api_error=None,
user_api_key=None):
if debug:
@@ -504,11 +621,14 @@ async def request_completion(prompt,
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/completion',
json={
+ "input_prefix": prompt_prefix,
"prompt": prompt,
- "n_predict": int(n_predict) if n_predict is not None else -1,
- "seed": server_seed if server_seed is not None else 42
+ "input_suffix": prompt_suffix,
+ "n_predict": n_predict if n_predict is not None else -1,
+ "seed": seed if seed is not None else 42
},
- headers=headers) as response:
+ headers=headers,
+ timeout=3600) as response:
if expect_api_error is None or not expect_api_error:
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
model=None,
n_predict=None,
enable_streaming=None,
- server_seed=None,
+ seed=None,
user_api_key=None,
expect_api_error=None):
if debug:
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope'
- seed = server_seed if server_seed is not None else 42
+ seed = seed if seed is not None else 42
enable_streaming = enable_streaming if enable_streaming is not None else False
payload = {
"messages": [
@@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n']
assert len(content) > 0, "no token predicted"
- if expected_predicted_n is not None:
+ if re_content is not None:
+ p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
+ matches = p.finditer(content)
+ last_match = 0
+ highlighted = ''
+ for match in matches:
+ start, end = match.span()
+ highlighted += content[last_match: start]
+ highlighted += '\x1b[33m'
+ highlighted += content[start: end]
+ highlighted += '\x1b[0m'
+ last_match = end
+ highlighted += content[last_match:]
+ if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+ print(f"Checking completion response: {highlighted}\n")
+ assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+ if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')
- if re_content is not None:
- re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
- assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
- f'invalid tokens predicted:'
- f' ```\n{content}\n``` do not match /{re_content}/')
+
async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks)
if context.debug:
- print(f"Waiting for all {n_tasks} tasks results...")
+ print(f"Waiting for all {n_tasks} tasks results...\n")
for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.tasks_result)
@@ -716,15 +848,13 @@ async def wait_for_health_status(context,
base_url,
expected_http_status_code,
expected_health_status,
+ timeout=3,
params=None,
slots_idle=None,
slots_processing=None,
expected_slots=None):
if context.debug:
- print(f"Starting checking for health for expected_health_status={expected_health_status}")
- timeout = 3 # seconds
- if expected_health_status == 'ok':
- timeout = 10 # CI slow inference
+ print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
interval = 0.5
counter = 0
async with aiohttp.ClientSession() as session:
@@ -734,7 +864,7 @@ async def wait_for_health_status(context,
health = await health_response.json()
if context.debug:
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
- f"'{base_url}/health'?{params} is {health}")
+ f"'{base_url}/health'?{params} is {health}\n")
if (status_code == expected_http_status_code
and health['status'] == expected_health_status
and (slots_idle is None or health['slots_idle'] == slots_idle)
@@ -757,7 +887,7 @@ async def wait_for_health_status(context,
if expected_http_status_code == 503:
if len(context.tasks_result) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
- " busy health check missed, probably too fast inference\x1b[0m")
+ " busy health check missed, probably too fast inference\x1b[0m\n")
n_completions = await gather_tasks_results(context)
if n_completions > 0:
return
@@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
f" = {expected[key]} != {slot[key]}")
+async def completions_seed(context):
+ return context.seed if hasattr(context, 'seed') and context.seed is not None \
+ else context.server_seed if hasattr(context, 'server_seed') else None
+
+
def start_server_background(context):
context.server_path = '../../../build/bin/server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@@ -800,27 +935,35 @@ def start_server_background(context):
'--port', context.server_port,
'--model', context.model_file
]
+ if context.n_batch:
+ server_args.extend(['--batch-size', context.n_batch])
+ if context.n_gpu_layer:
+ server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
if context.server_continuous_batching:
server_args.append('--cont-batching')
if context.server_embeddings:
server_args.append('--embedding')
if context.server_metrics:
server_args.append('--metrics')
- if context.model_alias is not None:
+ if context.model_alias:
server_args.extend(['--alias', context.model_alias])
- if context.n_ctx is not None:
+ if context.n_ctx:
server_args.extend(['--ctx-size', context.n_ctx])
- if context.n_slots is not None:
+ if context.n_slots:
server_args.extend(['--parallel', context.n_slots])
- if context.n_server_predict is not None:
+ if context.n_server_predict:
server_args.extend(['--n-predict', context.n_server_predict])
- if context.server_api_key is not None:
+ if context.server_api_key:
server_args.extend(['--api-key', context.server_api_key])
+ if context.n_ga:
+ server_args.extend(['--grp-attn-n', context.n_ga])
+ if context.n_ga_w:
+ server_args.extend(['--grp-attn-w', context.n_ga_w])
if context.debug:
server_args.append('--verbose')
if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"])
- print(f"starting server with: {context.server_path}", *server_args)
+ print(f"starting server with: {context.server_path} {server_args}\n")
context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]],
close_fds=True)
diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature
index e228b237..cf14b3b4 100644
--- a/examples/server/tests/features/wrong_usages.feature
+++ b/examples/server/tests/features/wrong_usages.feature
@@ -1,4 +1,4 @@
-# run with ./test.sh --tags wrong_usage
+# run with: ./tests.sh --no-skipped --tags wrong_usage
@wrong_usage
Feature: Wrong usage of llama.cpp server
@@ -7,7 +7,7 @@ Feature: Wrong usage of llama.cpp server
# or pass n_predict/max_tokens in the request.
Scenario: Infinite loop
Given a server listening on localhost:8080
- And a model file stories260K.gguf
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
# Uncomment below to fix the issue
#And 64 server max tokens to predict
Then the server is starting
@@ -18,4 +18,5 @@ Feature: Wrong usage of llama.cpp server
# Uncomment below to fix the issue
#And 128 max tokens to predict
Given concurrent completion requests
+ Then the server is idle
Then all prompts are predicted
diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt
index 334fa4a7..5d421016 100644
--- a/examples/server/tests/requirements.txt
+++ b/examples/server/tests/requirements.txt
@@ -1,4 +1,5 @@
aiohttp~=3.9.3
behave~=1.2.6
+huggingface_hub~=0.20.3
openai~=0.25.0
prometheus-client~=0.20.0
diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh
index 17a4e6fc..1c6c5695 100755
--- a/examples/server/tests/tests.sh
+++ b/examples/server/tests/tests.sh
@@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ]
then
# Start @llama.cpp scenario
- behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
+ behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
else
behave "$@"
fi