summaryrefslogtreecommitdiff
path: root/common/sampling.h
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.h')
-rw-r--r--common/sampling.h32
1 files changed, 21 insertions, 11 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 50afcbc1..62ea6d4c 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -10,30 +10,30 @@
// sampling parameters
typedef struct llama_sampling_params {
+ int32_t n_prev = 64; // number of previous tokens to remember
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
- float repeat_penalty = 1.10f; // 1.0 = disabled
- int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float frequency_penalty = 0.00f; // 0.0 = disabled
- float presence_penalty = 0.00f; // 0.0 = disabled
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat = 1.10f; // 1.0 = disabled
+ float penalty_freq = 0.00f; // 0.0 = disabled
+ float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
-
bool penalize_nl = true; // consider newlines as a repeatable token
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ std::string grammar; // optional BNF-like grammar to constrain sampling
// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
- std::string cfg_negative_prompt; // string to help guidance
- float cfg_scale = 1.f; // How strong is guidance
+ std::string cfg_negative_prompt; // string to help guidance
+ float cfg_scale = 1.f; // how strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
} llama_sampling_params;
// general sampler context
@@ -58,7 +58,7 @@ struct llama_sampling_context {
#include "common.h"
// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
void llama_sampling_free(struct llama_sampling_context * ctx);
@@ -70,6 +70,15 @@ void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
+// Get the last sampled token
+llama_token llama_sampling_last(llama_sampling_context * ctx);
+
+// Get a string representation of the last sampled tokens
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+
+// Print sampling parameters into a string
+std::string llama_sampling_print(const llama_sampling_params & params);
+
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
@@ -96,4 +105,5 @@ llama_token llama_sampling_sample(
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
- llama_token id);
+ llama_token id,
+ bool apply_grammar);