| author | Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> | 2024-03-05 03:24:00 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-04 20:24:00 +0200 |
| commit | 6d341ab6c53cd51f2921d986d0090cc8b049b39a (patch) | |
| tree | f212b497e210c8c73fe52369f6bc81297c7b1dab /common | |
| parent | 4ffcdce2ff877ebb683cd217ea38faf20faa5ffe (diff) | |
speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding
* sample from residual distribution on draft accept failure (see the sketch after this list)
* fix #5657: force greedy sampling with probs when temp is 0
* remove p_accept parameter
* fix style
* remove unused variables
* add srand() in speculative.cpp
* replace use of rand() with mt19937 sampling
* fixes based on review (@JohannesGaessler)
* fix r random generation
* randomly select next sequence to verify + fix bug in memory freeing
* fix bug in active_seqs sync
* fix uniform int distribution initialization
* remove warnings from comparison between int and size_t
* check grammar in `llama_sample_probability_distribution_impl`
* remove malloc code by utilizing vectors
* add PR link to README
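
The bullets above describe the core of stochastic speculative sampling: a drafted token is accepted with probability min(1, p_target/p_draft), and on rejection the next token is drawn from the renormalized residual max(0, p_target − p_draft), using `std::mt19937` instead of `rand()`. A minimal sketch of that accept/reject step, assuming per-token probability vectors are already available; the function and variable names are illustrative and not taken from the patch:

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Hypothetical illustration of the stochastic verification step.
// p_target and p_draft are full-vocabulary probability distributions
// for the same position; drafted_token is the token proposed by the draft model.
int stochastic_verify(const std::vector<float> & p_target,
                      const std::vector<float> & p_draft,
                      int drafted_token,
                      std::mt19937 & rng) {
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);

    // accept with probability min(1, p_target[t] / p_draft[t])
    const float r = unif(rng);
    if (p_draft[drafted_token] > 0.0f &&
        r < std::min(1.0f, p_target[drafted_token] / p_draft[drafted_token])) {
        return drafted_token; // draft accepted
    }

    // rejected: sample from the residual distribution max(0, p_target - p_draft)
    std::vector<float> residual(p_target.size());
    float sum = 0.0f;
    for (size_t i = 0; i < p_target.size(); ++i) {
        residual[i] = std::max(0.0f, p_target[i] - p_draft[i]);
        sum += residual[i];
    }
    if (sum <= 0.0f) {
        // degenerate case (assumption): fall back to the target distribution
        residual = p_target;
    }
    // std::discrete_distribution renormalizes the weights internally
    std::discrete_distribution<int> dist(residual.begin(), residual.end());
    return dist(rng);
}
```

This accept/reject scheme preserves the target model's output distribution exactly, which is why the separate `p_accept` heuristic removed by this commit is no longer needed.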
Diffstat (limited to 'common')
| -rw-r--r-- | common/common.cpp | 7 |
| -rw-r--r-- | common/common.h | 3 |
| -rw-r--r-- | common/sampling.cpp | 79 |
| -rw-r--r-- | common/sampling.h | 7 |

4 files changed, 87 insertions, 9 deletions
```diff
diff --git a/common/common.cpp b/common/common.cpp
index dbe7e922..036a9813 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_sequences = std::stoi(argv[i]);
-        } else if (arg == "--p-accept" || arg == "-pa") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.p_accept = std::stof(argv[i]);
         } else if (arg == "--p-split" || arg == "-ps") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
-    printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
diff --git a/common/common.h b/common/common.h
index b2868833..977ce419 100644
--- a/common/common.h
+++ b/common/common.h
@@ -53,11 +53,10 @@ struct gpt_params {
     int32_t n_ctx              = 512;  // context size
     int32_t n_batch            = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep             = 0;    // number of tokens to keep from initial prompt
-    int32_t n_draft            = 8;    // number of tokens to draft during speculative decoding
+    int32_t n_draft            = 5;    // number of tokens to draft during speculative decoding
     int32_t n_chunks           = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel         = 1;    // number of parallel sequences to decode
     int32_t n_sequences        = 1;    // number of sequences to decode
-    float   p_accept           = 0.5f; // speculative decoding accept probability
     float   p_split            = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers       = -1;   // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1;   // number of layers to store in VRAM for the draft model (-1 - use default)
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e67096be..823031fe 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
+static llama_token_data_array llama_sample_probability_distribution_impl(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
diff --git a/common/sampling.h b/common/sampling.h
index 95d87539..48b2459d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
+// returns the probability that token of given id will be sampled
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx = 0);
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
```
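
The new `llama_sampling_probability_distribution()` helper returns the full, softmax-normalized token distribution (with penalties, guidance, and grammar applied), which is what the stochastic acceptance test needs from both the target and draft models. A hypothetical usage sketch, not code from this commit: context names, the include path, and the lookup-by-id loop are assumptions about how caller code such as examples/speculative might use it.

```cpp
#include <algorithm>

#include "sampling.h" // from common/ (include path is an assumption)

// Hypothetical helper: compute the acceptance probability
// min(1, p_target(t) / p_draft(t)) for a drafted token, assuming both
// contexts have already been decoded so that logits exist at position idx.
static float draft_accept_prob(
        llama_sampling_context * ctx_sampling_tgt, llama_context * ctx_tgt,
        llama_sampling_context * ctx_sampling_dft, llama_context * ctx_dft,
        int idx, llama_token drafted_id) {
    // full-vocabulary, softmax-normalized distributions of both models
    const llama_token_data_array dist_tgt =
        llama_sampling_probability_distribution(ctx_sampling_tgt, ctx_tgt, nullptr, idx);
    const llama_token_data_array dist_dft =
        llama_sampling_probability_distribution(ctx_sampling_dft, ctx_dft, nullptr, idx);

    float p_tgt = 0.0f;
    float p_dft = 0.0f;

    // the softmax may reorder the candidates, so look the token up by id
    for (size_t i = 0; i < dist_tgt.size; ++i) {
        if (dist_tgt.data[i].id == drafted_id) { p_tgt = dist_tgt.data[i].p; }
    }
    for (size_t i = 0; i < dist_dft.size; ++i) {
        if (dist_dft.data[i].id == drafted_id) { p_dft = dist_dft.data[i].p; }
    }

    return p_dft > 0.0f ? std::min(1.0f, p_tgt / p_dft) : 1.0f;
}
```

Note that each call reuses the `cur` buffer of its own sampling context, so the target and draft models must use separate `llama_sampling_context` instances, as they do here.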