author     Matt Pulver <matt.pulver@heavy.ai>    2023-08-25 11:18:48 -0400
committer  GitHub <noreply@github.com>           2023-08-25 18:18:48 +0300
commit     c82742ac9cd96fd34aa961978805c1d8a361d589 (patch)
tree       ee377f2559d967955ce1dde65b698504a33e2928 /llama.h
parent     28b2c996ca0ab90a5669946084f13443ec98e241 (diff)
llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  37
1 file changed, 37 insertions(+), 0 deletions(-)
diff --git a/llama.h b/llama.h
index d4746817..86737200 100644
--- a/llama.h
+++ b/llama.h
@@ -469,6 +469,43 @@ extern "C" {
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+ //
+ // Beam search
+ //
+
+ struct llama_beam_view {
+ const llama_token * tokens;
+ size_t n_tokens;
+ float p; // Cumulative beam probability (renormalized relative to all beams)
+ bool eob; // Callback should set this to true when a beam is at end-of-beam.
+ };
+
+ // Passed to beam_search_callback function.
+ // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+ // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+ // These pointers are valid only during the synchronous callback, so should not be saved.
+ struct llama_beams_state {
+ llama_beam_view * beam_views;
+ size_t n_beams; // Number of elements in beam_views[].
+ size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
+ bool last_call; // True iff this is the last callback invocation.
+ };
+
+ // Type of pointer to the beam_search_callback function.
+ // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+ // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+ typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
+
+ /// @details Deterministically returns entire sentence constructed by a beam search.
+ /// @param ctx Pointer to the llama_context.
+ /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+ /// @param callback_data A pointer that is simply passed back to callback.
+ /// @param n_beams Number of beams to use.
+ /// @param n_past Number of tokens already evaluated.
+ /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+ /// @param n_threads Number of threads as passed to llama_eval().
+ LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx);
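
For context, here is a minimal sketch of how the callback API introduced by this commit is meant to be driven. It is not part of the commit: the names beam_search_data, run_beam_search, and the concrete beam/predict/thread counts are hypothetical, and the end-of-sequence check assumes llama_token_eos(ctx) is the EOS accessor in this revision of llama.h. The callback copies the common-prefix tokens (which are about to be shifted out of every beam) and marks finished beams via eob, as the header comments describe.

    // Sketch only: a callback matching llama_beam_search_callback_fn_t plus a driver call.
    #include <vector>
    #include "llama.h"

    // Hypothetical caller-side state passed through callback_data.
    struct beam_search_data {
        llama_context * ctx;
        std::vector<llama_token> response;   // tokens that are final (shared by all beams)
    };

    // Matches llama_beam_search_callback_fn_t.
    static void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
        auto & data = *static_cast<beam_search_data *>(callback_data);

        // Tokens in the common prefix will be removed (shifted) from all beams before the
        // next callback, so copy them from any beam (here beams[0]) now.
        if (beams_state.common_prefix_length > 0) {
            const llama_beam_view & b = beams_state.beam_views[0];
            data.response.insert(data.response.end(),
                                 b.tokens, b.tokens + beams_state.common_prefix_length);
        }

        // Mark beams that produced an end-of-sequence token so they stop being extended.
        // Assumption: llama_token_eos(ctx) is the EOS accessor at this revision.
        for (size_t i = 0; i < beams_state.n_beams; ++i) {
            llama_beam_view & b = beams_state.beam_views[i];
            if (!b.eob && b.n_tokens > 0 &&
                b.tokens[b.n_tokens - 1] == llama_token_eos(data.ctx)) {
                b.eob = true;
            }
        }
    }

    // Driver: assumes the prompt (n_past tokens) has already been evaluated on ctx.
    void run_beam_search(llama_context * ctx, int n_past) {
        beam_search_data data = { ctx, {} };
        llama_beam_search(ctx, beam_search_callback, &data,
                          /*n_beams=*/4, n_past, /*n_predict=*/128, /*n_threads=*/4);
        // data.response now holds the deterministically selected token sequence.
    }

Copying the common prefix inside the callback is what keeps the pointers short-lived, as required by the header comment: the beam_views pointers are only valid for the duration of the synchronous callback.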