| author | Matt Pulver <matt.pulver@heavy.ai> | 2023-08-25 11:18:48 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-25 18:18:48 +0300 |
| commit | c82742ac9cd96fd34aa961978805c1d8a361d589 (patch) | |
| tree | ee377f2559d967955ce1dde65b698504a33e2928 /examples/server/server.cpp | |
| parent | 28b2c996ca0ab90a5669946084f13443ec98e241 (diff) | |
llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
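The new API is callback-driven: the caller hands `llama_beam_search()` a function matching `llama_beam_search_callback_fn_t`, and each time the beam lengths grow, tokens shared by all beams are reported through `beams_state.common_prefix_length`. Below is a minimal caller-side sketch using only the signatures visible in the diff that follows; the `collect_beam_tokens` collector is illustrative, not part of this commit.

```cpp
#include <vector>
#include "llama.h"

// Illustrative callback: append each round's convergent (common-prefix)
// tokens from beam 0 into a caller-owned vector. This mirrors how
// beam_search_callback in this commit accumulates generated_token_probs.
static void collect_beam_tokens(void * callback_data, llama_beams_state beams_state) {
    auto & out = *static_cast<std::vector<llama_token> *>(callback_data);
    const llama_beam_view & beam = beams_state.beam_views[0];
    for (size_t i = 0; i < beams_state.common_prefix_length; ++i) {
        out.push_back(beam.tokens[i]);
    }
}

// Usage, with ctx/n_past/n_remain/n_threads prepared as in server.cpp:
//   std::vector<llama_token> response;
//   llama_beam_search(ctx, collect_beam_tokens, &response,
//                     /*n_beams=*/4, n_past, n_remain, n_threads);
```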
Diffstat (limited to 'examples/server/server.cpp')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | examples/server/server.cpp | 90 |

1 file changed, 77 insertions, 13 deletions
```diff
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 025b385c..3300553f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+    auto & llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (const size_t n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
+        const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{llama.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context * ctx;
+    std::string operator()(llama_token tok)             const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+    auto & gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+    const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (const completion_token_output & cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example
@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
             llama.beginCompletion();
 
             if (!llama.stream) {
-                size_t stop_pos = std::string::npos;
+                if (llama.params.n_beams) {
+                    // Fill llama.generated_token_probs vector with final beam.
+                    llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                      llama.n_past, llama.n_remain, llama.params.n_threads);
+                    // Translate llama.generated_token_probs to llama.generated_text.
+                    append_to_generated_text_from_generated_token_probs(llama);
+                } else {
+                    size_t stop_pos = std::string::npos;
 
-                while (llama.has_next_token) {
-                    const completion_token_output token_with_probs = llama.doCompletion();
-                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                    while (llama.has_next_token) {
+                        const completion_token_output token_with_probs = llama.doCompletion();
+                        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                    stop_pos = llama.findStoppingStrings(llama.generated_text,
-                        token_text.size(), STOP_FULL);
-                }
+                        stop_pos = llama.findStoppingStrings(llama.generated_text,
+                            token_text.size(), STOP_FULL);
+                    }
 
-                if (stop_pos == std::string::npos) {
-                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-                }
-                if (stop_pos != std::string::npos) {
-                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                        llama.generated_text.end());
+                    if (stop_pos == std::string::npos) {
+                        stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                    }
+                    if (stop_pos != std::string::npos) {
+                        llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                            llama.generated_text.end());
+                    }
                 }
 
                 const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
```
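One detail worth calling out in `append_to_generated_text_from_generated_token_probs()` above: it totals the translated lengths with `std::accumulate` first, reserves the destination string once, and only then appends, so `generated_text` never reallocates mid-loop. A self-contained sketch of the same two-pass pattern, with a plain list of strings standing in for the token translator (names here are illustrative):

```cpp
#include <cstdio>
#include <numeric>
#include <string>
#include <vector>

int main() {
    // Stand-ins for translated tokens (the real code maps llama_token -> std::string).
    const std::vector<std::string> pieces = {"Hello", ",", " ", "beam", " ", "search"};

    // Pass 1: compute the total output length.
    const size_t len = std::accumulate(pieces.begin(), pieces.end(), size_t(0),
        [](size_t sum, const std::string & s) { return sum + s.size(); });

    // Pass 2: reserve once, then append each piece without reallocation.
    std::string text;
    text.reserve(len);
    for (const std::string & s : pieces) {
        text += s;
    }

    std::printf("%s (%zu bytes)\n", text.c_str(), text.size());
    return 0;
}
```

For short outputs the reserve is a micro-optimization, but for long completions it avoids the repeated grow-and-copy cycles of incremental appends.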