Diffstat (limited to 'examples/speculative')
 examples/speculative/speculative.cpp | 36 +++++++++++++++++++++++-------------
 1 file changed, 23 insertions(+), 13 deletions(-)
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 6e0815b3..6a7367b0 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -76,6 +76,28 @@ int main(int argc, char ** argv) {
     params.n_threads_batch = params.n_threads_batch_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+    LOG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
+        fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
     {
         const int n_vocab_tgt = llama_n_vocab(model_tgt);
         const int n_vocab_dft = llama_n_vocab(model_dft);
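Taken together, the checks added above form a single compatibility predicate over the two models. A minimal standalone sketch of that predicate, using only the llama.cpp calls already visible in this diff (the helper name speculation_compatible is hypothetical, not part of the example):

    #include "llama.h"

    // Returns true when the draft model's vocabulary and special-token
    // handling match the target model closely enough for speculation.
    static bool speculation_compatible(const struct llama_model * model_tgt,
                                       const struct llama_model * model_dft) {
        // The vocab families (e.g. SPM vs BPE) must match, otherwise token
        // ids produced by the draft model are meaningless to the target.
        if (llama_vocab_type(model_tgt) != llama_vocab_type(model_dft)) {
            return false;
        }

        // BOS/EOS behavior must also agree: both the add-BOS/add-EOS flags
        // and the special token ids themselves.
        return llama_add_bos_token(model_tgt) == llama_add_bos_token(model_dft) &&
               llama_add_eos_token(model_tgt) == llama_add_eos_token(model_dft) &&
               llama_token_bos(model_tgt)     == llama_token_bos(model_dft)     &&
               llama_token_eos(model_tgt)     == llama_token_eos(model_dft);
    }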
@@ -105,20 +127,8 @@ int main(int argc, char ** argv) {
     // Tokenize the prompt
-    const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
-    LOG("add_bos tgt: %d\n", add_bos_tgt);
-
-    const bool add_bos_dft = llama_should_add_bos_token(model_dft);
-    LOG("add_bos dft: %d\n", add_bos_dft);
-
-    if (add_bos_tgt != add_bos_dft) {
-        fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
-        fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
-        return 1;
-    }
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
+    inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
     const int max_context_size     = llama_n_ctx(ctx_tgt);
     const int max_tokens_list_size = max_context_size - 4;
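With the BOS/EOS agreement now enforced up front, the per-model add_bos comparison at the tokenization site became redundant, and the call simply passes true. The hunk ends here, but the tokenized prompt is then bounded against the target context; a sketch of that follow-up check, consistent with the variables above (not the verbatim continuation of the file):

    // Reject prompts that leave no room for drafted/generated tokens.
    if ((int) inp.size() > max_tokens_list_size) {
        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n",
                __func__, (int) inp.size(), max_tokens_list_size);
        return 1;
    }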