summary refs log tree commit diff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r-- common/common.cpp 119
-rw-r--r-- common/common.h 7
-rw-r--r-- common/sampling.cpp 33
-rw-r--r-- common/sampling.h 20
4 files changed, 122 insertions, 57 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 9a489a55..f64da2cb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- sparams.samplers_sequence = parse_samplers_input(argv[i]);
+ const auto sampler_names = string_split(argv[i], ';');
+ sparams.samplers_sequence = sampler_types_from_names(sampler_names);
} else if (arg == "--sampling-seq") {
if (++i >= argc) {
invalid_param = true;
break;
}
- sparams.samplers_sequence = argv[i];
+ sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
} else if (arg == "--top-p") {
if (++i >= argc) {
invalid_param = true;
@@ -906,6 +907,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sparams;
+ std::string sampler_type_chars;
+ std::string sampler_type_names;
+ for (const auto sampler_type : sparams.samplers_sequence) {
+ sampler_type_chars += static_cast<char>(sampler_type);
+ sampler_type_names += sampler_type_to_name_string(sampler_type) + ";";
+ }
+ sampler_type_names.pop_back();
+
printf("\n");
printf("usage: %s [options]\n", argv[0]);
printf("\n");
@@ -947,8 +956,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
- printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
- printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
+ printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
+ printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -1097,45 +1106,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
}
//
-// String parsing
+// String utils
//
-std::string parse_samplers_input(std::string input) {
- std::string output = "";
+std::vector<std::string> string_split(std::string input, char separator) { // splits input at every separator; the result always holds at least one part
+ std::vector<std::string> parts;
+ size_t separator_pos = input.find(separator);
+ while (separator_pos != std::string::npos) {
+ std::string part = input.substr(0, separator_pos); // text before the next separator (may be empty)
+ parts.emplace_back(part);
+ input = input.substr(separator_pos + 1); // keep only the text after the separator; input is a by-value copy, so the caller's string is untouched
+ separator_pos = input.find(separator);
+ }
+ parts.emplace_back(input); // whatever remains after the last separator, an empty string included
+ return parts;
+}
+
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) { // converts sampler names to sampler types, keeping the input order
// since samplers names are written multiple ways
// make it ready for both system names and input names
- std::unordered_map<std::string, char> samplers_symbols {
- {"top_k", 'k'},
- {"top-k", 'k'},
- {"top_p", 'p'},
- {"top-p", 'p'},
- {"nucleus", 'p'},
- {"typical_p", 'y'},
- {"typical-p", 'y'},
- {"typical", 'y'},
- {"min_p", 'm'},
- {"min-p", 'm'},
- {"tfs_z", 'f'},
- {"tfs-z", 'f'},
- {"tfs", 'f'},
- {"temp", 't'},
- {"temperature",'t'}
+ std::unordered_map<std::string, llama_sampler_type> sampler_name_map { // every accepted spelling -> sampler type; rebuilt on each call
+ {"top_k", llama_sampler_type::TOP_K},
+ {"top-k", llama_sampler_type::TOP_K},
+ {"top_p", llama_sampler_type::TOP_P},
+ {"top-p", llama_sampler_type::TOP_P},
+ {"nucleus", llama_sampler_type::TOP_P}, // accepted as another name for top_p
+ {"typical_p", llama_sampler_type::TYPICAL_P},
+ {"typical-p", llama_sampler_type::TYPICAL_P},
+ {"typical", llama_sampler_type::TYPICAL_P},
+ {"min_p", llama_sampler_type::MIN_P},
+ {"min-p", llama_sampler_type::MIN_P},
+ {"tfs_z", llama_sampler_type::TFS_Z},
+ {"tfs-z", llama_sampler_type::TFS_Z},
+ {"tfs", llama_sampler_type::TFS_Z},
+ {"temp", llama_sampler_type::TEMP},
+ {"temperature", llama_sampler_type::TEMP}
+ };
+
+ std::vector<llama_sampler_type> sampler_types;
+ sampler_types.reserve(names.size()); // upper bound: unrecognized names are skipped below
+ for (const auto& name : names) {
+ const auto sampler_item = sampler_name_map.find(name); // exact, case-sensitive match against the map keys
+ if (sampler_item != sampler_name_map.end()) { // names missing from the map are dropped without any error
+ sampler_types.push_back(sampler_item->second);
+ }
+ }
+ return sampler_types;
+}
+
+std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string) { // converts each character of names_string to a sampler type, keeping the input order
+ std::unordered_map<char, llama_sampler_type> sampler_name_map { // one-letter code -> sampler type; rebuilt on each call
+ {'k', llama_sampler_type::TOP_K},
+ {'p', llama_sampler_type::TOP_P},
+ {'y', llama_sampler_type::TYPICAL_P},
+ {'m', llama_sampler_type::MIN_P},
+ {'f', llama_sampler_type::TFS_Z},
+ {'t', llama_sampler_type::TEMP}
};
- // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
- size_t separator = input.find(';');
- while (separator != input.npos) {
- std::string name = input.substr(0,separator);
- input = input.substr(separator+1);
- separator = input.find(';');
-
- if (samplers_symbols.find(name) != samplers_symbols.end()) {
- output += samplers_symbols[name];
+
+ std::vector<llama_sampler_type> sampler_types;
+ sampler_types.reserve(names_string.size()); // upper bound: unrecognized characters are skipped below
+ for (const auto & c : names_string) {
+ const auto sampler_item = sampler_name_map.find(c);
+ if (sampler_item != sampler_name_map.end()) { // characters missing from the map are dropped without any error
+ sampler_types.push_back(sampler_item->second);
}
}
- if (samplers_symbols.find(input) != samplers_symbols.end()) {
- output += samplers_symbols[input];
+ return sampler_types;
+}
+
+std::string sampler_type_to_name_string(llama_sampler_type sampler_type) { // returns the sampler's underscore-style name, or an empty string for a value not listed below
+ switch (sampler_type) {
+ case llama_sampler_type::TOP_K: return "top_k";
+ case llama_sampler_type::TFS_Z: return "tfs_z";
+ case llama_sampler_type::TYPICAL_P: return "typical_p";
+ case llama_sampler_type::TOP_P: return "top_p";
+ case llama_sampler_type::MIN_P: return "min_p";
+ case llama_sampler_type::TEMP: return "temp";
+ default : return ""; // value matches none of the enumerators handled above
}
- return output;
}
//
diff --git a/common/common.h b/common/common.h
index 62de25d6..9bdd45cf 100644
--- a/common/common.h
+++ b/common/common.h
@@ -162,10 +162,13 @@ std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input);
//
-// String parsing
+// String utils
//
-std::string parse_samplers_input(std::string input);
+std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names); // sampler names -> sampler types; unrecognized names are skipped
+std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string); // one-letter codes -> sampler types; unrecognized characters are skipped
+std::vector<std::string> string_split(std::string input, char separator); // splits input at every separator; always returns at least one part
+std::string sampler_type_to_name_string(llama_sampler_type sampler_type); // sampler type -> name such as "top_k"; empty string for an unhandled value
//
// Model utils
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 82cbdece..a001750d 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
if (params.mirostat == 0) {
- for (auto s : params.samplers_sequence) {
- switch (s) {
- case 'k': result += "-> top_k "; break;
- case 'f': result += "-> tfs_z "; break;
- case 'y': result += "-> typical_p "; break;
- case 'p': result += "-> top_p "; break;
- case 'm': result += "-> min_p "; break;
- case 't': result += "-> temp "; break;
- default : break;
+ for (auto sampler_type : params.samplers_sequence) {
+ const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
+ if (!sampler_type_name.empty()) {
+ result += "-> " + sampler_type_name + " ";
}
}
} else {
@@ -135,16 +130,16 @@ static void sampler_queue(
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
- const std::string & samplers_sequence = params.samplers_sequence;
-
- for (auto s : samplers_sequence) {
- switch (s){
- case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
- case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
- case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
- case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
- case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case 't':
+ const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
+
+ for (auto sampler_type : samplers_sequence) {
+ switch (sampler_type) {
+ case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
+ case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
+ case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
+ case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
+ case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
+ case llama_sampler_type::TEMP:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
diff --git a/common/sampling.h b/common/sampling.h
index 88899c09..2bd6a75d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -8,6 +8,16 @@
#include <vector>
#include <unordered_map>
+// sampler types
+enum class llama_sampler_type : char { // each enumerator's underlying char value is its one-letter code
+ TOP_K = 'k',
+ TOP_P = 'p',
+ MIN_P = 'm',
+ TFS_Z = 'f',
+ TYPICAL_P = 'y',
+ TEMP = 't'
+};
+
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@@ -28,7 +38,15 @@ typedef struct llama_sampling_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token
- std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
+
+ std::vector<llama_sampler_type> samplers_sequence = {
+ llama_sampler_type::TOP_K,
+ llama_sampler_type::TFS_Z,
+ llama_sampler_type::TYPICAL_P,
+ llama_sampler_type::TOP_P,
+ llama_sampler_type::MIN_P,
+ llama_sampler_type::TEMP
+ };
std::string grammar; // optional BNF-like grammar to constrain sampling