summary | refs | log | tree | commit | diff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
author: Kawrakow <iwankawrakow@gmail.com> 2025-06-03 17:35:09 +0300
committer: GitHub <noreply@github.com> 2025-06-03 17:35:09 +0300
commit: f6d5fbdc5780b6dca770c896b8463de3239c7f8b (patch)
tree: 5174cb76b596d23383a4434ab2179d8b5213512f /common/sampling.cpp
parent: ccb265c01676aad9ae5860ba50e74e61dfcd1cf8 (diff)
Adding top-n-sigma sampler (#489)
* Adding top-n-sigma sampler
* Fix typos in XTC PR
* Update README.md for main and server
* More README
* More README

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r-- common/sampling.cpp | 22
1 file changed, 14 insertions, 8 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 84691d93..4db12ee1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -122,11 +122,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
- "\txtc_probability = %.3f, xtc_threshold = %.3f",
+ "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau,
- params.xtc_probability, params.xtc_threshold);
+ params.xtc_probability, params.xtc_threshold, params.top_n_sigma);
return std::string(result);
}
@@ -156,6 +156,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::TEMPERATURE: return "temperature";
case llama_sampler_type::XTC : return "xtc";
+ case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma";
default : return "";
}
}
@@ -168,6 +169,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"min_p", llama_sampler_type::MIN_P},
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
+ {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE}
};
@@ -183,6 +185,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
+ {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE}
};
@@ -218,6 +221,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
{'m', llama_sampler_type::MIN_P},
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
+ {'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE}
};
@@ -248,16 +252,18 @@ static void sampler_queue(
const float typical_p = params.typical_p;
const float xtc_probability = params.xtc_probability;
const float xtc_threshold = params.xtc_threshold;
+ const float top_n_sigma = params.top_n_sigma;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
for (auto sampler_type : samplers_sequence) {
switch (sampler_type) {
- case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
- case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
- case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
- case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
- case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
+ case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
+ case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
+ case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
+ case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
+ case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
+ case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
+ case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);