From 28103f4832e301a9c84d44ff0df9d75d46ab6c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 24 Apr 2024 11:08:36 +0200 Subject: Server: fix seed for multiple slots (#6835) * Server: add tests for consistent results * sampling: separate rng per sampling context --- llama.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'llama.cpp') diff --git a/llama.cpp b/llama.cpp index e4ca34bd..3a4a03d8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); @@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = ctx->rng; int idx = dist(rng); llama_token result = candidates->data[idx].id; @@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(ctx, candidates, ctx->rng); +} + void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); -- cgit v1.2.3