diff options
Diffstat (limited to 'examples/json_schema_to_grammar.py')
-rwxr-xr-x | examples/json_schema_to_grammar.py | 276 |
1 files changed, 252 insertions, 24 deletions
diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index b588497b..a8779bf3 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import itertools import json import re import sys -from typing import Any, Dict, List, Set, Tuple, Union - +from typing import Any, List, Optional, Set, Tuple, Union def _build_repetition(item_rule, min_items, max_items, separator_rule=None): @@ -23,9 +24,173 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None): result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) return f'({result})?' if min_items == 0 else result +def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): + has_min = min_value != None + has_max = max_value != None + + def digit_range(from_char: str, to_char: str): + out.append("[") + if from_char == to_char: + out.append(from_char) + else: + out.append(from_char) + out.append("-") + out.append(to_char) + out.append("]") + + def more_digits(min_digits: int, max_digits: int): + out.append("[0-9]") + if min_digits == max_digits and min_digits == 1: + return + out.append("{") + out.append(str(min_digits)) + if max_digits != min_digits: + out.append(",") + if max_digits != sys.maxsize: + out.append(str(max_digits)) + out.append("}") + + def uniform_range(from_str: str, to_str: str): + i = 0 + while i < len(from_str) and from_str[i] == to_str[i]: + i += 1 + if i > 0: + out.append("\"") + out.append(from_str[:i]) + out.append("\"") + if i < len(from_str): + if i > 0: + out.append(" ") + sub_len = len(from_str) - i - 1 + if sub_len > 0: + from_sub = from_str[i+1:] + to_sub = to_str[i+1:] + sub_zeros = "0" * sub_len + sub_nines = "9" * sub_len + + to_reached = False + out.append("(") + if from_sub == sub_zeros: + digit_range(from_str[i], chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + else: + out.append("[") + out.append(from_str[i]) + out.append("] ") + out.append("(") + uniform_range(from_sub, sub_nines) + out.append(")") + if ord(from_str[i]) < ord(to_str[i]) - 1: + out.append(" | ") + if to_sub == sub_nines: + digit_range(chr(ord(from_str[i]) + 1), to_str[i]) + to_reached = True + else: + digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1)) + out.append(" ") + more_digits(sub_len, sub_len) + if not to_reached: + out.append(" | ") + digit_range(to_str[i], to_str[i]) + out.append(" ") + uniform_range(sub_zeros, to_sub) + out.append(")") + else: + out.append("[") + out.append(from_str[i]) + out.append("-") + out.append(to_str[i]) + out.append("]") + + if has_min and has_max: + if min_value < 0 and max_value < 0: + out.append("\"-\" (") + _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) + out.append(")") + return + + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) + out.append(") | ") + min_value = 0 + + min_s = str(min_value) + max_s = str(max_value) + min_digits = len(min_s) + max_digits = len(max_s) + + for digits in range(min_digits, max_digits): + uniform_range(min_s, "9" * digits) + min_s = "1" + "0" * digits + out.append(" | ") + uniform_range(min_s, max_s) + return + + less_decimals = max(decimals_left - 1, 1) + + if has_min: + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) + out.append(") | [0] | [1-9] ") + more_digits(0, decimals_left - 1) + elif min_value == 0: + if top_level: + out.append("[0] | [1-9] ") + more_digits(0, less_decimals) + else: + more_digits(1, decimals_left) + elif min_value <= 9: + c = str(min_value) + range_start = '1' if top_level else '0' + if c > range_start: + digit_range(range_start, chr(ord(c) - 1)) + out.append(" ") + more_digits(1, less_decimals) + out.append(" | ") + digit_range(c, "9") + out.append(" ") + more_digits(0, less_decimals) + else: + min_s = str(min_value) + length = len(min_s) + c = min_s[0] + + if c > "1": + digit_range("1" if top_level else "0", chr(ord(c) - 1)) + out.append(" ") + more_digits(length, less_decimals) + out.append(" | ") + digit_range(c, c) + out.append(" (") + _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) + out.append(")") + if c < "9": + out.append(" | ") + digit_range(chr(ord(c) + 1), "9") + out.append(" ") + more_digits(length - 1, less_decimals) + return + + if has_max: + if max_value >= 0: + if top_level: + out.append("\"-\" [1-9] ") + more_digits(0, less_decimals) + out.append(" | ") + _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) + else: + out.append("\"-\" (") + _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) + out.append(")") + return + + raise RuntimeError("At least one of min_value or max_value must be set") class BuiltinRule: - def __init__(self, content: str, deps: list = None): + def __init__(self, content: str, deps: list | None = None): self.content = content self.deps = deps or [] @@ -68,7 +233,7 @@ GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} NON_LITERAL_SET = set('|.()[]{}*+?') -ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') class SchemaConverter: @@ -85,7 +250,7 @@ class SchemaConverter: def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal ) return f'"{escaped}"' @@ -112,6 +277,51 @@ class SchemaConverter: return ''.join(('(', *recurse(0), ')')) + def _not_strings(self, strings): + class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_string = False + + def insert(self, string): + node = self + for c in string: + node = node.children.setdefault(c, TrieNode()) + node.is_end_of_string = True + + trie = TrieNode() + for s in strings: + trie.insert(s) + + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + out = ['["] ( '] + + def visit(node): + rejects = [] + first = True + for c in sorted(node.children.keys()): + child = node.children[c] + rejects.append(c) + if first: + first = False + else: + out.append(' | ') + out.append(f'[{c}]') + if child.children: + out.append(f' (') + visit(child) + out.append(')') + elif child.is_end_of_string: + out.append(f' {char_rule}+') + if node.children: + if not first: + out.append(' | ') + out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) + + out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + return ''.join(out) + def _add_rule(self, name, rule): esc_name = INVALID_RULE_CHARS_RE.sub('-', name) if esc_name not in self._rules or self._rules[esc_name] == rule: @@ -195,11 +405,11 @@ class SchemaConverter: i = 0 length = len(pattern) - def to_rule(s: Tuple[str, bool]) -> str: + def to_rule(s: tuple[str, bool]) -> str: (txt, is_literal) = s return "\"" + txt + "\"" if is_literal else txt - def transform() -> Tuple[str, bool]: + def transform() -> tuple[str, bool]: ''' Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. ''' @@ -212,7 +422,7 @@ class SchemaConverter: # We only need a flat structure here to apply repetition operators to the last item, and # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially # (GBNF's syntax is luckily very close to regular expressions!) - seq: list[Tuple[str, bool]] = [] + seq: list[tuple[str, bool]] = [] def get_dot(): if self._dotall: @@ -357,13 +567,13 @@ class SchemaConverter: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) + return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) + return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') elif 'enum' in schema: - rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ @@ -394,7 +604,7 @@ class SchemaConverter: else: add_component(t, is_required=True) - return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): items = schema.get('items') or schema['prefixItems'] @@ -432,6 +642,24 @@ class SchemaConverter: return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + elif schema_type in (None, 'integer') and \ + ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): + min_value = None + max_value = None + if 'minimum' in schema: + min_value = schema['minimum'] + elif 'exclusiveMinimum' in schema: + min_value = schema['exclusiveMinimum'] + 1 + if 'maximum' in schema: + max_value = schema['maximum'] + elif 'exclusiveMaximum' in schema: + max_value = schema['exclusiveMaximum'] - 1 + + out = ["("] + _generate_min_max_int(min_value, max_value, out) + out.append(") space") + return self._add_rule(rule_name, ''.join(out)) + elif (schema_type == 'object') or (len(schema) == 0): return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) @@ -450,7 +678,7 @@ class SchemaConverter: self._add_primitive(dep, dep_rule) return n - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] @@ -465,12 +693,16 @@ class SchemaConverter: required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] - if additional_properties == True or isinstance(additional_properties, dict): + if additional_properties is not None and additional_properties != False: sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') + value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ + self._add_primitive('value', PRIMITIVE_RULES['value']) + key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ + else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + prop_kv_rule_names["*"] = self._add_rule( f'{sub_name}-kv', - self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + f'{key_rule} ":" space {value_rule}' ) optional_props.append("*") @@ -485,15 +717,11 @@ class SchemaConverter: def get_recursive_refs(ks, first_is_optional): [k, *rest] = ks kv_rule_name = prop_kv_rule_names[k] - if k == '*': - res = self._add_rule( - f'{name}{"-" if name else ""}additional-kvs', - f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' - ) - elif first_is_optional: - res = f'( "," space {kv_rule_name} )?' + comma_ref = f'( "," space {kv_rule_name} )' + if first_is_optional: + res = comma_ref + ('*' if k == '*' else '?') else: - res = kv_rule_name + res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') if len(rest) > 0: res += ' ' + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', |