diff options
Diffstat (limited to 'examples/speculative/speculative.cpp')
-rw-r--r-- | examples/speculative/speculative.cpp | 80 |
1 files changed, 69 insertions, 11 deletions
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index f0400c13..c6211ac7 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" +#include "grammar-parser.h" #include <cmath> #include <cstdio> @@ -109,16 +110,35 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + // grammar stuff + struct llama_grammar * grammar_dft = NULL; + struct llama_grammar * grammar_tgt = NULL; + + grammar_parser::parse_state parsed_grammar; + + // if requested - load the grammar, error checking is omitted for brevity + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; + } + + std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); + grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + const auto t_dec_start = ggml_time_us(); while (true) { LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); - // sample from the drafted tokens if any int i_dft = 0; while (true) { - const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft); + // sample from the target model + const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); @@ -134,8 +154,9 @@ int main(int argc, char ** argv) { ++n_predict; + // check if the draft matches the target if (i_dft < (int) drafted.size() && id == drafted[i_dft]) { - LOG("drafted token %d accepted\n", id); + LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -145,6 +166,14 @@ int main(int argc, char ** argv) { } // the drafted token was rejected or we are out of drafted tokens + + if (i_dft < (int) drafted.size()) { + LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n", + i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str()); + } else { + LOG("out of drafted tokens\n"); + } + llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); ++n_past_dft; @@ -158,7 +187,16 @@ int main(int argc, char ** argv) { break; } - // sample n_draft tokens from the draft model picking the best token + if (grammar_tgt) { + if (grammar_dft) { + llama_grammar_free(grammar_dft); + } + grammar_dft = llama_grammar_copy(grammar_tgt); + + LOG("copied target grammar to draft grammar\n"); + } + + // sample n_draft tokens from the draft model using greedy decoding int n_past_cur = n_past_dft; for (int i = 0; i < n_draft; ++i) { float * logits = llama_get_logits(ctx_dft); @@ -170,25 +208,40 @@ int main(int argc, char ** argv) { llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + if (grammar_dft != NULL) { + llama_sample_grammar(ctx_dft, &cur_p, grammar_dft); + } + // computes softmax and sorts the candidates llama_sample_softmax(ctx_dft, &cur_p); for (int i = 0; i < 3; ++i) { - LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p); + LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str()); } - // too low probability, stop drafting + // TODO: better logic? if (cur_p.data[0].p < 2*cur_p.data[1].p) { + LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p); break; } - drafted.push_back(cur_p.data[0].id); + // drafted token + const llama_token id = cur_p.data[0].id; + + drafted.push_back(id); ++n_drafted; - if (i < n_draft - 1) { - // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); - ++n_past_cur; + // no need to evaluate the last drafted token, since we won't use the result + if (i == n_draft - 1) { + break; + } + + // evaluate the drafted token on the draft model + llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + ++n_past_cur; + + if (grammar_dft != NULL) { + llama_grammar_accept_token(ctx_dft, grammar_dft, id); } } @@ -196,6 +249,7 @@ int main(int argc, char ** argv) { llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); ++n_past_tgt; + // the first token is always proposed by the traget model before the speculation loop drafted.erase(drafted.begin()); } @@ -226,6 +280,10 @@ int main(int argc, char ** argv) { llama_free(ctx_dft); llama_free_model(model_dft); + if (grammar_dft != NULL) { + llama_grammar_free(grammar_dft); + llama_grammar_free(grammar_tgt); + } llama_backend_free(); fprintf(stderr, "\n\n"); |