summaryrefslogtreecommitdiff
path: root/examples/common.h
diff options
context:
space:
mode:
authorBach Le <bach@bullno1.com>2023-07-12 00:18:43 +0800
committerGitHub <noreply@github.com>2023-07-11 19:18:43 +0300
commitc9c74b4e3f9dcfab8b0032749ff8a579ab4e4d8d (patch)
tree651d6915218efa83cad8745310f7d1114ca21e2a /examples/common.h
parent3ec7e596b2ba3f43c22f441254ca2bcfa91102ba (diff)
llama : add classifier-free guidance (#2135)
* Initial implementation * Remove debug print * Restore signature of llama_init_from_gpt_params * Free guidance context * Make freeing of guidance_ctx conditional * Make Classifier-Free Guidance a sampling function * Correct typo. CFG already means context-free grammar. * Record sampling time in llama_sample_classifier_free_guidance * Shift all values by the max value before applying logsoftmax * Fix styling based on review
Diffstat (limited to 'examples/common.h')
-rw-r--r--examples/common.h7
1 files changed, 7 insertions, 0 deletions
diff --git a/examples/common.h b/examples/common.h
index 96f2228f..6315df96 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -48,6 +48,12 @@ struct gpt_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
+ // Classifier-Free Guidance
+ // https://arxiv.org/abs/2306.17806
+ std::string cfg_negative_prompt; // string to help guidance
+ float cfg_scale = 1.f; // How strong is guidance
+ float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
+
std::string model = "models/7B/ggml-model.bin"; // model path
std::string model_alias = "unknown"; // model alias
std::string prompt = "";
@@ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
// Console utils