Diffstat (limited to 'examples')
-rw-r--r-- examples/CMakeLists.txt              |   1
-rw-r--r-- examples/beam_search/CMakeLists.txt  |   8
-rw-r--r-- examples/beam_search/beam_search.cpp | 188
-rw-r--r-- examples/server/server.cpp           |  90
4 files changed, 274 insertions(+), 13 deletions(-)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d2176c91..94b78522 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -25,6 +25,7 @@ else()
add_subdirectory(simple)
add_subdirectory(embd-input)
add_subdirectory(llama-bench)
+ add_subdirectory(beam_search)
if (LLAMA_METAL)
add_subdirectory(metal)
endif()
diff --git a/examples/beam_search/CMakeLists.txt b/examples/beam_search/CMakeLists.txt
new file mode 100644
index 00000000..b29e0109
--- /dev/null
+++ b/examples/beam_search/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+ add_dependencies(${TARGET} BUILD_INFO)
+endif()
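With the target above added to examples/CMakeLists.txt, the example builds alongside the others; from a configured build directory, something like "cmake --build . --target beam_search" should produce the binary (the exact command depends on your generator, so treat it as illustrative). The resulting program is invoked as "beam_search MODEL_PATH [BEAM_WIDTH=2] [PROMPT]", matching the usage string printed by beam_search.cpp below.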
diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
new file mode 100644
index 00000000..1c04fabc
--- /dev/null
+++ b/examples/beam_search/beam_search.cpp
@@ -0,0 +1,188 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+ llama_context * ctx;
+ llama_beam_view beam_view;
+};
+std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
+ os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
+ for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
+ os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+ }
+ return os << ')';
+}
+
+// Put here anything you want passed back into beam_search_callback().
+struct beam_search_callback_data {
+ llama_context * ctx;
+ std::vector<llama_token> response;
+};
+
+// In this case, end-of-beam (eob) is equivalent to end-of-sequence (eos), but this need not always be the case.
+// For example, eob can be flagged due to a maximum token length, stop words, etc.
+bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
+ return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+// * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//   This is also called when the stop condition is met.
+// Tokens are collected into the std::vector<llama_token> response pointed to by callback_data.
+void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
+ auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
+ // Mark beams as EOS as needed.
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ llama_beam_view& beam_view = beams_state.beam_views[i];
+ if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+ beam_view.eob = true;
+ }
+ }
+ printf(","); // Show progress
+ if (const size_t n = beams_state.common_prefix_length) {
+ callback_data.response.resize(callback_data.response.size() + n);
+ assert(0u < beams_state.n_beams);
+ const llama_token * tokens = beams_state.beam_views[0].tokens;
+ std::copy(tokens, tokens + n, callback_data.response.end() - n);
+ printf("%lu", n);
+ }
+ fflush(stdout);
+#if 1 // DEBUG: print current beams for this iteration
+ std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
+ }
+#endif
+}
+
+int main(int argc, char ** argv)
+{
+ gpt_params params;
+ //params.n_gpu_layers = 200;
+
+ //---------------------------------
+ // Print help :
+ //---------------------------------
+
+ if ( argc < 2 || argv[1][0] == '-' )
+ {
+ printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
+ return 1 ;
+ }
+
+ //---------------------------------
+ // Load parameters :
+ //---------------------------------
+
+ params.model = argv[1];
+
+ params.n_beams = argc > 2 ? std::stoi(argv[2]) : 2;
+
+ if ( argc > 3 )
+ {
+ params.prompt = argv[3];
+ }
+
+ if ( params.prompt.empty() )
+ {
+ params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+ }
+
+ //---------------------------------
+ // Init LLM :
+ //---------------------------------
+
+ llama_backend_init(params.numa);
+
+ llama_model * model;
+ llama_context * ctx;
+
+ std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+ if ( model == NULL )
+ {
+ fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+ return 1;
+ }
+
+ //---------------------------------
+ // Tokenize the prompt :
+ //---------------------------------
+
+ std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+ const size_t max_context_size = llama_n_ctx( ctx );
+ const size_t max_tokens_list_size = max_context_size - 4 ;
+
+ if (tokens_list.size() > max_tokens_list_size)
+ {
+ fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
+ __func__ , tokens_list.size() , max_tokens_list_size );
+ return 1;
+ }
+
+ fprintf( stderr, "\n\n" );
+
+ // Print the tokens from the prompt :
+
+ for( auto id : tokens_list )
+ {
+ std::cout << llama_token_to_str(ctx, id);
+ }
+ std::cout << std::flush;
+
+ int n_past = llama_get_kv_cache_token_count(ctx);
+ if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+ {
+ fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+ return 1;
+ }
+ n_past += tokens_list.size();
+
+ beam_search_callback_data callback_data{ctx, {}};
+ size_t const beam_width = static_cast<size_t>(params.n_beams);
+ int const n_predict = 256;
+ llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+
+ std::cout << "\n\n";
+ for (llama_token const token_id : callback_data.response) {
+ std::cout << llama_token_to_str(ctx,token_id);
+ }
+ std::cout << std::endl;
+
+ llama_free( ctx );
+ llama_free_model( model );
+
+ llama_backend_free();
+
+ return 0;
+}
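For reference, the callback above boils down to two responsibilities: flagging finished beams via eob, and harvesting the prefix shared by all beams from common_prefix_length. Below is a minimal sketch of the same contract, assuming the llama_beams_state and llama_beam_view fields used in this diff (see llama.h for the authoritative definitions); it is illustrative, not part of the patch.

    // Minimal llama_beam_search_callback_fn_t sketch (illustrative, not part of the patch).
    struct my_beam_data {
        llama_context * ctx;
        std::vector<llama_token> out; // converged tokens collected so far
    };

    static void my_beam_callback(void * data_ptr, llama_beams_state beams_state) {
        auto & data = *static_cast<my_beam_data *>(data_ptr);
        // 1. Flag beams whose last token is end-of-sequence so the search stops extending them.
        for (size_t i = 0; i < beams_state.n_beams; ++i) {
            llama_beam_view & bv = beams_state.beam_views[i];
            if (!bv.eob && bv.n_tokens && bv.tokens[bv.n_tokens - 1] == llama_token_eos(data.ctx)) {
                bv.eob = true;
            }
        }
        // 2. Harvest the tokens shared by every beam; this prefix is final,
        //    so beam_views[0] is as good a source for it as any other beam.
        if (const size_t n = beams_state.common_prefix_length) {
            const llama_token * tokens = beams_state.beam_views[0].tokens;
            data.out.insert(data.out.end(), tokens, tokens + n);
        }
    }

It is driven exactly as in main() above: construct my_beam_data data{ctx, {}}, then call llama_beam_search(ctx, my_beam_callback, &data, beam_width, n_past, n_predict, params.n_threads).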
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 025b385c..3300553f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
});
}
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+ return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+// * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//   This is also called when the stop condition is met.
+// Tokens are collected into llama.generated_token_probs via the llama_server_context pointed to by callback_data.
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+ auto & llama = *static_cast<llama_server_context*>(callback_data);
+ // Mark beams as EOS as needed.
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ llama_beam_view& beam_view = beams_state.beam_views[i];
+ if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+ beam_view.eob = true;
+ }
+ }
+ printf(","); // Show progress
+ if (const size_t n = beams_state.common_prefix_length) {
+ llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+ assert(0u < beams_state.n_beams);
+ const llama_token * tokens = beams_state.beam_views[0].tokens;
+ const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+ std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+ printf("%lu", n);
+ }
+ fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration (requires the ostream_beam_view helper from examples/beam_search/beam_search.cpp)
+ std::cout << "\n\nCurrent beams:\n";
+ for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+ std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+ }
+#endif
+}
+
+struct token_translator {
+ llama_context * ctx;
+ std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+ std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+ auto & gtps = llama.generated_token_probs;
+ auto translator = token_translator{llama.ctx};
+ auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+ const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+ if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+ llama.generated_text.reserve(llama.generated_text.size() + len);
+ }
+ for (const completion_token_output & cto : gtps) {
+ llama.generated_text += translator(cto);
+ }
+}
+
int main(int argc, char **argv)
{
// own arguments required by this example
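The append_to_generated_text_from_generated_token_probs() helper above makes two passes over the tokens: std::accumulate first sums the detokenized lengths so that generated_text reallocates at most once, then the loop appends each piece. The cost is that translator(cto) detokenizes every token twice. The same pattern in isolation, with illustrative names rather than the server types:

    #include <numeric>
    #include <string>
    #include <vector>

    std::string join_reserved(const std::vector<std::string> & parts) {
        // Pass 1: total length, so the destination grows exactly once.
        const size_t len = std::accumulate(parts.begin(), parts.end(), size_t(0),
            [](size_t sum, const std::string & s) { return sum + s.size(); });
        std::string out;
        out.reserve(len);
        // Pass 2: append without further reallocation.
        for (const std::string & s : parts) {
            out += s;
        }
        return out;
    }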
@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
llama.beginCompletion();

if (!llama.stream) {
- size_t stop_pos = std::string::npos;
+ if (llama.params.n_beams) {
+ // Fill llama.generated_token_probs vector with final beam.
+ llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+ llama.n_past, llama.n_remain, llama.params.n_threads);
+ // Translate llama.generated_token_probs to llama.generated_text.
+ append_to_generated_text_from_generated_token_probs(llama);
+ } else {
+ size_t stop_pos = std::string::npos;

- while (llama.has_next_token) {
- const completion_token_output token_with_probs = llama.doCompletion();
- const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+ while (llama.has_next_token) {
+ const completion_token_output token_with_probs = llama.doCompletion();
+ const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);

- stop_pos = llama.findStoppingStrings(llama.generated_text,
- token_text.size(), STOP_FULL);
- }
+ stop_pos = llama.findStoppingStrings(llama.generated_text,
+ token_text.size(), STOP_FULL);
+ }

- if (stop_pos == std::string::npos) {
- stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
- }
- if (stop_pos != std::string::npos) {
- llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
- llama.generated_text.end());
+ if (stop_pos == std::string::npos) {
+ stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+ }
+ if (stop_pos != std::string::npos) {
+ llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+ llama.generated_text.end());
+ }
}

const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);