summaryrefslogtreecommitdiff
path: root/examples/gbnf-validator/gbnf-validator.cpp
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-07-27 07:55:01 +0200
committerGitHub <noreply@github.com>2024-07-27 07:55:01 +0200
commit154e0d75fccf1784fe9ff6fd76a630b66563da3d (patch)
tree81ce6dbb5b1900c1aa78a879f0593c694cab9d27 /examples/gbnf-validator/gbnf-validator.cpp
parent0684c3e9c70d49323b4fc517128cbe222cab7f96 (diff)
Merge mainline llama.cpp (#3)
* Merging mainline - WIP * Merging mainline - WIP AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower as it is so often the case with llama.cpp/ggml after some "improvements" have been made. * Merging mainline - fix Metal * Remove check --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples/gbnf-validator/gbnf-validator.cpp')
-rw-r--r--examples/gbnf-validator/gbnf-validator.cpp19
1 files changed, 13 insertions, 6 deletions
diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp
index 0406dc33..48a705e1 100644
--- a/examples/gbnf-validator/gbnf-validator.cpp
+++ b/examples/gbnf-validator/gbnf-validator.cpp
@@ -16,20 +16,25 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
auto decoded = decode_utf8(input_str, {});
const auto & code_points = decoded.first;
+ const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
+ llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+
size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- auto prev_stacks = grammar->stacks;
- llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
- if (grammar->stacks.empty()) {
+ const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+
+ llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+
+ if (cur_stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
- grammar->stacks = prev_stacks;
+ cur_stacks = prev_stacks;
return false;
}
++pos;
}
- for (const auto & stack : grammar->stacks) {
+ for (const auto & stack : cur_stacks) {
if (stack.empty()) {
return true;
}
@@ -101,7 +106,9 @@ int main(int argc, char** argv) {
auto grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
+ if (grammar == nullptr) {
+ throw std::runtime_error("Failed to initialize llama_grammar");
+ }
// Read the input file
std::string input_str;
{