author     Johannes Gäßler <johannesg@5d6.de>    2024-04-24 11:08:36 +0200
committer  GitHub <noreply@github.com>           2024-04-24 11:08:36 +0200
commit     28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch)
tree       8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /examples
parent     c0d1b3e03e27634ac2871761f5033cf9324d472d (diff)
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
Diffstat (limited to 'examples')
-rw-r--r--  examples/lookup/lookup-stats.cpp                  1
-rw-r--r--  examples/lookup/lookup.cpp                        1
-rw-r--r--  examples/main/main.cpp                            1
-rw-r--r--  examples/server/server.cpp                        3
-rw-r--r--  examples/server/tests/features/results.feature   57
-rw-r--r--  examples/server/tests/features/steps/steps.py    34
6 files changed, 92 insertions, 5 deletions
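The core of the change: instead of reseeding a single RNG owned by the shared llama_context (via llama_set_rng_seed), each sampling context now carries its own RNG seeded from its sampling parameters, so concurrent server slots no longer perturb each other's sampling. The sketch below is a hypothetical, simplified illustration of that idea only; the stand-in names (sampling_params, sampling_context, make_sampling_context, sample_from_probs) are not the real structs and functions in common/sampling.h.

#include <cstdint>
#include <random>
#include <vector>

// Hypothetical stand-ins for llama.cpp's sampling types (the real ones live in
// common/sampling.h and carry far more state).
struct sampling_params {
    uint32_t seed = 0xFFFFFFFF;   // sentinel meaning "pick a random seed"
};

struct sampling_context {
    sampling_params params;
    std::mt19937    rng;          // per-context RNG, not shared with the model context
};

sampling_context make_sampling_context(const sampling_params & params) {
    sampling_context sctx;
    sctx.params = params;
    uint32_t seed = params.seed;
    if (seed == 0xFFFFFFFF) {
        seed = std::random_device{}();
    }
    sctx.rng.seed(seed);
    return sctx;
}

// Drawing a token id uses the context's own RNG, so two contexts created with the
// same seed produce the same sequence of draws regardless of other contexts.
int sample_from_probs(sampling_context & sctx, const std::vector<float> & probs) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(sctx.rng);
}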
diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp
index 41b62c2f..87ecc0a4 100644
--- a/examples/lookup/lookup-stats.cpp
+++ b/examples/lookup/lookup-stats.cpp
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 9526e898..eebbd00a 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1180734b..a74d4d9c 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
return 1;
}
session_tokens.resize(n_token_count_out);
- llama_set_rng_seed(ctx, params.seed);
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
}
}
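The three hunks above follow the same pattern: the example programs no longer reseed the RNG inside the shared llama_context, because with a per-sampling-context RNG the seed is applied when the sampling context is constructed. In terms of the hypothetical sketch after the diffstat, the call sites change roughly as follows (assumed, simplified flow, not the literal common/ code):

#include <cstdint>

// Reuses the hypothetical sampling_params / sampling_context sketch above.
void configure_sampling(uint32_t cli_seed) {
    // before: the examples seeded the RNG inside the shared llama_context:
    //     llama_set_rng_seed(ctx, params.seed);
    // after: the seed travels with the sampling parameters and is applied when the
    // sampling context is created, so the explicit call above becomes redundant.
    sampling_params sparams;
    sparams.seed = cli_seed;
    sampling_context sctx = make_sampling_context(sparams);
    (void) sctx; // every draw made through sctx is now deterministic for this seed
}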
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 25bc2963..68c63f9f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -854,7 +854,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.params.seed = json_value(data, "seed", default_params.seed);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -1028,7 +1028,6 @@ struct server_context {
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
- llama_set_rng_seed(ctx, slot.params.seed);
}
slot.command = SLOT_COMMAND_LOAD_PROMPT;
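In the server, the requested seed is now stored in the slot's sampling parameters (slot.sparams.seed) rather than in slot.params, and the global llama_set_rng_seed call is dropped, so each slot seeds its own sampling context independently. Reusing the hypothetical types from the sketch after the diffstat, the property this buys is roughly the following (illustrative only; the real end-to-end check is the new results.feature scenario):

#include <cassert>
#include <vector>

int main() {
    sampling_params p;
    p.seed = 42;                       // "seed" field of the incoming request

    // each slot owns its own sampling context, seeded from its own sparams
    sampling_context slot0 = make_sampling_context(p);
    sampling_context slot1 = make_sampling_context(p);

    const std::vector<float> probs = {0.1f, 0.7f, 0.2f};

    // however the server interleaves the two slots, their draws stay identical,
    // because neither slot shares an RNG with the other or with the model context
    for (int i = 0; i < 8; ++i) {
        assert(sample_from_probs(slot0, probs) == sample_from_probs(slot1, probs));
    }
    return 0;
}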
diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature
new file mode 100644
index 00000000..f17120f7
--- /dev/null
+++ b/examples/server/tests/features/results.feature
@@ -0,0 +1,57 @@
+@llama.cpp
+@results
+Feature: Results
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
+ And a model file test-model-00001-of-00003.gguf
+ And 128 as batch size
+ And 256 KV cache size
+ And 128 max tokens to predict
+
+ Scenario Outline: Multi users completion
+ Given <n_slots> slots
+ And continuous batching
+ Then the server is starting
+ Then the server is healthy
+
+ Given 42 as seed
+ And a prompt:
+ """
+ Write a very long story about AI.
+ """
+
+ Given 42 as seed
+ And a prompt:
+ """
+ Write a very long story about AI.
+ """
+
+ Given 42 as seed
+ And a prompt:
+ """
+ Write a very long story about AI.
+ """
+
+ Given 42 as seed
+ And a prompt:
+ """
+ Write a very long story about AI.
+ """
+
+ Given 42 as seed
+ And a prompt:
+ """
+ Write a very long story about AI.
+ """
+
+ Given concurrent completion requests
+ Then the server is busy
+ Then the server is idle
+ And all slots are idle
+ Then all predictions are equal
+ Examples:
+ | n_slots |
+ | 1 |
+ | 2 |
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index ca400efa..f71e0d70 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
context.server_metrics = False
context.server_process = None
context.seed = None
+ context.draft = None
context.server_seed = None
context.user_api_key = None
context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
context.n_gpu_layer = ngl
+@step('{draft:d} as draft')
+def step_draft(context, draft):
+ context.draft = draft
+
+
@step('{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
context.n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
assert_n_tokens_predicted(context.completion, predicted_n)
+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+ n_completions = await gather_tasks_results(context)
+ assert n_completions >= 2, "need at least 2 completions"
+ assert_all_predictions_equal(context.tasks_result)
+ context.tasks_result = []
+
+
@step('the completion is truncated')
def step_assert_completion_truncated(context):
step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')
+def assert_all_predictions_equal(completion_responses):
+ content_0 = completion_responses[0]['content']
+
+ if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+ print(f"content 0: {content_0}")
+
+ i = 1
+ for response in completion_responses[1:]:
+ content = response['content']
+
+ if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+ print(f"content {i}: {content}")
+
+ assert content == content_0, "contents not equal"
+
+ i += 1
+
async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks)
@@ -1148,6 +1180,8 @@ def start_server_background(context):
server_args.extend(['--ubatch-size', context.n_ubatch])
if context.n_gpu_layer:
server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+ if context.draft is not None:
+ server_args.extend(['--draft', context.draft])
if context.server_continuous_batching:
server_args.append('--cont-batching')
if context.server_embeddings: