summary refs log tree commit diff
path: root/examples/speculative
diff options
context:
space:
mode:
Diffstat (limited to 'examples/speculative')
-rw-r--r--	examples/speculative/speculative.cpp	17
1 file changed, 15 insertions, 2 deletions
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 3a8e2781..ace755c5 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -94,9 +94,22 @@ int main(int argc, char ** argv) {
}
}
- // tokenize the prompt
+
+ // Tokenize the prompt
+ const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
+ LOG("add_bos tgt: %d\n", add_bos_tgt);
+
+ const bool add_bos_dft = llama_should_add_bos_token(model_dft);
+ LOG("add_bos dft: %d\n", add_bos_dft);
+
+ if (add_bos_tgt != add_bos_dft) {
+ fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
+ fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
+ return 1;
+ }
+
std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
+ inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
const int max_context_size = llama_n_ctx(ctx_tgt);
const int max_tokens_list_size = max_context_size - 4;