summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 8e45909f..dd1ffeb1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -190,6 +190,11 @@ static llama_token llama_sampling_sample_impl(
logits[it->first] += it->second;
}
+ if (ctx_cfg) {
+ float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+ llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+ }
+
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -198,10 +203,6 @@ static llama_token llama_sampling_sample_impl(
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
- if (ctx_cfg) {
- llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
- }
-
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);