author     Johannes Gäßler <johannesg@5d6.de>   2024-03-23 01:24:36 +0100
committer  GitHub <noreply@github.com>          2024-03-23 01:24:36 +0100
commit     50ccaf5eacb50a2ca378a4ef0dc7aeb45fead652 (patch)
tree       3ebcfdadf96bb6f3aadd752a1bfe9771ac182d3b /examples
parent     56a00f0a2f48a85376f48b5ce77699df781631ae (diff)
lookup: complement data from context with general text statistics (#5479)
* lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
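The patch replaces the ad-hoc prompt scan in lookup.cpp with n-gram caches built from three sources: the current context, previous generations (dynamic cache), and a general text corpus (static cache), which are combined when drafting, with the context cache as the primary source per the commit title. A minimal standalone sketch of the idea follows; token_t, NGramCache, cache_update and draft_from_caches are hypothetical names for illustration only, not the types or functions added by this commit.

    #include <map>
    #include <vector>

    using token_t    = int;
    // n-gram -> (next token -> observed count)
    using NGramCache = std::map<std::vector<token_t>, std::map<token_t, int>>;

    // Count every (n-gram, following token) pair in a token stream.
    static void cache_update(NGramCache & cache, const std::vector<token_t> & toks, int n) {
        for (size_t i = 0; i + n < toks.size(); ++i) {
            std::vector<token_t> ngram(toks.begin() + i, toks.begin() + i + n);
            cache[ngram][toks[i + n]]++;
        }
    }

    // Propose up to n_draft tokens by repeatedly taking the most frequent follower of the
    // trailing n-gram, preferring the primary cache and falling back to a second cache.
    static std::vector<token_t> draft_from_caches(
            const std::vector<token_t> & inp, int n, int n_draft,
            const NGramCache & primary, const NGramCache & fallback) {
        std::vector<token_t> draft;
        if ((int) inp.size() < n) {
            return draft;
        }
        std::vector<token_t> key(inp.end() - n, inp.end());
        for (int d = 0; d < n_draft; ++d) {
            token_t best       = -1;
            int     best_count = 0;
            for (const NGramCache * cache : { &primary, &fallback }) {
                const auto it = cache->find(key);
                if (it == cache->end()) {
                    continue;
                }
                for (const auto & follower : it->second) {
                    if (follower.second > best_count) {
                        best_count = follower.second;
                        best       = follower.first;
                    }
                }
                if (best != -1) {
                    break; // a higher-priority cache had data for this n-gram
                }
            }
            if (best == -1) {
                break; // no cache knows this n-gram, stop drafting
            }
            draft.push_back(best);
            key.erase(key.begin());
            key.push_back(best);
        }
        return draft;
    }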
Diffstat (limited to 'examples')
-rw-r--r--  examples/lookup/CMakeLists.txt     |  18
-rw-r--r--  examples/lookup/lookup-create.cpp  |  43
-rw-r--r--  examples/lookup/lookup-merge.cpp   |  47
-rw-r--r--  examples/lookup/lookup-stats.cpp   | 163
-rw-r--r--  examples/lookup/lookup.cpp         | 116
5 files changed, 339 insertions, 48 deletions
diff --git a/examples/lookup/CMakeLists.txt b/examples/lookup/CMakeLists.txt
index c060b8f5..b91633f6 100644
--- a/examples/lookup/CMakeLists.txt
+++ b/examples/lookup/CMakeLists.txt
@@ -3,3 +3,21 @@ add_executable(${TARGET} lookup.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-create)
+add_executable(${TARGET} lookup-create.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-merge)
+add_executable(${TARGET} lookup-merge.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-stats)
+add_executable(${TARGET} lookup-stats.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp
new file mode 100644
index 00000000..46a6bed0
--- /dev/null
+++ b/examples/lookup/lookup-create.cpp
@@ -0,0 +1,43 @@
+#include "ggml.h"
+#include "llama.h"
+#include "common.h"
+#include "ngram-cache.h"
+
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+int main(int argc, char ** argv){
+ gpt_params params;
+
+ if (!gpt_params_parse(argc, argv, params)) {
+ return 1;
+ }
+ // init llama.cpp
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ llama_model * model = NULL;
+ llama_context * ctx = NULL;
+
+ // load the model
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ GGML_ASSERT(model != nullptr);
+
+ // tokenize the prompt
+ const bool add_bos = llama_should_add_bos_token(model);
+
+ std::vector<llama_token> inp;
+ inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ fprintf(stderr, "%s: tokenization done\n", __func__);
+
+
+ llama_ngram_cache ngram_cache;
+ llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
+ fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
+
+ llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
+}
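lookup-create tokenizes the prompt (intended to be a large text corpus), hashes it into a cache of LLAMA_NGRAM_STATIC-sized n-grams, and writes that cache to params.lookup_cache_static. A toy run of the hypothetical cache_update sketch above shows what such a cache holds conceptually; the real on-disk format written by llama_ngram_cache_save is not reproduced here.

    #include <cstdio>

    // Continues the hypothetical sketch above (token_t, NGramCache, cache_update).
    // The stream {1,2,3,1,2,4} with n = 2 yields:
    //   (1,2) -> {3: 1, 4: 1}   (2,3) -> {1: 1}   (3,1) -> {2: 1}
    int main() {
        NGramCache cache;
        cache_update(cache, {1, 2, 3, 1, 2, 4}, 2);
        for (const auto & entry : cache) {
            printf("(%d,%d):", entry.first[0], entry.first[1]);
            for (const auto & follower : entry.second) {
                printf(" %d x%d", follower.first, follower.second);
            }
            printf("\n");
        }
        return 0;
    }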
diff --git a/examples/lookup/lookup-merge.cpp b/examples/lookup/lookup-merge.cpp
new file mode 100644
index 00000000..07c93eb8
--- /dev/null
+++ b/examples/lookup/lookup-merge.cpp
@@ -0,0 +1,47 @@
+#include "ggml.h"
+#include "llama.h"
+#include "common.h"
+#include "ngram-cache.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+static void print_usage() {
+ fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
+ fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
+}
+
+int main(int argc, char ** argv){
+ if (argc < 3) {
+ print_usage();
+ exit(1);
+ }
+
+ std::vector<std::string> args;
+ args.resize(argc-1);
+ for (int i = 0; i < argc-1; ++i) {
+ args[i] = argv[i+1];
+ if (args[i] == "-h" || args[i] == "--help") {
+ print_usage();
+ exit(0);
+ }
+ }
+
+ fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
+ llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
+
+ for (size_t i = 1; i < args.size()-1; ++i) {
+ fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
+ llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
+
+ llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
+ }
+
+ fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
+ llama_ngram_cache_save(ngram_cache_merged, args.back());
+}
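lookup-merge folds every part file into the first one and writes the result to the last argument. Conceptually a merge just adds the follower counts of identical n-grams; the sketch below uses the hypothetical NGramCache type from above, and the real llama_ngram_cache_merge may differ in how it weights or stores entries.

    // Sum the follower counts of src into dst, n-gram by n-gram (hypothetical types).
    static void cache_merge(NGramCache & dst, const NGramCache & src) {
        for (const auto & entry : src) {
            for (const auto & follower : entry.second) {
                dst[entry.first][follower.first] += follower.second;
            }
        }
    }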
diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp
new file mode 100644
index 00000000..31f22777
--- /dev/null
+++ b/examples/lookup/lookup-stats.cpp
@@ -0,0 +1,163 @@
+#include "ggml.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+#include "ngram-cache.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+int main(int argc, char ** argv){
+ gpt_params params;
+
+ if (!gpt_params_parse(argc, argv, params)) {
+ return 1;
+ }
+
+ const int n_draft = params.n_draft;
+
+ // init llama.cpp
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ llama_model * model = NULL;
+ llama_context * ctx = NULL;
+
+ // load the model
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_set_rng_seed(ctx, params.seed);
+ GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
+
+ // tokenize the prompt
+ const bool add_bos = llama_should_add_bos_token(model);
+ LOG("add_bos tgt: %d\n", add_bos);
+
+ std::vector<llama_token> inp;
+ inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+
+ llama_ngram_cache ngram_cache_context;
+ llama_ngram_cache ngram_cache_dynamic;
+ llama_ngram_cache ngram_cache_static;
+ int64_t t_draft_flat_us = 0;
+ int64_t t_draft_us = 0;
+
+ {
+ const int64_t t_start_draft_us = ggml_time_us();
+
+ if (!params.lookup_cache_static.empty()) {
+ try {
+ ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
+ } catch (std::ifstream::failure const &) {
+ fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+ exit(1);
+ }
+ }
+
+ if (!params.lookup_cache_dynamic.empty()) {
+ try {
+ ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
+ } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
+ }
+
+ t_draft_flat_us += ggml_time_us() - t_start_draft_us;
+ }
+
+ const int n_input = inp.size();
+ const int n_ctx = params.n_ctx;
+
+ int n_drafted = 0;
+ int n_accept = 0;
+
+ const int64_t t_start_ms = ggml_time_ms();
+
+ // Iterate over input tokens in chunks of size n_ctx.
+ // Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility.
+ for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
+ const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
+ std::vector<llama_token> pseudo_output;
+ pseudo_output.push_back(inp_slice[0]);
+
+ while ((int) pseudo_output.size() < n_ctx) {
+ // Simulate drafting and decoding from draft:
+ std::vector<llama_token> draft;
+ draft.push_back(pseudo_output.back());
+
+ {
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+ t_draft_us += ggml_time_us() - t_start_draft_us;
+ }
+
+ n_drafted += draft.size() - 1;
+
+ for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
+ const llama_token ground_truth = inp_slice[pseudo_output.size()];
+ const llama_token drafted = draft[j];
+
+ if (ground_truth != drafted) {
+ break;
+ }
+
+ ++n_accept;
+ pseudo_output.push_back(ground_truth);
+
+ {
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+ t_draft_us += ggml_time_us() - t_start_draft_us;
+ }
+ }
+
+ // After each simulated batch decoding simulate the sampling of a single token:
+ if ((int) pseudo_output.size() < n_ctx) {
+ pseudo_output.push_back(inp_slice[pseudo_output.size()]);
+ {
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+ t_draft_us += ggml_time_us() - t_start_draft_us;
+ }
+ }
+
+ draft.erase(draft.begin());
+
+ }
+ if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
+ const int64_t t_now_ms = ggml_time_ms();
+ const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
+ const int64_t eta_min = eta_ms / (60*1000);
+ const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
+
+ LOG_TEE("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s);
+ }
+
+ // After each chunk, update the dynamic ngram cache with the context ngram cache:
+ llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+ ngram_cache_context.clear();
+ }
+
+ LOG_TEE("\n");
+
+ LOG_TEE("\n");
+ LOG_TEE("n_draft = %d\n", n_draft);
+ LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx);
+ LOG_TEE("n_drafted = %d\n", n_drafted);
+ LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
+ LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
+ t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
+ LOG_TEE("n_accept = %d\n", n_accept);
+ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+ llama_free(ctx);
+ llama_free_model(model);
+
+ llama_backend_free();
+
+ fprintf(stderr, "\n\n");
+
+ return 0;
+}
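lookup-stats replays a tokenized corpus in chunks of n_ctx tokens and simulates drafting against the known continuation, so acceptance rates can be measured without running the model's decode step at all. The acceptance rule reduces to a longest-matching-prefix check between the draft and the ground truth; a minimal sketch with hypothetical names, where draft[0] is the last confirmed token as in the code above:

    // Count how many drafted tokens would be accepted under exact-match verification.
    static int count_accepted(const std::vector<token_t> & draft,          // draft[0]: last confirmed token
                              const std::vector<token_t> & ground_truth) { // tokens that actually follow it
        int n_accept = 0;
        for (size_t j = 1; j < draft.size() && j - 1 < ground_truth.size(); ++j) {
            if (draft[j] != ground_truth[j - 1]) {
                break; // first mismatch invalidates the rest of the draft
            }
            ++n_accept;
        }
        return n_accept;
    }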
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index b53fae11..2e8c35de 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -1,12 +1,15 @@
-#include "common.h"
#include "ggml.h"
#include "llama.h"
+#include "common.h"
+#include "ngram-cache.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
+#include <fstream>
#include <string>
#include <vector>
+#include <unordered_map>
int main(int argc, char ** argv){
gpt_params params;
@@ -15,11 +18,7 @@ int main(int argc, char ** argv){
return 1;
}
- // max/min n-grams size to search for in prompt
- const int ngram_max = 4;
- const int ngram_min = 1;
-
- // length of the candidate / draft sequence, if match is found
+ // max. number of additional tokens to draft if match is found
const int n_draft = params.n_draft;
const bool dump_kv_cache = params.dump_kv_cache;
@@ -39,6 +38,8 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_set_rng_seed(ctx, params.seed);
+ GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
@@ -47,6 +48,35 @@ int main(int argc, char ** argv){
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ llama_ngram_cache ngram_cache_context;
+ llama_ngram_cache ngram_cache_dynamic;
+ llama_ngram_cache ngram_cache_static;
+ int64_t t_draft_flat_us = 0;
+ int64_t t_draft_us = 0;
+
+ {
+ // Fill up context ngram cache with tokens from user input:
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
+
+ if (!params.lookup_cache_static.empty()) {
+ try {
+ ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
+ } catch (std::ifstream::failure const &) {
+ fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+ exit(1);
+ }
+ }
+
+ if (!params.lookup_cache_dynamic.empty()) {
+ try {
+ ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
+ } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
+ }
+
+ t_draft_flat_us += ggml_time_us() - t_start_draft_us;
+ }
+
const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;
@@ -76,8 +106,6 @@ int main(int argc, char ** argv){
int n_drafted = 0;
int n_accept = 0;
- int64_t t_draft_us = 0;
-
int n_past = inp.size();
bool has_eos = false;
@@ -129,6 +157,12 @@ int main(int argc, char ** argv){
++n_past;
++i_dft;
inp.push_back(id);
+ {
+ // Update context ngram cache with the newly accepted token:
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+ t_draft_us += ggml_time_us() - t_start_draft_us;
+ }
if (params.use_color) {
// color accepted draft token
@@ -149,6 +183,12 @@ int main(int argc, char ** argv){
draft.clear();
draft.push_back(id);
inp.push_back(id);
+ {
+ // Update context ngram cache with the newly accepted token:
+ const int64_t t_start_draft_us = ggml_time_us();
+ llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+ t_draft_us += ggml_time_us() - t_start_draft_us;
+ }
break;
}
@@ -163,44 +203,19 @@ int main(int argc, char ** argv){
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
- // generate n_pred tokens through prompt lookup
- auto prompt_lookup = [&]() -> void {
- const int inp_size = inp.size();
- for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
- const llama_token * ngram = &inp[inp_size - ngram_size];
-
- for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
- bool match = true;
- for (int j = 0; j < ngram_size; ++j) {
- if (inp[i + j] != ngram[j]) {
- match = false;
- break;
- }
- }
-
- if (match) {
- const int startIdx = i + ngram_size;
- const int endIdx = startIdx + n_draft;
- if (endIdx < inp_size) {
- for (int j = startIdx; j < endIdx; ++j) {
- LOG(" - draft candidate %d: %d\n", j, inp[j]);
- draft.push_back(inp[j]);
- llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
- ++n_drafted;
- }
- return;
- }
- }
- }
- }
- return;
- };
-
+ // Draft already contains a single token sampled from the model:
+ GGML_ASSERT(draft.size() == 1);
+ GGML_ASSERT(draft[0] == inp.back());
const int64_t t_start_draft_us = ggml_time_us();
- prompt_lookup();
+ llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+
+ for (size_t i = 1; i < draft.size(); ++i) {
+ llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+ }
t_draft_us += ggml_time_us() - t_start_draft_us;
+ n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt);
++n_past;
@@ -210,19 +225,24 @@ int main(int argc, char ** argv){
auto t_dec_end = ggml_time_us();
+ // Update dynamic ngram cache with context ngram cache and save it to disk:
+ llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+ llama_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
+
LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("\n");
- LOG_TEE("n_draft = %d\n", n_draft);
- LOG_TEE("n_predict = %d\n", n_predict);
- LOG_TEE("n_drafted = %d\n", n_drafted);
- LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
+ LOG_TEE("n_draft = %d\n", n_draft);
+ LOG_TEE("n_predict = %d\n", n_predict);
+ LOG_TEE("n_drafted = %d\n", n_drafted);
+ LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
+ LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
- LOG_TEE("n_accept = %d\n", n_accept);
- LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+ LOG_TEE("n_accept = %d\n", n_accept);
+ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx);