summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
authorJohannes Gäßler <johannesg@5d6.de>2024-04-24 11:08:36 +0200
committerGitHub <noreply@github.com>2024-04-24 11:08:36 +0200
commit28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch)
tree8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /common/sampling.cpp
parentc0d1b3e03e27634ac2871761f5033cf9324d472d (diff)
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results * sampling: separate rng per sampling context
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp13
1 files changed, 12 insertions, 1 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 45d68b26..f2466550 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
#include "sampling.h"
+#include <random>
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->prev.resize(params.n_prev);
+ llama_sampling_set_rng_seed(result, params.seed);
+
return result;
}
@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->cur.clear();
}
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+ if (seed == LLAMA_DEFAULT_SEED) {
+ seed = time(NULL);
+ }
+ ctx->rng.seed(seed);
+}
+
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
sampler_queue(ctx_main, params, cur_p, min_keep);
- id = llama_sample_token(ctx_main, &cur_p);
+ id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
//{
// const int n_top = 10;