feat: implement async operation execution and validation in AirflowClient; enhance tool initialization

This commit is contained in:
2025-05-04 04:19:39 +00:00
parent bba42eea00
commit c5565e6a00
3 changed files with 75 additions and 20 deletions

View File

@@ -92,3 +92,68 @@ class AirflowClient:
except httpx.RequestError as e: except httpx.RequestError as e:
raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}") raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
return response.json() return response.json()
def _get_operation(self, operation_id: str):
"""Get operation details from OpenAPI spec."""
for path, path_item in self._paths.items():
for method, operation_data in path_item.items():
if method.startswith("x-") or method == "parameters":
continue
if operation_data.get("operationId") == operation_id:
converted_data = convert_dict_keys(operation_data)
from types import SimpleNamespace
operation_obj = SimpleNamespace(**converted_data)
return path, method, operation_obj
raise ValueError(f"Operation {operation_id} not found in spec")
def _validate_path_params(self, path: str, params: dict | None) -> None:
if not params:
params = {}
path_params = set(re.findall(r"{([^}]+)}", path))
missing_params = path_params - set(params.keys())
if missing_params:
raise ValueError(f"Missing required path parameters: {missing_params}")
invalid_params = set(params.keys()) - path_params
if invalid_params:
raise ValueError(f"Invalid path parameters: {invalid_params}")
async def execute(
self,
operation_id: str,
path_params: dict = None,
query_params: dict = None,
body: dict = None,
) -> dict:
"""Execute an API operation."""
if not self._client:
raise RuntimeError("Client not in async context")
path, method, _ = self._get_operation(operation_id)
self._validate_path_params(path, path_params)
if path_params:
path = path.format(**path_params)
url = f"{self.base_url.rstrip('/')}{path}"
request_headers = self.headers.copy()
if body is not None:
request_headers["Content-Type"] = "application/json"
try:
response = await self._client.request(
method=method.upper(),
url=url,
params=query_params,
json=body,
headers=request_headers,
)
response.raise_for_status()
content_type = response.headers.get("content-type", "").lower()
if response.status_code == 204:
return response.status_code
if "application/json" in content_type:
return response.json()
return {"content": await response.aread()}
except httpx.HTTPStatusError as e:
logger.error("HTTP error executing operation %s: %s", operation_id, e)
raise
except Exception as e:
logger.error("Error executing operation %s: %s", operation_id, e)
raise ValueError(f"Failed to execute operation: {e}")

View File

@@ -29,22 +29,11 @@ def _initialize_client(config: AirflowConfig) -> AirflowClient:
async def _initialize_tools(config: AirflowConfig) -> None: async def _initialize_tools(config: AirflowConfig) -> None:
"""Initialize tools cache with Airflow operations. """Initialize tools cache with Airflow operations (async)."""
Args:
config: Configuration object with auth and URL settings
Raises:
ValueError: If initialization fails
"""
global _tools_cache global _tools_cache
try: try:
client = _initialize_client(config) async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
# Use the OpenAPI spec dict from the client
parser = OperationParser(client.raw_spec) parser = OperationParser(client.raw_spec)
# Generate tools for each operation
for operation_id in parser.get_operations(): for operation_id in parser.get_operations():
operation_details = parser.parse_operation(operation_id) operation_details = parser.parse_operation(operation_id)
tool = AirflowTool(operation_details, client) tool = AirflowTool(operation_details, client)

View File

@@ -1,11 +1,11 @@
"""Tests for AirflowTool.""" """Tests for AirflowTool."""
import pytest import pytest
from pydantic import ValidationError
from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.airflow_tool import AirflowTool from airflow_mcp_server.tools.airflow_tool import AirflowTool
from pydantic import ValidationError
from tests.tools.test_models import TestRequestModel from tests.tools.test_models import TestRequestModel
@@ -41,6 +41,7 @@ def operation_details():
}, },
}, },
input_model=model, input_model=model,
description="Test operation for AirflowTool",
) )