diff options
author | Leng Yue <lengyue@lengyue.me> | 2023-09-14 09:14:44 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 19:14:44 +0300 |
commit | 35f73049af6c676a106a5a990a819ae0bc3fcd7d (patch) | |
tree | 55807c47e621aca6ffe3cb8936ade0f3f80e2921 | |
parent | 71ca2fad7d6c0ef95ef9944fb3a1a843e481f314 (diff) |
speculative : add heuristic algorithm (#3006)
* Add heuristic algo for speculative
* Constrain minimum n_draft to 2
* speculative : improve heuristic impl
* speculative : be more rewarding upon guessing max drafted tokens
* speculative : fix typos
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- | examples/speculative/speculative.cpp | 24 |
1 file changed, 23 insertions, 1 deletion
```diff
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 2cd153f9..aa904183 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
     // how many tokens to draft each time
-    const int n_draft = params.n_draft;
+    int n_draft = params.n_draft;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
 
         int i_dft = 0;
+
         while (true) {
             // sample from the target model
             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
         llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
         ++n_past_dft;
 
+        // heuristic for n_draft
+        {
+            const int n_draft_cur = (int) drafted.size();
+            const bool all_accepted = i_dft == n_draft_cur;
+
+            LOG("n_draft = %d\n", n_draft);
+            LOG("n_draft_cur = %d\n", n_draft_cur);
+            LOG("i_dft = %d\n", i_dft);
+            LOG("all_accepted = %d\n", all_accepted);
+
+            if (all_accepted && n_draft == n_draft_cur) {
+                LOG(" - max drafted tokens accepted - n_draft += 8\n");
+                n_draft = std::min(30, n_draft + 8);
+            } else if (all_accepted) {
+                LOG(" - partially drafted tokens accepted - no change\n");
+            } else {
+                LOG(" - drafted token rejected - n_draft -= 1\n");
+                n_draft = std::max(2, n_draft - 1);
+            }
+        }
+
         drafted.clear();
         drafted.push_back(id);
 
```