diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index 3f38a1d..fdddd5b 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -20,7 +20,6 @@ class OperationDetails: method: str parameters: dict[str, Any] input_model: type[BaseModel] - response_model: type[BaseModel] | None = None class OperationParser: @@ -83,7 +82,6 @@ class OperationParser: operation["path_item"] = path_item parameters = self.extract_parameters(operation) - response_model = self._parse_response_model(operation) # Get request body schema if present body_schema = None @@ -97,14 +95,7 @@ class OperationParser: # Create unified input model input_model = self._create_input_model(operation_id, parameters, body_schema) - return OperationDetails( - operation_id=operation_id, - path=str(path), - method=method, - parameters=parameters, - input_model=input_model, - response_model=response_model, - ) + return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model) raise ValueError(f"Operation {operation_id} not found in spec") @@ -124,24 +115,21 @@ class OperationParser: # Add path parameters for name, schema in parameters.get("path", {}).items(): - field_type = schema["type"] - required = schema.get("required", True) # Path parameters are required by default - fields[name] = (field_type, ... if required else None) + field_type = schema["type"] # Use the mapped type from parameter schema + fields[name] = (field_type | None, None) # Make all optional parameter_mapping["path"].append(name) # Add query parameters for name, schema in parameters.get("query", {}).items(): - field_type = schema["type"] - required = schema.get("required", False) # Query parameters are optional by default - fields[name] = (field_type, ... 
if required else None) + field_type = schema["type"] # Use the mapped type from parameter schema + fields[name] = (field_type | None, None) # Make all optional parameter_mapping["query"].append(name) # Add body fields if present if body_schema and body_schema.get("type") == "object": for prop_name, prop_schema in body_schema.get("properties", {}).items(): - field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format")) - required = prop_name in body_schema.get("required", []) - fields[prop_name] = (field_type, ... if required else None) + field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema) + fields[prop_name] = (field_type | None, None) # Make all optional parameter_mapping["body"].append(prop_name) logger.debug("Creating input model for %s with fields: %s", operation_id, fields) @@ -224,8 +212,12 @@ class OperationParser: if "$ref" in schema: schema = self._resolve_ref(schema["$ref"]) + # Get the type and format from schema + openapi_type = schema.get("type", "string") + format_type = schema.get("format") + return { - "type": self._map_type(schema.get("type", "string")), + "type": self._map_type(openapi_type, format_type, schema), # Pass format_type and full schema "required": param.get("required", False), "default": schema.get("default"), "description": param.get("description"), @@ -291,7 +283,7 @@ class OperationParser: logger.error("Failed to create response model: %s", e) return None - def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: + def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: # noqa: C901 """Create Pydantic model from schema.
Args: diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py index bd0b3cb..2c93d69 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py @@ -1,7 +1,7 @@ import logging from typing import Any -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.parser.operation_parser import OperationDetails @@ -55,40 +55,27 @@ class AirflowTool(BaseTools): ) -> Any: """Execute the operation with provided parameters.""" try: - mapping = self.operation.input_model.model_config["parameter_mapping"] - body = body or {} - path_params = {k: body[k] for k in mapping.get("path", []) if k in body} - query_params = {k: body[k] for k in mapping.get("query", []) if k in body} - body_params = {k: body[k] for k in mapping.get("body", []) if k in body} + # Validate input + validated_input = self.operation.input_model(**(body or {})) + validated_body = validated_input.model_dump(exclude_none=True) # Only include non-None values - # Execute operation + mapping = self.operation.input_model.model_config["parameter_mapping"] + path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body} + query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body} + body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body} + + # Execute operation and return raw response response = await self.client.execute( operation_id=self.operation.operation_id, - path_params=path_params, - query_params=query_params, - body=body_params, + path_params=path_params or None, + query_params=query_params or None, + body=body_params or None, ) - logger.debug("Raw response: %s", response) - - # Validate response if model exists - if 
self.operation.response_model and isinstance(response, dict): - try: - logger.debug("Response model schema: %s", self.operation.response_model.model_json_schema()) - model: type[BaseModel] = self.operation.response_model - logger.debug("Attempting to validate response with model: %s", model.__name__) - validated_response = model(**response) - logger.debug("Response validation successful") - result = validated_response.model_dump() - logger.debug("Final response after model_dump: %s", result) - return result - except ValidationError as e: - logger.error("Response validation failed: %s", e) - logger.error("Validation error details: %s", e.errors()) - raise RuntimeError(f"Invalid response format: {e}") - return response + except ValidationError: + raise except Exception as e: logger.error("Operation execution failed: %s", e) raise diff --git a/airflow-mcp-server/tests/client/test_airflow_client.py b/airflow-mcp-server/tests/client/test_airflow_client.py index 41cd497..9a2fc79 100644 --- a/airflow-mcp-server/tests/client/test_airflow_client.py +++ b/airflow-mcp-server/tests/client/test_airflow_client.py @@ -35,7 +35,7 @@ def client(spec_file: dict[str, Any]) -> AirflowClient: def test_init_client_initialization(client: AirflowClient) -> None: assert isinstance(client.spec, OpenAPI) assert client.base_url == "http://localhost:8080/api/v1" - assert client.headers["Authorization"] == "Bearer test-token" + assert client.headers["Authorization"] == "Basic test-token" def test_init_load_spec_from_bytes() -> None: diff --git a/airflow-mcp-server/tests/parser/test_operation_parser.py b/airflow-mcp-server/tests/parser/test_operation_parser.py index e112f91..a726a89 100644 --- a/airflow-mcp-server/tests/parser/test_operation_parser.py +++ b/airflow-mcp-server/tests/parser/test_operation_parser.py @@ -43,8 +43,11 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None: # Verify path parameter field exists fields = operation.input_model.__annotations__ - 
assert "path_dag_id" in fields - assert isinstance(fields["path_dag_id"], type(str)) + assert "dag_id" in fields + assert str in fields["dag_id"].__args__ # Check if str is in the Union types + + # Verify parameter is mapped correctly + assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"] def test_parse_operation_with_query_params(parser: OperationParser) -> None: @@ -53,8 +56,11 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None: # Verify query parameter field exists fields = operation.input_model.__annotations__ - assert "query_limit" in fields - assert isinstance(fields["query_limit"], type(int)) + assert "limit" in fields + assert int in fields["limit"].__args__ # Check if int is in the Union types + + # Verify parameter is mapped correctly + assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"] def test_parse_operation_with_body_params(parser: OperationParser) -> None: @@ -63,21 +69,11 @@ def test_parse_operation_with_body_params(parser: OperationParser) -> None: # Verify body fields exist fields = operation.input_model.__annotations__ - assert "body_dag_run_id" in fields - assert isinstance(fields["body_dag_run_id"], type(str)) + assert "dag_run_id" in fields + assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types - -def test_parse_operation_with_response_model(parser: OperationParser) -> None: - """Test parsing operation with response model.""" - operation = parser.parse_operation("get_dag") - - assert operation.response_model is not None - assert issubclass(operation.response_model, BaseModel) - - # Test model fields - fields = operation.response_model.__annotations__ - assert "dag_id" in fields - assert "is_paused" in fields + # Verify parameter is mapped correctly + assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"] def test_parse_operation_not_found(parser: OperationParser) -> None: @@ -118,7 +114,9 @@ def 
test_map_parameter_schema_nullable(parser: OperationParser) -> None: } result = parser._map_parameter_schema(param) - assert isinstance(result["type"], type(str)) + # Check that str is in the Union types + assert str in result["type"].__args__ + assert None.__class__ in result["type"].__args__ # Check for NoneType assert not result["required"] diff --git a/airflow-mcp-server/tests/tools/test_airflow_tool.py b/airflow-mcp-server/tests/tools/test_airflow_tool.py index 5705cfd..7981edb 100644 --- a/airflow-mcp-server/tests/tools/test_airflow_tool.py +++ b/airflow-mcp-server/tests/tools/test_airflow_tool.py @@ -6,7 +6,7 @@ from airflow_mcp_server.parser.operation_parser import OperationDetails from airflow_mcp_server.tools.airflow_tool import AirflowTool from pydantic import ValidationError -from tests.tools.test_models import TestRequestModel, TestResponseModel +from tests.tools.test_models import TestRequestModel @pytest.fixture @@ -20,20 +20,27 @@ def mock_client(mocker): @pytest.fixture def operation_details(): """Create test operation details.""" + model = TestRequestModel + # Add parameter mapping to model config + model.model_config["parameter_mapping"] = { + "path": ["path_id"], + "query": ["query_filter"], + "body": ["body_name", "body_value"], + } + return OperationDetails( operation_id="test_operation", - path="/test/{id}", + path="/test/{path_id}", method="POST", parameters={ "path": { - "id": {"type": int, "required": True}, + "path_id": {"type": int, "required": True}, }, "query": { - "filter": {"type": str, "required": False}, + "query_filter": {"type": str, "required": False}, }, }, - input_model=TestRequestModel, - response_model=TestResponseModel, + input_model=model, ) @@ -49,20 +56,23 @@ async def test_successful_execution(airflow_tool, mock_client): # Setup mock response mock_client.execute.return_value = {"item_id": 1, "result": "success"} - # Execute operation + # Execute operation with unified body result = await airflow_tool.run( - 
path_params={"id": 123}, - query_params={"filter": "test"}, - body={"name": "test", "value": 42}, + body={ + "path_id": 123, + "query_filter": "test", + "body_name": "test", + "body_value": 42, + } ) # Verify response assert result == {"item_id": 1, "result": "success"} mock_client.execute.assert_called_once_with( operation_id="test_operation", - path_params={"id": 123}, - query_params={"filter": "test"}, - body={"name": "test", "value": 42}, + path_params={"path_id": 123}, + query_params={"query_filter": "test"}, + body={"body_name": "test", "body_value": 42}, ) @@ -71,8 +81,11 @@ async def test_invalid_path_parameter(airflow_tool): """Test validation error for invalid path parameter type.""" with pytest.raises(ValidationError): await airflow_tool.run( - path_params={"id": "not_an_integer"}, - body={"name": "test", "value": 42}, + body={ + "path_id": "not_an_integer", # Invalid type + "body_name": "test", + "body_value": 42, + } ) @@ -81,22 +94,29 @@ async def test_invalid_request_body(airflow_tool): """Test validation error for invalid request body.""" with pytest.raises(ValidationError): await airflow_tool.run( - path_params={"id": 123}, - body={"name": "test", "value": "not_an_integer"}, + body={ + "path_id": 123, + "body_name": "test", + "body_value": "not_an_integer", # Invalid type + } ) @pytest.mark.asyncio async def test_invalid_response_format(airflow_tool, mock_client): """Test error handling for invalid response format.""" - # Setup mock response with invalid format + # Setup mock response mock_client.execute.return_value = {"invalid": "response"} - with pytest.raises(RuntimeError): - await airflow_tool.run( - path_params={"id": 123}, - body={"name": "test", "value": 42}, - ) + # Should not raise any validation error + result = await airflow_tool.run( + body={ + "path_id": 123, + "body_name": "test", + "body_value": 42, + } + ) + assert result == {"invalid": "response"} @pytest.mark.asyncio @@ -107,6 +127,9 @@ async def test_client_error(airflow_tool, 
mock_client): with pytest.raises(RuntimeError): await airflow_tool.run( - path_params={"id": 123}, - body={"name": "test", "value": 42}, + body={ + "path_id": 123, + "body_name": "test", + "body_value": 42, + } ) diff --git a/airflow-mcp-server/tests/tools/test_models.py b/airflow-mcp-server/tests/tools/test_models.py index 78a7459..4d2b170 100644 --- a/airflow-mcp-server/tests/tools/test_models.py +++ b/airflow-mcp-server/tests/tools/test_models.py @@ -10,10 +10,3 @@ class TestRequestModel(BaseModel): query_filter: str | None = None body_name: str body_value: int - - -class TestResponseModel(BaseModel): - """Test response model.""" - - item_id: int - result: str diff --git a/airflow-mcp-server/uv.lock b/airflow-mcp-server/uv.lock index 24c9800..15e19b1 100644 --- a/airflow-mcp-server/uv.lock +++ b/airflow-mcp-server/uv.lock @@ -111,7 +111,7 @@ wheels = [ [[package]] name = "airflow-mcp-server" -version = "0.1.0" +version = "0.2.0" source = { editable = "." } dependencies = [ { name = "aiofiles" },