author    Johannes Gäßler <johannesg@5d6.de>    2024-04-24 11:08:36 +0200
committer GitHub <noreply@github.com>           2024-04-24 11:08:36 +0200
commit    28103f4832e301a9c84d44ff0df9d75d46ab6c76 (patch)
tree      8ba391e3a7e0ce9a20d4b41782ef133bd7e32738 /common
parent    c0d1b3e03e27634ac2871761f5033cf9324d472d (diff)
Server: fix seed for multiple slots (#6835)
* Server: add tests for consistent results
* sampling: separate rng per sampling context
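The gist of the fix: token sampling previously went through llama_sample_token(ctx_main, ...), so every server slot drew from the RNG of the shared llama_context, and concurrent requests perturbed each other's token streams even with fixed seeds. After this commit each llama_sampling_context carries its own std::mt19937, seeded per slot. The following minimal standalone sketch (hypothetical illustration, not code from this commit; slot_context and sample_index are made-up names) shows why a private RNG per slot restores reproducibility:

    #include <cstdio>
    #include <random>

    // Hypothetical stand-in for llama_sampling_context: each slot owns its
    // RNG, mirroring the std::mt19937 member this commit adds.
    struct slot_context {
        std::mt19937 rng;
    };

    // Draw one token index from a probability distribution using the slot's
    // private RNG, loosely analogous to llama_sample_token_with_rng().
    static int sample_index(slot_context & slot, const float * probs, int n) {
        std::discrete_distribution<int> dist(probs, probs + n);
        return dist(slot.rng);
    }

    int main() {
        const float probs[4] = {0.1f, 0.2f, 0.3f, 0.4f};

        slot_context a, b;
        a.rng.seed(42); // same seed for both slots
        b.rng.seed(42);

        // Interleaved draws: with one shared RNG the two streams would steal
        // each other's random numbers; with per-slot RNGs they stay identical.
        for (int i = 0; i < 4; ++i) {
            printf("slot a: %d   slot b: %d\n",
                   sample_index(a, probs, 4),
                   sample_index(b, probs, 4));
        }
        return 0;
    }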
Diffstat (limited to 'common')
-rw-r--r--  common/common.cpp   |  2
-rw-r--r--  common/sampling.cpp | 13
-rw-r--r--  common/sampling.h   | 47
3 files changed, 41 insertions(+), 21 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index 06f252ea..a0d1f8d5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
+        // This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 45d68b26..f2466550 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, params.seed);
+
     return result;
 }
 
@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
 
     sampler_queue(ctx_main, params, cur_p, min_keep);
 
-    id = llama_sample_token(ctx_main, &cur_p);
+    id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
 
     //{
     //    const int n_top = 10;
diff --git a/common/sampling.h b/common/sampling.h
index 639b819a..cf7081e3 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -4,9 +4,10 @@
 
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t     n_prev                = 64;    // number of previous tokens to remember
-    int32_t     n_probs               = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep              = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k                 = 40;    // <= 0 to use vocab size
-    float       top_p                 = 0.95f; // 1.0 = disabled
-    float       min_p                 = 0.05f; // 0.0 = disabled
-    float       tfs_z                 = 1.00f; // 1.0 = disabled
-    float       typical_p             = 1.00f; // 1.0 = disabled
-    float       temp                  = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range        = 0.00f; // 0.0 = disabled
-    float       dynatemp_exponent     = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n        = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat        = 1.00f; // 1.0 = disabled
-    float       penalty_freq          = 0.00f; // 0.0 = disabled
-    float       penalty_present       = 0.00f; // 0.0 = disabled
-    int32_t     mirostat              = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau          = 5.00f; // target entropy
-    float       mirostat_eta          = 0.10f; // learning rate
-    bool        penalize_nl           = false; // consider newlines as a repeatable token
+    int32_t     n_prev                = 64;                 // number of previous tokens to remember
+    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
+    float       top_p                 = 0.95f;              // 1.0 = disabled
+    float       min_p                 = 0.05f;              // 0.0 = disabled
+    float       tfs_z                 = 1.00f;              // 1.0 = disabled
+    float       typical_p             = 1.00f;              // 1.0 = disabled
+    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
+    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
+    float       penalty_freq          = 0.00f;              // 0.0 = disabled
+    float       penalty_present       = 0.00f;              // 0.0 = disabled
+    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float       mirostat_tau          = 5.00f;              // target entropy
+    float       mirostat_eta          = 0.10f;              // learning rate
+    bool        penalize_nl           = false;              // consider newlines as a repeatable token
+    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -79,6 +81,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
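Usage note: a sketch of how a caller (e.g. one server slot) might exercise the API added above, using only the functions declared in common/sampling.h by this commit; model loading and actual decoding are omitted, and the seed value is arbitrary. Passing LLAMA_DEFAULT_SEED instead of a fixed value makes llama_sampling_set_rng_seed fall back to time(NULL), i.e. a non-reproducible seed.

    #include "sampling.h"

    int main() {
        llama_sampling_params sparams;
        sparams.seed = 1234; // new field from this commit; a fixed seed makes the slot reproducible

        // One sampling context per server slot; each gets a private RNG,
        // seeded inside llama_sampling_init() via llama_sampling_set_rng_seed().
        llama_sampling_context * slot0 = llama_sampling_init(sparams);
        llama_sampling_context * slot1 = llama_sampling_init(sparams);

        // When a slot is reused for a new request: clear sampling state, then
        // reseed so results do not depend on what the slot generated before.
        llama_sampling_reset(slot0);
        llama_sampling_set_rng_seed(slot0, sparams.seed);

        llama_sampling_free(slot0);
        llama_sampling_free(slot1);
        return 0;
    }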