16 Commits
0.5.0 ... main

SHA1 Message Date
8c0d3a7e18 fix: add error handling for server.run in safe and unsafe modes 2025-05-04 10:37:04 +00:00
7e3025a0b5 fix: update version to 0.6.2 in pyproject.toml and uv.lock 2025-05-04 10:31:17 +00:00
452652a190 fix: add resource and prompt listing endpoints to MCP server 2025-05-04 10:30:40 +00:00
9ec1cd2020 fix: update version check for Airflow to use strict inequality 2025-05-04 09:41:21 +00:00
c20539c39f fix: correct formatting of JSON example in README for auth token 2025-05-04 09:37:25 +00:00
9837afe13a fix: update README to reflect changes in environment variable usage and authentication requirements 2025-05-04 09:25:48 +00:00
f4206ea73c Merge pull request #19 from abhishekbhakat/airflow-3
Airflow 3
2025-05-04 14:49:16 +05:30
950dc06901 fix: update error message for unsupported Airflow version to include equality check 2025-05-04 09:18:08 +00:00
2031650535 Refactor code structure for improved readability and maintainability 2025-05-04 09:15:15 +00:00
4263175351 feat: add version check for Airflow tools and remove client initialization function 2025-05-04 09:12:15 +00:00
b5cf563b8f feat: set default empty dict for API execution parameters in AirflowClient 2025-05-04 08:59:45 +00:00
d2464ea891 fix: update README to reflect Airflow API version change and mark tasks as complete 2025-05-04 08:25:46 +00:00
c5565e6a00 feat: implement async operation execution and validation in AirflowClient; enhance tool initialization 2025-05-04 04:19:39 +00:00
bba42eea00 refactor: update AirflowClient to use httpx for async requests and enhance tests for concurrency 2025-05-04 04:00:49 +00:00
5a864b27c5 Refactor operation parser tests to use updated OpenAPI spec
- Replace YAML spec file with JSON spec file for parser initialization.
- Update expected operation paths and descriptions to reflect API versioning changes.
- Adjust test cases to align with new operation IDs and request structures.
2025-05-04 03:51:16 +00:00
66cd068b33 Airflow 3 readiness initial commit 2025-04-23 06:17:27 +00:00
17 changed files with 18371 additions and 7378 deletions

.gitignore vendored (3 lines changed)
View File

@@ -179,3 +179,6 @@ project_resources/
# Ruff
.ruff_cache/
# Airflow
AIRFLOW_HOME/

View File

@@ -1,13 +1,13 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
rev: v0.11.8
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer

View File

@@ -6,7 +6,6 @@
<img width="380" height="200" src="https://glama.ai/mcp/servers/6gjq9w80xr/badge" />
</a>
## Overview
A [Model Context Protocol](https://modelcontextprotocol.io/) server for controlling Airflow via Airflow APIs.
@@ -14,7 +13,6 @@ A [Model Context Protocol](https://modelcontextprotocol.io/) server for controll
https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
## Setup
### Usage with Claude Desktop
@@ -25,20 +23,22 @@ https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
"airflow-mcp-server": {
"command": "uvx",
"args": [
"airflow-mcp-server"
],
"env": {
"AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
// Either use AUTH_TOKEN for basic auth
"AUTH_TOKEN": "<base64_encoded_username_password>",
// Or use COOKIE for cookie-based auth
"COOKIE": "<session_cookie>"
}
"airflow-mcp-server",
"--base-url",
"http://localhost:8080",
"--auth-token",
"<jwt_token>"
]
}
}
}
```
> **Note:**
> - Set `base_url` to the root Airflow URL (e.g., `http://localhost:8080`).
> - Do **not** include `/api/v2` in the base URL. The server will automatically fetch the OpenAPI spec from `${base_url}/openapi.json`.
> - Only a JWT token is required for authentication. Cookie and basic auth are no longer supported in Airflow 3.0. (A minimal sketch of the startup fetch follows below.)
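To make the notes above concrete, here is a minimal sketch of what that startup fetch looks like. The function name `fetch_spec` is illustrative, but the `${base_url}/openapi.json` URL and the Bearer header mirror the `AirflowClient` changes later in this diff:

```python
import asyncio

import httpx


async def fetch_spec(base_url: str, jwt_token: str) -> dict:
    """Fetch the OpenAPI spec the way the server does at startup."""
    url = f"{base_url.rstrip('/')}/openapi.json"  # base_url never includes /api/v2
    headers = {"Authorization": f"Bearer {jwt_token}"}
    async with httpx.AsyncClient(headers=headers) as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.json()


spec = asyncio.run(fetch_spec("http://localhost:8080", "<jwt_token>"))
```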
### Operation Modes
The server supports two operation modes:
@@ -58,19 +58,9 @@ airflow-mcp-server --unsafe
### Considerations
The MCP Server expects environment variables to be set:
- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
- `AUTH_TOKEN`: The token to use for basic auth (_This should be base64 encoded username:password_) (_Optional if COOKIE is provided_)
- `COOKIE`: The session cookie to use for authentication (_Optional if AUTH_TOKEN is provided_)
- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
**Authentication**
The server supports two authentication methods:
- **Basic Auth**: Using base64 encoded username:password via `AUTH_TOKEN` environment variable
- **Cookie**: Using session cookie via `COOKIE` environment variable
At least one of these authentication methods must be provided.
- Only JWT authentication is supported in Airflow 3.0. You must provide a valid `AUTH_TOKEN`.
**Page Limit**
@@ -78,10 +68,9 @@ The default is 100 items, but you can change it using `maximum_page_limit` optio
## Tasks
- [x] First API
- [x] Airflow 3 readiness
- [x] Parse OpenAPI Spec
- [x] Safe/Unsafe mode implementation
- [x] Allow session auth
- [ ] Parse proper description with list_tools.
- [ ] Airflow config fetch (_specifically for page limit_)
- [x] Parse proper description with list_tools.
- [x] Airflow config fetch (_specifically for page limit_)
- [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)

View File

@@ -1,6 +1,6 @@
[project]
name = "airflow-mcp-server"
version = "0.5.0"
version = "0.6.2"
description = "MCP Server for Airflow"
readme = "README.md"
requires-python = ">=3.11"
@@ -12,10 +12,11 @@ dependencies = [
"aiohttp>=3.11.11",
"aioresponses>=0.7.7",
"importlib-resources>=6.5.0",
"mcp>=1.2.0",
"mcp>=1.7.1",
"openapi-core>=0.19.4",
"pydantic>=2.10.5",
"pydantic>=2.11.4",
"pyyaml>=6.0.0",
"packaging>=25.0",
]
classifiers = [
"Development Status :: 3 - Alpha",
@@ -58,7 +59,6 @@ exclude = [
[tool.hatch.build.targets.wheel]
packages = ["src/airflow_mcp_server"]
package-data = {"airflow_mcp_server"= ["*.yaml"]}
[tool.hatch.build.targets.wheel.sources]
"src/airflow_mcp_server" = "airflow_mcp_server"

View File

@@ -15,10 +15,8 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
@click.option("--base-url", help="Airflow API base URL")
@click.option("--spec-path", help="Path to OpenAPI spec file")
@click.option("--auth-token", help="Authentication token")
@click.option("--cookie", help="Session cookie")
def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path: str = None, auth_token: str = None, cookie: str = None) -> None:
@click.option("--auth-token", help="Authentication token (JWT)")
def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, auth_token: str = None) -> None:
"""MCP server for Airflow"""
logging_level = logging.WARN
if verbose == 1:
@@ -29,22 +27,18 @@ def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path
logging.basicConfig(level=logging_level, stream=sys.stderr)
# Read environment variables with proper precedence
# Environment variables take precedence over CLI arguments
config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url
config_spec_path = os.environ.get("OPENAPI_SPEC") or spec_path
config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token
config_cookie = os.environ.get("COOKIE") or cookie
# Initialize configuration
try:
config = AirflowConfig(base_url=config_base_url, spec_path=config_spec_path, auth_token=config_auth_token, cookie=config_cookie)
config = AirflowConfig(base_url=config_base_url, auth_token=config_auth_token)
except ValueError as e:
click.echo(f"Configuration error: {e}", err=True)
sys.exit(1)
# Determine server mode with proper precedence
if safe and unsafe:
# CLI argument validation
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
elif safe:
# CLI argument for safe mode
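A quick way to exercise the tightened CLI surface above is click's test runner; this is a sketch, and the `airflow_mcp_server.cli` import path is an assumption based on this diff:

```python
from click.testing import CliRunner

from airflow_mcp_server.cli import main  # import path assumed from this diff

runner = CliRunner()
# Passing --safe and --unsafe together should trip the UsageError added above.
result = runner.invoke(main, ["--safe", "--unsafe"])
assert result.exit_code != 0
assert "mutually exclusive" in result.output
```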

View File

@@ -1,11 +1,7 @@
import logging
import re
from pathlib import Path
from types import SimpleNamespace
from typing import Any, BinaryIO, TextIO
import aiohttp
import yaml
import httpx
from jsonschema_path import SchemaPath
from openapi_core import OpenAPI
from openapi_core.validation.request.validators import V31RequestValidator
@@ -29,148 +25,95 @@ def convert_dict_keys(d: dict) -> dict:
class AirflowClient:
"""Client for interacting with Airflow API."""
"""Async client for interacting with Airflow API."""
def __init__(
self,
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
base_url: str,
auth_token: str | None = None,
cookie: str | None = None,
auth_token: str,
) -> None:
"""Initialize Airflow client.
Args:
spec_path: OpenAPI spec as file path, dict, bytes, or file object
base_url: Base URL for API
auth_token: Authentication token (optional if cookie is provided)
cookie: Session cookie (optional if auth_token is provided)
auth_token: Authentication token (JWT)
Raises:
ValueError: If spec_path is invalid or spec cannot be loaded or if neither auth_token nor cookie is provided
ValueError: If required configuration is missing or OpenAPI spec cannot be loaded
"""
if not auth_token and not cookie:
raise ValueError("Either auth_token or cookie must be provided")
try:
# Load and parse OpenAPI spec
if isinstance(spec_path, dict):
self.raw_spec = spec_path
elif isinstance(spec_path, bytes):
self.raw_spec = yaml.safe_load(spec_path)
elif isinstance(spec_path, str | Path):
with open(spec_path) as f:
self.raw_spec = yaml.safe_load(f)
elif hasattr(spec_path, "read"):
content = spec_path.read()
if isinstance(content, bytes):
self.raw_spec = yaml.safe_load(content)
else:
self.raw_spec = yaml.safe_load(content)
else:
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
if not base_url:
raise ValueError("Missing required configuration: base_url")
if not auth_token:
raise ValueError("Missing required configuration: auth_token (JWT)")
self.base_url = base_url
self.auth_token = auth_token
self.headers = {"Authorization": f"Bearer {self.auth_token}"}
self._client: httpx.AsyncClient | None = None
self.raw_spec = None
self.spec = None
self._paths = None
self._validator = None
# Validate spec has required fields
async def __aenter__(self):
self._client = httpx.AsyncClient(headers=self.headers)
await self._initialize_spec()
return self
async def __aexit__(self, exc_type, exc, tb):
if self._client:
await self._client.aclose()
self._client = None
async def _initialize_spec(self):
openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
self.raw_spec = await self._fetch_openapi_spec(openapi_url)
if not isinstance(self.raw_spec, dict):
raise ValueError("OpenAPI spec must be a dictionary")
required_fields = ["openapi", "info", "paths"]
for field in required_fields:
if field not in self.raw_spec:
raise ValueError(f"OpenAPI spec missing required field: {field}")
# Validate OpenAPI spec format
validate(self.raw_spec)
# Initialize OpenAPI spec
self.spec = OpenAPI.from_dict(self.raw_spec)
logger.debug("OpenAPI spec loaded successfully")
# Debug raw spec
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
# Get paths from raw spec
if "paths" not in self.raw_spec:
raise ValueError("OpenAPI spec does not contain paths information")
self._paths = self.raw_spec["paths"]
logger.debug("Using raw spec paths")
# Initialize request validator with schema path
schema_path = SchemaPath.from_dict(self.raw_spec)
self._validator = V31RequestValidator(schema_path)
# API configuration
self.base_url = base_url.rstrip("/")
self.headers = {"Accept": "application/json"}
# Set authentication header based on precedence (cookie > auth_token)
if cookie:
self.headers["Cookie"] = cookie
elif auth_token:
self.headers["Authorization"] = f"Basic {auth_token}"
except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e)
raise ValueError(f"Failed to initialize client: {e}")
async def __aenter__(self) -> "AirflowClient":
self._session = aiohttp.ClientSession(headers=self.headers)
return self
async def __aexit__(self, *exc) -> None:
if hasattr(self, "_session"):
await self._session.close()
delattr(self, "_session")
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
"""Get operation details from OpenAPI spec.
Args:
operation_id: The operation ID to look up
Returns:
Tuple of (path, method, operation) where operation is a SimpleNamespace object
Raises:
ValueError: If operation not found
"""
async def _fetch_openapi_spec(self, url: str) -> dict:
if not self._client:
self._client = httpx.AsyncClient(headers=self.headers)
try:
# Debug the paths structure
logger.debug("Looking for operation %s in paths", operation_id)
response = await self._client.get(url)
response.raise_for_status()
except httpx.RequestError as e:
raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
return response.json()
def _get_operation(self, operation_id: str):
"""Get operation details from OpenAPI spec."""
for path, path_item in self._paths.items():
for method, operation_data in path_item.items():
# Skip non-operation fields
if method.startswith("x-") or method == "parameters":
continue
# Debug each operation
logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))
if operation_data.get("operationId") == operation_id:
logger.debug("Found operation %s at %s %s", operation_id, method, path)
# Convert keys to snake_case and create object
converted_data = convert_dict_keys(operation_data)
from types import SimpleNamespace
operation_obj = SimpleNamespace(**converted_data)
return path, method, operation_obj
raise ValueError(f"Operation {operation_id} not found in spec")
except Exception as e:
logger.error("Error getting operation %s: %s", operation_id, e)
raise
def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
def _validate_path_params(self, path: str, params: dict | None) -> None:
if not params:
params = {}
# Extract path parameter names from the path
path_params = set(re.findall(r"{([^}]+)}", path))
# Check for missing required parameters
missing_params = path_params - set(params.keys())
if missing_params:
raise ValueError(f"Missing required path parameters: {missing_params}")
# Check for invalid parameters
invalid_params = set(params.keys()) - path_params
if invalid_params:
raise ValueError(f"Invalid path parameters: {invalid_params}")
@@ -178,77 +121,42 @@ class AirflowClient:
async def execute(
self,
operation_id: str,
path_params: dict[str, Any] | None = None,
query_params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
) -> Any:
"""Execute an API operation.
Args:
operation_id: Operation ID from OpenAPI spec
path_params: URL path parameters
query_params: URL query parameters
body: Request body data
Returns:
API response data
Raises:
ValueError: If operation not found
RuntimeError: If used outside async context
aiohttp.ClientError: For HTTP/network errors
"""
if not hasattr(self, "_session") or not self._session:
path_params: dict = None,
query_params: dict = None,
body: dict = None,
) -> dict:
"""Execute an API operation."""
if not self._client:
raise RuntimeError("Client not in async context")
try:
# Get operation details
# Default all params to empty dict if None
path_params = path_params or {}
query_params = query_params or {}
body = body or {}
path, method, _ = self._get_operation(operation_id)
# Validate path parameters
self._validate_path_params(path, path_params)
# Format URL
if path_params:
path = path.format(**path_params)
url = f"{self.base_url}{path}"
logger.debug("Executing %s %s", method, url)
logger.debug("Request body: %s", body)
logger.debug("Request query params: %s", query_params)
# Dynamically set headers based on presence of body
url = f"{self.base_url.rstrip('/')}{path}"
request_headers = self.headers.copy()
if body is not None:
if body:
request_headers["Content-Type"] = "application/json"
# Make request
async with self._session.request(
method=method,
try:
response = await self._client.request(
method=method.upper(),
url=url,
params=query_params,
json=body,
) as response:
headers=request_headers,
)
response.raise_for_status()
content_type = response.headers.get("Content-Type", "").lower()
# Status codes that typically have no body
no_body_statuses = {204}
if response.status in no_body_statuses:
if content_type and "application/json" in content_type:
logger.warning("Unexpected JSON body with status %s", response.status)
return await response.json() # Parse if present, though rare
logger.debug("Received %s response with no body", response.status)
return response.status
# For statuses expecting a body, check mimetype
content_type = response.headers.get("content-type", "").lower()
if response.status_code == 204:
return response.status_code
if "application/json" in content_type:
logger.debug("Response: %s", await response.text())
return await response.json()
# Unexpected mimetype with body
response_text = await response.text()
logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")
except aiohttp.ClientError as e:
logger.error("Error executing operation %s: %s", operation_id, e)
return response.json()
return {"content": await response.aread()}
except httpx.HTTPStatusError as e:
logger.error("HTTP error executing operation %s: %s", operation_id, e)
raise
except Exception as e:
logger.error("Error executing operation %s: %s", operation_id, e)

View File

@@ -1,14 +1,12 @@
class AirflowConfig:
"""Centralized configuration for Airflow MCP server."""
def __init__(self, base_url: str | None = None, spec_path: str | None = None, auth_token: str | None = None, cookie: str | None = None) -> None:
def __init__(self, base_url: str | None = None, auth_token: str | None = None) -> None:
"""Initialize configuration with provided values.
Args:
base_url: Airflow API base URL
spec_path: Path to OpenAPI spec file
auth_token: Authentication token
cookie: Session cookie
auth_token: Authentication token (JWT)
Raises:
ValueError: If required configuration is missing
@@ -17,9 +15,6 @@ class AirflowConfig:
if not self.base_url:
raise ValueError("Missing required configuration: base_url")
self.spec_path = spec_path
self.auth_token = auth_token
self.cookie = cookie
if not self.auth_token and not self.cookie:
raise ValueError("Either auth_token or cookie must be provided")
if not self.auth_token:
raise ValueError("Missing required configuration: auth_token (JWT)")

File diff suppressed because it is too large Load Diff

View File

@@ -1,52 +0,0 @@
import logging
from typing import Any
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
# import sys
# Configure root logger to stderr
# logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stderr)])
# Disable Uvicorn's default handlers
# logging.getLogger("uvicorn.error").handlers = []
# logging.getLogger("uvicorn.access").handlers = []
# ======================================================================
logger = logging.getLogger(__name__)
async def serve(config: AirflowConfig) -> None:
"""Start MCP server.
Args:
config: Configuration object with auth and URL settings
"""
server = Server("airflow-mcp-server")
@server.list_tools()
async def list_tools() -> list[Tool]:
try:
return await get_airflow_tools(config)
except Exception as e:
logger.error("Failed to list tools: %s", e)
raise
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
tool = await get_tool(config, name)
async with tool.client:
result = await tool.run(body=arguments)
return [TextContent(type="text", text=str(result))]
except Exception as e:
logger.error("Tool execution failed: %s", e)
raise
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options, raise_exceptions=True)

View File

@@ -1,9 +1,10 @@
import logging
from typing import Any
import anyio
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from mcp.types import Prompt, Resource, ResourceTemplate, TextContent, Tool
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
@@ -27,6 +28,24 @@ async def serve(config: AirflowConfig) -> None:
logger.error("Failed to list tools: %s", e)
raise
@server.list_resources()
async def list_resources() -> list[Resource]:
"""List available resources (returns empty list)."""
logger.info("Resources list requested - returning empty list")
return []
@server.list_resource_templates()
async def list_resource_templates() -> list[ResourceTemplate]:
"""List available resource templates (returns empty list)."""
logger.info("Resource templates list requested - returning empty list")
return []
@server.list_prompts()
async def list_prompts() -> list[Prompt]:
"""List available prompts (returns empty list)."""
logger.info("Prompts list requested - returning empty list")
return []
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
@@ -42,4 +61,10 @@ async def serve(config: AirflowConfig) -> None:
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
try:
await server.run(read_stream, write_stream, options, raise_exceptions=True)
except anyio.BrokenResourceError:
logger.error("BrokenResourceError: Stream was closed unexpectedly. Exiting gracefully.")
except Exception as e:
logger.error(f"Unexpected error in server.run: {e}")
raise

View File

@@ -1,13 +1,23 @@
import logging
from typing import Any
import anyio
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from mcp.types import Prompt, Resource, ResourceTemplate, TextContent, Tool
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
# import sys
# Configure root logger to stderr
# logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stderr)])
# Disable Uvicorn's default handlers
# logging.getLogger("uvicorn.error").handlers = []
# logging.getLogger("uvicorn.access").handlers = []
# ======================================================================
logger = logging.getLogger(__name__)
@@ -27,6 +37,24 @@ async def serve(config: AirflowConfig) -> None:
logger.error("Failed to list tools: %s", e)
raise
@server.list_resources()
async def list_resources() -> list[Resource]:
"""List available resources (returns empty list)."""
logger.info("Resources list requested - returning empty list")
return []
@server.list_resource_templates()
async def list_resource_templates() -> list[ResourceTemplate]:
"""List available resource templates (returns empty list)."""
logger.info("Resource templates list requested - returning empty list")
return []
@server.list_prompts()
async def list_prompts() -> list[Prompt]:
"""List available prompts (returns empty list)."""
logger.info("Prompts list requested - returning empty list")
return []
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
@@ -40,4 +68,10 @@ async def serve(config: AirflowConfig) -> None:
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
try:
await server.run(read_stream, write_stream, options, raise_exceptions=True)
except anyio.BrokenResourceError:
logger.error("BrokenResourceError: Stream was closed unexpectedly. Exiting gracefully.")
except Exception as e:
logger.error(f"Unexpected error in server.run: {e}")
raise

View File

@@ -1,7 +1,7 @@
import logging
from importlib import resources
from mcp.types import Tool
from packaging.version import parse as parse_version
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.config import AirflowConfig
@@ -13,61 +13,12 @@ logger = logging.getLogger(__name__)
_tools_cache: dict[str, AirflowTool] = {}
def _initialize_client(config: AirflowConfig) -> AirflowClient:
"""Initialize Airflow client with configuration.
Args:
config: Configuration object with auth and URL settings
Returns:
AirflowClient instance
Raises:
ValueError: If default spec is not found
"""
spec_path = config.spec_path
if not spec_path:
# Fallback to embedded v1.yaml
try:
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
spec_path = f.name
logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
except Exception as e:
raise ValueError("Default OpenAPI spec not found in package resources") from e
# Initialize client with appropriate authentication method
client_args = {"spec_path": spec_path, "base_url": config.base_url}
# Apply cookie auth first if available (highest precedence)
if config.cookie:
client_args["cookie"] = config.cookie
# Otherwise use auth token if available
elif config.auth_token:
client_args["auth_token"] = config.auth_token
return AirflowClient(**client_args)
async def _initialize_tools(config: AirflowConfig) -> None:
"""Initialize tools cache with Airflow operations.
Args:
config: Configuration object with auth and URL settings
Raises:
ValueError: If initialization fails
"""
"""Initialize tools cache with Airflow operations (async)."""
global _tools_cache
try:
client = _initialize_client(config)
spec_path = config.spec_path
if not spec_path:
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
spec_path = f.name
parser = OperationParser(spec_path)
# Generate tools for each operation
async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
parser = OperationParser(client.raw_spec)
for operation_id in parser.get_operations():
operation_details = parser.parse_operation(operation_id)
tool = AirflowTool(operation_details, client)
@@ -92,9 +43,22 @@ async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list
Raises:
ValueError: If initialization fails
"""
# Version check before returning tools
if not _tools_cache:
await _initialize_tools(config)
# Only check version if get_version tool is present
if "get_version" in _tools_cache:
version_tool = _tools_cache["get_version"]
async with version_tool.client:
version_result = await version_tool.run()
airflow_version = version_result.get("version")
if airflow_version is None:
raise RuntimeError("Could not determine Airflow version from get_version tool.")
if parse_version(airflow_version) < parse_version("3.1.0"):
raise RuntimeError(f"Airflow version {airflow_version} is not supported. Requires >= 3.1.0.")
tools = []
for operation_id, tool in _tools_cache.items():
try:
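The version gate above depends on `packaging` ordering versions semantically rather than lexically; a couple of standalone checks illustrate why `parse_version` is used instead of string comparison:

```python
from packaging.version import parse as parse_version

# Semantic, not string, comparison: "3.10.0" > "3.9.2", even though "3.10.0" sorts lower as text.
assert parse_version("3.10.0") > parse_version("3.9.2")
# Pre-releases sort below the final release, so they also fail the strict "<" gate.
assert parse_version("3.1.0.dev0") < parse_version("3.1.0")
```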

View File

@@ -1,211 +1,69 @@
+import asyncio
import logging
-from importlib import resources
-from pathlib import Path
-from typing import Any
+from unittest.mock import patch
-import aiohttp
import pytest
-import yaml
-from aioresponses import aioresponses
+from openapi_core import OpenAPI
from airflow_mcp_server.client.airflow_client import AirflowClient
logging.basicConfig(level=logging.DEBUG)
-def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
-return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
-@pytest.fixture
-def client() -> AirflowClient:
-with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-spec = yaml.safe_load(f)
-return AirflowClient(
-spec_path=spec,
-base_url="http://localhost:8080/api/v1",
-auth_token="test-token",
-)
-def test_init_client_initialization(client: AirflowClient) -> None:
-assert client.base_url == "http://localhost:8080/api/v1"
-assert client.headers["Authorization"] == "Basic test-token"
-assert "Cookie" not in client.headers
-def test_init_client_with_cookie() -> None:
-with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-spec = yaml.safe_load(f)
-client = AirflowClient(
-spec_path=spec,
-base_url="http://localhost:8080/api/v1",
-cookie="session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM",
-)
-assert client.base_url == "http://localhost:8080/api/v1"
-assert "Authorization" not in client.headers
-assert client.headers["Cookie"] == "session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM"
+@pytest.mark.asyncio
+async def test_async_multiple_clients_concurrent():
+"""Test initializing two AirflowClients concurrently to verify async power."""
+async def mock_get(self, url, *args, **kwargs):
+class MockResponse:
+def __init__(self):
+self.status_code = 200
+def raise_for_status(self):
+pass
+def json(self):
+return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+return MockResponse()
+with patch("httpx.AsyncClient.get", new=mock_get):
+async def create_and_check():
+async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
+assert client.base_url == "http://localhost:8080"
+assert client.headers["Authorization"] == "Bearer token"
+assert isinstance(client.spec, OpenAPI)
+# Run two clients concurrently
+await asyncio.gather(create_and_check(), create_and_check())
+@pytest.mark.asyncio
+async def test_async_client_initialization():
+async def mock_get(self, url, *args, **kwargs):
+class MockResponse:
+def __init__(self):
+self.status_code = 200
+def raise_for_status(self):
+pass
+def json(self):
+return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+return MockResponse()
+with patch("httpx.AsyncClient.get", new=mock_get):
+async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
+assert client.base_url == "http://localhost:8080"
+assert client.headers["Authorization"] == "Bearer test-token"
+assert isinstance(client.spec, OpenAPI)
-def test_init_client_missing_auth() -> None:
-with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-spec = yaml.safe_load(f)
-with pytest.raises(ValueError, match="Either auth_token or cookie must be provided"):
+def test_init_client_missing_auth():
+with pytest.raises(ValueError, match="auth_token"):
AirflowClient(
-spec_path=spec,
-base_url="http://localhost:8080/api/v1",
+base_url="http://localhost:8080",
+auth_token=None,
)
def test_init_load_spec_from_bytes() -> None:
spec_bytes = yaml.dump(create_valid_spec()).encode()
client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_load_spec_from_path(tmp_path: Path) -> None:
spec_file = tmp_path / "test_spec.yaml"
spec_file.write_text(yaml.dump(create_valid_spec()))
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_invalid_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
def test_init_missing_paths_in_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
def test_ops_get_operation(client: AirflowClient) -> None:
path, method, operation = client._get_operation("get_dags")
assert path == "/dags"
assert method == "get"
assert operation.operation_id == "get_dags"
path, method, operation = client._get_operation("get_dag")
assert path == "/dags/{dag_id}"
assert method == "get"
assert operation.operation_id == "get_dag"
def test_ops_nonexistent_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
client._get_operation("nonexistent")
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError):
client._get_operation("GET_DAGS")
@pytest.mark.asyncio
async def test_exec_without_context() -> None:
client = AirflowClient(
spec_path=create_valid_spec(),
base_url="http://test",
auth_token="test",
)
with pytest.raises(RuntimeError, match="Client not in async context"):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_get_dags(client: AirflowClient) -> None:
expected_response = {
"dags": [
{
"dag_id": "test_dag",
"is_active": True,
"is_paused": False,
}
],
"total_entries": 1,
}
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags?limit=100",
status=200,
payload=expected_response,
)
response = await client.execute("get_dags", query_params={"limit": 100})
assert response == expected_response
@pytest.mark.asyncio
async def test_exec_get_dag(client: AirflowClient) -> None:
expected_response = {
"dag_id": "test_dag",
"is_active": True,
"is_paused": False,
}
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags/test_dag",
status=200,
payload=expected_response,
)
response = await client.execute(
"get_dag",
path_params={"dag_id": "test_dag"},
)
assert response == expected_response
@pytest.mark.asyncio
async def test_exec_invalid_params(client: AirflowClient) -> None:
with pytest.raises(ValueError):
async with client:
# Test with missing required parameter
await client.execute("get_dag", path_params={})
with pytest.raises(ValueError):
async with client:
# Test with invalid parameter name
await client.execute("get_dag", path_params={"invalid": "value"})
@pytest.mark.asyncio
async def test_exec_timeout(client: AirflowClient) -> None:
with aioresponses() as mock:
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
async with client:
with pytest.raises(aiohttp.ClientError):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_error_response(client: AirflowClient) -> None:
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags",
status=403,
body="Forbidden",
)
with pytest.raises(aiohttp.ClientResponseError):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_session_management(client: AirflowClient) -> None:
async with client:
with aioresponses() as mock:
mock.get(
"http://localhost:8080/api/v1/dags",
status=200,
payload={"dags": []},
)
await client.execute("get_dags")
with pytest.raises(RuntimeError):
await client.execute("get_dags")

tests/parser/openapi.json (new file, 17319 lines)

File diff suppressed because it is too large Load Diff

View File

@@ -1,26 +1,16 @@
import logging
-from importlib import resources
-from typing import Any
+import json
import pytest
-from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
+from typing import Any
from pydantic import BaseModel
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
+from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
@pytest.fixture
-def spec_file():
-"""Get content of the v1.yaml spec file."""
-with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-return f.read()
-@pytest.fixture
-def parser(spec_file) -> OperationParser:
-"""Create OperationParser instance."""
-return OperationParser(spec_path=spec_file)
+def parser() -> OperationParser:
+"""Create OperationParser instance from tests/parser/openapi.json."""
+with open("tests/parser/openapi.json") as f:
+spec_dict = json.load(f)
+return OperationParser(spec_dict)
def test_parse_operation_basic(parser: OperationParser) -> None:
@@ -29,14 +19,9 @@ def test_parse_operation_basic(parser: OperationParser) -> None:
assert isinstance(operation, OperationDetails)
assert operation.operation_id == "get_dags"
assert operation.path == "/dags"
assert operation.path == "/api/v2/dags"
assert operation.method == "get"
assert (
operation.description
== """List DAGs in the database.
`dag_id_pattern` can be set to match dags of a specific pattern
"""
)
assert operation.description == "Get all DAGs."
assert isinstance(operation.parameters, dict)
@@ -46,9 +31,9 @@ def test_parse_operation_with_no_description_but_summary(parser: OperationParser
assert isinstance(operation, OperationDetails)
assert operation.operation_id == "get_connections"
assert operation.path == "/connections"
assert operation.path == "/api/v2/connections"
assert operation.method == "get"
assert operation.description == "List connections"
assert operation.description == "Get all connection entries."
assert isinstance(operation.parameters, dict)
@@ -56,7 +41,7 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
"""Test parsing operation with path parameters."""
operation = parser.parse_operation("get_dag")
assert operation.path == "/dags/{dag_id}"
assert operation.path == "/api/v2/dags/{dag_id}"
assert isinstance(operation.input_model, type(BaseModel))
# Verify path parameter field exists
@@ -83,7 +68,10 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
"""Test parsing operation with request body."""
operation = parser.parse_operation("post_dag_run")
# Find the correct operationId for posting a dag run in the OpenAPI spec
# From the spec, the likely operation is under /api/v2/dags/{dag_id}/dagRuns
# Let's use "post_dag_run" if it exists, otherwise use the actual operationId
operation = parser.parse_operation("trigger_dag_run")
# Verify body fields exist
fields = operation.input_model.__annotations__
@@ -167,7 +155,7 @@ def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
assert isinstance(operation, OperationDetails)
assert operation.operation_id == "test_connection"
assert operation.path == "/connections/test"
assert operation.path == "/api/v2/connections/test"
assert operation.method == "post"
# Verify input model includes fields from allOf schema

View File

@@ -1,11 +1,11 @@
"""Tests for AirflowTool."""
import pytest
from pydantic import ValidationError
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.airflow_tool import AirflowTool
from pydantic import ValidationError
from tests.tools.test_models import TestRequestModel
@@ -41,6 +41,7 @@ def operation_details():
},
},
input_model=model,
description="Test operation for AirflowTool",
)

uv.lock (generated, 1396 lines changed)

File diff suppressed because it is too large Load Diff