From c152852767803107fd5959930042359da183d31f Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 12 Feb 2025 13:30:20 +0000 Subject: [PATCH] Airflow Tools Generation --- airflow-mcp-server/pyproject.toml | 6 + .../airflow_mcp_server/tools/airflow_tool.py | 166 ++++++++++++++++++ airflow-mcp-server/tests/tools/__init__.py | 0 .../tests/tools/test_airflow_tool.py | 112 ++++++++++++ airflow-mcp-server/tests/tools/test_models.py | 17 ++ 5 files changed, 301 insertions(+) create mode 100644 airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py create mode 100644 airflow-mcp-server/tests/tools/__init__.py create mode 100644 airflow-mcp-server/tests/tools/test_airflow_tool.py create mode 100644 airflow-mcp-server/tests/tools/test_models.py diff --git a/airflow-mcp-server/pyproject.toml b/airflow-mcp-server/pyproject.toml index bc634ba..70f1d4a 100644 --- a/airflow-mcp-server/pyproject.toml +++ b/airflow-mcp-server/pyproject.toml @@ -22,6 +22,7 @@ dev = [ "pre-commit>=4.0.1", "pytest>=8.3.4", "pytest-asyncio>=0.25.0", + "pytest-mock>=3.14.0", "ruff>=0.9.2" ] @@ -33,6 +34,8 @@ build-backend = "hatchling.build" pythonpath = ["src"] asyncio_mode = "strict" testpaths = ["tests"] +python_classes = "!TestRequestModel,!TestResponseModel" +asyncio_default_fixture_loop_scope = "function" [tool.ruff] line-length = 200 @@ -69,3 +72,6 @@ skip-magic-trailing-comma = false [tool.ruff.lint.isort] combine-as-imports = true + +[tool.ruff.lint.mccabe] +max-complexity = 12 diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py new file mode 100644 index 0000000..096cc88 --- /dev/null +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py @@ -0,0 +1,166 @@ +import logging +from typing import Any + +from airflow_mcp_server.client.airflow_client import AirflowClient +from airflow_mcp_server.parser.operation_parser import OperationDetails +from airflow_mcp_server.tools.base_tools import BaseTools +from pydantic import BaseModel, ValidationError + +logger = logging.getLogger(__name__) + + +def create_validation_error(field: str, message: str) -> ValidationError: + """Create a properly formatted validation error. + + Args: + field: The field that failed validation + message: The error message + + Returns: + ValidationError: A properly formatted validation error + """ + errors = [ + { + "loc": (field,), + "msg": message, + "type": "value_error", + "input": None, + "ctx": {"error": message}, + } + ] + return ValidationError.from_exception_data("validation_error", errors) + + +class AirflowTool(BaseTools): + """Tool for executing Airflow API operations.""" + + def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None: + """Initialize tool with operation details and client. + + Args: + operation_details: Parsed operation details + client: Configured Airflow API client + """ + super().__init__() + self.operation = operation_details + self.client = client + + def _validate_parameters( + self, + path_params: dict[str, Any] | None = None, + query_params: dict[str, Any] | None = None, + body: dict[str, Any] | None = None, + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None]: + """Validate input parameters against operation schemas. + + Args: + path_params: URL path parameters + query_params: URL query parameters + body: Request body data + + Returns: + Tuple of validated (path_params, query_params, body) + + Raises: + ValidationError: If parameters fail validation + """ + validated_params: dict[str, dict[str, Any] | None] = { + "path": None, + "query": None, + "body": None, + } + + try: + # Validate path parameters + if path_params and "path" in self.operation.parameters: + path_schema = self.operation.parameters["path"] + for name, value in path_params.items(): + if name in path_schema: + param_type = path_schema[name]["type"] + if not isinstance(value, param_type): + raise create_validation_error( + field=name, + message=f"Path parameter {name} must be of type {param_type.__name__}", + ) + validated_params["path"] = path_params + + # Validate query parameters + if query_params and "query" in self.operation.parameters: + query_schema = self.operation.parameters["query"] + for name, value in query_params.items(): + if name in query_schema: + param_type = query_schema[name]["type"] + if not isinstance(value, param_type): + raise create_validation_error( + field=name, + message=f"Query parameter {name} must be of type {param_type.__name__}", + ) + validated_params["query"] = query_params + + # Validate request body + if body and self.operation.request_body: + try: + model: type[BaseModel] = self.operation.request_body + validated_body = model(**body) + validated_params["body"] = validated_body.model_dump() + except ValidationError as e: + # Re-raise Pydantic validation errors directly + raise e + + return ( + validated_params["path"], + validated_params["query"], + validated_params["body"], + ) + + except Exception as e: + logger.error("Parameter validation failed: %s", e) + raise + + async def run( + self, + path_params: dict[str, Any] | None = None, + query_params: dict[str, Any] | None = None, + body: dict[str, Any] | None = None, + ) -> Any: + """Execute the operation with provided parameters. + + Args: + path_params: URL path parameters + query_params: URL query parameters + body: Request body data + + Returns: + API response data + + Raises: + ValidationError: If parameters fail validation + RuntimeError: If client execution fails + """ + try: + # Validate parameters + validated_path_params, validated_query_params, validated_body = self._validate_parameters(path_params, query_params, body) + + # Execute operation + response = await self.client.execute( + operation_id=self.operation.operation_id, + path_params=validated_path_params, + query_params=validated_query_params, + body=validated_body, + ) + + # Validate response if model exists + if self.operation.response_model and isinstance(response, dict): + try: + model: type[BaseModel] = self.operation.response_model + validated_response = model(**response) + return validated_response.model_dump() + except ValidationError as e: + logger.error("Response validation failed: %s", e) + raise RuntimeError(f"Invalid response format: {e}") + + return response + + except Exception as e: + logger.error("Operation execution failed: %s", e) + raise diff --git a/airflow-mcp-server/tests/tools/__init__.py b/airflow-mcp-server/tests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/airflow-mcp-server/tests/tools/test_airflow_tool.py b/airflow-mcp-server/tests/tools/test_airflow_tool.py new file mode 100644 index 0000000..f5537cf --- /dev/null +++ b/airflow-mcp-server/tests/tools/test_airflow_tool.py @@ -0,0 +1,112 @@ +"""Tests for AirflowTool.""" + +import pytest +from airflow_mcp_server.client.airflow_client import AirflowClient +from airflow_mcp_server.parser.operation_parser import OperationDetails +from airflow_mcp_server.tools.airflow_tool import AirflowTool +from pydantic import ValidationError + +from tests.tools.test_models import TestRequestModel, TestResponseModel + + +@pytest.fixture +def mock_client(mocker): + """Create mock Airflow client.""" + client = mocker.Mock(spec=AirflowClient) + client.execute = mocker.AsyncMock() + return client + + +@pytest.fixture +def operation_details(): + """Create test operation details.""" + return OperationDetails( + operation_id="test_operation", + path="/test/{id}", + method="POST", + parameters={ + "path": { + "id": {"type": int, "required": True}, + }, + "query": { + "filter": {"type": str, "required": False}, + }, + }, + request_body=TestRequestModel, + response_model=TestResponseModel, + ) + + +@pytest.fixture +def airflow_tool(mock_client, operation_details): + """Create AirflowTool instance for testing.""" + return AirflowTool(operation_details, mock_client) + + +@pytest.mark.asyncio +async def test_successful_execution(airflow_tool, mock_client): + """Test successful operation execution with valid parameters.""" + # Setup mock response + mock_client.execute.return_value = {"item_id": 1, "result": "success"} + + # Execute operation + result = await airflow_tool.run( + path_params={"id": 123}, + query_params={"filter": "test"}, + body={"name": "test", "value": 42}, + ) + + # Verify response + assert result == {"item_id": 1, "result": "success"} + mock_client.execute.assert_called_once_with( + operation_id="test_operation", + path_params={"id": 123}, + query_params={"filter": "test"}, + body={"name": "test", "value": 42}, + ) + + +@pytest.mark.asyncio +async def test_invalid_path_parameter(airflow_tool): + """Test validation error for invalid path parameter type.""" + with pytest.raises(ValidationError): + await airflow_tool.run( + path_params={"id": "not_an_integer"}, + body={"name": "test", "value": 42}, + ) + + +@pytest.mark.asyncio +async def test_invalid_request_body(airflow_tool): + """Test validation error for invalid request body.""" + with pytest.raises(ValidationError): + await airflow_tool.run( + path_params={"id": 123}, + body={"name": "test", "value": "not_an_integer"}, + ) + + +@pytest.mark.asyncio +async def test_invalid_response_format(airflow_tool, mock_client): + """Test error handling for invalid response format.""" + # Setup mock response with invalid format + mock_client.execute.return_value = {"invalid": "response"} + + with pytest.raises(RuntimeError): + await airflow_tool.run( + path_params={"id": 123}, + body={"name": "test", "value": 42}, + ) + + +@pytest.mark.asyncio +async def test_client_error(airflow_tool, mock_client): + """Test error handling for client execution failure.""" + # Setup mock to raise exception + mock_client.execute.side_effect = RuntimeError("API Error") + + with pytest.raises(RuntimeError): + await airflow_tool.run( + path_params={"id": 123}, + body={"name": "test", "value": 42}, + ) diff --git a/airflow-mcp-server/tests/tools/test_models.py b/airflow-mcp-server/tests/tools/test_models.py new file mode 100644 index 0000000..d7f9ca2 --- /dev/null +++ b/airflow-mcp-server/tests/tools/test_models.py @@ -0,0 +1,17 @@ +"""Test models for Airflow tool tests.""" + +from pydantic import BaseModel + + +class TestRequestModel(BaseModel): + """Test request model.""" + + name: str + value: int + + +class TestResponseModel(BaseModel): + """Test response model.""" + + item_id: int + result: str