| author | David Friehs <david@friehs.info> | 2024-01-15 14:06:52 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-15 15:06:52 +0200 |
| commit | 4483396751c79dea540808b9cb9238245d06da2b (patch) | |
| tree | 81819684d572a750d80c740a8cc0103ae1ab0c2d /llama.h | |
| parent | d9aa4ffa6e0296d42f1f676dd85de97c8491eb73 (diff) | |
llama : apply classifier-free guidance to logits directly (#4951)
Diffstat (limited to 'llama.h')
| -rw-r--r-- | llama.h | 17 |
|---|---|---|

1 file changed, 12 insertions, 5 deletions
```diff
@@ -714,14 +714,21 @@ extern "C" {
                        float   penalty_present);
 
     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
-    /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_classifier_free_guidance(
+    /// @param logits Logits extracted from the original generation context.
+    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    LLAMA_API void llama_sample_apply_guidance(
+            struct llama_context * ctx,
+                           float * logits,
+                           float * logits_guidance,
+                           float   scale);
+
+    LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
               struct llama_context * ctx,
             llama_token_data_array * candidates,
               struct llama_context * guidance_ctx,
-                             float   scale);
+                             float   scale),
+              "use llama_sample_apply_guidance() instead");
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(
```
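For illustration, a minimal migration sketch (not part of the commit) of how a caller might switch from the deprecated candidates-based call to the new logits-based one. It assumes `ctx` and `guidance_ctx` already exist and have been evaluated so that `llama_get_logits()` returns the logits for the last token; the scale value `1.5f` is only an example.

```c
// Hypothetical migration sketch, not taken from the commit or its PR.
float * logits          = llama_get_logits(ctx);           // main generation context
float * logits_guidance = llama_get_logits(guidance_ctx);  // negative-prompt context

// Before (now deprecated): guidance operated on a llama_token_data_array
// built from the unsorted logits of ctx:
//   llama_sample_classifier_free_guidance(ctx, &candidates, guidance_ctx, 1.5f);

// After: guidance is applied to the raw logits buffers directly; any
// llama_token_data_array for the remaining samplers is built from `logits`
// afterwards.
llama_sample_apply_guidance(ctx, logits, logits_guidance, 1.5f);
```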