diff options
author | drbh <david.richard.holtz@gmail.com> | 2023-08-13 10:00:48 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-13 17:00:48 +0300 |
commit | ee77efea2a1e3f7d153976b0934522b6bbaa62e6 (patch) | |
tree | 86a7db004cec2cc262d404d044408c153feeb843 | |
parent | f64d44a9b9581cd58f7ec40f4fa1c3ca5ca18e1e (diff) |
test : add simple grammar parsing tests (#2594)
* adds simple grammar parsing tests
* adds cassert header
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | Makefile | 5 | ||||
-rw-r--r-- | tests/CMakeLists.txt | 1 | ||||
-rw-r--r-- | tests/test-grammar-parser.cpp | 249 |
4 files changed, 255 insertions, 1 deletions
@@ -70,6 +70,7 @@ poetry.lock poetry.toml # Test binaries +tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt @@ -2,7 +2,7 @@ BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test # Binaries only useful for tests -TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 +TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 default: $(BUILD_TARGETS) @@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) +tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1a40edbe..689fb6f2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,5 +11,6 @@ llama_add_test(test-quantize-fns.cpp) llama_add_test(test-quantize-perf.cpp) llama_add_test(test-sampling.cpp) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) +llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) llama_add_test(test-grad0.cpp) # SLOW # llama_add_test(test-opt.cpp) # SLOW diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp new file mode 100644 index 00000000..7022988b --- /dev/null +++ b/tests/test-grammar-parser.cpp @@ -0,0 +1,249 @@ +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include "llama.h" +#include "examples/grammar-parser.cpp" +#include <cassert> + +int main() +{ + grammar_parser::parse_state parsed_grammar; + + const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+ +expr ::= term ([-+*/] term)* +term ::= [0-9]+)"""; + + parsed_grammar = grammar_parser::parse(grammar_bytes); + + std::vector<std::pair<std::string, uint32_t>> expected = { + {"expr", 2}, + {"expr_5", 5}, + {"expr_6", 6}, + {"root", 0}, + {"root_1", 1}, + {"root_4", 4}, + {"term", 3}, + {"term_7", 7}, + }; + + uint32_t index = 0; + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) + { + std::string key = it->first; + uint32_t value = it->second; + std::pair<std::string, uint32_t> expected_pair = expected[index]; + + // pretty print error message before asserting + if (expected_pair.first != key || expected_pair.second != value) + { + fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second); + fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value); + fprintf(stderr, "expected_pair != actual_pair\n"); + } + + assert(expected_pair.first == key && expected_pair.second == value); + + index++; + } + std::vector<llama_grammar_element> expected_rules = { + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 2}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, 10}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_RULE_REF, 6}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 1}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_RULE_REF, 1}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 45}, + {LLAMA_GRETYPE_CHAR_ALT, 43}, + {LLAMA_GRETYPE_CHAR_ALT, 42}, + {LLAMA_GRETYPE_CHAR_ALT, 47}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_RULE_REF, 6}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_END, 0}, + }; + + index = 0; + for (auto rule : parsed_grammar.rules) + { + // compare rule to expected rule + for (uint32_t i = 0; i < rule.size(); i++) + { + llama_grammar_element element = rule[i]; + llama_grammar_element expected_element = expected_rules[index]; + + // pretty print error message before asserting + if (expected_element.type != element.type || expected_element.value != element.value) + { + fprintf(stderr, "index: %d\n", index); + fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value); + fprintf(stderr, "expected_element != actual_element\n"); + } + + assert(expected_element.type == element.type && expected_element.value == element.value); + index++; + } + } + + const char *longer_grammar_bytes = R"""( + root ::= (expr "=" ws term "\n")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | "(" ws expr ")" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + )"""; + + parsed_grammar = grammar_parser::parse(longer_grammar_bytes); + + expected = { + {"expr", 2}, + {"expr_6", 6}, + {"expr_7", 7}, + {"ident", 8}, + {"ident_10", 10}, + {"num", 9}, + {"num_11", 11}, + {"root", 0}, + {"root_1", 1}, + {"root_5", 5}, + {"term", 4}, + {"ws", 3}, + {"ws_12", 12}, + }; + + index = 0; + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) + { + std::string key = it->first; + uint32_t value = it->second; + std::pair<std::string, uint32_t> expected_pair = expected[index]; + + // pretty print error message before asserting + if (expected_pair.first != key || expected_pair.second != value) + { + fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second); + fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value); + fprintf(stderr, "expected_pair != actual_pair\n"); + } + + assert(expected_pair.first == key && expected_pair.second == value); + + index++; + } + expected_rules = { + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 2}, + {LLAMA_GRETYPE_CHAR, 61}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_CHAR, 10}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 12}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 8}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_RULE_REF, 9}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 40}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_RULE_REF, 2}, + {LLAMA_GRETYPE_CHAR, 41}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 1}, + {LLAMA_GRETYPE_RULE_REF, 5}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_RULE_REF, 1}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 45}, + {LLAMA_GRETYPE_CHAR_ALT, 43}, + {LLAMA_GRETYPE_CHAR_ALT, 42}, + {LLAMA_GRETYPE_CHAR_ALT, 47}, + {LLAMA_GRETYPE_RULE_REF, 4}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 6}, + {LLAMA_GRETYPE_RULE_REF, 7}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 97}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, + {LLAMA_GRETYPE_RULE_REF, 10}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_RULE_REF, 11}, + {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 97}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, + {LLAMA_GRETYPE_CHAR_ALT, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_CHAR_ALT, 95}, + {LLAMA_GRETYPE_RULE_REF, 10}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_RULE_REF, 11}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 48}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + {LLAMA_GRETYPE_END, 0}, + {LLAMA_GRETYPE_CHAR, 32}, + {LLAMA_GRETYPE_CHAR_ALT, 9}, + {LLAMA_GRETYPE_CHAR_ALT, 10}, + {LLAMA_GRETYPE_RULE_REF, 12}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }; + + index = 0; + for (auto rule : parsed_grammar.rules) + { + // compare rule to expected rule + for (uint32_t i = 0; i < rule.size(); i++) + { + llama_grammar_element element = rule[i]; + llama_grammar_element expected_element = expected_rules[index]; + + // pretty print error message before asserting + if (expected_element.type != element.type || expected_element.value != element.value) + { + fprintf(stderr, "index: %d\n", index); + fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value); + fprintf(stderr, "expected_element != actual_element\n"); + } + + assert(expected_element.type == element.type && expected_element.value == element.value); + index++; + } + } + + return 0; +} |