diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-06-03 17:35:09 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-03 17:35:09 +0300 |
commit | f6d5fbdc5780b6dca770c896b8463de3239c7f8b (patch) | |
tree | 5174cb76b596d23383a4434ab2179d8b5213512f /common | |
parent | ccb265c01676aad9ae5860ba50e74e61dfcd1cf8 (diff) |
Adding top-n-sigma sampler (#489)
* Adding top-n-sigma sampler
* Fix typos in XTC PR
* Update README.md for main and server
* More README
* More README
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 9 | ||||
-rw-r--r-- | common/sampling.cpp | 22 | ||||
-rw-r--r-- | common/sampling.h | 4 |
3 files changed, 25 insertions, 10 deletions
diff --git a/common/common.cpp b/common/common.cpp index cefbf63f..232101e4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -659,6 +659,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.xtc_threshold = std::stof(argv[i]); return true; } + if (arg == "--top-n-sigma") { + CHECK_ARG + sparams.top_n_sigma = std::stof(argv[i]); + return true; + } if (arg == "--cfg-negative-prompt") { CHECK_ARG sparams.cfg_negative_prompt = argv[i]; @@ -1646,7 +1651,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta }); options.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau }); options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability }); - options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_threshold}); + options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold}); + options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma}); options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); @@ -3410,6 +3416,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability); fprintf(stream, "xtc_threshold: %f # default: 0.0\n", sparams.xtc_threshold); + fprintf(stream, "top_n_sigma: %f # default: 0.0\n", sparams.top_n_sigma); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); diff --git a/common/sampling.cpp b/common/sampling.cpp index 84691d93..4db12ee1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -122,11 +122,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n" - "\txtc_probability = %.3f, xtc_threshold = %.3f", + "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f", params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau, - params.xtc_probability, params.xtc_threshold); + params.xtc_probability, params.xtc_threshold, params.top_n_sigma); return std::string(result); } @@ -156,6 +156,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { case llama_sampler_type::MIN_P: return "min_p"; case llama_sampler_type::TEMPERATURE: return "temperature"; case llama_sampler_type::XTC : return "xtc"; + case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma"; default : return ""; } } @@ -168,6 +169,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto {"min_p", llama_sampler_type::MIN_P}, {"tfs_z", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, + {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA}, {"temperature", llama_sampler_type::TEMPERATURE} }; @@ -183,6 +185,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto {"tfs-z", llama_sampler_type::TFS_Z}, {"tfs", llama_sampler_type::TFS_Z}, {"xtc", llama_sampler_type::XTC}, + {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA}, {"temp", llama_sampler_type::TEMPERATURE} }; @@ -218,6 +221,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin {'m', llama_sampler_type::MIN_P}, {'f', llama_sampler_type::TFS_Z}, {'x', llama_sampler_type::XTC}, + {'n', llama_sampler_type::TOP_N_SIGMA}, {'t', llama_sampler_type::TEMPERATURE} }; @@ -248,16 +252,18 @@ static void sampler_queue( const float typical_p = params.typical_p; const float xtc_probability = params.xtc_probability; const float xtc_threshold = params.xtc_threshold; + const float top_n_sigma = params.top_n_sigma; const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence; for (auto sampler_type : samplers_sequence) { switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break; + case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; + case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; + case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; + case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; + case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break; + case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break; case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); diff --git a/common/sampling.h b/common/sampling.h index 163cdfca..99fb07ac 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -16,6 +16,7 @@ enum class llama_sampler_type : char { MIN_P = 'm', TFS_Z = 'f', XTC = 'x', + TOP_N_SIGMA = 'n', TYPICAL_P = 'y', TEMPERATURE = 't' }; @@ -41,7 +42,8 @@ typedef struct llama_sampling_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate float xtc_probability = 0.0f; // xtc probability - float xtc_threshold = 1.0f; // xtc threashold, disabled if > 0.5 + float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5 + float top_n_sigma = 0.0f; // top-n-sigma bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context |