Airflow openapi client
This commit is contained in:
@@ -7,7 +7,9 @@ dependencies = [
    "aiofiles>=24.1.0",
    "aiohttp>=3.11.11",
    "aioresponses>=0.7.7",
    "importlib-resources>=6.5.0",
    "mcp>=1.2.0",
    "openapi-core>=0.19.4",
    "pydantic>=2.10.5",
]
|
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
|
|
||||||
|
# Explicit public API: only the client class is re-exported from this package.
__all__ = ["AirflowClient"]
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import yaml
|
||||||
|
from openapi_core import OpenAPI
|
||||||
|
|
||||||
|
# Module-level logger, named after the module per stdlib convention.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def camel_to_snake(name: str) -> str:
    """Translate a camelCase (or PascalCase) identifier into snake_case."""
    # Pass 1 inserts "_" before each capitalized word; pass 2 splits
    # remaining lower/digit-to-upper seams, then everything is lowercased.
    with_word_seams = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    fully_seamed = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_seams)
    return fully_seamed.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dict_keys(d: dict) -> dict:
    """Recursively convert dictionary keys from camelCase to snake_case.

    Args:
        d: Mapping whose keys should be renamed. A non-dict input is
           returned unchanged so callers can apply this blindly.

    Returns:
        A new dict with every nested key converted to snake_case.

    The original comprehension only recursed into values that were
    themselves dicts, so dicts nested inside lists (e.g. an OpenAPI
    operation's ``parameters`` array) kept their camelCase keys.  This
    version walks lists as well.
    """
    if not isinstance(d, dict):
        return d

    def _walk(value):
        # Recurse through both mappings and sequences so keys at any
        # depth are normalized.
        if isinstance(value, dict):
            return {camel_to_snake(k): _walk(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_walk(item) for item in value]
        return value

    return _walk(d)
|
||||||
|
|
||||||
|
|
||||||
|
class AirflowClient:
    """Client for interacting with Airflow API.

    Loads an OpenAPI specification, resolves operations by their
    ``operationId``, and executes them over HTTP with bearer-token auth.
    Must be used as an async context manager so the underlying aiohttp
    session is created and closed deterministically.
    """

    def __init__(
        self,
        spec_path: Path | str | object,
        base_url: str,
        auth_token: str,
    ) -> None:
        """Initialize Airflow client.

        Args:
            spec_path: Filesystem path to the OpenAPI YAML spec, or any
                object ``yaml.safe_load`` accepts (file object, bytes, str).
            base_url: Base URL of the Airflow API; a trailing slash is stripped.
            auth_token: Bearer token sent in the Authorization header.

        Raises:
            ValueError: If the spec contains no ``paths`` section.
        """
        # Load and parse OpenAPI spec. A plain tuple of types is used here;
        # the previous `(str | Path)` was a parenthesized union that only
        # looked like a tuple.
        if isinstance(spec_path, (str, Path)):
            with open(spec_path) as f:
                self.raw_spec = yaml.safe_load(f)
        else:
            self.raw_spec = yaml.safe_load(spec_path)

        # Validate with openapi-core but keep the raw dict around: it is
        # simpler to traverse for operation lookups.
        try:
            self.spec = OpenAPI.from_dict(self.raw_spec)
            logger.debug("OpenAPI spec loaded successfully")
            logger.debug("Raw spec keys: %s", self.raw_spec.keys())

            if "paths" in self.raw_spec:
                self._paths = self.raw_spec["paths"]
                logger.debug("Using raw spec paths")
            else:
                raise ValueError("OpenAPI spec does not contain paths information")
        except Exception as e:
            logger.error("Failed to initialize OpenAPI spec: %s", e)
            raise

        # API configuration
        self.base_url = base_url.rstrip("/")
        self.headers = {
            "Authorization": f"Bearer {auth_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        # Session management: created lazily in __aenter__.
        self._session: aiohttp.ClientSession | None = None

    async def __aenter__(self) -> "AirflowClient":
        """Enter async context, creating the HTTP session."""
        self._session = aiohttp.ClientSession(headers=self.headers)
        return self

    async def __aexit__(self, *exc) -> None:
        """Exit async context, closing the HTTP session."""
        if self._session:
            try:
                await self._session.close()
            finally:
                # Reset even if close() raises so the client can be re-entered.
                self._session = None

    def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
        """Get operation details from OpenAPI spec.

        Args:
            operation_id: The operation ID to look up.

        Returns:
            Tuple of (path, method, operation) where operation is a
            SimpleNamespace built from the snake_cased operation dict.

        Raises:
            ValueError: If the operation is not found.
        """
        try:
            logger.debug("Looking for operation %s in paths", operation_id)

            for path, path_item in self._paths.items():
                for method, operation_data in path_item.items():
                    # Skip non-operation fields: extensions, shared
                    # parameters, and scalar path-item keys such as
                    # "summary"/"description" (strings would crash .get()).
                    if method.startswith("x-") or method == "parameters":
                        continue
                    if not isinstance(operation_data, dict):
                        continue

                    logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))

                    if operation_data.get("operationId") == operation_id:
                        logger.debug("Found operation %s at %s %s", operation_id, method, path)
                        # Convert keys to snake_case for attribute access.
                        converted_data = convert_dict_keys(operation_data)
                        operation_obj = SimpleNamespace(**converted_data)
                        return path, method, operation_obj

            raise ValueError(f"Operation {operation_id} not found in spec")
        except Exception as e:
            logger.error("Error getting operation %s: %s", operation_id, e)
            raise

    async def execute(
        self,
        operation_id: str,
        path_params: dict[str, Any] | None = None,
        query_params: dict[str, Any] | None = None,
        body: dict[str, Any] | None = None,
    ) -> Any:
        """Execute an API operation.

        Args:
            operation_id: Operation ID from OpenAPI spec.
            path_params: URL path parameters substituted into the template.
            query_params: URL query parameters.
            body: Request body data (sent as JSON).

        Returns:
            Decoded JSON response data.

        Raises:
            RuntimeError: If called outside the async context manager.
            ValueError: If the operation is not found.
            aiohttp.ClientError: For HTTP/network errors.
        """
        if not self._session:
            raise RuntimeError("Client not in async context")

        try:
            # Resolve path/method from the spec.
            path, method, _ = self._get_operation(operation_id)

            # Fill in templated path segments, e.g. /dags/{dag_id}.
            if path_params:
                path = path.format(**path_params)
            url = f"{self.base_url}{path}"

            logger.debug("Executing %s %s", method, url)

            # Make request; raise_for_status maps non-2xx to ClientError.
            async with self._session.request(
                method=method,
                url=url,
                params=query_params,
                json=body,
            ) as response:
                response.raise_for_status()
                return await response.json()

        except Exception as e:
            logger.error("Error executing operation %s: %s", operation_id, e)
            raise
