summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 1dcc235e..8e6d74d0 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -280,6 +280,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
+ } else if (arg == "--samplers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = parse_samplers_input(argv[i]);
+ } else if (arg == "--sampling-seq") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = argv[i];
} else if (arg == "--top-p") {
if (++i >= argc) {
invalid_param = true;
@@ -761,6 +773,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+ printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -887,6 +901,48 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
}
//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+ std::string output = "";
+ // since samplers names are written multiple ways
+ // make it ready for both system names and input names
+ std::unordered_map<std::string, char> samplers_symbols {
+ {"top_k", 'k'},
+ {"top-k", 'k'},
+ {"top_p", 'p'},
+ {"top-p", 'p'},
+ {"nucleus", 'p'},
+ {"typical_p", 'y'},
+ {"typical-p", 'y'},
+ {"typical", 'y'},
+ {"min_p", 'm'},
+ {"min-p", 'm'},
+ {"tfs_z", 'f'},
+ {"tfs-z", 'f'},
+ {"tfs", 'f'},
+ {"temp", 't'},
+ {"temperature",'t'}
+ };
+ // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+ size_t separator = input.find(';');
+ while (separator != input.npos) {
+ std::string name = input.substr(0,separator);
+ input = input.substr(separator+1);
+ separator = input.find(';');
+
+ if (samplers_symbols.find(name) != samplers_symbols.end()) {
+ output += samplers_symbols[name];
+ }
+ }
+ if (samplers_symbols.find(input) != samplers_symbols.end()) {
+ output += samplers_symbols[input];
+ }
+ return output;
+}
+
+//
// Model utils
//