summaryrefslogtreecommitdiff
path: root/common/sampling.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--common/sampling.cpp73
1 file changed, 51 insertions, 22 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 0b246658..6f0af3c4 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,9 +1,9 @@
#include "sampling.h"
-struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
- result->params = params.sampling_params;
+ result->params = params;
result->grammar = nullptr;
// if there is a grammar, parse it
@@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
- result->prev.resize(params.n_ctx);
+ result->prev.resize(params.n_prev);
return result;
}
@@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
dst->prev = src->prev;
}
+// returns the most recently accepted token from the sampling context's
+// recent-token ring buffer
+// NOTE(review): calls back() without an emptiness check — prev is presumably
+// always pre-sized by llama_sampling_init (resize(n_prev)); confirm at call sites
+llama_token llama_sampling_last(llama_sampling_context * ctx) {
+ return ctx->prev.back();
+}
+
+// detokenize the last n accepted tokens into a single string
+// (n is clamped to the number of tokens held in the prev ring buffer)
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
+ const int size = ctx_sampling->prev.size();
+
+ // never read more tokens than the ring actually holds
+ n = std::min(n, size);
+
+ std::string result;
+
+ // walk the last n entries in order and append each token's text piece
+ for (int i = size - n; i < size; i++) {
+ result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
+ }
+
+ return result;
+}
+
+// format the sampling parameters as a tab-indented, multi-line summary string
+// (for logging; output is truncated to the fixed 1024-byte buffer by snprintf)
+std::string llama_sampling_print(const llama_sampling_params & params) {
+ char result[1024];
+
+ snprintf(result, sizeof(result),
+ "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
+ "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+ "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+ params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
+ params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
+ params.mirostat, params.mirostat_eta, params.mirostat_tau);
+
+ return std::string(result);
+}
+
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
- const int n_ctx = llama_n_ctx(ctx_main);
- const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
const llama_sampling_params & params = ctx_sampling->params;
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
- const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
- const float repeat_penalty = params.repeat_penalty;
- const float alpha_presence = params.presence_penalty;
- const float alpha_frequency = params.frequency_penalty;
+ const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+ const float penalty_repeat = params.penalty_repeat;
+ const float penalty_freq = params.penalty_freq;
+ const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
@@ -97,7 +128,7 @@ llama_token llama_sampling_sample(
float * logits = llama_get_logits_ith(ctx_main, idx);
- // Apply params.logit_bias map
+ // apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
@@ -117,14 +148,10 @@ llama_token llama_sampling_sample(
// apply penalties
if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)];
- const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
- llama_sample_repetition_penalty(ctx_main, &cur_p,
- prev.data() + prev.size() - last_n_repeat,
- last_n_repeat, repeat_penalty);
- llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
- prev.data() + prev.size() - last_n_repeat,
- last_n_repeat, alpha_frequency, alpha_presence);
+ llama_sample_repetition_penalties(ctx_main, &cur_p,
+ prev.data() + prev.size() - penalty_last_n,
+ penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
@@ -141,7 +168,7 @@ llama_token llama_sampling_sample(
}
if (temp <= 0) {
- // Greedy sampling
+ // greedy sampling
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
@@ -152,8 +179,9 @@ llama_token llama_sampling_sample(
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
- // Temperature sampling
+ // temperature sampling
size_t min_keep = std::max(1, params.n_probs);
+
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
@@ -183,11 +211,12 @@ llama_token llama_sampling_sample(
// record an accepted token: push it into the fixed-size recent-token ring
// (oldest entry dropped) and optionally advance the grammar state
// NOTE(review): callers presumably pass apply_grammar = false for tokens that
// must not advance the grammar — confirm against call sites
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
- llama_token id) {
+ llama_token id,
+ bool apply_grammar) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id);
- if (ctx_sampling->grammar != NULL) {
+ if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}