diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-04-24 11:08:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-24 11:08:36 +0200 |
commit | 28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch) | |
tree | 8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /common/sampling.cpp | |
parent | c0d1b3e03e27634ac2871761f5033cf9324d472d (diff) |
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r-- | common/sampling.cpp | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp index 45d68b26..f2466550 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,4 +1,6 @@ +#define LLAMA_API_INTERNAL #include "sampling.h" +#include <random> struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + llama_sampling_set_rng_seed(result, params.seed); + return result; } @@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->cur.clear(); } +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); @@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl( sampler_queue(ctx_main, params, cur_p, min_keep); - id = llama_sample_token(ctx_main, &cur_p); + id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); //{ // const int n_top = 10; |