diff options
author | kalomaze <66376113+kalomaze@users.noreply.github.com> | 2023-10-31 14:44:49 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-31 20:44:49 +0100 |
commit | 238657db2364cfb728c694470a4a81702afea760 (patch) | |
tree | 8b870a0600d1a2de4d9efe7981c24164357f5552 /common/common.cpp | |
parent | 07178c98e1b61a5e2af39d347add12e7eb9e08e1 (diff) |
samplers : Min-P sampler implementation [alternative to Top P/Top K] (#3841)
* Introduce the new Min-P sampler by @kalomaze
The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token.
* Min-P enabled and set to 0.05 default
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index c187128d..dc4865e8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } sparams.top_p = std::stof(argv[i]); + } else if (arg == "--min-p") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.min_p = std::stof(argv[i]); } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; @@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); @@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); } |