summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
authorkalomaze <66376113+kalomaze@users.noreply.github.com>2023-10-31 14:44:49 -0500
committerGitHub <noreply@github.com>2023-10-31 20:44:49 +0100
commit238657db2364cfb728c694470a4a81702afea760 (patch)
tree8b870a0600d1a2de4d9efe7981c24164357f5552 /common/common.cpp
parent07178c98e1b61a5e2af39d347add12e7eb9e08e1 (diff)
samplers : Min-P sampler implementation [alternative to Top P/Top K] (#3841)
* Introduce the new Min-P sampler by @kalomaze The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. * Min-P enabled and set to 0.05 default --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp8
1 files changed, 8 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index c187128d..dc4865e8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.top_p = std::stof(argv[i]);
+ } else if (arg == "--min-p") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.min_p = std::stof(argv[i]);
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
@@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
+ printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
+ fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}