summary refs log tree commit diff
path: root/common/sampling.h
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.h')
-rw-r--r--  common/sampling.h  36
1 file changed, 20 insertions, 16 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 7c9b8dcf..fdfa9eed 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -10,22 +10,23 @@
// sampling parameters
typedef struct llama_sampling_params {
- int32_t n_prev = 64; // number of previous tokens to remember
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
- int32_t top_k = 40; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float min_p = 0.05f; // 0.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- float temp = 0.80f; // 1.0 = disabled
- int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float penalty_repeat = 1.10f; // 1.0 = disabled
- float penalty_freq = 0.00f; // 0.0 = disabled
- float penalty_present = 0.00f; // 0.0 = disabled
- int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
- bool penalize_nl = true; // consider newlines as a repeatable token
+ int32_t n_prev = 64; // number of previous tokens to remember
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ int32_t top_k = 40; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float min_p = 0.05f; // 0.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typical_p = 1.00f; // 1.0 = disabled
+ float temp = 0.80f; // 1.0 = disabled
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat = 1.10f; // 1.0 = disabled
+ float penalty_freq = 0.00f; // 0.0 = disabled
+ float penalty_present = 0.00f; // 0.0 = disabled
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+ bool penalize_nl = true; // consider newlines as a repeatable token
+ std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
std::string grammar; // optional BNF-like grammar to constrain sampling
@@ -80,6 +81,9 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
// Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params);
+// Print sampling order into a string
+std::string llama_sampling_order_print(const llama_sampling_params & params);
+
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call