author     Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>  2024-03-05 03:24:00 +0900
committer  GitHub <noreply@github.com>                                   2024-03-04 20:24:00 +0200
commit     6d341ab6c53cd51f2921d986d0090cc8b049b39a (patch)
tree       f212b497e210c8c73fe52369f6bc81297c7b1dab /common/sampling.cpp
parent     4ffcdce2ff877ebb683cd217ea38faf20faa5ffe (diff)
speculative : implement stochastic speculative sampling (#5625)
* (WIP) Implement stochastic speculative decoding
* sample from residual distribution on draft accept failure
* fix #5657: force greedy sampling with probs when temp is 0
* remove p_accept parameter
* fix style
* remove unused variables
* add srand() in speculative.cpp
* replace use of rand() with mt19937 sampling
* fixes based on review (@JohannesGaessler)
* fix r random generation
* randomly select next sequence to verify + fix bug in memory freeing
* fix bug in active_seqs sync
* fix uniform int distribution initialization
* remove warnings from comparison between int and size_t
* check grammar in `llama_sample_probability_distribution_impl`
* remove malloc code by utilizing vectors
* add PR link to README
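
For readers unfamiliar with the technique, the "sample from residual distribution on draft accept failure" item above refers to the standard stochastic speculative sampling rule: a drafted token is accepted with probability min(1, p_target / p_draft), and on rejection a replacement token is drawn from the normalized residual max(0, p_target - p_draft). The sketch below is a minimal illustration of that rule using std::mt19937 (as mentioned in the summary); the function and variable names are hypothetical and this is not the code added to examples/speculative/speculative.cpp.

    #include <algorithm>
    #include <random>
    #include <vector>

    // Hypothetical sketch of the accept/resample rule used by stochastic
    // speculative sampling; NOT the actual speculative.cpp implementation.
    static int accept_or_resample(const std::vector<float> & p_tgt,  // target-model probs, normalized
                                  const std::vector<float> & p_dft,  // draft-model probs, normalized
                                  int drafted_id,
                                  std::mt19937 & rng) {
        std::uniform_real_distribution<float> u01(0.0f, 1.0f);

        // accept the drafted token with probability min(1, p_tgt / p_dft)
        if (p_dft[drafted_id] > 0.0f && u01(rng) < p_tgt[drafted_id] / p_dft[drafted_id]) {
            return drafted_id;
        }

        // rejected: sample from the residual distribution max(0, p_tgt - p_dft)
        std::vector<float> residual(p_tgt.size());
        float sum = 0.0f;
        for (size_t i = 0; i < p_tgt.size(); ++i) {
            residual[i] = std::max(0.0f, p_tgt[i] - p_dft[i]);
            sum += residual[i];
        }
        if (sum <= 0.0f) {
            residual = p_tgt; // distributions coincide: fall back to the target probs
        }

        // std::discrete_distribution normalizes the weights internally
        std::discrete_distribution<int> dist(residual.begin(), residual.end());
        return dist(rng);
    }

Drawing the random numbers with std::mt19937 rather than rand() matches the change noted above; the real implementation additionally tracks multiple draft sequences and grammar state.
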
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--  common/sampling.cpp  79
1 file changed, 79 insertions, 0 deletions
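
The diff below adds a helper, llama_sampling_probability_distribution, which returns the full softmax-normalized token distribution at output index idx after logit bias, optional CFG guidance, repetition penalties and grammar constraints have been applied. As a rough usage sketch (the wrapper function, variable names and include line are assumptions, not code from this commit), a speculative-decoding loop could use it to look up the target model's probability for a drafted token:

    #include "sampling.h"   // declared in common/sampling.h in the llama.cpp tree

    // Hypothetical helper, not part of this commit: probability that the target
    // model assigns to `drafted_id` at output index `idx`.
    static float target_prob_of_draft(
            struct llama_sampling_context * ctx_sampling,
            struct llama_context          * ctx_tgt,
            llama_token                     drafted_id,
            int                             idx) {
        // full normalized distribution over the vocabulary; .p is filled in by
        // the final llama_sample_softmax call inside the implementation
        llama_token_data_array cur_p =
            llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, /*ctx_cfg=*/nullptr, idx);

        for (size_t i = 0; i < cur_p.size; ++i) {
            if (cur_p.data[i].id == drafted_id) {
                return cur_p.data[i].p;
            }
        }
        return 0.0f;
    }
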
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e67096be..823031fe 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
+static llama_token_data_array llama_sample_probability_distribution_impl(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,