From 98c1486f56ec007131b509d2dcc2e435e0f100ea Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Thu, 13 Feb 2025 16:38:18 +0000 Subject: [PATCH] Combined pydantic model with parameters and request body --- .../parser/operation_parser.py | 125 +++++++++-------- .../src/airflow_mcp_server/server.py | 2 +- .../airflow_mcp_server/tools/airflow_tool.py | 127 +++++++----------- .../airflow_mcp_server/tools/tool_manager.py | 2 +- .../tests/parser/test_operation_parser.py | 33 +++-- .../tests/tools/test_airflow_tool.py | 2 +- airflow-mcp-server/tests/tools/test_models.py | 6 +- 7 files changed, 134 insertions(+), 163 deletions(-) diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index 9b5a7c1..588cbf2 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -18,7 +18,7 @@ class OperationDetails: path: str method: str parameters: dict[str, Any] - request_body: type[BaseModel] | None = None + input_model: type[BaseModel] response_model: type[BaseModel] | None = None @@ -35,7 +35,6 @@ class OperationParser: ValueError: If spec_path is invalid or spec cannot be loaded """ try: - # Load and parse OpenAPI spec if isinstance(spec_path, bytes): self.raw_spec = yaml.safe_load(spec_path) elif isinstance(spec_path, dict): @@ -48,7 +47,6 @@ class OperationParser: else: raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object") - # Initialize OpenAPI spec spec = OpenAPI.from_dict(self.raw_spec) self.spec = spec self._paths = self.raw_spec["paths"] @@ -72,7 +70,6 @@ class OperationParser: ValueError: If operation not found or invalid """ try: - # Find operation in spec for path, path_item in self._paths.items(): for method, operation in path_item.items(): if method.startswith("x-") or method == "parameters": @@ -81,21 +78,30 @@ class OperationParser: if operation.get("operationId") == operation_id: logger.debug("Found operation %s at %s %s", operation_id, method, path) - # Add path to operation for parameter context operation["path"] = path operation["path_item"] = path_item - # Extract operation details parameters = self.extract_parameters(operation) - request_body = self._parse_request_body(operation) response_model = self._parse_response_model(operation) + # Get request body schema if present + body_schema = None + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}) + if "application/json" in content: + body_schema = content["application/json"].get("schema", {}) + if "$ref" in body_schema: + body_schema = self._resolve_ref(body_schema["$ref"]) + + # Create unified input model + input_model = self._create_input_model(operation_id, parameters, body_schema) + return OperationDetails( operation_id=operation_id, path=str(path), method=method, parameters=parameters, - request_body=request_body, + input_model=input_model, response_model=response_model, ) @@ -105,6 +111,37 @@ class OperationParser: logger.error("Error parsing operation %s: %s", operation_id, e) raise + def _create_input_model( + self, + operation_id: str, + parameters: dict[str, Any], + body_schema: dict[str, Any] | None = None, + ) -> type[BaseModel]: + """Create unified input model for all parameters.""" + fields: dict[str, tuple[type, Any]] = {} + + # Add path parameters + for name, schema in parameters.get("path", {}).items(): + field_type = schema["type"] + required = schema.get("required", True) # Path parameters are required by default + fields[f"path_{name}"] = (field_type, ... if required else None) + + # Add query parameters + for name, schema in parameters.get("query", {}).items(): + field_type = schema["type"] + required = schema.get("required", False) # Query parameters are optional by default + fields[f"query_{name}"] = (field_type, ... if required else None) + + # Add body fields if present + if body_schema and body_schema.get("type") == "object": + for prop_name, prop_schema in body_schema.get("properties", {}).items(): + field_type = self._map_type(prop_schema.get("type", "string")) + required = prop_name in body_schema.get("required", []) + fields[f"body_{prop_name}"] = (field_type, ... if required else None) + + logger.debug("Creating input model for %s with fields: %s", operation_id, fields) + return create_model(f"{operation_id}_input", **fields) + def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]: """Extract and categorize operation parameters. @@ -120,12 +157,10 @@ class OperationParser: "header": {}, } - # Handle path-level parameters path_item = operation.get("path_item", {}) if path_item and "parameters" in path_item: self._process_parameters(path_item["parameters"], parameters) - # Handle operation-level parameters self._process_parameters(operation.get("parameters", []), parameters) return parameters @@ -138,11 +173,9 @@ class OperationParser: target: Target dictionary to store processed parameters """ for param in params: - # Resolve parameter reference if needed if "$ref" in param: param = self._resolve_ref(param["$ref"]) - # Validate parameter structure if not isinstance(param, dict) or "in" not in param: logger.warning("Invalid parameter format: %s", param) continue @@ -165,7 +198,7 @@ class OperationParser: parts = ref.split("/") current = self.raw_spec - for part in parts[1:]: # Skip first '#' + for part in parts[1:]: current = current[part] self._schema_cache[ref] = current @@ -192,14 +225,7 @@ class OperationParser: } def _map_type(self, openapi_type: str) -> type: - """Map OpenAPI type to Python type. - - Args: - openapi_type: OpenAPI type string - - Returns: - Corresponding Python type - """ + """Map OpenAPI type to Python type.""" type_map = { "string": str, "integer": int, @@ -210,28 +236,6 @@ class OperationParser: } return type_map.get(openapi_type, Any) - def _parse_request_body(self, operation: dict[str, Any]) -> type[BaseModel] | None: - """Parse request body schema into Pydantic model. - - Args: - operation: Operation object from OpenAPI spec - - Returns: - Pydantic model for request body or None - """ - if "requestBody" not in operation: - return None - - content = operation["requestBody"].get("content", {}) - if "application/json" not in content: - return None - - schema = content["application/json"].get("schema", {}) - if "$ref" in schema: - schema = self._resolve_ref(schema["$ref"]) - - return self._create_model("RequestBody", schema) - def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None: """Parse response schema into Pydantic model. @@ -259,23 +263,6 @@ class OperationParser: return self._create_model("Response", schema) - def get_operations(self) -> list[str]: - """Get list of all operation IDs from spec. - - Returns: - List of operation IDs - """ - operations = [] - - for path in self._paths.values(): - for method, operation in path.items(): - if method.startswith("x-") or method == "parameters": - continue - if "operationId" in operation: - operations.append(operation["operationId"]) - - return operations - def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: """Create Pydantic model from schema. @@ -292,21 +279,18 @@ class OperationParser: if "$ref" in schema: schema = self._resolve_ref(schema["$ref"]) - if schema.get("type") != "object": + if schema.get("type", "object") != "object": raise ValueError("Schema must be an object type") fields = {} for prop_name, prop_schema in schema.get("properties", {}).items(): - # Resolve property schema reference if needed if "$ref" in prop_schema: prop_schema = self._resolve_ref(prop_schema["$ref"]) if prop_schema.get("type") == "object": - # Create nested model nested_model = self._create_model(f"{name}_{prop_name}", prop_schema) field_type = nested_model elif prop_schema.get("type") == "array": - # Handle array types items = prop_schema.get("items", {}) if "$ref" in items: items = self._resolve_ref(items["$ref"]) @@ -328,3 +312,16 @@ class OperationParser: except Exception as e: logger.error("Error creating model %s: %s", name, e) raise ValueError(f"Failed to create model {name}: {e}") + + def get_operations(self) -> list[str]: + """Get list of all operation IDs from spec.""" + operations = [] + + for path in self._paths.values(): + for method, operation in path.items(): + if method.startswith("x-") or method == "parameters": + continue + if "operationId" in operation: + operations.append(operation["operationId"]) + + return operations diff --git a/airflow-mcp-server/src/airflow_mcp_server/server.py b/airflow-mcp-server/src/airflow_mcp_server/server.py index 0332da1..8691efe 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/server.py +++ b/airflow-mcp-server/src/airflow_mcp_server/server.py @@ -35,4 +35,4 @@ async def serve() -> None: options = server.create_initialization_options() async with stdio_server() as (read_stream, write_stream): - server.run(read_stream, write_stream, options, raise_exceptions=True) + await server.run(read_stream, write_stream, options, raise_exceptions=True) diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py index 96d1920..75977df 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py @@ -1,10 +1,11 @@ import logging from typing import Any +from pydantic import BaseModel, ValidationError + from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.parser.operation_parser import OperationDetails from airflow_mcp_server.tools.base_tools import BaseTools -from pydantic import BaseModel, ValidationError logger = logging.getLogger(__name__) @@ -48,108 +49,82 @@ class AirflowTool(BaseTools): self.operation = operation_details self.client = client - def _validate_parameters( + def _validate_input( self, path_params: dict[str, Any] | None = None, query_params: dict[str, Any] | None = None, body: dict[str, Any] | None = None, - ) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None]: - """Validate input parameters against operation schemas. + ) -> dict[str, Any]: + """Validate input parameters using unified input model. Args: - path_params: URL path parameters - query_params: URL query parameters - body: Request body data + path_params: Path parameters + query_params: Query parameters + body: Body parameters Returns: - Tuple of validated (path_params, query_params, body) - - Raises: - ValidationError: If parameters fail validation + dict[str, Any]: Validated input parameters """ - validated_params: dict[str, dict[str, Any] | None] = { - "path": None, - "query": None, - "body": None, - } - try: - # Validate path parameters - if path_params and "path" in self.operation.parameters: - path_schema = self.operation.parameters["path"] - for name, value in path_params.items(): - if name in path_schema: - param_type = path_schema[name]["type"] - if not isinstance(value, param_type): - raise create_validation_error( - field=name, - message=f"Path parameter {name} must be of type {param_type.__name__}", - ) - validated_params["path"] = path_params + input_data = {} - # Validate query parameters - if query_params and "query" in self.operation.parameters: - query_schema = self.operation.parameters["query"] - for name, value in query_params.items(): - if name in query_schema: - param_type = query_schema[name]["type"] - if not isinstance(value, param_type): - raise create_validation_error( - field=name, - message=f"Query parameter {name} must be of type {param_type.__name__}", - ) - validated_params["query"] = query_params + if path_params: + input_data.update({f"path_{k}": v for k, v in path_params.items()}) - # Validate request body - if body and self.operation.request_body: - try: - model: type[BaseModel] = self.operation.request_body - validated_body = model(**body) - validated_params["body"] = validated_body.model_dump() - except ValidationError as e: - # Re-raise Pydantic validation errors directly - raise e + if query_params: + input_data.update({f"query_{k}": v for k, v in query_params.items()}) - return ( - validated_params["path"], - validated_params["query"], - validated_params["body"], - ) + if body: + input_data.update({f"body_{k}": v for k, v in body.items()}) - except Exception as e: - logger.error("Parameter validation failed: %s", e) + validated = self.operation.input_model(**input_data) + return validated.model_dump() + + except ValidationError as e: + logger.error("Input validation failed: %s", e) raise + def _extract_parameters(self, validated_input: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Extract validated parameters by type.""" + path_params = {} + query_params = {} + body = {} + + # Extract parameters based on operation definition + for key, value in validated_input.items(): + # Remove prefix from key if present + param_key = key + if key.startswith(("path_", "query_", "body_")): + param_key = key.split("_", 1)[1] + + if key.startswith("path_"): + path_params[param_key] = value + elif key.startswith("query_"): + query_params[param_key] = value + elif key.startswith("body_"): + body[param_key] = value + else: + body[key] = value + + return path_params, query_params, body + async def run( self, path_params: dict[str, Any] | None = None, query_params: dict[str, Any] | None = None, body: dict[str, Any] | None = None, ) -> Any: - """Execute the operation with provided parameters. - - Args: - path_params: URL path parameters - query_params: URL query parameters - body: Request body data - - Returns: - API response data - - Raises: - ValidationError: If parameters fail validation - RuntimeError: If client execution fails - """ + """Execute the operation with provided parameters.""" try: - # Validate parameters - validated_path_params, validated_query_params, validated_body = self._validate_parameters(path_params, query_params, body) + validated_input = self._validate_input(path_params, query_params, body) + path_params, query_params, body = self._extract_parameters(validated_input) # Execute operation response = await self.client.execute( operation_id=self.operation.operation_id, - path_params=validated_path_params, - query_params=validated_query_params, - body=validated_body, + path_params=path_params, + query_params=query_params, + body=body, ) # Validate response if model exists diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py index 9306729..fc7792a 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py @@ -52,7 +52,7 @@ def get_airflow_tools() -> list[Tool]: Tool( name=operation_id, description=tool.operation.operation_id, - inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None, + inputSchema=tool.operation.input_model.model_json_schema(), ) for operation_id, tool in _tools_cache.items() ] diff --git a/airflow-mcp-server/tests/parser/test_operation_parser.py b/airflow-mcp-server/tests/parser/test_operation_parser.py index afd778a..e112f91 100644 --- a/airflow-mcp-server/tests/parser/test_operation_parser.py +++ b/airflow-mcp-server/tests/parser/test_operation_parser.py @@ -39,33 +39,32 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None: operation = parser.parse_operation("get_dag") assert operation.path == "/dags/{dag_id}" - assert "dag_id" in operation.parameters["path"] - param = operation.parameters["path"]["dag_id"] - assert isinstance(param["type"], type(str)) - assert param["required"] is True + assert isinstance(operation.input_model, type(BaseModel)) + + # Verify path parameter field exists + fields = operation.input_model.__annotations__ + assert "path_dag_id" in fields + assert isinstance(fields["path_dag_id"], type(str)) def test_parse_operation_with_query_params(parser: OperationParser) -> None: """Test parsing operation with query parameters.""" operation = parser.parse_operation("get_dags") - assert "limit" in operation.parameters["query"] - param = operation.parameters["query"]["limit"] - assert isinstance(param["type"], type(int)) - assert param["required"] is False + # Verify query parameter field exists + fields = operation.input_model.__annotations__ + assert "query_limit" in fields + assert isinstance(fields["query_limit"], type(int)) -def test_parse_operation_with_request_body(parser: OperationParser) -> None: +def test_parse_operation_with_body_params(parser: OperationParser) -> None: """Test parsing operation with request body.""" operation = parser.parse_operation("post_dag_run") - assert operation.request_body is not None - assert issubclass(operation.request_body, BaseModel) - - # Test model fields - fields = operation.request_body.__annotations__ - assert "dag_run_id" in fields - assert isinstance(fields["dag_run_id"], type(str)) + # Verify body fields exist + fields = operation.input_model.__annotations__ + assert "body_dag_run_id" in fields + assert isinstance(fields["body_dag_run_id"], type(str)) def test_parse_operation_with_response_model(parser: OperationParser) -> None: @@ -140,9 +139,7 @@ def test_create_model_nested_objects(parser: OperationParser) -> None: assert issubclass(model, BaseModel) fields = model.__annotations__ assert "nested" in fields - # Check that nested field is a Pydantic model assert issubclass(fields["nested"], BaseModel) - # Verify nested model structure nested_fields = fields["nested"].__annotations__ assert "field" in nested_fields assert isinstance(nested_fields["field"], type(str)) diff --git a/airflow-mcp-server/tests/tools/test_airflow_tool.py b/airflow-mcp-server/tests/tools/test_airflow_tool.py index f5537cf..5705cfd 100644 --- a/airflow-mcp-server/tests/tools/test_airflow_tool.py +++ b/airflow-mcp-server/tests/tools/test_airflow_tool.py @@ -32,7 +32,7 @@ def operation_details(): "filter": {"type": str, "required": False}, }, }, - request_body=TestRequestModel, + input_model=TestRequestModel, response_model=TestResponseModel, ) diff --git a/airflow-mcp-server/tests/tools/test_models.py b/airflow-mcp-server/tests/tools/test_models.py index d7f9ca2..78a7459 100644 --- a/airflow-mcp-server/tests/tools/test_models.py +++ b/airflow-mcp-server/tests/tools/test_models.py @@ -6,8 +6,10 @@ from pydantic import BaseModel class TestRequestModel(BaseModel): """Test request model.""" - name: str - value: int + path_id: int + query_filter: str | None = None + body_name: str + body_value: int class TestResponseModel(BaseModel):