summaryrefslogtreecommitdiff
path: root/common/sampling.h
diff options
context:
space:
mode:
authorAlexey Parfenov <zxed@alkatrazstudio.net>2024-02-11 13:43:31 +0000
committerGitHub <noreply@github.com>2024-02-11 15:43:31 +0200
commita803333a4e6fc534c93afe90d741bc2388bdec87 (patch)
tree0f4781f58f4391691823cdabaaf604afae333155 /common/sampling.h
parent684780141a08200ec98eba3e982dbafd1d0b5000 (diff)
common : use enums for sampler types (#5418)
* common: use enums for sampler types * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * minor : spaces --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common/sampling.h')
-rw-r--r--common/sampling.h20
1 files changed, 19 insertions, 1 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 88899c09..2bd6a75d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -8,6 +8,16 @@
#include <vector>
#include <unordered_map>
+// sampler types; each enumerator's value is that sampler's one-character code (the codes that made up the old "kfypmt" samplers_sequence string)
+enum class llama_sampler_type : char {
+ TOP_K = 'k',     // top_k
+ TOP_P = 'p',     // top_p
+ MIN_P = 'm',     // min_p
+ TFS_Z = 'f',     // tail_free
+ TYPICAL_P = 'y', // typical_p
+ TEMP = 't'       // temp
+};
+
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@@ -28,7 +38,15 @@ typedef struct llama_sampling_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token
- std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
+
+ std::vector<llama_sampler_type> samplers_sequence = {
+ llama_sampler_type::TOP_K,
+ llama_sampler_type::TFS_Z,
+ llama_sampler_type::TYPICAL_P,
+ llama_sampler_type::TOP_P,
+ llama_sampler_type::MIN_P,
+ llama_sampler_type::TEMP
+ };
std::string grammar; // optional BNF-like grammar to constrain sampling