Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--  common/sampling.cpp  79
1 file changed, 79 insertions, 0 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e67096be..823031fe 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
return id;
}
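+// Compute the token probability distribution for output index idx: apply logit biases,
+// guidance, repetition penalties and grammar constraints, then softmax the candidates.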
+static llama_token_data_array llama_sample_probability_distribution_impl(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float penalty_repeat = params.penalty_repeat;
+    const float penalty_freq = params.penalty_freq;
+    const float penalty_present = params.penalty_present;
+    const bool penalize_nl = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
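+    // optionally apply classifier-free guidance from the guidance context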
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
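+    // rebuild the candidate list from the current logits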
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
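+    // normalize the candidates with softmax so the caller receives a full probability distribution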
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
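+// Expose the full token probability distribution (after penalties, grammar and softmax) for index idx.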
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);
+}
+
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,