diff options
author | Olivier Chafik <ochafik@users.noreply.github.com> | 2024-06-06 10:07:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-06 10:07:06 +0100 |
commit | 55b2d0849d3ec9e45e4a4d9e480f5fa7977872a6 (patch) | |
tree | aef3a9e9a051670751912daf0d06a76783b44b45 /common/grammar-parser.cpp | |
parent | f5d7b268ec4bf8628aa6ccc9f6631d0230dde76f (diff) |
grammars: x{min,max} repetition operator (#6640)
* grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates
* grammars: handle `x{n}` and fix `x{n,n}`
* grammars: document new repetition operators
* grammars: uniform use of int for min & max
* grammars: refactor parser test
* grammar: parsing tests w/ natural pretty print of updated expectations
* grammars: much prettier print of expectations (+ TEST_GRAMMAR_PARSER_PRINT_ALL=1 to force all)
* grammars: improve test pretty print again
* grammars: pretty print rules and chars
* grammars: fix copy rule skipping
* grammars: disallow `a{,}` (not allowed in regexps)
* Update common/grammar-parser.cpp
Co-authored-by: Clint Herron <hanclinto@gmail.com>
* grammars: fix copy rule skipping (again) & display of expectations
* grammars: more test cases
* grammars: update reps parsing to bring ? / * / + closer to before
* json: use new GBNF repetitions{m,n} syntax
* grammars: update performance gotchas w/ repetition advice
* Update examples/json_schema_to_grammar.py
Co-authored-by: Clint Herron <hanclinto@gmail.com>
* Update examples/server/public/json-schema-to-grammar.mjs
Co-authored-by: Clint Herron <hanclinto@gmail.com>
* grammars: comment on rule repetitions
* grammars: ensure unambiguous number alternatives
* grammar: nit typo switched error msgs
* grammar: nit numbering in comment
* json: update numeric rule to be unambiguous
* Apply suggestions from code review
Co-authored-by: Clint Herron <hanclinto@gmail.com>
* Update examples/server/public/json-schema-to-grammar.mjs
Co-authored-by: Clint Herron <hanclinto@gmail.com>
* json: fix integral-part
* grammar: add repetition tests
---------
Co-authored-by: Clint Herron <hanclinto@gmail.com>
Diffstat (limited to 'common/grammar-parser.cpp')
-rw-r--r-- | common/grammar-parser.cpp | 138 |
1 files changed, 107 insertions, 31 deletions
diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index b5bc7d49..79d2b035 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -46,8 +46,12 @@ namespace grammar_parser { state.rules[rule_id] = rule; } + static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; + } + static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { @@ -99,6 +103,17 @@ namespace grammar_parser { return pos; } + static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; + } + static std::pair<uint32_t, const char *> parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { @@ -137,6 +152,60 @@ namespace grammar_parser { bool is_nested) { size_t last_sym_start = out_elements.size(); const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + if (min_times == 0) { + out_elements.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + std::vector<llama_grammar_element> rec_rule(previous_elements); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(previous_elements.size()); + uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + while (*pos) { if (*pos == '"') { // literal string pos++; @@ -197,40 +266,47 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); - } + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); - // apply transformation to previous symbol (last_sym_start to end) according to - // rewrite rules: - // S* --> S' ::= S S' | - // S+ --> S' ::= S S' | S - // S? --> S' ::= S | - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - std::vector<llama_grammar_element> sub_rule; - // add preceding symbol to generated rule - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - if (*pos == '*' || *pos == '+') { - // cause generated rule to recurse - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - } - // mark start of alternate def - sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if (*pos == '+') { - // add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); } - sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, sub_rule_id, sub_rule); + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); - // in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + int max_times = -1; - pos = parse_space(pos + 1, is_nested); + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); } else { break; } |