diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index fca8ca5..3f38a1d 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -1,11 +1,12 @@ import logging from dataclasses import dataclass +from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Literal, Union import yaml from openapi_core import OpenAPI -from pydantic import BaseModel, create_model +from pydantic import BaseModel, Field, create_model logger = logging.getLogger(__name__) @@ -138,7 +139,7 @@ class OperationParser: # Add body fields if present if body_schema and body_schema.get("type") == "object": for prop_name, prop_schema in body_schema.get("properties", {}).items(): - field_type = self._map_type(prop_schema.get("type", "string")) + field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format")) required = prop_name in body_schema.get("required", []) fields[prop_name] = (field_type, ... if required else None) parameter_mapping["body"].append(prop_name) @@ -230,45 +231,66 @@ class OperationParser: "description": param.get("description"), } - def _map_type(self, openapi_type: str) -> type: + def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type: """Map OpenAPI type to Python type.""" - type_map = { + # Handle enums + if schema and "enum" in schema: + return Literal[tuple(schema["enum"])] # type: ignore + + # Handle date-time format + if openapi_type == "string" and format_type == "date-time": + return datetime + + # Handle nullable fields + base_type = { "string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict, - } - return type_map.get(openapi_type, Any) + }.get(openapi_type, Any) + + if schema and schema.get("nullable"): + return Union[base_type, None] # noqa: UP007 + + return base_type def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None: - """Parse response schema into Pydantic model. + """Parse response schema into Pydantic model.""" + try: + responses = operation.get("responses", {}) + if "200" not in responses: + return None - Args: - operation: Operation object from OpenAPI spec + response = responses["200"] + if "$ref" in response: + response = self._resolve_ref(response["$ref"]) - Returns: - Pydantic model for response or None - """ - responses = operation.get("responses", {}) - if "200" not in responses: + content = response.get("content", {}) + if "application/json" not in content: + return None + + schema = content["application/json"].get("schema", {}) + if "$ref" in schema: + schema = self._resolve_ref(schema["$ref"]) + + logger.debug("Response schema before model creation: %s", schema) + + # Ensure schema has properties + if "properties" not in schema: + logger.error("Response schema missing properties") + return None + + # Create model with schema properties + model = self._create_model("Response", schema) + logger.debug("Created response model with schema: %s", model.model_json_schema()) + return model + + except Exception as e: + logger.error("Failed to create response model: %s", e) return None - response = responses["200"] - if "$ref" in response: - response = self._resolve_ref(response["$ref"]) - - content = response.get("content", {}) - if "application/json" not in content: - return None - - schema = content["application/json"].get("schema", {}) - if "$ref" in schema: - schema = self._resolve_ref(schema["$ref"]) - - return self._create_model("Response", schema) - def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: """Create Pydantic model from schema. @@ -285,6 +307,23 @@ class OperationParser: if "$ref" in schema: schema = self._resolve_ref(schema["$ref"]) + # Handle discriminated unions + if "anyOf" in schema: + discriminator = schema.get("discriminator", {}).get("propertyName") + if discriminator: + union_types = [] + for subschema in schema["anyOf"]: + if "$ref" in subschema: + resolved = self._resolve_ref(subschema["$ref"]) + sub_model = self._create_model(f"{name}_type", resolved) + union_types.append(sub_model) + + fields = { + "type_name": (str, ...), # Use type_name instead of __type + **{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator}, # noqa: UP007 + } + return create_model(name, **fields) + if schema.get("type", "object") != "object": raise ValueError("Schema must be an object type") @@ -293,6 +332,8 @@ class OperationParser: if "$ref" in prop_schema: prop_schema = self._resolve_ref(prop_schema["$ref"]) + required = prop_name in schema.get("required", []) + if prop_schema.get("type") == "object": nested_model = self._create_model(f"{name}_{prop_name}", prop_schema) field_type = nested_model @@ -307,9 +348,27 @@ class OperationParser: item_type = self._map_type(items.get("type", "string")) field_type = list[item_type] else: - field_type = self._map_type(prop_schema.get("type", "string")) + field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema) + + # Add validation parameters + field_config = {} + if "pattern" in prop_schema: + field_config["pattern"] = prop_schema["pattern"] + if "minLength" in prop_schema: + field_config["min_length"] = prop_schema["minLength"] + if "maxLength" in prop_schema: + field_config["max_length"] = prop_schema["maxLength"] + if "minimum" in prop_schema: + field_config["ge"] = prop_schema["minimum"] + if "maximum" in prop_schema: + field_config["le"] = prop_schema["maximum"] + + # Create field with validation + if field_config: + field = Field(..., **field_config) if required else Field(None, **field_config) + fields[prop_name] = (field_type, field) + continue - required = prop_name in schema.get("required", []) fields[prop_name] = (field_type, ... if required else None) logger.debug("Creating model %s with fields: %s", name, fields)