summary | refs | log | tree | commit | diff
path: root/examples/speculative
diff options
context:
space:
mode:
Diffstat (limited to 'examples/speculative')
-rw-r--r-- examples/speculative/speculative.cpp 24
1 file changed, 23 insertions, 1 deletion
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 2cd153f9..aa904183 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
// how many tokens to draft each time
- const int n_draft = params.n_draft;
+ int n_draft = params.n_draft;
int n_predict = 0;
int n_drafted = 0;
@@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
int i_dft = 0;
+
while (true) {
// sample from the target model
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
++n_past_dft;
+ // heuristic for n_draft
+ {
+ const int n_draft_cur = (int) drafted.size();
+ const bool all_accepted = i_dft == n_draft_cur;
+
+ LOG("n_draft = %d\n", n_draft);
+ LOG("n_draft_cur = %d\n", n_draft_cur);
+ LOG("i_dft = %d\n", i_dft);
+ LOG("all_accepted = %d\n", all_accepted);
+
+ if (all_accepted && n_draft == n_draft_cur) {
+ LOG(" - max drafted tokens accepted - n_draft += 8\n");
+ n_draft = std::min(30, n_draft + 8);
+ } else if (all_accepted) {
+ LOG(" - partially drafted tokens accepted - no change\n");
+ } else {
+ LOG(" - drafted token rejected - n_draft -= 1\n");
+ n_draft = std::max(2, n_draft - 1);
+ }
+ }
+
drafted.clear();
drafted.push_back(id);