From c5565e6a00f59b7e915a07db73a3207133c53959 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sun, 4 May 2025 04:19:39 +0000 Subject: [PATCH] feat: implement async operation execution and validation in AirflowClient; enhance tool initialization --- .../client/airflow_client.py | 65 +++++++++++++++++++ src/airflow_mcp_server/tools/tool_manager.py | 25 ++----- tests/tools/test_airflow_tool.py | 5 +- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/airflow_mcp_server/client/airflow_client.py b/src/airflow_mcp_server/client/airflow_client.py index 7ec9afe..ba081c3 100644 --- a/src/airflow_mcp_server/client/airflow_client.py +++ b/src/airflow_mcp_server/client/airflow_client.py @@ -92,3 +92,68 @@ class AirflowClient: except httpx.RequestError as e: raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}") return response.json() + + def _get_operation(self, operation_id: str): + """Get operation details from OpenAPI spec.""" + for path, path_item in self._paths.items(): + for method, operation_data in path_item.items(): + if method.startswith("x-") or method == "parameters": + continue + if operation_data.get("operationId") == operation_id: + converted_data = convert_dict_keys(operation_data) + from types import SimpleNamespace + + operation_obj = SimpleNamespace(**converted_data) + return path, method, operation_obj + raise ValueError(f"Operation {operation_id} not found in spec") + + def _validate_path_params(self, path: str, params: dict | None) -> None: + if not params: + params = {} + path_params = set(re.findall(r"{([^}]+)}", path)) + missing_params = path_params - set(params.keys()) + if missing_params: + raise ValueError(f"Missing required path parameters: {missing_params}") + invalid_params = set(params.keys()) - path_params + if invalid_params: + raise ValueError(f"Invalid path parameters: {invalid_params}") + + async def execute( + self, + operation_id: str, + path_params: dict = None, + query_params: dict = None, + body: dict = None, + ) -> dict: + """Execute an API operation.""" + if not self._client: + raise RuntimeError("Client not in async context") + path, method, _ = self._get_operation(operation_id) + self._validate_path_params(path, path_params) + if path_params: + path = path.format(**path_params) + url = f"{self.base_url.rstrip('/')}{path}" + request_headers = self.headers.copy() + if body is not None: + request_headers["Content-Type"] = "application/json" + try: + response = await self._client.request( + method=method.upper(), + url=url, + params=query_params, + json=body, + headers=request_headers, + ) + response.raise_for_status() + content_type = response.headers.get("content-type", "").lower() + if response.status_code == 204: + return response.status_code + if "application/json" in content_type: + return response.json() + return {"content": await response.aread()} + except httpx.HTTPStatusError as e: + logger.error("HTTP error executing operation %s: %s", operation_id, e) + raise + except Exception as e: + logger.error("Error executing operation %s: %s", operation_id, e) + raise ValueError(f"Failed to execute operation: {e}") diff --git a/src/airflow_mcp_server/tools/tool_manager.py b/src/airflow_mcp_server/tools/tool_manager.py index 59173b2..d837a61 100644 --- a/src/airflow_mcp_server/tools/tool_manager.py +++ b/src/airflow_mcp_server/tools/tool_manager.py @@ -29,26 +29,15 @@ def _initialize_client(config: AirflowConfig) -> AirflowClient: async def _initialize_tools(config: AirflowConfig) -> None: - """Initialize tools cache with Airflow operations. - - Args: - config: Configuration object with auth and URL settings - - Raises: - ValueError: If initialization fails - """ + """Initialize tools cache with Airflow operations (async).""" global _tools_cache - try: - client = _initialize_client(config) - # Use the OpenAPI spec dict from the client - parser = OperationParser(client.raw_spec) - - # Generate tools for each operation - for operation_id in parser.get_operations(): - operation_details = parser.parse_operation(operation_id) - tool = AirflowTool(operation_details, client) - _tools_cache[operation_id] = tool + async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client: + parser = OperationParser(client.raw_spec) + for operation_id in parser.get_operations(): + operation_details = parser.parse_operation(operation_id) + tool = AirflowTool(operation_details, client) + _tools_cache[operation_id] = tool except Exception as e: logger.error("Failed to initialize tools: %s", e) diff --git a/tests/tools/test_airflow_tool.py b/tests/tools/test_airflow_tool.py index 7981edb..4f62eec 100644 --- a/tests/tools/test_airflow_tool.py +++ b/tests/tools/test_airflow_tool.py @@ -1,11 +1,11 @@ """Tests for AirflowTool.""" import pytest +from pydantic import ValidationError + from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.parser.operation_parser import OperationDetails from airflow_mcp_server.tools.airflow_tool import AirflowTool -from pydantic import ValidationError - from tests.tools.test_models import TestRequestModel @@ -41,6 +41,7 @@ def operation_details(): }, }, input_model=model, + description="Test operation for AirflowTool", )