feat: implement async operation execution and validation in AirflowClient; enhance tool initialization
@@ -92,3 +92,68 @@ class AirflowClient:
         except httpx.RequestError as e:
             raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
         return response.json()
+
+    def _get_operation(self, operation_id: str):
+        """Get operation details from OpenAPI spec."""
+        for path, path_item in self._paths.items():
+            for method, operation_data in path_item.items():
+                if method.startswith("x-") or method == "parameters":
+                    continue
+                if operation_data.get("operationId") == operation_id:
+                    converted_data = convert_dict_keys(operation_data)
+                    from types import SimpleNamespace
+
+                    operation_obj = SimpleNamespace(**converted_data)
+                    return path, method, operation_obj
+        raise ValueError(f"Operation {operation_id} not found in spec")
+
+    def _validate_path_params(self, path: str, params: dict | None) -> None:
+        if not params:
+            params = {}
+        path_params = set(re.findall(r"{([^}]+)}", path))
+        missing_params = path_params - set(params.keys())
+        if missing_params:
+            raise ValueError(f"Missing required path parameters: {missing_params}")
+        invalid_params = set(params.keys()) - path_params
+        if invalid_params:
+            raise ValueError(f"Invalid path parameters: {invalid_params}")
+
+    async def execute(
+        self,
+        operation_id: str,
+        path_params: dict = None,
+        query_params: dict = None,
+        body: dict = None,
+    ) -> dict:
+        """Execute an API operation."""
+        if not self._client:
+            raise RuntimeError("Client not in async context")
+        path, method, _ = self._get_operation(operation_id)
+        self._validate_path_params(path, path_params)
+        if path_params:
+            path = path.format(**path_params)
+        url = f"{self.base_url.rstrip('/')}{path}"
+        request_headers = self.headers.copy()
+        if body is not None:
+            request_headers["Content-Type"] = "application/json"
+        try:
+            response = await self._client.request(
+                method=method.upper(),
+                url=url,
+                params=query_params,
+                json=body,
+                headers=request_headers,
+            )
+            response.raise_for_status()
+            content_type = response.headers.get("content-type", "").lower()
+            if response.status_code == 204:
+                return response.status_code
+            if "application/json" in content_type:
+                return response.json()
+            return {"content": await response.aread()}
+        except httpx.HTTPStatusError as e:
+            logger.error("HTTP error executing operation %s: %s", operation_id, e)
+            raise
+        except Exception as e:
+            logger.error("Error executing operation %s: %s", operation_id, e)
+            raise ValueError(f"Failed to execute operation: {e}")
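
For orientation, a minimal usage sketch of the new execute() path, assuming the client is entered as an async context manager (as the _initialize_tools change below does) and assuming the get_dags / get_dag operation IDs and the {dag_id} path parameter from Airflow's stable REST API spec; the URL, token, and calling code here are illustrative and not part of this commit.

import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient


async def main() -> None:
    # Constructor arguments mirror the _initialize_tools change below;
    # the URL and token values are placeholders.
    async with AirflowClient(base_url="http://localhost:8080/api/v1", auth_token="...") as client:
        # No templated segments in this path, so only query_params are supplied.
        dags = await client.execute("get_dags", query_params={"limit": 5})

        # The get_dag path contains {dag_id}; _validate_path_params raises
        # ValueError if dag_id is missing or an unknown key is passed.
        dag = await client.execute("get_dag", path_params={"dag_id": "example_dag"})
        print(dags, dag)


asyncio.run(main())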
@@ -29,26 +29,15 @@ def _initialize_client(config: AirflowConfig) -> AirflowClient:
 
 async def _initialize_tools(config: AirflowConfig) -> None:
-    """Initialize tools cache with Airflow operations.
-
-    Args:
-        config: Configuration object with auth and URL settings
-
-    Raises:
-        ValueError: If initialization fails
-    """
+    """Initialize tools cache with Airflow operations (async)."""
     global _tools_cache
 
     try:
-        client = _initialize_client(config)
-        # Use the OpenAPI spec dict from the client
-        parser = OperationParser(client.raw_spec)
-        for operation_id in parser.get_operations():
-            operation_details = parser.parse_operation(operation_id)
-            # Generate tools for each operation
-            tool = AirflowTool(operation_details, client)
-            _tools_cache[operation_id] = tool
+        async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
+            parser = OperationParser(client.raw_spec)
+            for operation_id in parser.get_operations():
+                operation_details = parser.parse_operation(operation_id)
+                tool = AirflowTool(operation_details, client)
+                _tools_cache[operation_id] = tool
 
     except Exception as e:
         logger.error("Failed to initialize tools: %s", e)
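
Both execute() (which raises RuntimeError("Client not in async context") when self._client is unset) and the rewritten _initialize_tools rely on AirflowClient behaving as an async context manager. That part of the class is outside this diff; the sketch below is purely an assumption about its shape, showing how such a contract is typically implemented with an httpx.AsyncClient.

import httpx


class _AirflowClientShape:
    """Illustrative sketch of the assumed context-manager contract, not the committed code."""

    def __init__(self, base_url: str, auth_token: str) -> None:
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {auth_token}"}
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self) -> "_AirflowClientShape":
        # execute() checks self._client, so the pooled HTTP client is opened here.
        self._client = httpx.AsyncClient()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Leaving the async with block closes the pool; execute() calls made
        # after this point would hit the "Client not in async context" guard.
        await self._client.aclose()
        self._client = None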
@@ -1,11 +1,11 @@
 """Tests for AirflowTool."""
 
 import pytest
+from pydantic import ValidationError
 
 from airflow_mcp_server.client.airflow_client import AirflowClient
 from airflow_mcp_server.parser.operation_parser import OperationDetails
 from airflow_mcp_server.tools.airflow_tool import AirflowTool
-from pydantic import ValidationError
 
 from tests.tools.test_models import TestRequestModel
 
@@ -41,6 +41,7 @@ def operation_details():
             },
         },
         input_model=model,
+        description="Test operation for AirflowTool",
     )
 
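
A hypothetical test built on the extended fixture could assert the new field directly; the test name and assertion below are illustrative and assume OperationDetails exposes description as an attribute.

def test_operation_details_carries_description(operation_details):
    # The fixture above now sets description, so downstream tool wrappers
    # (e.g. AirflowTool) have a human-readable summary to surface.
    assert operation_details.description == "Test operation for AirflowTool"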