diff options
author | Olivier Chafik <ochafik@users.noreply.github.com> | 2024-04-12 19:43:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-12 19:43:38 +0100 |
commit | ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa (patch) | |
tree | aa2efe58bc95a650827db07c83eb8bc0e026162c /examples/json-schema-to-grammar.py | |
parent | fbbc030ba93561fac842af994c5c6c4c1147f13b (diff) |
JSON schema conversion: ⚡️ faster repetitions, min/maxLength for strings, cap number length (#6555)
* json: rename python schema converter to make import easier
* server: skip null json_schema / grammar fields
* json: deps management for primitive rules (+ allow null values)
* json: optimize repetitions for minItems/maxItems and regexps: `a{,3}` goes from `"a"? "a"? "a"?` (explosive combos) to `(a (a (a)?)?)?`
* grammars: add troubleshooting section to readme
* json: cap length of numbers to 15 digits before/after decimal point
(avoids infinite gen, e.g. "one third" -> `0.333333333333...`)
* json: unify all repetition code (w/ or w/o sep)
* json: support string minLength/maxLength
* server+json: update server/README w/ result_format
* nits
* json: fix type error w/ python 3.8
* json: fix server/README (json_schema in /completion vs. result_format in /v1/chat/completions)
* json: simplify DOT `{"type": "string", "pattern": "^.$"}`
* json: remove recursion in opt_repetitions (avoids Python stack overflow)
* json: rm dead code
* json: rm useless assert & ggml.h import
Diffstat (limited to 'examples/json-schema-to-grammar.py')
-rwxr-xr-x | examples/json-schema-to-grammar.py | 550 |
1 files changed, 0 insertions, 550 deletions
diff --git a/examples/json-schema-to-grammar.py b/examples/json-schema-to-grammar.py deleted file mode 100755 index 91dd734c..00000000 --- a/examples/json-schema-to-grammar.py +++ /dev/null @@ -1,550 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import itertools -import json -import re -import sys -from typing import Any, Dict, List, Set, Tuple, Union - -# whitespace is constrained to a single space char to prevent model "running away" in -# whitespace. Also maybe improves generation quality? -SPACE_RULE = '" "?' - -PRIMITIVE_RULES = { - 'boolean': '("true" | "false") space', - 'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', - 'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space', - 'value' : 'object | array | string | number | boolean', - 'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', - 'array' : '"[" space ( value ("," space value)* )? "]" space', - 'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space', - 'string': r''' "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" space''', - 'null': '"null" space', -} -OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'] - -# TODO: support "uri", "email" string formats -DATE_RULES = { - 'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', - 'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', - 'date-time': 'date "T" time', - 'date-string': '"\\"" date "\\"" space', - 'time-string': '"\\"" time "\\"" space', - 'date-time-string': '"\\"" date-time "\\"" space', -} - -RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()]) - -INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') -GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') -GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} - -NON_LITERAL_SET = set('|.()[]{}*+?') -ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') - -DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])' -TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits - -class SchemaConverter: - def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): - self._prop_order = prop_order - self._allow_fetch = allow_fetch - self._dotall = dotall - self._raw_pattern = raw_pattern - self._rules = {'space': SPACE_RULE} - self._refs = {} - self._refs_being_resolved = set() - - def _format_literal(self, literal): - escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( - lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal - ) - return f'"{escaped}"' - - def _add_rule(self, name, rule): - esc_name = INVALID_RULE_CHARS_RE.sub('-', name) - if esc_name not in self._rules or self._rules[esc_name] == rule: - key = esc_name - else: - i = 0 - while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: - i += 1 - key = f'{esc_name}{i}' - self._rules[key] = rule - return key - - def resolve_refs(self, schema: dict, url: str): - ''' - Resolves all $ref fields in the given schema, fetching any remote schemas, - replacing $ref with absolute reference URL and populating self._refs with the - respective referenced (sub)schema dictionaries. - ''' - def visit(n: dict): - if isinstance(n, list): - return [visit(x) for x in n] - elif isinstance(n, dict): - ref = n.get('$ref') - if ref is not None and ref not in self._refs: - if ref.startswith('https://'): - assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' - import requests - - frag_split = ref.split('#') - base_url = frag_split[0] - - target = self._refs.get(base_url) - if target is None: - target = self.resolve_refs(requests.get(ref).json(), base_url) - self._refs[base_url] = target - - if len(frag_split) == 1 or frag_split[-1] == '': - return target - elif ref.startswith('#/'): - target = schema - ref = f'{url}{ref}' - n['$ref'] = ref - else: - raise ValueError(f'Unsupported ref {ref}') - - for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] - - self._refs[ref] = target - else: - for v in n.values(): - visit(v) - - return n - return visit(schema) - - def _generate_union_rule(self, name, alt_schemas): - return ' | '.join(( - self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') - for i, alt_schema in enumerate(alt_schemas) - )) - - def _visit_pattern(self, pattern, name): - ''' - Transforms a regular expression pattern into a GBNF rule. - - Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions - Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - - Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. - - Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which - we define sub-rules to keep the output lean. - ''' - - assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' - pattern = pattern[1:-1] - sub_rule_ids = {} - - i = 0 - length = len(pattern) - - def to_rule(s: Tuple[str, bool]) -> str: - (txt, is_literal) = s - return "\"" + txt + "\"" if is_literal else txt - - def transform() -> Tuple[str, bool]: - ''' - Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. - ''' - nonlocal i - nonlocal pattern - nonlocal sub_rule_ids - - start = i - # For each component of this sequence, store its string representation and whether it's a literal. - # We only need a flat structure here to apply repetition operators to the last item, and - # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially - # (GBNF's syntax is luckily very close to regular expressions!) - seq: list[Tuple[str, bool]] = [] - - def get_dot(): - if self._dotall: - rule = '[\\U00000000-\\U0010FFFF]' - else: - # Accept any character... except \n and \r line break chars (\x0A and \xOD) - rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]' - return self._add_rule(f'dot', rule) - - def join_seq(): - nonlocal seq - ret = [] - for is_literal, g in itertools.groupby(seq, lambda x: x[1]): - if is_literal: - ret.append((''.join(x[0] for x in g), True)) - else: - ret.extend(g) - if len(ret) == 1: - return ret[0] - return (' '.join(to_rule(x) for x in seq), False) - - while i < length: - c = pattern[i] - if c == '.': - seq.append((get_dot(), False)) - i += 1 - elif c == '(': - i += 1 - if i < length: - assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' - seq.append((f'({to_rule(transform())})', False)) - elif c == ')': - i += 1 - assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' - return join_seq() - elif c == '[': - square_brackets = c - i += 1 - while i < length and pattern[i] != ']': - if pattern[i] == '\\': - square_brackets += pattern[i:i+2] - i += 2 - else: - square_brackets += pattern[i] - i += 1 - assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' - square_brackets += ']' - i += 1 - seq.append((square_brackets, False)) - elif c == '|': - seq.append(('|', False)) - i += 1 - elif c in ('*', '+', '?'): - seq[-1] = (to_rule(seq[-1]) + c, False) - i += 1 - elif c == '{': - curly_brackets = c - i += 1 - while i < length and pattern[i] != '}': - curly_brackets += pattern[i] - i += 1 - assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' - curly_brackets += '}' - i += 1 - nums = [s.strip() for s in curly_brackets[1:-1].split(',')] - min_times = 0 - max_times = None - try: - if len(nums) == 1: - min_times = int(nums[0]) - max_times = min_times - else: - assert len(nums) == 2 - min_times = int(nums[0]) if nums[0] else 0 - max_times = int(nums[1]) if nums[1] else None - except ValueError: - raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') - - (sub, sub_is_literal) = seq[-1] - - if min_times == 0 and max_times is None: - seq[-1] = (f'{sub}*', False) - elif min_times == 0 and max_times == 1: - seq[-1] = (f'{sub}?', False) - elif min_times == 1 and max_times is None: - seq[-1] = (f'{sub}+', False) - else: - if not sub_is_literal: - id = sub_rule_ids.get(sub) - if id is None: - id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) - sub_rule_ids[sub] = id - sub = id - - seq[-1] = ( - ' '.join( - ([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) + - ([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])), - False - ) - else: - literal = '' - while i < length: - if pattern[i] == '\\' and i < length - 1: - next = pattern[i + 1] - if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: - i += 1 - literal += pattern[i] - i += 1 - else: - literal += pattern[i:i+2] - i += 2 - elif pattern[i] == '"' and not self._raw_pattern: - literal += '\\"' - i += 1 - elif pattern[i] not in NON_LITERAL_SET and \ - (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): - literal += pattern[i] - i += 1 - else: - break - if literal: - seq.append((literal, True)) - - return join_seq() - - return self._add_rule( - name, - to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") - - - def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] - if ref_name not in self._rules and ref not in self._refs_being_resolved: - self._refs_being_resolved.add(ref) - resolved = self._refs[ref] - ref_name = self.visit(resolved, ref_name) - self._refs_being_resolved.remove(ref) - return ref_name - - def _generate_constant_rule(self, value): - return self._format_literal(json.dumps(value)) - - def visit(self, schema, name): - schema_type = schema.get('type') - schema_format = schema.get('format') - rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' - - if (ref := schema.get('$ref')) is not None: - return self._add_rule(rule_name, self._resolve_ref(ref)) - - elif 'oneOf' in schema or 'anyOf' in schema: - return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) - - elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type])) - - elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) - - elif 'enum' in schema: - rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) - return self._add_rule(rule_name, rule) - - elif schema_type in (None, 'object') and \ - ('properties' in schema or \ - ('additionalProperties' in schema and schema['additionalProperties'] is not True)): - required = set(schema.get('required', [])) - properties = list(schema.get('properties', {}).items()) - return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) - - elif schema_type in (None, 'object') and 'allOf' in schema: - required = set() - properties = [] - hybrid_name = name - def add_component(comp_schema, is_required): - if (ref := comp_schema.get('$ref')) is not None: - comp_schema = self._refs[ref] - - if 'properties' in comp_schema: - for prop_name, prop_schema in comp_schema['properties'].items(): - properties.append((prop_name, prop_schema)) - if is_required: - required.add(prop_name) - - for t in schema['allOf']: - if 'anyOf' in t: - for tt in t['anyOf']: - add_component(tt, is_required=False) - else: - add_component(t, is_required=True) - - return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) - - elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): - items = schema.get('items') or schema['prefixItems'] - if isinstance(items, list): - return self._add_rule( - rule_name, - '"[" space ' + - ' "," space '.join( - self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') - for i, item in enumerate(items)) + - ' "]" space') - else: - item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') - list_item_operator = f'( "," space {item_rule_name} )' - successive_items = "" - min_items = schema.get("minItems", 0) - max_items = schema.get("maxItems") - if min_items > 0: - successive_items = list_item_operator * (min_items - 1) - min_items -= 1 - if max_items is not None and max_items > min_items: - successive_items += (list_item_operator + "?") * (max_items - min_items - 1) - else: - successive_items += list_item_operator + "*" - if min_items == 0: - rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space' - else: - rule = f'"[" space {item_rule_name} {successive_items} "]" space' - return self._add_rule(rule_name, rule) - - elif schema_type in (None, 'string') and 'pattern' in schema: - return self._visit_pattern(schema['pattern'], rule_name) - - elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): - return self._add_rule( - 'root' if rule_name == 'root' else schema_format, - PRIMITIVE_RULES['uuid'] - ) - - elif schema_type in (None, 'string') and schema_format in DATE_RULES: - for t, r in DATE_RULES.items(): - self._add_rule(t, r) - return schema_format + '-string' - - elif (schema_type == 'object') or (len(schema) == 0): - for n in OBJECT_RULE_NAMES: - self._add_rule(n, PRIMITIVE_RULES[n]) - return self._add_rule(rule_name, 'object') - - else: - assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' - # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return self._add_rule( - 'root' if rule_name == 'root' else schema_type, - PRIMITIVE_RULES[schema_type] - ) - - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): - prop_order = self._prop_order - # sort by position in prop_order (if specified) then by original order - sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] - - prop_kv_rule_names = {} - for prop_name, prop_schema in properties: - prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') - prop_kv_rule_names[prop_name] = self._add_rule( - f'{name}{"-" if name else ""}{prop_name}-kv', - fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' - ) - required_props = [k for k in sorted_props if k in required] - optional_props = [k for k in sorted_props if k not in required] - - if additional_properties == True or isinstance(additional_properties, dict): - sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') - prop_kv_rule_names["*"] = self._add_rule( - f'{sub_name}-kv', - self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' - ) - optional_props.append("*") - - rule = '"{" space ' - rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) - - if optional_props: - rule += ' (' - if required_props: - rule += ' "," space ( ' - - def get_recursive_refs(ks, first_is_optional): - [k, *rest] = ks - kv_rule_name = prop_kv_rule_names[k] - if k == '*': - res = self._add_rule( - f'{name}{"-" if name else ""}additional-kvs', - f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*' - ) - elif first_is_optional: - res = f'( "," space {kv_rule_name} )?' - else: - res = kv_rule_name - if len(rest) > 0: - res += ' ' + self._add_rule( - f'{name}{"-" if name else ""}{k}-rest', - get_recursive_refs(rest, first_is_optional=True) - ) - return res - - rule += ' | '.join( - get_recursive_refs(optional_props[i:], first_is_optional=False) - for i in range(len(optional_props)) - ) - if required_props: - rule += ' )' - rule += ' )?' - - rule += ' "}" space' - - return rule - - def format_grammar(self): - return '\n'.join( - f'{name} ::= {rule}' - for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) - ) - - -def main(args_in = None): - parser = argparse.ArgumentParser( - description=''' - Generates a grammar (suitable for use in ./main) that produces JSON conforming to a - given JSON schema. Only a subset of JSON schema features are supported; more may be - added in the future. - ''', - ) - parser.add_argument( - '--prop-order', - default=[], - type=lambda s: s.split(','), - help=''' - comma-separated property names defining the order of precedence for object properties; - properties not specified here are given lower precedence than those that are, and - are kept in their original order from the schema. Required properties are always - given precedence over optional properties. - ''' - ) - parser.add_argument( - '--allow-fetch', - action='store_true', - default=False, - help='Whether to allow fetching referenced schemas over HTTPS') - parser.add_argument( - '--dotall', - action='store_true', - default=False, - help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns') - parser.add_argument( - '--raw-pattern', - action='store_true', - default=False, - help='Treats string patterns as raw patterns w/o quotes (or quote escapes)') - - parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') - args = parser.parse_args(args_in) - - if args.schema.startswith('https://'): - url = args.schema - import requests - schema = requests.get(url).json() - elif args.schema == '-': - url = 'stdin' - schema = json.load(sys.stdin) - else: - url = f'file://{args.schema}' - with open(args.schema) as f: - schema = json.load(f) - converter = SchemaConverter( - prop_order={name: idx for idx, name in enumerate(args.prop_order)}, - allow_fetch=args.allow_fetch, - dotall=args.dotall, - raw_pattern=args.raw_pattern) - schema = converter.resolve_refs(schema, url) - converter.visit(schema, '') - print(converter.format_grammar()) - - -if __name__ == '__main__': - main() |