summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
authorAlexey Parfenov <zxed@alkatrazstudio.net>2024-02-11 13:43:31 +0000
committerGitHub <noreply@github.com>2024-02-11 15:43:31 +0200
commita803333a4e6fc534c93afe90d741bc2388bdec87 (patch)
tree0f4781f58f4391691823cdabaaf604afae333155 /common/sampling.cpp
parent684780141a08200ec98eba3e982dbafd1d0b5000 (diff)
common : use enums for sampler types (#5418)
* common: use enums for sampler types * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * minor : spaces --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp33
1 files changed, 14 insertions, 19 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 82cbdece..a001750d 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
if (params.mirostat == 0) {
- for (auto s : params.samplers_sequence) {
- switch (s) {
- case 'k': result += "-> top_k "; break;
- case 'f': result += "-> tfs_z "; break;
- case 'y': result += "-> typical_p "; break;
- case 'p': result += "-> top_p "; break;
- case 'm': result += "-> min_p "; break;
- case 't': result += "-> temp "; break;
- default : break;
+ for (auto sampler_type : params.samplers_sequence) {
+ const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
+ if (!sampler_type_name.empty()) {
+ result += "-> " + sampler_type_name + " ";
}
}
} else {
@@ -135,16 +130,16 @@ static void sampler_queue(
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
- const std::string & samplers_sequence = params.samplers_sequence;
-
- for (auto s : samplers_sequence) {
- switch (s){
- case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
- case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
- case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
- case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
- case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case 't':
+ const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
+
+ for (auto sampler_type : samplers_sequence) {
+ switch (sampler_type) {
+ case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
+ case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
+ case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
+ case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
+ case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
+ case llama_sampler_type::TEMP:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);