summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp59
-rw-r--r--common/common.h2
-rw-r--r--common/sampling.cpp2
-rw-r--r--common/sampling.h14
4 files changed, 47 insertions, 30 deletions
diff --git a/common/common.cpp b/common/common.cpp
index c5e83cc2..3a92d379 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -341,7 +341,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
const auto sampler_names = string_split(argv[i], ';');
- sparams.samplers_sequence = sampler_types_from_names(sampler_names);
+ sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
} else if (arg == "--sampling-seq") {
if (++i >= argc) {
invalid_param = true;
@@ -964,7 +964,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
- printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
+ printf(" --samplers samplers that will be used for generation in the order, separated by \';\'\n");
+ printf(" (default: %s)\n", sampler_type_names.c_str());
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
@@ -1133,34 +1134,50 @@ std::vector<std::string> string_split(std::string input, char separator) {
return parts;
}
-std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) {
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+ std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+ {"top_k", llama_sampler_type::TOP_K},
+ {"top_p", llama_sampler_type::TOP_P},
+ {"typical_p", llama_sampler_type::TYPICAL_P},
+ {"min_p", llama_sampler_type::MIN_P},
+ {"tfs_z", llama_sampler_type::TFS_Z},
+ {"temperature", llama_sampler_type::TEMPERATURE}
+ };
+
// since samplers names are written multiple ways
// make it ready for both system names and input names
- std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
- {"top_k", llama_sampler_type::TOP_K},
+ std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
{"top-k", llama_sampler_type::TOP_K},
- {"top_p", llama_sampler_type::TOP_P},
{"top-p", llama_sampler_type::TOP_P},
{"nucleus", llama_sampler_type::TOP_P},
- {"typical_p", llama_sampler_type::TYPICAL_P},
{"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", llama_sampler_type::TYPICAL_P},
- {"min_p", llama_sampler_type::MIN_P},
{"min-p", llama_sampler_type::MIN_P},
- {"tfs_z", llama_sampler_type::TFS_Z},
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
- {"temp", llama_sampler_type::TEMP},
- {"temperature", llama_sampler_type::TEMP}
+ {"temp", llama_sampler_type::TEMPERATURE}
};
std::vector<llama_sampler_type> sampler_types;
sampler_types.reserve(names.size());
- for (const auto& name : names) {
- const auto sampler_item = sampler_name_map.find(name);
- if (sampler_item != sampler_name_map.end()) {
+ for (const auto & name : names)
+ {
+ auto sampler_item = sampler_canonical_name_map.find(name);
+ if (sampler_item != sampler_canonical_name_map.end())
+ {
sampler_types.push_back(sampler_item->second);
}
+ else
+ {
+ if (allow_alt_names)
+ {
+ sampler_item = sampler_alt_name_map.find(name);
+ if (sampler_item != sampler_alt_name_map.end())
+ {
+ sampler_types.push_back(sampler_item->second);
+ }
+ }
+ }
}
return sampler_types;
}
@@ -1172,7 +1189,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'f', llama_sampler_type::TFS_Z},
- {'t', llama_sampler_type::TEMP}
+ {'t', llama_sampler_type::TEMPERATURE}
};
std::vector<llama_sampler_type> sampler_types;
@@ -1188,12 +1205,12 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
switch (sampler_type) {
- case llama_sampler_type::TOP_K: return "top_k";
- case llama_sampler_type::TFS_Z: return "tfs_z";
- case llama_sampler_type::TYPICAL_P: return "typical_p";
- case llama_sampler_type::TOP_P: return "top_p";
- case llama_sampler_type::MIN_P: return "min_p";
- case llama_sampler_type::TEMP: return "temp";
+ case llama_sampler_type::TOP_K: return "top_k";
+ case llama_sampler_type::TFS_Z: return "tfs_z";
+ case llama_sampler_type::TYPICAL_P: return "typical_p";
+ case llama_sampler_type::TOP_P: return "top_p";
+ case llama_sampler_type::MIN_P: return "min_p";
+ case llama_sampler_type::TEMPERATURE: return "temperature";
default : return "";
}
}
diff --git a/common/common.h b/common/common.h
index 74c13699..935771d4 100644
--- a/common/common.h
+++ b/common/common.h
@@ -165,7 +165,7 @@ void process_escapes(std::string& input);
// String utils
//
-std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names);
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
std::vector<std::string> string_split(std::string input, char separator);
std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index a001750d..53013138 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -139,7 +139,7 @@ static void sampler_queue(
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case llama_sampler_type::TEMP:
+ case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
diff --git a/common/sampling.h b/common/sampling.h
index 2bd6a75d..e1279a89 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -10,12 +10,12 @@
// sampler types
enum class llama_sampler_type : char {
- TOP_K = 'k',
- TOP_P = 'p',
- MIN_P = 'm',
- TFS_Z = 'f',
- TYPICAL_P = 'y',
- TEMP = 't'
+ TOP_K = 'k',
+ TOP_P = 'p',
+ MIN_P = 'm',
+ TFS_Z = 'f',
+ TYPICAL_P = 'y',
+ TEMPERATURE = 't'
};
// sampling parameters
@@ -45,7 +45,7 @@ typedef struct llama_sampling_params {
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
- llama_sampler_type::TEMP
+ llama_sampler_type::TEMPERATURE
};
std::string grammar; // optional BNF-like grammar to constrain sampling