Diffstat (limited to 'common/sampling.h')
-rw-r--r-- | common/sampling.h | 17
1 file changed, 14 insertions, 3 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 4fc86595..1d5bf0b9 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -35,11 +35,16 @@ typedef struct llama_sampling_params {
     float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range = 0.00f; // 0.0 = disabled
     float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float penalty_repeat = 1.00f; // 1.0 = disabled
     float penalty_freq = 0.00f; // 0.0 = disabled
     float penalty_present = 0.00f; // 0.0 = disabled
-    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t total_context_size = 16840;
+    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     float xtc_probability = 0.0f; // xtc probability
@@ -48,12 +53,16 @@ typedef struct llama_sampling_params {
     bool penalize_nl = false; // consider newlines as a repeatable token
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY
+    std::vector<llama_sampler_type> samplers_sequence = {
+        llama_sampler_type::DRY, llama_sampler_type::TOP_K, llama_sampler_type::TFS_Z, llama_sampler_type::TYPICAL_P, llama_sampler_type::TOP_P, llama_sampler_type::MIN_P,
+        llama_sampler_type::XTC, llama_sampler_type::TOP_N_SIGMA, llama_sampler_type::TEMPERATURE };
@@ -88,6 +97,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+    llama_sampler_dry* smpl;
+
     size_t n_valid; // Number of correct top tokens with correct probabilities.
     std::mt19937 rng;
@@ -96,7 +107,7 @@ struct llama_sampling_context {
 #include "common.h"

 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);

 void llama_sampling_free(struct llama_sampling_context * ctx);
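
The comment on dry_base encodes the DRY penalty formula: a candidate token that would extend a repeated sequence of length n (with n at least dry_allowed_length) is penalized by dry_multiplier * dry_base ^ (n - dry_allowed_length). Below is a minimal sketch of that formula only; the helper name and the choice to subtract the penalty from the raw logit are assumptions for illustration, not the implementation in sampling.cpp.

    #include <cmath>
    #include <cstdint>

    // Illustrative sketch of the formula from the dry_base comment:
    //   penalty = dry_multiplier * dry_base ^ (repeat_len - dry_allowed_length)
    // repeat_len is the length of the repeated sequence the candidate token would extend.
    // Subtracting the penalty from the logit is an assumption for this sketch.
    static float dry_penalized_logit(float logit, int32_t repeat_len,
                                     float dry_multiplier, float dry_base,
                                     int32_t dry_allowed_length) {
        if (dry_multiplier <= 0.0f || repeat_len < dry_allowed_length) {
            return logit; // sampler disabled, or repetition still within the allowed length
        }
        const float penalty = dry_multiplier *
            std::pow(dry_base, (float)(repeat_len - dry_allowed_length));
        return logit - penalty;
    }

With the defaults above (dry_base = 1.75, dry_allowed_length = 2), the penalty grows geometrically with the length of the repeated run, so short incidental repeats are barely touched while long verbatim loops are strongly suppressed.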
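The signature change means call sites must now supply a llama_vocab pointer when creating the sampling context. A hypothetical call site, assuming a loaded llama_model * named model, a populated llama_sampling_params named sparams, and that llama_model_get_vocab is available in this fork:

    // Hypothetical usage; model and sparams are assumed to exist already.
    const struct llama_vocab * vocab = llama_model_get_vocab(model);

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(vocab, sparams);
    if (ctx_sampling == nullptr) {
        // handle initialization failure
    }

    // ... run the sampling loop ...

    llama_sampling_free(ctx_sampling);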