summaryrefslogtreecommitdiff
path: root/common/sampling.h
diff options
context:
space:
mode:
authorfirecoperana <xuqiaowei1124@gmail.com>2025-06-19 02:24:53 -0500
committerGitHub <noreply@github.com>2025-06-19 10:24:53 +0300
commit3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa (patch)
treea3a17ee74e0436253e17f0d322320ed554d34b0a /common/sampling.h
parentc5368148cf3af7a3694e0eb03d24a08326c01d12 (diff)
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
Diffstat (limited to 'common/sampling.h')
-rw-r--r--common/sampling.h17
1 files changed, 14 insertions, 3 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 4fc86595..1d5bf0b9 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -35,11 +35,16 @@ typedef struct llama_sampling_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
- int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
- int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+ float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+ int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
+ int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+ int32_t total_context_size = 16840;
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
float xtc_probability = 0.0f; // xtc probability
@@ -48,12 +53,16 @@ typedef struct llama_sampling_params {
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+ std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY
+
std::vector<llama_sampler_type> samplers_sequence = {
+ llama_sampler_type::DRY,
llama_sampler_type::TOP_K,
llama_sampler_type::TFS_Z,
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
+ llama_sampler_type::XTC,
llama_sampler_type::TOP_N_SIGMA,
llama_sampler_type::TEMPERATURE
};
@@ -88,6 +97,8 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
+ llama_sampler_dry* smpl;
+
size_t n_valid; // Number of correct top tokens with correct probabilities.
std::mt19937 rng;
@@ -96,7 +107,7 @@ struct llama_sampling_context {
#include "common.h"
// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);
void llama_sampling_free(struct llama_sampling_context * ctx);