Better handling of Schemas from Yaml

This commit is contained in:
2025-02-14 10:13:32 +00:00
parent 7148335192
commit 289a20650d

View File

@@ -1,11 +1,12 @@
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal, Union
import yaml import yaml
from openapi_core import OpenAPI from openapi_core import OpenAPI
from pydantic import BaseModel, create_model from pydantic import BaseModel, Field, create_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -138,7 +139,7 @@ class OperationParser:
# Add body fields if present # Add body fields if present
if body_schema and body_schema.get("type") == "object": if body_schema and body_schema.get("type") == "object":
for prop_name, prop_schema in body_schema.get("properties", {}).items(): for prop_name, prop_schema in body_schema.get("properties", {}).items():
field_type = self._map_type(prop_schema.get("type", "string")) field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"))
required = prop_name in body_schema.get("required", []) required = prop_name in body_schema.get("required", [])
fields[prop_name] = (field_type, ... if required else None) fields[prop_name] = (field_type, ... if required else None)
parameter_mapping["body"].append(prop_name) parameter_mapping["body"].append(prop_name)
@@ -230,45 +231,66 @@ class OperationParser:
"description": param.get("description"), "description": param.get("description"),
} }
def _map_type(self, openapi_type: str) -> type: def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type:
"""Map OpenAPI type to Python type.""" """Map OpenAPI type to Python type."""
type_map = { # Handle enums
if schema and "enum" in schema:
return Literal[tuple(schema["enum"])] # type: ignore
# Handle date-time format
if openapi_type == "string" and format_type == "date-time":
return datetime
# Handle nullable fields
base_type = {
"string": str, "string": str,
"integer": int, "integer": int,
"number": float, "number": float,
"boolean": bool, "boolean": bool,
"array": list, "array": list,
"object": dict, "object": dict,
} }.get(openapi_type, Any)
return type_map.get(openapi_type, Any)
if schema and schema.get("nullable"):
return Union[base_type, None] # noqa: UP007
return base_type
def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None: def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
"""Parse response schema into Pydantic model. """Parse response schema into Pydantic model."""
try:
responses = operation.get("responses", {})
if "200" not in responses:
return None
Args: response = responses["200"]
operation: Operation object from OpenAPI spec if "$ref" in response:
response = self._resolve_ref(response["$ref"])
Returns: content = response.get("content", {})
Pydantic model for response or None if "application/json" not in content:
""" return None
responses = operation.get("responses", {})
if "200" not in responses: schema = content["application/json"].get("schema", {})
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
logger.debug("Response schema before model creation: %s", schema)
# Ensure schema has properties
if "properties" not in schema:
logger.error("Response schema missing properties")
return None
# Create model with schema properties
model = self._create_model("Response", schema)
logger.debug("Created response model with schema: %s", model.model_json_schema())
return model
except Exception as e:
logger.error("Failed to create response model: %s", e)
return None return None
response = responses["200"]
if "$ref" in response:
response = self._resolve_ref(response["$ref"])
content = response.get("content", {})
if "application/json" not in content:
return None
schema = content["application/json"].get("schema", {})
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
return self._create_model("Response", schema)
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
"""Create Pydantic model from schema. """Create Pydantic model from schema.
@@ -285,6 +307,23 @@ class OperationParser:
if "$ref" in schema: if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"]) schema = self._resolve_ref(schema["$ref"])
# Handle discriminated unions
if "anyOf" in schema:
discriminator = schema.get("discriminator", {}).get("propertyName")
if discriminator:
union_types = []
for subschema in schema["anyOf"]:
if "$ref" in subschema:
resolved = self._resolve_ref(subschema["$ref"])
sub_model = self._create_model(f"{name}_type", resolved)
union_types.append(sub_model)
fields = {
"type_name": (str, ...), # Use type_name instead of __type
**{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator}, # noqa: UP007
}
return create_model(name, **fields)
if schema.get("type", "object") != "object": if schema.get("type", "object") != "object":
raise ValueError("Schema must be an object type") raise ValueError("Schema must be an object type")
@@ -293,6 +332,8 @@ class OperationParser:
if "$ref" in prop_schema: if "$ref" in prop_schema:
prop_schema = self._resolve_ref(prop_schema["$ref"]) prop_schema = self._resolve_ref(prop_schema["$ref"])
required = prop_name in schema.get("required", [])
if prop_schema.get("type") == "object": if prop_schema.get("type") == "object":
nested_model = self._create_model(f"{name}_{prop_name}", prop_schema) nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
field_type = nested_model field_type = nested_model
@@ -307,9 +348,27 @@ class OperationParser:
item_type = self._map_type(items.get("type", "string")) item_type = self._map_type(items.get("type", "string"))
field_type = list[item_type] field_type = list[item_type]
else: else:
field_type = self._map_type(prop_schema.get("type", "string")) field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
# Add validation parameters
field_config = {}
if "pattern" in prop_schema:
field_config["pattern"] = prop_schema["pattern"]
if "minLength" in prop_schema:
field_config["min_length"] = prop_schema["minLength"]
if "maxLength" in prop_schema:
field_config["max_length"] = prop_schema["maxLength"]
if "minimum" in prop_schema:
field_config["ge"] = prop_schema["minimum"]
if "maximum" in prop_schema:
field_config["le"] = prop_schema["maximum"]
# Create field with validation
if field_config:
field = Field(..., **field_config) if required else Field(None, **field_config)
fields[prop_name] = (field_type, field)
continue
required = prop_name in schema.get("required", [])
fields[prop_name] = (field_type, ... if required else None) fields[prop_name] = (field_type, ... if required else None)
logger.debug("Creating model %s with fields: %s", name, fields) logger.debug("Creating model %s with fields: %s", name, fields)