Better handling of Schemas from Yaml

This commit is contained in:
2025-02-14 10:13:32 +00:00
parent 7148335192
commit 289a20650d

View File

@@ -1,11 +1,12 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, Literal, Union
import yaml
from openapi_core import OpenAPI
from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model
logger = logging.getLogger(__name__)
@@ -138,7 +139,7 @@ class OperationParser:
# Add body fields if present
if body_schema and body_schema.get("type") == "object":
for prop_name, prop_schema in body_schema.get("properties", {}).items():
field_type = self._map_type(prop_schema.get("type", "string"))
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"))
required = prop_name in body_schema.get("required", [])
fields[prop_name] = (field_type, ... if required else None)
parameter_mapping["body"].append(prop_name)
@@ -230,27 +231,34 @@ class OperationParser:
"description": param.get("description"),
}
def _map_type(self, openapi_type: str) -> type:
def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type:
"""Map OpenAPI type to Python type."""
type_map = {
# Handle enums
if schema and "enum" in schema:
return Literal[tuple(schema["enum"])] # type: ignore
# Handle date-time format
if openapi_type == "string" and format_type == "date-time":
return datetime
# Handle nullable fields
base_type = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return type_map.get(openapi_type, Any)
}.get(openapi_type, Any)
if schema and schema.get("nullable"):
return Union[base_type, None] # noqa: UP007
return base_type
def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
"""Parse response schema into Pydantic model.
Args:
operation: Operation object from OpenAPI spec
Returns:
Pydantic model for response or None
"""
"""Parse response schema into Pydantic model."""
try:
responses = operation.get("responses", {})
if "200" not in responses:
return None
@@ -267,7 +275,21 @@ class OperationParser:
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
return self._create_model("Response", schema)
logger.debug("Response schema before model creation: %s", schema)
# Ensure schema has properties
if "properties" not in schema:
logger.error("Response schema missing properties")
return None
# Create model with schema properties
model = self._create_model("Response", schema)
logger.debug("Created response model with schema: %s", model.model_json_schema())
return model
except Exception as e:
logger.error("Failed to create response model: %s", e)
return None
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
"""Create Pydantic model from schema.
@@ -285,6 +307,23 @@ class OperationParser:
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
# Handle discriminated unions
if "anyOf" in schema:
discriminator = schema.get("discriminator", {}).get("propertyName")
if discriminator:
union_types = []
for subschema in schema["anyOf"]:
if "$ref" in subschema:
resolved = self._resolve_ref(subschema["$ref"])
sub_model = self._create_model(f"{name}_type", resolved)
union_types.append(sub_model)
fields = {
"type_name": (str, ...), # Use type_name instead of __type
**{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator}, # noqa: UP007
}
return create_model(name, **fields)
if schema.get("type", "object") != "object":
raise ValueError("Schema must be an object type")
@@ -293,6 +332,8 @@ class OperationParser:
if "$ref" in prop_schema:
prop_schema = self._resolve_ref(prop_schema["$ref"])
required = prop_name in schema.get("required", [])
if prop_schema.get("type") == "object":
nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
field_type = nested_model
@@ -307,9 +348,27 @@ class OperationParser:
item_type = self._map_type(items.get("type", "string"))
field_type = list[item_type]
else:
field_type = self._map_type(prop_schema.get("type", "string"))
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
# Add validation parameters
field_config = {}
if "pattern" in prop_schema:
field_config["pattern"] = prop_schema["pattern"]
if "minLength" in prop_schema:
field_config["min_length"] = prop_schema["minLength"]
if "maxLength" in prop_schema:
field_config["max_length"] = prop_schema["maxLength"]
if "minimum" in prop_schema:
field_config["ge"] = prop_schema["minimum"]
if "maximum" in prop_schema:
field_config["le"] = prop_schema["maximum"]
# Create field with validation
if field_config:
field = Field(..., **field_config) if required else Field(None, **field_config)
fields[prop_name] = (field_type, field)
continue
required = prop_name in schema.get("required", [])
fields[prop_name] = (field_type, ... if required else None)
logger.debug("Creating model %s with fields: %s", name, fields)