diff options
author | David Friehs <david@friehs.info> | 2024-01-15 14:06:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-15 15:06:52 +0200 |
commit | 4483396751c79dea540808b9cb9238245d06da2b (patch) | |
tree | 81819684d572a750d80c740a8cc0103ae1ab0c2d /common/sampling.cpp | |
parent | d9aa4ffa6e0296d42f1f676dd85de97c8491eb73 (diff) |
llama : apply classifier-free guidance to logits directly (#4951)
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r-- | common/sampling.cpp | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp index 8e45909f..dd1ffeb1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -190,6 +190,11 @@ static llama_token llama_sampling_sample_impl( logits[it->first] += it->second; } + if (ctx_cfg) { + float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); + llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + } + cur.clear(); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -198,10 +203,6 @@ static llama_token llama_sampling_sample_impl( llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - if (ctx_cfg) { - llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale); - } - // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); |