diff options
Diffstat (limited to 'tests/test-grammar-integration.cpp')
-rw-r--r-- | tests/test-grammar-integration.cpp | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp new file mode 100644 index 00000000..0a9c3b6f --- /dev/null +++ b/tests/test-grammar-integration.cpp @@ -0,0 +1,243 @@ +#ifdef NDEBUG +#undef NDEBUG +#endif + +#define LLAMA_API_INTERNAL + +#include "ggml.h" +#include "llama.h" +#include "grammar-parser.h" +#include "unicode.h" +#include <cassert> +#include <string> + +static void test_simple_grammar() { + // Test case for a simple grammar + const std::string grammar_str = R"""(root ::= expr +expr ::= term ("+" term)* +term ::= number +number ::= [0-9]+)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we have a root node + assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); + + std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules()); + llama_grammar* grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + std::string input = "123+456"; + + auto decoded = decode_utf8(input, {}); + + const auto & code_points = decoded.first; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + assert(completed_grammar); + + // Clean up allocated memory + llama_grammar_free(grammar); +} + +static void test_complex_grammar() { + // Test case for a more complex grammar, with both failure strings and success strings + const std::string grammar_str = R"""(root ::= expression +expression ::= term ws (("+"|"-") ws term)* +term ::= factor ws (("*"|"/") ws factor)* +factor ::= number | variable | "(" expression ")" | function-call +number ::= [0-9]+ +variable ::= [a-zA-Z_][a-zA-Z0-9_]* +function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" +ws ::= [ \t\n\r]?)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we have a root node + assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); + + std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules()); + llama_grammar* grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + + // Save the original grammar stacks so that we can reset after every new string we want to test + auto original_stacks = grammar->stacks; + + // Test a few strings + std::vector<std::string> test_strings_pass = { + "42", + "1*2*3*4*5", + "x", + "x+10", + "x1+y2", + "(a+b)*(c-d)", + "func()", + "func(x,y+2)", + "a*(b+c)-d/e", + "f(g(x),h(y,z))", + "x + 10", + "x1 + y2", + "(a + b) * (c - d)", + "func()", + "func(x, y + 2)", + "a * (b + c) - d / e", + "f(g(x), h(y, z))", + "123+456", + "123*456*789-123/456+789*123", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" + }; + + std::vector<std::string> test_strings_fail = { + "+", + "/ 3x", + "x + + y", + "a * / b", + "func(,)", + "func(x y)", + "(a + b", + "x + y)", + "a + b * (c - d", + "42 +", + "x +", + "x + 10 +", + "(a + b) * (c - d", + "func(", + "func(x, y + 2", + "a * (b + c) - d /", + "f(g(x), h(y, z)", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", + }; + + // Passing strings + for (const auto & test_string : test_strings_pass) { + auto decoded = decode_utf8(test_string, {}); + + const auto & code_points = decoded.first; + + int pos = 0; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + ++pos; + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + + // Expect that each code point will not cause the grammar to fail + if (grammar->stacks.empty()) { + fprintf(stdout, "Error at position %d\n", pos); + fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str()); + fprintf(stderr, "Input string is %s:\n", test_string.c_str()); + } + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + assert(completed_grammar); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Failing strings + for (const auto & test_string : test_strings_fail) { + auto decoded = decode_utf8(test_string, {}); + + const auto & code_points = decoded.first; + bool parse_failed = false; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + auto prev_stacks = grammar->stacks; + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + if (grammar->stacks.empty()) { + parse_failed = true; + break; + } + assert(!grammar->stacks.empty()); + } + + bool completed_grammar = false; + + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + completed_grammar = true; + break; + } + } + + // Ensure that the grammar is not completed, or that each string failed to match as-expected + assert((!completed_grammar) || parse_failed); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Clean up allocated memory + llama_grammar_free(grammar); +} + +static void test_failure_missing_root() { + // Test case for a grammar that is missing a root rule + const std::string grammar_str = R"""(rot ::= expr +expr ::= term ("+" term)* +term ::= number +number ::= [0-9]+)"""; + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we parsed correctly + assert(!parsed_grammar.rules.empty()); + + // Ensure we do NOT have a root node + assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()); +} + +static void test_failure_missing_reference() { + // Test case for a grammar that is missing a referenced rule + const std::string grammar_str = R"""(root ::= expr +expr ::= term ("+" term)* +term ::= numero +number ::= [0-9]+)"""; + + fprintf(stderr, "Expected error: "); + + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + + // Ensure we did NOT parsed correctly + assert(parsed_grammar.rules.empty()); + + fprintf(stderr, "End of expected error. Test successful.\n"); +} + +int main() { + test_simple_grammar(); + test_complex_grammar(); + test_failure_missing_root(); + test_failure_missing_reference(); + return 0; +} |