author    Olivier Chafik <ochafik@users.noreply.github.com>  2024-05-21 20:40:00 +0100
committer GitHub <noreply@github.com>  2024-05-21 20:40:00 +0100
commit    e402de364b643cb89ea9f43057733b5d36298670 (patch)
tree      0c3b1d54bc5def33eb553182955260eee37908f6
parent    fcf6538ba6702c55eaec70da9a75c81d04900a72 (diff)
`grammars`: fix resampling logic regression (#7424)
-rw-r--r--  common/sampling.cpp      13
-rw-r--r--  examples/main/main.cpp    4
2 files changed, 9 insertions(+), 8 deletions(-)
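
Context for the fix: the regression inverted the apply_grammar flag in the two-pass grammar sampling flow. The intended flow, as visible in the patch, is to sample without grammar constraints first, check the candidate token against the grammar, and only resample with the grammar applied (from a saved copy of the original logits) if the first pick is rejected. The following is a minimal, self-contained sketch of that control flow, not the llama.cpp API; the helper names (pick_token, grammar_accepts, apply_grammar_penalties) are hypothetical stand-ins for illustration.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the compiled grammar state.
struct Grammar {};

// Stand-in: pick the highest-logit token (greedy, for illustration).
static int pick_token(const std::vector<float> & logits) {
    int best = 0;
    for (size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) best = (int) i;
    }
    return best;
}

// Stand-in acceptance rule: pretend the grammar only allows even token ids.
static bool grammar_accepts(const Grammar * grammar, int token) {
    (void) grammar;
    return token % 2 == 0;
}

// Stand-in constraint step: mask out tokens the grammar would reject.
static void apply_grammar_penalties(const Grammar * grammar, std::vector<float> & logits) {
    (void) grammar;
    for (size_t i = 1; i < logits.size(); i += 2) {
        logits[i] = -1e9f;
    }
}

// Sample unconstrained first; if the grammar rejects the pick, restore the
// saved logits and resample with the grammar applied.
static int sample_with_grammar(const Grammar * grammar, std::vector<float> logits, bool is_resampling = false) {
    std::vector<float> original_logits;
    if (grammar != nullptr && !is_resampling) {
        original_logits = logits; // keep a copy for a potential resample
    }
    if (is_resampling) {
        apply_grammar_penalties(grammar, logits); // constrain only on the second pass
    }
    const int id = pick_token(logits);
    if (grammar != nullptr && !is_resampling && !grammar_accepts(grammar, id)) {
        assert(!original_logits.empty());
        return sample_with_grammar(grammar, original_logits, /* is_resampling= */ true);
    }
    return id;
}

int main() {
    Grammar grammar;
    std::vector<float> logits = {0.1f, 2.0f, 0.5f, 1.5f};
    printf("sampled token: %d\n", sample_with_grammar(&grammar, logits));
    return 0;
}

Sampling unconstrained first and only falling back to a constrained resample keeps the common case cheap: when the unconstrained pick already satisfies the grammar, no logits copy needs to be restored and no second pass runs.
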
diff --git a/common/sampling.cpp b/common/sampling.cpp
index f0f1b92d..7fc2e215 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
- bool is_resampling) { // Add a parameter to indicate if we are resampling
+ bool is_resampling) {
const llama_sampling_params & params = ctx_sampling->params;
const float temp = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
const float mirostat_eta = params.mirostat_eta;
std::vector<float> original_logits;
- auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
- if (!is_resampling) {
+ auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+ if (ctx_sampling->grammar != NULL && !is_resampling) {
GGML_ASSERT(!original_logits.empty());
}
llama_token id = 0;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);
- return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
}
}
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);
- if (apply_grammar && original_logits != NULL) {
+ if (ctx_sampling->grammar != NULL && !apply_grammar) {
+ GGML_ASSERT(original_logits != NULL);
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
*original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
}
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
const int idx) {
// Call the implementation function with is_resampling set to false by default
- return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+ return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
}
llama_token_data_array llama_sampling_prepare(
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 9dee4100..832b51ee 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
- llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+ llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {