Better handling of schemas from YAML
@@ -1,11 +1,12 @@
 import logging
 from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal, Union
 
 import yaml
 from openapi_core import OpenAPI
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 
 logger = logging.getLogger(__name__)
 
@@ -138,7 +139,7 @@ class OperationParser:
         # Add body fields if present
         if body_schema and body_schema.get("type") == "object":
             for prop_name, prop_schema in body_schema.get("properties", {}).items():
-                field_type = self._map_type(prop_schema.get("type", "string"))
+                field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"))
                 required = prop_name in body_schema.get("required", [])
                 fields[prop_name] = (field_type, ... if required else None)
                 parameter_mapping["body"].append(prop_name)
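As context for this one-line change, a minimal sketch of the body model that loop ends up building once the property format is forwarded. The schema and field tuples here are invented for illustration, not taken from the parser:

from datetime import datetime

from pydantic import create_model

body_schema = {  # hypothetical request-body schema
    "type": "object",
    "required": ["name"],
    "properties": {
        "name": {"type": "string"},
        "created_at": {"type": "string", "format": "date-time"},
    },
}

# With the format now forwarded to _map_type, a "date-time" string property
# maps to datetime rather than str; required properties use `...` (no default).
fields = {
    "name": (str, ...),
    "created_at": (datetime, None),
}
BodyModel = create_model("BodyModel", **fields)

print(BodyModel(name="x", created_at="2024-01-01T00:00:00Z").created_at.year)  # 2024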
@@ -230,45 +231,66 @@ class OperationParser:
             "description": param.get("description"),
         }
 
-    def _map_type(self, openapi_type: str) -> type:
+    def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type:
         """Map OpenAPI type to Python type."""
-        type_map = {
+        # Handle enums
+        if schema and "enum" in schema:
+            return Literal[tuple(schema["enum"])]  # type: ignore
+
+        # Handle date-time format
+        if openapi_type == "string" and format_type == "date-time":
+            return datetime
+
+        # Handle nullable fields
+        base_type = {
             "string": str,
             "integer": int,
             "number": float,
             "boolean": bool,
             "array": list,
             "object": dict,
-        }
-        return type_map.get(openapi_type, Any)
+        }.get(openapi_type, Any)
+
+        if schema and schema.get("nullable"):
+            return Union[base_type, None]  # noqa: UP007
+
+        return base_type
 
     def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
-        """Parse response schema into Pydantic model.
-
-        Args:
-            operation: Operation object from OpenAPI spec
-
-        Returns:
-            Pydantic model for response or None
-        """
-        responses = operation.get("responses", {})
-        if "200" not in responses:
-            return None
-
-        response = responses["200"]
-        if "$ref" in response:
-            response = self._resolve_ref(response["$ref"])
-
-        content = response.get("content", {})
-        if "application/json" not in content:
-            return None
-
-        schema = content["application/json"].get("schema", {})
-        if "$ref" in schema:
-            schema = self._resolve_ref(schema["$ref"])
-
-        return self._create_model("Response", schema)
+        """Parse response schema into Pydantic model."""
+        try:
+            responses = operation.get("responses", {})
+            if "200" not in responses:
+                return None
+
+            response = responses["200"]
+            if "$ref" in response:
+                response = self._resolve_ref(response["$ref"])
+
+            content = response.get("content", {})
+            if "application/json" not in content:
+                return None
+
+            schema = content["application/json"].get("schema", {})
+            if "$ref" in schema:
+                schema = self._resolve_ref(schema["$ref"])
+
+            logger.debug("Response schema before model creation: %s", schema)
+
+            # Ensure schema has properties
+            if "properties" not in schema:
+                logger.error("Response schema missing properties")
+                return None
+
+            # Create model with schema properties
+            model = self._create_model("Response", schema)
+            logger.debug("Created response model with schema: %s", model.model_json_schema())
+            return model
+
+        except Exception as e:
+            logger.error("Failed to create response model: %s", e)
+            return None
 
     def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
         """Create Pydantic model from schema.
 
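As a quick check of the new mapping rules, here is a standalone sketch (assumed, not part of the commit) that mirrors the enum, date-time and nullable branches of _map_type outside the class:

from datetime import datetime
from typing import Any, Literal, Union


def map_type(openapi_type: str, format_type: str | None = None,
             schema: dict[str, Any] | None = None) -> Any:
    # Mirrors the enum / date-time / nullable handling added above.
    if schema and "enum" in schema:
        return Literal[tuple(schema["enum"])]
    if openapi_type == "string" and format_type == "date-time":
        return datetime
    base_type = {"string": str, "integer": int, "number": float,
                 "boolean": bool, "array": list, "object": dict}.get(openapi_type, Any)
    if schema and schema.get("nullable"):
        return Union[base_type, None]
    return base_type


assert map_type("string", "date-time") is datetime
assert map_type("string", schema={"enum": ["asc", "desc"]}) == Literal["asc", "desc"]
assert map_type("integer", schema={"nullable": True}) == Union[int, None]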
@@ -285,6 +307,23 @@ class OperationParser:
         if "$ref" in schema:
             schema = self._resolve_ref(schema["$ref"])
 
+        # Handle discriminated unions
+        if "anyOf" in schema:
+            discriminator = schema.get("discriminator", {}).get("propertyName")
+            if discriminator:
+                union_types = []
+                for subschema in schema["anyOf"]:
+                    if "$ref" in subschema:
+                        resolved = self._resolve_ref(subschema["$ref"])
+                        sub_model = self._create_model(f"{name}_type", resolved)
+                        union_types.append(sub_model)
+
+                fields = {
+                    "type_name": (str, ...),  # Use type_name instead of __type
+                    **{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator},  # noqa: UP007
+                }
+                return create_model(name, **fields)
+
         if schema.get("type", "object") != "object":
             raise ValueError("Schema must be an object type")
 
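The anyOf branch above collapses the variants into one generated model whose non-discriminator properties accept any of the resolved sub-models. A hypothetical sketch of the shape it produces (the CatPayload/DogPayload variants and field names are invented; this is not the parser itself):

from typing import Union

from pydantic import BaseModel, create_model


class CatPayload(BaseModel):  # stand-in for a resolved $ref sub-schema
    meows: bool


class DogPayload(BaseModel):  # stand-in for a second resolved $ref sub-schema
    barks: bool


union_types = [CatPayload, DogPayload]

# Same structure the branch assembles: a required "type_name" string plus the
# remaining properties typed as the union of the generated sub-models.
fields = {
    "type_name": (str, ...),
    "payload": (Union[tuple(union_types)], ...),
}
Pet = create_model("Pet", **fields)

pet = Pet(type_name="cat", payload={"meows": True})
print(type(pet.payload).__name__)  # CatPayload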
@@ -293,6 +332,8 @@ class OperationParser:
             if "$ref" in prop_schema:
                 prop_schema = self._resolve_ref(prop_schema["$ref"])
 
+            required = prop_name in schema.get("required", [])
+
             if prop_schema.get("type") == "object":
                 nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
                 field_type = nested_model
@@ -307,9 +348,27 @@ class OperationParser:
                 item_type = self._map_type(items.get("type", "string"))
                 field_type = list[item_type]
             else:
-                field_type = self._map_type(prop_schema.get("type", "string"))
+                field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
 
-            required = prop_name in schema.get("required", [])
+            # Add validation parameters
+            field_config = {}
+            if "pattern" in prop_schema:
+                field_config["pattern"] = prop_schema["pattern"]
+            if "minLength" in prop_schema:
+                field_config["min_length"] = prop_schema["minLength"]
+            if "maxLength" in prop_schema:
+                field_config["max_length"] = prop_schema["maxLength"]
+            if "minimum" in prop_schema:
+                field_config["ge"] = prop_schema["minimum"]
+            if "maximum" in prop_schema:
+                field_config["le"] = prop_schema["maximum"]
+
+            # Create field with validation
+            if field_config:
+                field = Field(..., **field_config) if required else Field(None, **field_config)
+                fields[prop_name] = (field_type, field)
+                continue
+
             fields[prop_name] = (field_type, ... if required else None)
 
         logger.debug("Creating model %s with fields: %s", name, fields)
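To see the constraint mapping in action, a small sketch with a hypothetical property schema; pattern, minLength and maxLength are translated into pydantic Field keywords exactly as in the hunk above, and the resulting model rejects values that violate them:

from pydantic import Field, ValidationError, create_model

prop_schema = {  # hypothetical constrained string property
    "type": "string",
    "pattern": r"^[a-z]+$",
    "minLength": 3,
    "maxLength": 10,
}

# OpenAPI constraint keywords -> pydantic Field keyword arguments.
field_config = {}
if "pattern" in prop_schema:
    field_config["pattern"] = prop_schema["pattern"]
if "minLength" in prop_schema:
    field_config["min_length"] = prop_schema["minLength"]
if "maxLength" in prop_schema:
    field_config["max_length"] = prop_schema["maxLength"]

SlugModel = create_model("SlugModel", slug=(str, Field(..., **field_config)))

print(SlugModel(slug="abc"))       # passes all constraints
try:
    SlugModel(slug="Nope!")        # uppercase and punctuation violate the pattern
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # string_pattern_mismatch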