diff options
Diffstat (limited to 'examples/speculative/speculative.cpp')
-rw-r--r-- | examples/speculative/speculative.cpp | 216 |
1 files changed, 168 insertions, 48 deletions
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 3848791d..85bc0a76 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include <cstdio> #include <string> #include <vector> +#include <set> #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -18,6 +19,7 @@ struct seq_draft { std::vector<int> i_batch_tgt; std::vector<llama_token> tokens; + std::vector<std::vector<llama_token_data>> dists; struct llama_sampling_context * ctx_sampling; }; @@ -37,12 +39,15 @@ int main(int argc, char ** argv) { // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; - // probability threshold for accepting a token from the draft model - const float p_accept = params.p_accept; - // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + std::default_random_engine rng(params.seed); + std::uniform_real_distribution<> u_dist; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -166,7 +171,9 @@ int main(int argc, char ** argv) { std::vector<seq_draft> drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + if (params.sparams.temp == 0) { + params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); @@ -182,12 +189,15 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt[0] = 0; while (true) { + std::set<int> active_seqs = {}; + // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + active_seqs.insert(s); const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); @@ -196,48 +206,156 @@ int main(int argc, char ** argv) { int i_dft = 0; int s_keep = 0; + llama_token token_id; + std::string token_str; + + // loop until we fail to accept a drafted token or we run out of drafted tokens while (true) { - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + // check if the target token matches any of the drafts + // for stochastic sampling, attempt to match the token with the drafted tokens + { + bool accept = false; + if (params.sparams.temp > 0) { + // stochastic verification + + llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + float p_tgt = 0, p_dft = 0; + + // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1); + int s = *std::next(active_seqs.begin(), u_int_dist(rng)); + if (i_dft >= (int) drafts[s].tokens.size()) { + drafts[s].active = false; + active_seqs.erase(s); + continue; + } + if (accept) { + // if we already accepted a token, we can skip the rest + if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { + drafts[s].active = false; + active_seqs.erase(s); + } + continue; + } + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); + float r = u_dist(rng); + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + // acquire the token probabilities assigned by the draft and target models + for (size_t i = 0; i < dist_tgt.size; i++) { + if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { + p_tgt = dist_tgt.data[i].p; + } + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + } + if (p_tgt && p_dft) { + break; + } + } + LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + if (r <= p_tgt / p_dft) { + s_keep = s; + accept = true; + token_id = drafts[s].tokens[i_dft]; + token_str = llama_token_to_piece(ctx_tgt, token_id); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + break; + } else { + LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + drafts[s].active = false; + + // calculate residual probability + GGML_ASSERT(dist_tgt.sorted); + GGML_ASSERT(dist_dft.sorted); + float sum_probs = 0.0f; + + // sort dist by id + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + sum_probs += dist_tgt.data[i].p; + } + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p /= sum_probs; + } - llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); + // sort dist_tgt by p desc + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.p > b.p; + }); + } - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + active_seqs.erase(s); + for(int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } + if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + // synchronize active status for sequences with the same drafted token + drafts[i].active = drafts[i].active && accept; + if (!drafts[i].active) { + active_seqs.erase(s); + } + } + } + } - const std::string token_str = llama_token_to_piece(ctx_tgt, id); + if (!accept) { + // all drafted tokens were rejected + // sample from the target model + LOG("all drafted tokens were rejected, sampling from residual distribution\n"); + token_id = llama_sample_token(ctx_tgt, &dist_tgt); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_str = llama_token_to_piece(ctx_tgt, token_id); + } - if (!params.use_color) { - printf("%s", token_str.c_str()); - } + } else { + // greedy verification - if (id == llama_token_eos(model_tgt)) { - has_eos = true; - } + // sample from the target model + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - ++n_predict; + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); - // check if the target token matches any of the drafts - { - bool matches = false; + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; - } + token_str = llama_token_to_piece(ctx_tgt, token_id); - if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } - s_keep = s; - matches = true; - } else { - drafts[s].active = false; + if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); + + s_keep = s; + accept = true; + } else { + drafts[s].active = false; + } } } - if (matches) { + if (token_id == llama_token_eos(model_tgt)) { + has_eos = true; + } + ++n_predict; + + if (accept) { ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -245,17 +363,21 @@ int main(int argc, char ** argv) { if (params.use_color) { // Color token according to its origin sequence printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); - fflush(stdout); + } else { + printf("%s", token_str.c_str()); } + fflush(stdout); continue; + } else { + printf("%s", token_str.c_str()); + fflush(stdout); + break; } } - if (params.use_color) { - printf("%s", token_str.c_str()); - } - fflush(stdout); + } - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + { + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); // TODO: simplify { @@ -275,21 +397,21 @@ int main(int argc, char ** argv) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); + drafts[s].dists.clear(); } // note: will be erased after the speculation phase - drafts[0].tokens.push_back(id); + drafts[0].tokens.push_back(token_id); + drafts[0].dists.push_back(std::vector<llama_token_data>()); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode (ctx_dft, batch_dft); + llama_decode(ctx_dft, batch_dft); ++n_past_dft; - - break; } if (n_predict > params.n_predict || has_eos) { @@ -334,12 +456,6 @@ int main(int argc, char ** argv) { k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); - drafts[s].drafting = false; - continue; - } - std::vector<int> sa(1, s); // attempt to split the branch if the probability is high enough @@ -367,6 +483,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].dists = drafts[s].dists; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -389,6 +506,8 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); + // save cur_p.data into drafts[s].dists + drafts[s].dists.push_back(cur_p); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -440,6 +559,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); + drafts[s].dists.erase(drafts[s].dists.begin()); } } |