summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/gbnf-validator/gbnf-validator.cpp2
-rw-r--r--llama.cpp16
-rw-r--r--llama.h5
-rw-r--r--tests/test-grammar-integration.cpp6
4 files changed, 17 insertions, 12 deletions
diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp
index e4c0c168..091069ff 100644
--- a/examples/gbnf-validator/gbnf-validator.cpp
+++ b/examples/gbnf-validator/gbnf-validator.cpp
@@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
- grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+ llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
diff --git a/llama.cpp b/llama.cpp
index b6e2ade9..ad07059c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
- const uint32_t chr) {
+ const uint32_t chr,
+ std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
- std::vector<std::vector<const llama_grammar_element *>> new_stacks;
+ new_stacks.clear();
for (const auto & stack : stacks) {
if (stack.empty()) {
@@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
llama_grammar_advance_stack(rules, new_stack, new_stacks);
}
}
-
- return new_stacks;
}
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const std::vector<llama_grammar_candidate> & candidates) {
std::vector<llama_grammar_candidate> rejects;
+ rejects.reserve(candidates.size());
if (stack.empty()) {
for (const auto & tok : candidates) {
@@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const llama_grammar_element * stack_pos = stack.back();
std::vector<llama_grammar_candidate> next_candidates;
+ next_candidates.reserve(candidates.size());
+
for (const auto & tok : candidates) {
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
@@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
+ std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+ llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+ grammar->stacks = tmp_new_stacks;
}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
diff --git a/llama.h b/llama.h
index b770a275..b5da686f 100644
--- a/llama.h
+++ b/llama.h
@@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
struct llama_context * ctx
);
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
- const uint32_t chr);
+ const uint32_t chr,
+ std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 0a9c3b6f..2d8f228e 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -38,7 +38,7 @@ number ::= [0-9]+)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
- grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+ llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}
@@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
- grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+ llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
@@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
- grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+ llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;