Diffstat (limited to 'common/sampling.cpp')
-rw-r--r--  common/sampling.cpp  29
1 file changed, 19 insertions, 10 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index f1f80351..079e4051 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
- result->grammar = llama_grammar_init(
+ struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+ if (grammar == nullptr) {
+ throw std::runtime_error("Failed to initialize llama_grammar");
+ }
+ result->grammar = grammar;
}
result->prev.resize(params.n_prev);
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
- ctx->grammar = llama_grammar_init(
+ struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+ if (grammar == nullptr) {
+ throw std::runtime_error("Failed to initialize llama_grammar");
+ }
+ ctx->grammar = grammar;
}
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
@@ -274,8 +282,6 @@ static llama_token llama_sampling_sample_impl(
GGML_ASSERT(!original_logits.empty());
}
llama_token id = 0;
- // Get a pointer to the logits
- float * logits = llama_get_logits_ith(ctx_main, idx);
if (temp < 0.0) {
// greedy sampling, with probs
@@ -316,12 +322,15 @@ static llama_token llama_sampling_sample_impl(
}
if (ctx_sampling->grammar != NULL && !is_resampling) {
+ // Get a pointer to the logits
+ float * logits = llama_get_logits_ith(ctx_main, idx);
+
// Create an array with a single token data element for the sampled id
llama_token_data single_token_data = {id, logits[id], 0.0f};
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
// Apply grammar constraints to the single token
- llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+ llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -369,7 +378,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (ctx_sampling->grammar != NULL && !apply_grammar) {
GGML_ASSERT(original_logits != NULL);
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
- *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+ *original_logits = {logits, logits + n_vocab};
}
// apply params.logit_bias map
@@ -382,10 +391,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}
- cur.clear();
+ cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -412,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
- llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+ llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
}
return cur_p;
@@ -446,6 +455,6 @@ void llama_sampling_accept(
ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL && apply_grammar) {
- llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+ llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
}
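
Note: after this change, llama_sampling_init() and llama_sampling_reset() throw std::runtime_error when llama_grammar_init() returns nullptr, instead of silently storing a null grammar. A minimal caller-side sketch of the new failure path (the main() wrapper and the grammar string are illustrative only; llama_sampling_params, llama_sampling_init and llama_sampling_free are assumed to be the existing common/sampling.h API):

#include "sampling.h"   // common/sampling.h from llama.cpp; include path assumed

#include <cstdio>
#include <stdexcept>

int main() {
    llama_sampling_params params;
    params.grammar = "root ::= \"yes\" | \"no\"";  // illustrative GBNF grammar

    try {
        // With this patch, a failed llama_grammar_init() inside
        // llama_sampling_init() surfaces as std::runtime_error.
        llama_sampling_context * ctx_sampling = llama_sampling_init(params);
        // ... run grammar-constrained sampling with ctx_sampling ...
        llama_sampling_free(ctx_sampling);
    } catch (const std::runtime_error & err) {
        fprintf(stderr, "sampling init failed: %s\n", err.what());
        return 1;
    }
    return 0;
}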