diff options
author | Matt Pulver <matt.pulver@heavy.ai> | 2023-08-25 11:18:48 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-25 18:18:48 +0300 |
commit | c82742ac9cd96fd34aa961978805c1d8a361d589 (patch) | |
tree | ee377f2559d967955ce1dde65b698504a33e2928 /llama.cpp | |
parent | 28b2c996ca0ab90a5669946084f13443ec98e241 (diff) |
llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 251 |
1 files changed, 251 insertions, 0 deletions
@@ -4327,6 +4327,257 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar } // +// Beam search +// + +struct llama_beam { + std::vector<llama_token> tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Initialize end-of-beam to false. Callback sets this to true. + // Sort beams by probability. In case of ties, prefer beams at eob. + bool operator<(const llama_beam & rhs) const { + return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); + } + // Shift off first n tokens and discard them. + void shift_tokens(const size_t n) { + if (n) { + std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); + tokens.resize(tokens.size() - n); + } + } + llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } +}; + +// A struct for calculating logit-related info. +struct llama_logit_info { + const float * const logits; + const int n_vocab; + const float max_l; + const float normalizer; + struct sum_exp { + float max_l; + float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } + }; + llama_logit_info(llama_context * ctx) + : logits(llama_get_logits(ctx)) + , n_vocab(llama_n_vocab(ctx)) + , max_l(*std::max_element(logits, logits + n_vocab)) + , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) + { } + llama_token_data get_token_data(const llama_token token_id) const { + constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used + return {token_id, logits[token_id], p}; + } + // Return top k token_data by logit. + std::vector<llama_token_data> top_k(size_t k) { + std::vector<llama_token_data> min_heap; // min-heap by logit + const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab); + min_heap.reserve(k_min); + for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) { + min_heap.push_back(get_token_data(token_id)); + } + auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; + std::make_heap(min_heap.begin(), min_heap.end(), comp); + for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) { + if (min_heap.front().logit < logits[token_id]) { + std::pop_heap(min_heap.begin(), min_heap.end(), comp); + min_heap.back().id = token_id; + min_heap.back().logit = logits[token_id]; + std::push_heap(min_heap.begin(), min_heap.end(), comp); + } + } + return min_heap; + } + float probability_from_logit(float logit) { + return normalizer * std::exp(logit - max_l); + } +}; + +struct llama_beam_search_data { + llama_context * ctx; + size_t n_beams; + int n_past; + int n_predict; + int n_threads; + std::vector<llama_beam> beams; + std::vector<llama_beam> next_beams; + + // Re-calculated on each loop iteration + size_t common_prefix_length; + + // Used to communicate to/from callback on beams state. + std::vector<llama_beam_view> beam_views; + + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + : ctx(ctx) + , n_beams(n_beams) + , n_past(n_past) + , n_predict(n_predict) + , n_threads(n_threads) + , beam_views(n_beams) { + beams.reserve(n_beams); + next_beams.reserve(n_beams); + } + + // Collapse beams to a single beam given by index. + void collapse_beams(const size_t beam_idx) { + if (0u < beam_idx) { + std::swap(beams[0], beams[beam_idx]); + } + beams.resize(1); + } + + // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). + // The repetative patterns below reflect the 2 stages of heaps: + // * Gather elements until the vector is full, then call std::make_heap() on it. + // * If the heap is full and a new element is found that should be included, pop the + // least element to the back(), replace it with the new, then push it into the heap. + void fill_next_beams_by_top_probabilities(llama_beam & beam) { + // Min-heaps use a greater-than comparator. + const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; + if (beam.eob) { + // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. + if (next_beams.size() < n_beams) { + next_beams.push_back(std::move(beam)); + if (next_beams.size() == n_beams) { + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } + } else if (next_beams.front().p < beam.p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = std::move(beam); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } else { + // beam is not at end-of-sentence, so branch with next top_k tokens. + if (!beam.tokens.empty()) { + llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + } + llama_logit_info logit_info(ctx); + std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams); + size_t i=0; + if (next_beams.size() < n_beams) { + for (; next_beams.size() < n_beams ; ++i) { + llama_beam next_beam = beam; + next_beam.tokens.push_back(next_tokens[i].id); + next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit); + next_beams.push_back(std::move(next_beam)); + } + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } else { + for (; next_beams.front().p == 0.0f ; ++i) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + for (; i < n_beams ; ++i) { + const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit); + if (next_beams.front().p < next_p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p = next_p; + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + } + } + + // Find common_prefix_length based on beams. + // Requires beams is not empty. + size_t find_common_prefix_length() { + size_t common_prefix_length = beams[0].tokens.size(); + for (size_t i = 1 ; i < beams.size() ; ++i) { + common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); + for (size_t j = 0 ; j < common_prefix_length ; ++j) { + if (beams[0].tokens[j] != beams[i].tokens[j]) { + common_prefix_length = j; + break; + } + } + } + return common_prefix_length; + } + + // Construct beams_state to send back to caller via the callback function. + // Side effect: set common_prefix_length = find_common_prefix_length(); + llama_beams_state get_beams_state(const bool last_call) { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beam_views[i] = beams[i].view(); + } + common_prefix_length = find_common_prefix_length(); + return {beam_views.data(), beams.size(), common_prefix_length, last_call}; + } + + // Loop: + // * while i < n_predict, AND + // * any of the beams have not yet reached end-of-beam (eob), AND + // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence + // (since all other beam probabilities can only decrease) + void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { + beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. + const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; + for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && + !beams[top_beam_index()].eob ; ++i) { + callback(callback_data, get_beams_state(false)); // Sets common_prefix_length + update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. + if (common_prefix_length) { + llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + n_past += common_prefix_length; + } + // Zero-out next_beam probabilities to place them last in following min-heap. + std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; }); + for (llama_beam & beam : beams) { + beam.shift_tokens(common_prefix_length); + fill_next_beams_by_top_probabilities(beam); + } + // next_beams become the beams of next/final iteration. Swap them to re-use memory. + beams.swap(next_beams); + renormalize_beam_probabilities(beams); + } + collapse_beams(top_beam_index()); + callback(callback_data, get_beams_state(true)); + } + + // As beams grow, the cumulative probabilities decrease. + // Renormalize them to avoid floating point underflow. + static void renormalize_beam_probabilities(std::vector<llama_beam> & beams) { + const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; + const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); + std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); + } + + // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. + size_t top_beam_index() { + return std::max_element(beams.begin(), beams.end()) - beams.begin(); + } + + // Copy (p,eob) for each beam which may have been changed by the callback. + void update_beams_from_beam_views() { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beams[i].p = beam_views[i].p; + beams[i].eob = beam_views[i].eob; + } + } +}; + +void llama_beam_search(llama_context * ctx, + llama_beam_search_callback_fn_t callback, void * callback_data, + size_t n_beams, int n_past, int n_predict, int n_threads) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + + beam_search_data.loop(callback, callback_data); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; +} + +// // quantization // |