author    Johannes Gäßler <johannesg@5d6.de>       2024-04-24 11:08:36 +0200
committer GitHub <noreply@github.com>              2024-04-24 11:08:36 +0200
commit    28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch)
tree      8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /common
parent    c0d1b3e03e27634ac2871761f5033cf9324d472d (diff)
Server: fix seed for multiple slots (#6835)

* Server: add tests for consistent results
* sampling: separate rng per sampling context
Diffstat (limited to 'common')
-rw-r--r--  common/common.cpp     2
-rw-r--r--  common/sampling.cpp  13
-rw-r--r--  common/sampling.h    47
3 files changed, 41 insertions, 21 deletions
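The heart of the fix is moving the RNG out of shared global state and into each llama_sampling_context, so every server slot seeded with the same value reproduces the same token stream regardless of what other slots are doing. A minimal standalone sketch of that idea (the toy_* names and the fixed probability table are illustrative, not code from this repo):

    // Sketch: one RNG per sampling context, as introduced by this commit.
    #include <cstdio>
    #include <random>
    #include <vector>

    struct toy_sampling_context {
        std::mt19937 rng; // owned by the context, not shared globally
    };

    // Draw a token index from a probability table using only this
    // context's generator, so contexts cannot perturb each other.
    static int toy_sample(toy_sampling_context & ctx, const std::vector<float> & probs) {
        std::discrete_distribution<int> dist(probs.begin(), probs.end());
        return dist(ctx.rng);
    }

    int main() {
        toy_sampling_context slot_a, slot_b;
        slot_a.rng.seed(42);
        slot_b.rng.seed(42);

        const std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f};

        // Same seed => identical sequences, even with interleaved draws.
        for (int i = 0; i < 4; ++i) {
            printf("%d %d\n", toy_sample(slot_a, probs), toy_sample(slot_b, probs));
        }
        return 0;
    }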
diff --git a/common/common.cpp b/common/common.cpp
index 06f252ea..a0d1f8d5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
+        // This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
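As the added comment notes, the seed is mirrored into both gpt_params and the new sampling params until the sampling state migrates fully into llama_sampling_context. A hedged sketch of that parse-and-mirror pattern, with stand-in structs rather than the real gpt_params:

    // Sketch of the --seed handling above: parse once, then mirror the
    // value into both the legacy location and its new home.
    #include <cstdint>
    #include <string>

    struct toy_sampling_params { uint32_t seed = 0xFFFFFFFF; };
    struct toy_gpt_params {
        uint32_t seed = 0xFFFFFFFF;
        toy_sampling_params sparams;
    };

    static void parse_seed_arg(toy_gpt_params & params, const char * value) {
        const uint32_t seed = (uint32_t) std::stoul(value);
        params.seed         = seed; // legacy field, still read by callers
        params.sparams.seed = seed; // new home inside the sampling params
    }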
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 45d68b26..f2466550 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, params.seed);
+
     return result;
 }
@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
         sampler_queue(ctx_main, params, cur_p, min_keep);
 
-        id = llama_sample_token(ctx_main, &cur_p);
+        id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
 
         //{
         //    const int n_top = 10;
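Two things happen in sampling.cpp: llama_sampling_set_rng_seed treats LLAMA_DEFAULT_SEED as "seed from the clock", and the token draw now goes through llama_sample_token_with_rng using the context's own generator. A self-contained sketch of that seeding rule (TOY_DEFAULT_SEED stands in for LLAMA_DEFAULT_SEED):

    // Sketch of the seeding rule added above: a sentinel value requests a
    // fresh, time-based seed; anything else is used verbatim.
    #include <cstdint>
    #include <ctime>
    #include <random>

    static const uint32_t TOY_DEFAULT_SEED = 0xFFFFFFFF; // stand-in sentinel

    static void toy_set_rng_seed(std::mt19937 & rng, uint32_t seed) {
        if (seed == TOY_DEFAULT_SEED) {
            seed = (uint32_t) time(NULL); // non-reproducible run requested
        }
        rng.seed(seed);
    }

Note that time(NULL) has one-second resolution, so two contexts created in the same second with the default seed would still coincide; passing distinct explicit seeds avoids that.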
diff --git a/common/sampling.h b/common/sampling.h
index 639b819a..cf7081e3 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -4,9 +4,10 @@
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {
 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t     n_prev             = 64;     // number of previous tokens to remember
-    int32_t     n_probs            = 0;      // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep           = 0;      // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k              = 40;     // <= 0 to use vocab size
-    float       top_p              = 0.95f;  // 1.0 = disabled
-    float       min_p              = 0.05f;  // 0.0 = disabled
-    float       tfs_z              = 1.00f;  // 1.0 = disabled
-    float       typical_p          = 1.00f;  // 1.0 = disabled
-    float       temp               = 0.80f;  // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range     = 0.00f;  // 0.0 = disabled
-    float       dynatemp_exponent  = 1.00f;  // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n     = 64;     // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat     = 1.00f;  // 1.0 = disabled
-    float       penalty_freq       = 0.00f;  // 0.0 = disabled
-    float       penalty_present    = 0.00f;  // 0.0 = disabled
-    int32_t     mirostat           = 0;      // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau       = 5.00f;  // target entropy
-    float       mirostat_eta       = 0.10f;  // learning rate
-    bool        penalize_nl        = false;  // consider newlines as a repeatable token
+    int32_t     n_prev             = 64;                 // number of previous tokens to remember
+    int32_t     n_probs            = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     min_keep           = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t     top_k              = 40;                 // <= 0 to use vocab size
+    float       top_p              = 0.95f;              // 1.0 = disabled
+    float       min_p              = 0.05f;              // 0.0 = disabled
+    float       tfs_z              = 1.00f;              // 1.0 = disabled
+    float       typical_p          = 1.00f;              // 1.0 = disabled
+    float       temp               = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float       dynatemp_range     = 0.00f;              // 0.0 = disabled
+    float       dynatemp_exponent  = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t     penalty_last_n     = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float       penalty_repeat     = 1.00f;              // 1.0 = disabled
+    float       penalty_freq       = 0.00f;              // 0.0 = disabled
+    float       penalty_present    = 0.00f;              // 0.0 = disabled
+    int32_t     mirostat           = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float       mirostat_tau      = 5.00f;               // target entropy
+    float       mirostat_eta      = 0.10f;               // learning rate
+    bool        penalize_nl       = false;               // consider newlines as a repeatable token
+    uint32_t    seed              = LLAMA_DEFAULT_SEED;  // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -79,6 +81,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 //  - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
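Putting the new declarations together, a plausible per-slot call sequence might look like the following. This is a usage sketch against the header above (model setup and the actual sampling loop are omitted), not tested server code:

    #include "sampling.h"

    // Usage sketch: one sampling context per server slot, each seeded
    // independently, so a slot's results are reproducible in isolation.
    void run_slot(uint32_t slot_seed) {
        llama_sampling_params sparams;
        sparams.seed = slot_seed; // LLAMA_DEFAULT_SEED would mean time-based

        llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

        // ... per-token loop calling llama_sampling_sample(...) goes here ...

        // Reusing the slot for a new request: re-seed instead of rebuilding.
        llama_sampling_set_rng_seed(ctx_sampling, slot_seed);

        llama_sampling_free(ctx_sampling);
    }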