From 28103f4832e301a9c84d44ff0df9d75d46ab6c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 24 Apr 2024 11:08:36 +0200 Subject: Server: fix seed for multiple slots (#6835) * Server: add tests for consistent results * sampling: separate rng per sampling context --- common/sampling.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'common/sampling.cpp') diff --git a/common/sampling.cpp b/common/sampling.cpp index 45d68b26..f2466550 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,4 +1,6 @@ +#define LLAMA_API_INTERNAL #include "sampling.h" +#include struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + llama_sampling_set_rng_seed(result, params.seed); + return result; } @@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->cur.clear(); } +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); @@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl( sampler_queue(ctx_main, params, cur_p, min_keep); - id = llama_sample_token(ctx_main, &cur_p); + id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); //{ // const int n_top = 10; -- cgit v1.2.3