diff options
Diffstat (limited to 'tests/test-llama-grammar.cpp')
-rw-r--r-- | tests/test-llama-grammar.cpp | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 27ca4d26..1f3a267b 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,10 +2,12 @@ #undef NDEBUG #endif -#include "llama.cpp" // TODO: not great +#define LLAMA_API_INTERNAL +#include "llama.h" #include "grammar-parser.h" #include <cassert> +#include <stdexcept> int main() { @@ -112,10 +114,14 @@ int main() } } - llama_grammar *grammar = NULL; + llama_grammar * grammar = NULL; std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) + { + throw std::runtime_error("Failed to initialize llama_grammar"); + } std::vector<std::vector<llama_grammar_element>> expected_stacks = { { @@ -168,7 +174,7 @@ int main() }}; auto index = 0; - for (auto stack : grammar->stacks) + for (auto stack : llama_grammar_get_stacks(grammar)) { // compare stack to expected_stack for (uint32_t i = 0; i < stack.size(); i++) @@ -370,13 +376,13 @@ int main() }, }; - std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates); + std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates); std::vector<std::vector<llama_grammar_candidate>> all_rejects; - for (std::size_t count = 0; count < grammar->stacks.size(); ++count) + for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count) { - rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates); + rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates); all_rejects.push_back(rejects); } @@ -397,6 +403,6 @@ int main() delete[] candidate.code_points; candidate.code_points = nullptr; } - delete grammar; + llama_grammar_free(grammar); return 0; } |