diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-03 17:35:09 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-03 17:35:09 +0300 |
commit | f6d5fbdc5780b6dca770c896b8463de3239c7f8b (patch) | |
tree | 5174cb76b596d23383a4434ab2179d8b5213512f /common/sampling.cpp | |
parent | ccb265c01676aad9ae5860ba50e74e61dfcd1cf8 (diff) |
Adding top-n-sigma sampler (#489)
* Adding top-n-sigma sampler
* Fix typos in XTC PR
* Update README.md for main and server
* More README
* More README
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r-- | common/sampling.cpp | 22 |
1 file changed, 14 insertions, 8 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp index 84691d93..4db12ee1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -122,11 +122,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n" - "\txtc_probability = %.3f, xtc_threshold = %.3f", + "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau, - params.xtc_probability, params.xtc_threshold); + params.xtc_probability, params.xtc_threshold, params.top_n_sigma); return std::string(result); } @@ -156,6 +156,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { case llama_sampler_type::MIN_P: return "min_p"; case llama_sampler_type::TEMPERATURE: return "temperature"; case llama_sampler_type::XTC : return "xtc"; + case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma"; default : return ""; } } @@ -168,6 +169,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto {"min_p", llama_sampler_type::MIN_P}, {"tfs_z", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, + {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA}, {"temperature", llama_sampler_type::TEMPERATURE} }; @@ -183,6 +185,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto {"tfs-z", llama_sampler_type::TFS_Z}, {"tfs", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, + {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA}, {"temp", llama_sampler_type::TEMPERATURE} }; @@ -218,6 +221,7 @@ std::vector<llama_sampler_type> 
llama_sampling_types_from_chars(const std::strin {'m', llama_sampler_type::MIN_P}, {'f', llama_sampler_type::TFS_Z}, {'x', llama_sampler_type::XTC}, + {'n', llama_sampler_type::TOP_N_SIGMA}, {'t', llama_sampler_type::TEMPERATURE} }; @@ -248,16 +252,18 @@ static void sampler_queue( const float typical_p = params.typical_p; const float xtc_probability = params.xtc_probability; const float xtc_threshold = params.xtc_threshold; + const float top_n_sigma = params.top_n_sigma; const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence; for (auto sampler_type : samplers_sequence) { switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break; + case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; + case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; + case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; + case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; + case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break; + case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break; case 
llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); |