author    Matt Pulver <matt.pulver@heavy.ai>  2023-08-25 11:18:48 -0400
committer GitHub <noreply@github.com>  2023-08-25 18:18:48 +0300
commit    c82742ac9cd96fd34aa961978805c1d8a361d589 (patch)
tree      ee377f2559d967955ce1dde65b698504a33e2928 /examples/server/server.cpp
parent    28b2c996ca0ab90a5669946084f13443ec98e241 (diff)
llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
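The new API follows a callback pattern: the caller hands llama_beam_search() a callback and an opaque data pointer, and the callback fires each time the beam lengths grow. Below is a minimal sketch of that contract, assuming only the llama_beams_state / llama_beam_view declarations referenced in the diff that follows; beam_collector and collect_common_prefix are hypothetical names used for illustration, not part of this patch.

// Sketch only: mirrors the callback contract used by the server code below,
// assuming the llama.h declarations this commit introduces.
#include <cstddef>
#include <vector>
#include "llama.h"

struct beam_collector {
    std::vector<llama_token> tokens; // common-prefix tokens collected so far (hypothetical)
};

// Matches llama_beam_search_callback_fn_t: invoked as beam lengths increase.
static void collect_common_prefix(void * callback_data, llama_beams_state beams_state) {
    auto & collector = *static_cast<beam_collector *>(callback_data);
    // Tokens shared by every beam form a stable prefix and can be consumed now.
    if (const size_t n = beams_state.common_prefix_length) {
        const llama_token * src = beams_state.beam_views[0].tokens;
        collector.tokens.insert(collector.tokens.end(), src, src + n);
    }
}

// Invocation mirrors the server call in the second hunk:
// llama_beam_search(ctx, collect_common_prefix, &collector, n_beams, n_past, n_remain, n_threads);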
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp | 90
1 file changed, 77 insertions(+), 13 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 025b385c..3300553f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
});
}
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+ return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+// This is also called when the stop condition is met.
+// Collects tokens into the generated_token_probs vector of the llama_server_context pointed to by callback_data.
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+ auto & llama = *static_cast<llama_server_context*>(callback_data);
+ // Mark beams as EOB (end-of-beam) as needed.
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ llama_beam_view & beam_view = beams_state.beam_views[i];
+ if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+ beam_view.eob = true;
+ }
+ }
+ printf(","); // Show progress
+ if (const size_t n = beams_state.common_prefix_length) {
+ llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+ assert(0u < beams_state.n_beams);
+ const llama_token * tokens = beams_state.beam_views[0].tokens;
+ const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+ std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+ printf("%lu", n);
+ }
+ fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+ std::cout << "\n\nCurrent beams:\n";
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+ }
+#endif
+}
+
+struct token_translator {
+ llama_context * ctx;
+ std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+ std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+ auto & gtps = llama.generated_token_probs;
+ auto translator = token_translator{llama.ctx};
+ auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+ const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+ if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+ llama.generated_text.reserve(llama.generated_text.size() + len);
+ }
+ for (const completion_token_output & cto : gtps) {
+ llama.generated_text += translator(cto);
+ }
+}
+
int main(int argc, char **argv)
{
// own arguments required by this example
@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
llama.beginCompletion();
if (!llama.stream) {
- size_t stop_pos = std::string::npos;
+ if (llama.params.n_beams) {
+ // Fill llama.generated_token_probs vector with final beam.
+ llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+ llama.n_past, llama.n_remain, llama.params.n_threads);
+ // Translate llama.generated_token_probs to llama.generated_text.
+ append_to_generated_text_from_generated_token_probs(llama);
+ } else {
+ size_t stop_pos = std::string::npos;
- while (llama.has_next_token) {
- const completion_token_output token_with_probs = llama.doCompletion();
- const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+ while (llama.has_next_token) {
+ const completion_token_output token_with_probs = llama.doCompletion();
+ const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
- stop_pos = llama.findStoppingStrings(llama.generated_text,
- token_text.size(), STOP_FULL);
- }
+ stop_pos = llama.findStoppingStrings(llama.generated_text,
+ token_text.size(), STOP_FULL);
+ }
- if (stop_pos == std::string::npos) {
- stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
- }
- if (stop_pos != std::string::npos) {
- llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
- llama.generated_text.end());
+ if (stop_pos == std::string::npos) {
+ stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+ }
+ if (stop_pos != std::string::npos) {
+ llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+ llama.generated_text.end());
+ }
}
const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
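One detail worth noting in the first hunk: append_to_generated_text_from_generated_token_probs() totals the translated token lengths with std::accumulate before appending, so generated_text grows with at most one reallocation. The same reserve-then-append idiom in isolation (a standalone sketch; join_pieces is a hypothetical name, not part of this patch):

#include <cstddef>
#include <numeric>
#include <string>
#include <vector>

// Reserve the exact capacity once, then append, avoiding repeated reallocations.
static std::string join_pieces(const std::vector<std::string> & pieces) {
    const size_t len = std::accumulate(pieces.begin(), pieces.end(), size_t(0),
        [](size_t sum, const std::string & s) { return sum + s.size(); });
    std::string out;
    out.reserve(len);
    for (const std::string & s : pieces) {
        out += s;
    }
    return out;
}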