summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
authorDavid Friehs <david@friehs.info>2024-01-15 14:06:52 +0100
committerGitHub <noreply@github.com>2024-01-15 15:06:52 +0200
commit4483396751c79dea540808b9cb9238245d06da2b (patch)
tree81819684d572a750d80c740a8cc0103ae1ab0c2d /common/sampling.cpp
parentd9aa4ffa6e0296d42f1f676dd85de97c8491eb73 (diff)
llama : apply classifier-free guidance to logits directly (#4951)
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 8e45909f..dd1ffeb1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -190,6 +190,11 @@ static llama_token llama_sampling_sample_impl(
logits[it->first] += it->second;
}
+ if (ctx_cfg) {
+ float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+ llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+ }
+
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -198,10 +203,6 @@ static llama_token llama_sampling_sample_impl(
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
- if (ctx_cfg) {
- llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
- }
-
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);