summaryrefslogtreecommitdiff
path: root/tests/test-llama-grammar.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test-llama-grammar.cpp')
-rw-r--r--tests/test-llama-grammar.cpp24
1 files changed, 15 insertions, 9 deletions
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
index 27ca4d26..1f3a267b 100644
--- a/tests/test-llama-grammar.cpp
+++ b/tests/test-llama-grammar.cpp
@@ -2,10 +2,12 @@
#undef NDEBUG
#endif
-#include "llama.cpp" // TODO: not great
+#define LLAMA_API_INTERNAL
+#include "llama.h"
#include "grammar-parser.h"
#include <cassert>
+#include <stdexcept>
int main()
{
@@ -112,10 +114,14 @@ int main()
}
}
- llama_grammar *grammar = NULL;
+ llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+ grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ if (grammar == nullptr)
+ {
+ throw std::runtime_error("Failed to initialize llama_grammar");
+ }
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
@@ -168,7 +174,7 @@ int main()
}};
auto index = 0;
- for (auto stack : grammar->stacks)
+ for (auto stack : llama_grammar_get_stacks(grammar))
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
@@ -370,13 +376,13 @@ int main()
},
};
- std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
+ std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
- for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
+ for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
{
- rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
+ rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
all_rejects.push_back(rejects);
}
@@ -397,6 +403,6 @@ int main()
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
- delete grammar;
+ llama_grammar_free(grammar);
return 0;
}