From 4483396751c79dea540808b9cb9238245d06da2b Mon Sep 17 00:00:00 2001
From: David Friehs
Date: Mon, 15 Jan 2024 14:06:52 +0100
Subject: llama : apply classifier-free guidance to logits directly (#4951)

---
 common/sampling.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

(limited to 'common/sampling.cpp')

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 8e45909f..dd1ffeb1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -190,6 +190,11 @@ static llama_token llama_sampling_sample_impl(
         logits[it->first] += it->second;
     }
 
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
     cur.clear();
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -198,10 +203,6 @@ static llama_token llama_sampling_sample_impl(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    if (ctx_cfg) {
-        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
-    }
-
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-- 
cgit v1.2.3