summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-22 20:04:20 +0300
committerGitHub <noreply@github.com>2024-05-22 20:04:20 +0300
commit6ff13987ad1a9519bee13dd98b6a21cd98979aab (patch)
treee085c9fbac76f57dbbef6233b42eb981352e9925 /common/sampling.cpp
parent38c03478a37e460ecd3a21155b338a83bfed7f90 (diff)
common : normalize naming style (#7462)
* common : normalize naming style ggml-ci * common : match declaration / definition order * zig : try to fix build
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp83
1 files changed, 82 insertions, 1 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 7fc2e215..f1f80351 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -125,7 +125,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
- const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
+ const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
if (!sampler_type_name.empty()) {
result += "-> " + sampler_type_name + " ";
}
@@ -137,6 +137,87 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
return result;
}
+std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
+ switch (sampler_type) {
+ case llama_sampler_type::TOP_K: return "top_k";
+ case llama_sampler_type::TFS_Z: return "tfs_z";
+ case llama_sampler_type::TYPICAL_P: return "typical_p";
+ case llama_sampler_type::TOP_P: return "top_p";
+ case llama_sampler_type::MIN_P: return "min_p";
+ case llama_sampler_type::TEMPERATURE: return "temperature";
+ default : return "";
+ }
+}
+
+std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+ std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+ {"top_k", llama_sampler_type::TOP_K},
+ {"top_p", llama_sampler_type::TOP_P},
+ {"typical_p", llama_sampler_type::TYPICAL_P},
+ {"min_p", llama_sampler_type::MIN_P},
+ {"tfs_z", llama_sampler_type::TFS_Z},
+ {"temperature", llama_sampler_type::TEMPERATURE}
+ };
+
+ // since samplers names are written multiple ways
+ // make it ready for both system names and input names
+ std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
+ {"top-k", llama_sampler_type::TOP_K},
+ {"top-p", llama_sampler_type::TOP_P},
+ {"nucleus", llama_sampler_type::TOP_P},
+ {"typical-p", llama_sampler_type::TYPICAL_P},
+ {"typical", llama_sampler_type::TYPICAL_P},
+ {"min-p", llama_sampler_type::MIN_P},
+ {"tfs-z", llama_sampler_type::TFS_Z},
+ {"tfs", llama_sampler_type::TFS_Z},
+ {"temp", llama_sampler_type::TEMPERATURE}
+ };
+
+ std::vector<llama_sampler_type> sampler_types;
+ sampler_types.reserve(names.size());
+ for (const auto & name : names)
+ {
+ auto sampler_item = sampler_canonical_name_map.find(name);
+ if (sampler_item != sampler_canonical_name_map.end())
+ {
+ sampler_types.push_back(sampler_item->second);
+ }
+ else
+ {
+ if (allow_alt_names)
+ {
+ sampler_item = sampler_alt_name_map.find(name);
+ if (sampler_item != sampler_alt_name_map.end())
+ {
+ sampler_types.push_back(sampler_item->second);
+ }
+ }
+ }
+ }
+ return sampler_types;
+}
+
+std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
+ std::unordered_map<char, llama_sampler_type> sampler_name_map {
+ {'k', llama_sampler_type::TOP_K},
+ {'p', llama_sampler_type::TOP_P},
+ {'y', llama_sampler_type::TYPICAL_P},
+ {'m', llama_sampler_type::MIN_P},
+ {'f', llama_sampler_type::TFS_Z},
+ {'t', llama_sampler_type::TEMPERATURE}
+ };
+
+ std::vector<llama_sampler_type> sampler_types;
+ sampler_types.reserve(names_string.size());
+ for (const auto & c : names_string) {
+ const auto sampler_item = sampler_name_map.find(c);
+ if (sampler_item != sampler_name_map.end()) {
+ sampler_types.push_back(sampler_item->second);
+ }
+ }
+ return sampler_types;
+}
+
// no reasons to expose this function in header
static void sampler_queue(
struct llama_context * ctx_main,