Diffstat (limited to 'examples/speculative/speculative.cpp')
-rw-r--r-- examples/speculative/speculative.cpp | 19
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index aa904183..2445d78d 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
- params.perplexity = true; // HACK: enable logits_all = true
+ params.logits_all = true;
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model
@@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
- llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
- llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
- llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
+ llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
+ llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
+ llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
const auto t_enc_end = ggml_time_us();
@@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
while (true) {
// sample from the target model
- const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+ llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
// remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin());
@@ -172,7 +172,8 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n");
}
- llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
+ llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+ llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
++n_past_dft;
// heuristic for n_draft
@@ -256,7 +257,8 @@ int main(int argc, char ** argv) {
}
// evaluate the drafted token on the draft model
- llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+ llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+ llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
++n_past_cur;
if (grammar_dft != NULL) {
@@ -265,7 +267,8 @@ int main(int argc, char ** argv) {
}
// evaluate the target model on the drafted tokens
- llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
+ llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+ llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
++n_past_tgt;
// the first token is always proposed by the target model before the speculation loop
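
For context, a minimal sketch of the pattern this diff applies when re-evaluating tokens: roll the KV cache back to the evaluation position with llama_kv_cache_seq_rm, then decode a single-sequence batch built by llama_batch_get_one. The helper name eval_at and its parameters are hypothetical and not part of this patch; the sketch assumes a loaded llama_context and the API exactly as it appears in the hunks above (including the n_threads argument to llama_decode, which later revisions of the API drop).

#include "llama.h"

#include <vector>

// Hypothetical helper (illustration only): re-evaluate `tokens` on sequence 0
// starting at position n_past, first discarding any stale KV cache entries at
// or after n_past so leftovers from a rejected draft are not attended against.
static void eval_at(llama_context * ctx, std::vector<llama_token> & tokens,
                    int n_past, int n_ctx, int n_threads) {
    // remove cached entries of sequence 0 in positions [n_past, n_ctx)
    llama_kv_cache_seq_rm(ctx, 0, n_past, n_ctx);

    // llama_batch_get_one builds a single-sequence batch from raw token data:
    // (tokens, n_tokens, starting position, sequence id)
    llama_decode(ctx, llama_batch_get_one(tokens.data(), (int) tokens.size(), n_past, 0), n_threads);
}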