diff options
Diffstat (limited to 'common/sampling.h')
-rw-r--r-- | common/sampling.h | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/common/sampling.h b/common/sampling.h index 88899c09..2bd6a75d 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -8,6 +8,16 @@ #include <vector> #include <unordered_map> +// sampler types +enum class llama_sampler_type : char { + TOP_K = 'k', + TOP_P = 'p', + MIN_P = 'm', + TFS_Z = 'f', + TYPICAL_P = 'y', + TEMP = 't' +}; + // sampling parameters typedef struct llama_sampling_params { int32_t n_prev = 64; // number of previous tokens to remember @@ -28,7 +38,15 @@ typedef struct llama_sampling_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool penalize_nl = true; // consider newlines as a repeatable token - std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp + + std::vector<llama_sampler_type> samplers_sequence = { + llama_sampler_type::TOP_K, + llama_sampler_type::TFS_Z, + llama_sampler_type::TYPICAL_P, + llama_sampler_type::TOP_P, + llama_sampler_type::MIN_P, + llama_sampler_type::TEMP + }; std::string grammar; // optional BNF-like grammar to constrain sampling |