|
||||||
137
airflow-mcp-server/tests/client/test_airflow_client.py
Normal file
137
airflow-mcp-server/tests/client/test_airflow_client.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import logging
|
||||||
|
from importlib import resources
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aioresponses import aioresponses
|
||||||
|
from airflow_mcp_server.client import AirflowClient
|
||||||
|
from openapi_core import OpenAPI
|
||||||
|
|
||||||
|
# Verbose logging so spec-parsing and HTTP details appear in pytest output.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def spec_file():
    """Return the raw bytes of the bundled v1.yaml OpenAPI spec."""
    spec_resource = resources.files("tests.client").joinpath("v1.yaml")
    with spec_resource.open("rb") as handle:
        return handle.read()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client(spec_file):
    """Build an AirflowClient against a local test URL with a dummy token."""
    return AirflowClient(spec_path=spec_file, base_url="http://localhost:8080/api/v1", auth_token="test-token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_initialization(client):
    """The client should parse the spec and record auth configuration."""
    assert isinstance(client.spec, OpenAPI)
    assert client.headers["Authorization"] == "Bearer test-token"
    assert client.base_url == "http://localhost:8080/api/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_operation(client):
    """Operation IDs should resolve to (path, method, namespace) triples."""
    expectations = [
        ("get_dags", "/dags"),
        ("get_dag", "/dags/{dag_id}"),
    ]
    for op_id, expected_path in expectations:
        path, method, operation = client._get_operation(op_id)
        assert path == expected_path
        assert method == "get"
        assert operation.operation_id == op_id
|
||||||
|
|
||||||
|
|
||||||
|
# Note: asyncio_mode is configured in pyproject.toml
@pytest.mark.asyncio
async def test_execute_without_context():
    """Calling execute() outside ``async with`` must raise RuntimeError."""
    with resources.files("tests.client").joinpath("v1.yaml").open("rb") as spec_stream:
        bare_client = AirflowClient(
            spec_path=spec_stream,
            base_url="http://test",
            auth_token="test",
        )
        with pytest.raises(RuntimeError, match="Client not in async context"):
            await bare_client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_get_dags(client):
    """A mocked /dags call should round-trip the JSON payload unchanged."""
    payload = {
        "dags": [{"dag_id": "test_dag", "is_active": True, "is_paused": False}],
        "total_entries": 1,
    }

    with aioresponses() as mock:
        async with client:
            mock.get(
                "http://localhost:8080/api/v1/dags?limit=100",
                status=200,
                payload=payload,
            )
            result = await client.execute("get_dags", query_params={"limit": 100})
            assert result == payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_get_dag(client):
    """Path parameters should be substituted into the operation URL."""
    payload = {"dag_id": "test_dag", "is_active": True, "is_paused": False}

    with aioresponses() as mock:
        async with client:
            mock.get(
                "http://localhost:8080/api/v1/dags/test_dag",
                status=200,
                payload=payload,
            )
            result = await client.execute("get_dag", path_params={"dag_id": "test_dag"})
            assert result == payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_execute_error_response(client):
    """Non-2xx responses must surface as aiohttp.ClientError."""
    with aioresponses() as mock:
        async with client:
            mock.get(
                "http://localhost:8080/api/v1/dags",
                status=403,
                body="Forbidden",
            )
            with pytest.raises(aiohttp.ClientError):
                await client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_session_management(client):
    """Entering the context opens a live session; exiting clears it."""
    assert client._session is None

    async with client:
        active_session = client._session
        assert active_session is not None
        assert not active_session.closed

    assert client._session is None
|
||||||
6161
airflow-mcp-server/tests/client/v1.yaml
Normal file
6161
airflow-mcp-server/tests/client/v1.yaml
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user