summaryrefslogtreecommitdiff
path: root/examples/pydantic_models_to_grammar.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/pydantic_models_to_grammar.py')
-rw-r--r--examples/pydantic_models_to_grammar.py106
1 files changed, 59 insertions, 47 deletions
diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py
index f029c73a..93e5dcb6 100644
--- a/examples/pydantic_models_to_grammar.py
+++ b/examples/pydantic_models_to_grammar.py
@@ -9,7 +9,7 @@ from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, create_model
if TYPE_CHECKING:
from types import GenericAlias
@@ -17,6 +17,9 @@ else:
# python 3.8 compat
from typing import _GenericAlias as GenericAlias
+# TODO: fix this
+# pyright: reportAttributeAccessIssue=information
+
class PydanticDataType(Enum):
"""
@@ -50,35 +53,38 @@ class PydanticDataType(Enum):
def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
- if isclass(pydantic_type) and issubclass(pydantic_type, str):
+ origin_type = get_origin(pydantic_type)
+ origin_type = pydantic_type if origin_type is None else origin_type
+
+ if isclass(origin_type) and issubclass(origin_type, str):
return PydanticDataType.STRING.value
- elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
+ elif isclass(origin_type) and issubclass(origin_type, bool):
return PydanticDataType.BOOLEAN.value
- elif isclass(pydantic_type) and issubclass(pydantic_type, int):
+ elif isclass(origin_type) and issubclass(origin_type, int):
return PydanticDataType.INTEGER.value
- elif isclass(pydantic_type) and issubclass(pydantic_type, float):
+ elif isclass(origin_type) and issubclass(origin_type, float):
return PydanticDataType.FLOAT.value
- elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
+ elif isclass(origin_type) and issubclass(origin_type, Enum):
return PydanticDataType.ENUM.value
- elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
- return format_model_and_field_name(pydantic_type.__name__)
- elif get_origin(pydantic_type) is list:
+ elif isclass(origin_type) and issubclass(origin_type, BaseModel):
+ return format_model_and_field_name(origin_type.__name__)
+ elif origin_type is list:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
- elif get_origin(pydantic_type) is set:
+ elif origin_type is set:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
- elif get_origin(pydantic_type) is Union:
+ elif origin_type is Union:
union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}"
- elif get_origin(pydantic_type) is Optional:
+ elif origin_type is Optional:
element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
- elif isclass(pydantic_type):
- return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
- elif get_origin(pydantic_type) is dict:
+ elif isclass(origin_type):
+ return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
+ elif origin_type is dict:
key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else:
@@ -115,7 +121,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
- for name, param_type in cls.__annotations__.items()
+ for name, param_type in get_type_hints(cls).items()
if name != "self"
]
@@ -234,8 +240,9 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None
# Define the integer part rule
integer_part_rule = (
- "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
- f"-min{min_digit}" if min_digit is not None else "")
+ "integer-part"
+ + (f"-max{max_digit}" if max_digit is not None else "")
+ + (f"-min{min_digit}" if min_digit is not None else "")
)
# Define the fractional part rule based on precision constraints
@@ -293,17 +300,20 @@ def generate_gbnf_rule_for_type(
field_name = format_model_and_field_name(field_name)
gbnf_type = map_pydantic_type_to_gbnf(field_type)
- if isclass(field_type) and issubclass(field_type, BaseModel):
+ origin_type = get_origin(field_type)
+ origin_type = field_type if origin_type is None else origin_type
+
+ if isclass(origin_type) and issubclass(origin_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules
- elif isclass(field_type) and issubclass(field_type, Enum):
+ elif isclass(origin_type) and issubclass(origin_type, Enum):
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
- elif get_origin(field_type) == list: # Array
+ elif origin_type is list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -313,7 +323,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
- elif get_origin(field_type) == set or field_type == set: # Array
+ elif origin_type is set: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
@@ -367,7 +377,7 @@ def generate_gbnf_rule_for_type(
gbnf_type = f"{model_name}-{field_name}-optional"
else:
gbnf_type = f"{model_name}-{field_name}-union"
- elif isclass(field_type) and issubclass(field_type, str):
+ elif isclass(origin_type) and issubclass(origin_type, str):
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
@@ -383,8 +393,8 @@ def generate_gbnf_rule_for_type(
gbnf_type = PydanticDataType.STRING.value
elif (
- isclass(field_type)
- and issubclass(field_type, float)
+ isclass(origin_type)
+ and issubclass(origin_type, float)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
@@ -409,8 +419,8 @@ def generate_gbnf_rule_for_type(
)
elif (
- isclass(field_type)
- and issubclass(field_type, int)
+ isclass(origin_type)
+ and issubclass(origin_type, int)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
@@ -458,7 +468,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__:
- model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
+ model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else:
init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters
@@ -466,7 +476,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
name != "self"}
else:
# For Pydantic models, use model_fields and check for ellipsis (required fields)
- model_fields = model.__annotations__
+ model_fields = get_type_hints(model)
model_rule_parts = []
nested_rules = []
@@ -680,7 +690,7 @@ def generate_markdown_documentation(
str: Generated text documentation.
"""
documentation = ""
- pyd_models = [(model, True) for model in pydantic_models]
+ pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
for model, add_prefix in pyd_models:
if add_prefix:
documentation += f"{model_prefix}: {model.__name__}\n"
@@ -700,9 +710,9 @@ def generate_markdown_documentation(
# Indenting the fields section
documentation += f" {fields_prefix}:\n"
else:
- documentation += f" Fields:\n"
+ documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel):
- for name, field_type in model.__annotations__.items():
+ for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
@@ -750,14 +760,17 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else ""
- if get_origin(field_type) == list:
+ origin_type = get_origin(field_type)
+ origin_type = field_type if origin_type is None else origin_type
+
+ if origin_type == list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "":
field_text += ":\n"
else:
field_text += "\n"
- elif get_origin(field_type) == Union:
+ elif origin_type == Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
@@ -778,7 +791,7 @@ def generate_field_markdown(
return field_text
if field_description != "":
- field_text += f" Description: " + field_description + "\n"
+ field_text += f" Description: {field_description}\n"
# Check for and include field-specific examples if available
if hasattr(model, "Config") and hasattr(model.Config,
@@ -788,9 +801,9 @@ def generate_field_markdown(
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
field_text += f"{indent} Example: {example_text}\n"
- if isclass(field_type) and issubclass(field_type, BaseModel):
+ if isclass(origin_type) and issubclass(origin_type, BaseModel):
field_text += f"{indent} Details:\n"
- for name, type_ in field_type.__annotations__.items():
+ for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2)
return field_text
@@ -833,7 +846,7 @@ def generate_text_documentation(
str: Generated text documentation.
"""
documentation = ""
- pyd_models = [(model, True) for model in pydantic_models]
+ pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
for model, add_prefix in pyd_models:
if add_prefix:
documentation += f"{model_prefix}: {model.__name__}\n"
@@ -851,7 +864,7 @@ def generate_text_documentation(
if isclass(model) and issubclass(model, BaseModel):
documentation_fields = ""
- for name, field_type in model.__annotations__.items():
+ for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
@@ -944,7 +957,7 @@ def generate_field_text(
if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n"
- for name, type_ in field_type.__annotations__.items():
+ for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2)
return field_text
@@ -1164,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model
- dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) # type: ignore[call-overload]
+ dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description
@@ -1228,9 +1241,6 @@ def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
return output
-from enum import Enum
-
-
def json_schema_to_python_types(schema):
type_map = {
"any": Any,
@@ -1275,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
if items != {}:
array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
- fields[field_name] = (List[array_type], ...) # type: ignore[valid-type]
+ fields[field_name] = (List[array_type], ...)
else:
fields[field_name] = (list, ...)
elif field_type == "object":
@@ -1285,7 +1295,8 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
required = field_data.get("enum", [])
for key, field in fields.items():
if key not in required:
- fields[key] = (Optional[fields[key][0]], ...)
+ optional_type = fields[key][0]
+ fields[key] = (Optional[optional_type], ...)
else:
field_type = json_schema_to_python_types(field_type)
fields[field_name] = (field_type, ...)
@@ -1305,6 +1316,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
required = dictionary.get("required", [])
for key, field in fields.items():
if key not in required:
- fields[key] = (Optional[fields[key][0]], ...)
+ optional_type = fields[key][0]
+ fields[key] = (Optional[optional_type], ...)
custom_model = create_model(model_name, **fields)
return custom_model