28 Commits
0.3.0 ... 0.6.0

Author SHA1 Message Date
9837afe13a fix: update README to reflect changes in environment variable usage and authentication requirements 2025-05-04 09:25:48 +00:00
f4206ea73c Merge pull request #19 from abhishekbhakat/airflow-3
Airflow 3
2025-05-04 14:49:16 +05:30
950dc06901 fix: update error message for unsupported Airflow version to include equality check 2025-05-04 09:18:08 +00:00
2031650535 Refactor code structure for improved readability and maintainability 2025-05-04 09:15:15 +00:00
4263175351 feat: add version check for Airflow tools and remove client initialization function 2025-05-04 09:12:15 +00:00
b5cf563b8f feat: set default empty dict for API execution parameters in AirflowClient 2025-05-04 08:59:45 +00:00
d2464ea891 fix: update README to reflect Airflow API version change and mark tasks as complete 2025-05-04 08:25:46 +00:00
c5565e6a00 feat: implement async operation execution and validation in AirflowClient; enhance tool initialization 2025-05-04 04:19:39 +00:00
bba42eea00 refactor: update AirflowClient to use httpx for async requests and enhance tests for concurrency 2025-05-04 04:00:49 +00:00
5a864b27c5 Refactor operation parser tests to use updated OpenAPI spec
- Replace YAML spec file with JSON spec file for parser initialization.
- Update expected operation paths and descriptions to reflect API versioning changes.
- Adjust test cases to align with new operation IDs and request structures.
2025-05-04 03:51:16 +00:00
66cd068b33 Airflow 3 readiness initial commit 2025-04-23 06:17:27 +00:00
4734005ae4 Bump version from 0.4.0 to 0.5.0 2025-03-19 16:48:15 +00:00
63ff02fa4b Merge pull request #17 from bhavaniravi/bhavani/2-add-tool-description
chore: #2 add description to tools
2025-03-19 21:58:27 +05:30
Bhavani Ravi
407eb00c1b chore: #2 add description to tools 2025-03-19 20:45:51 +05:30
707f3747d7 version bump and dependencies sorting 2025-02-25 11:28:49 +00:00
a8638d27a5 fix tests CI 2025-02-25 11:23:25 +00:00
dbbc3ef5e8 Merge pull request #15 from abhishekbhakat/14-support-session-with-cookies
14 support session with cookies
2025-02-25 11:21:03 +00:00
d8887d3a2b Task update 2025-02-25 11:20:09 +00:00
679523a7c6 Using older env variables 2025-02-25 11:18:45 +00:00
492e79ef2a uv lock update 2025-02-25 11:12:45 +00:00
ea60acd54a Add AirflowConfig class for centralized configuration management 2025-02-25 08:53:37 +00:00
420b6fc68f safe and unsafe are mutually exclusive 2025-02-25 06:10:28 +00:00
355fb55bdb precedence implementation 2025-02-25 06:10:04 +00:00
8b38a26e8a Updates for using Cookies 2025-02-25 02:33:48 +00:00
5663f56621 Tests for PRs 2025-02-25 02:29:45 +00:00
2b652c5926 support cookies 2025-02-25 02:29:16 +00:00
3fd605b111 Merge pull request #13 from abhishekbhakat/fix-parser-tests
Use default resources yaml
2025-02-25 02:22:24 +00:00
c5106f10a8 Use default resources yaml 2025-02-25 02:21:44 +00:00
20 changed files with 18418 additions and 7277 deletions

.github/workflows/pytest.yml vendored Normal file

@@ -0,0 +1,29 @@
+name: Run Tests
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.11", "3.12"]
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[dev]
+
+      - name: Run pytest
+        run: |
+          pytest tests/ -v

.gitignore vendored

@@ -179,3 +179,6 @@ project_resources/
 # Ruff
 .ruff_cache/
+
+# Airflow
+AIRFLOW_HOME/

.pre-commit-config.yaml

@@ -1,13 +1,13 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.11
+    rev: v0.11.8
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer

README.md

@@ -6,7 +6,6 @@
 <img width="380" height="200" src="https://glama.ai/mcp/servers/6gjq9w80xr/badge" />
 </a>
 
-
 ## Overview
 
 A [Model Context Protocol](https://modelcontextprotocol.io/) server for controlling Airflow via Airflow APIs.
@@ -14,28 +13,32 @@ A [Model Context Protocol](https://modelcontextprotocol.io/) server for controll
 https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
 
 ## Setup
 
 ### Usage with Claude Desktop
 
 ```json
 {
   "mcpServers": {
     "airflow-mcp-server": {
       "command": "uvx",
       "args": [
-        "airflow-mcp-server"
-      ],
-      "env": {
-        "AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
-        "AUTH_TOKEN": "<base64_encoded_username_password>"
-      }
+        "airflow-mcp-server",
+        "--base-url",
+        "http://localhost:8080",
+        "--auth-token",
+        "<jwt_token>"
+      ]
     }
   }
 }
 ```
+> **Note:**
+> - Set `base_url` to the root Airflow URL (e.g., `http://localhost:8080`).
+> - Do **not** include `/api/v2` in the base URL. The server will automatically fetch the OpenAPI spec from `${base_url}/openapi.json`.
+> - Only JWT token is required for authentication. Cookie and basic auth are no longer supported in Airflow 3.0.
 
 ### Operation Modes
 
 The server supports two operation modes:
@@ -55,12 +58,9 @@ airflow-mcp-server --unsafe
 
 ### Considerations
 
-The MCP Server expects environment variables to be set:
+**Authentication**
 
-- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
-- `AUTH_TOKEN`: The token to use for authorization (_This should be base64 encoded username:password_)
-- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
-
-*Currently, only Basic Auth is supported.*
+- Only JWT authentication is supported in Airflow 3.0. You must provide a valid `AUTH_TOKEN`.
 
 **Page Limit**
@@ -68,9 +68,9 @@ The default is 100 items, but you can change it using `maximum_page_limit` optio
 
 ## Tasks
 
-- [x] First API
+- [x] Airflow 3 readiness
 - [x] Parse OpenAPI Spec
 - [x] Safe/Unsafe mode implementation
-- [ ] Parse proper description with list_tools.
-- [ ] Airflow config fetch (_specifically for page limit_)
+- [x] Parse proper description with list_tools.
+- [x] Airflow config fetch (_specifically for page limit_)
 - [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)

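The README note above implies the base URL and JWT can be sanity-checked before wiring up Claude Desktop. A minimal connectivity check (a sketch, not part of this changeset; it assumes a local Airflow 3 instance and a valid JWT):

```python
import asyncio

import httpx


async def check_airflow(base_url: str, jwt_token: str) -> None:
    """Fetch the same OpenAPI spec the MCP server would fetch, to verify URL and auth."""
    async with httpx.AsyncClient(headers={"Authorization": f"Bearer {jwt_token}"}) as client:
        response = await client.get(f"{base_url.rstrip('/')}/openapi.json")
        response.raise_for_status()
        spec = response.json()
        print(spec["info"]["title"], spec["info"]["version"])


asyncio.run(check_airflow("http://localhost:8080", "<jwt_token>"))
```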
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "airflow-mcp-server"
-version = "0.3.0"
+version = "0.6.0"
 description = "MCP Server for Airflow"
 readme = "README.md"
 requires-python = ">=3.11"
@@ -12,10 +12,11 @@ dependencies = [
     "aiohttp>=3.11.11",
     "aioresponses>=0.7.7",
     "importlib-resources>=6.5.0",
-    "mcp>=1.2.0",
+    "mcp>=1.7.1",
     "openapi-core>=0.19.4",
-    "pydantic>=2.10.5",
+    "pydantic>=2.11.4",
     "pyyaml>=6.0.0",
+    "packaging>=25.0",
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -52,12 +53,12 @@ build-backend = "hatchling.build"
 exclude = [
     "*",
     "!src/**",
-    "!pyproject.toml"
+    "!pyproject.toml",
+    "!assets/**"
 ]
 
 [tool.hatch.build.targets.wheel]
 packages = ["src/airflow_mcp_server"]
-package-data = {"airflow_mcp_server"= ["*.yaml"]}
 
 [tool.hatch.build.targets.wheel.sources]
 "src/airflow_mcp_server" = "airflow_mcp_server"

src/airflow_mcp_server/__main__.py

@@ -1,9 +1,11 @@
 import asyncio
 import logging
+import os
 import sys
 
 import click
 
+from airflow_mcp_server.config import AirflowConfig
 from airflow_mcp_server.server_safe import serve as serve_safe
 from airflow_mcp_server.server_unsafe import serve as serve_unsafe
@@ -12,7 +14,9 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe
 @click.option("-v", "--verbose", count=True, help="Increase verbosity")
 @click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
 @click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
-def main(verbose: int, safe: bool, unsafe: bool) -> None:
+@click.option("--base-url", help="Airflow API base URL")
+@click.option("--auth-token", help="Authentication token (JWT)")
+def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, auth_token: str = None) -> None:
     """MCP server for Airflow"""
     logging_level = logging.WARN
     if verbose == 1:
@@ -22,13 +26,26 @@ def main(verbose: int, safe: bool, unsafe: bool) -> None:
     logging.basicConfig(level=logging_level, stream=sys.stderr)
 
+    # Read environment variables with proper precedence
+    config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url
+    config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token
+
+    # Initialize configuration
+    try:
+        config = AirflowConfig(base_url=config_base_url, auth_token=config_auth_token)
+    except ValueError as e:
+        click.echo(f"Configuration error: {e}", err=True)
+        sys.exit(1)
+
+    # Determine server mode with proper precedence
     if safe and unsafe:
         raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
-    elif safe:
-        asyncio.run(serve_safe())
-    else:  # Default to unsafe mode
-        asyncio.run(serve_unsafe())
+
+    if safe:  # CLI argument for safe mode
+        asyncio.run(serve_safe(config))
+    else:
+        # Default to unsafe mode
+        asyncio.run(serve_unsafe(config))
 
 
 if __name__ == "__main__":

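As the hunk above shows, environment variables take precedence over the new CLI flags (`os.environ.get(...) or flag`). A minimal sketch of that precedence rule in isolation (the variable values are illustrative):

```python
import os

# Environment variable wins when both are set; the CLI flag is the fallback.
os.environ["AIRFLOW_BASE_URL"] = "http://airflow.internal:8080"
base_url_flag = "http://localhost:8080"  # value a user passed via --base-url

config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url_flag
assert config_base_url == "http://airflow.internal:8080"

# With the variable unset, the flag is used.
del os.environ["AIRFLOW_BASE_URL"]
config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url_flag
assert config_base_url == "http://localhost:8080"
```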
src/airflow_mcp_server/client/airflow_client.py

@@ -1,11 +1,7 @@
 import logging
 import re
-from pathlib import Path
-from types import SimpleNamespace
-from typing import Any, BinaryIO, TextIO
 
-import aiohttp
-import yaml
+import httpx
 from jsonschema_path import SchemaPath
 from openapi_core import OpenAPI
 from openapi_core.validation.request.validators import V31RequestValidator
@@ -29,141 +25,95 @@ def convert_dict_keys(d: dict) -> dict:
 class AirflowClient:
-    """Client for interacting with Airflow API."""
+    """Async client for interacting with Airflow API."""
 
     def __init__(
         self,
-        spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
         base_url: str,
         auth_token: str,
     ) -> None:
         """Initialize Airflow client.
 
         Args:
-            spec_path: OpenAPI spec as file path, dict, bytes, or file object
             base_url: Base URL for API
-            auth_token: Authentication token
+            auth_token: Authentication token (JWT)
 
         Raises:
-            ValueError: If spec_path is invalid or spec cannot be loaded
+            ValueError: If required configuration is missing or OpenAPI spec cannot be loaded
         """
-        try:
-            # Load and parse OpenAPI spec
-            if isinstance(spec_path, dict):
-                self.raw_spec = spec_path
-            elif isinstance(spec_path, bytes):
-                self.raw_spec = yaml.safe_load(spec_path)
-            elif isinstance(spec_path, str | Path):
-                with open(spec_path) as f:
-                    self.raw_spec = yaml.safe_load(f)
-            elif hasattr(spec_path, "read"):
-                content = spec_path.read()
-                if isinstance(content, bytes):
-                    self.raw_spec = yaml.safe_load(content)
-                else:
-                    self.raw_spec = yaml.safe_load(content)
-            else:
-                raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
-
-            # Validate spec has required fields
-            if not isinstance(self.raw_spec, dict):
-                raise ValueError("OpenAPI spec must be a dictionary")
-
-            required_fields = ["openapi", "info", "paths"]
-            for field in required_fields:
-                if field not in self.raw_spec:
-                    raise ValueError(f"OpenAPI spec missing required field: {field}")
-
-            # Validate OpenAPI spec format
-            validate(self.raw_spec)
-
-            # Initialize OpenAPI spec
-            self.spec = OpenAPI.from_dict(self.raw_spec)
-            logger.debug("OpenAPI spec loaded successfully")
-
-            # Debug raw spec
-            logger.debug("Raw spec keys: %s", self.raw_spec.keys())
-
-            # Get paths from raw spec
-            if "paths" not in self.raw_spec:
-                raise ValueError("OpenAPI spec does not contain paths information")
-            self._paths = self.raw_spec["paths"]
-            logger.debug("Using raw spec paths")
-
-            # Initialize request validator with schema path
-            schema_path = SchemaPath.from_dict(self.raw_spec)
-            self._validator = V31RequestValidator(schema_path)
-
-            # API configuration
-            self.base_url = base_url.rstrip("/")
-            self.headers = {
-                "Authorization": f"Basic {auth_token}",
-                "Accept": "application/json",
-            }
-        except Exception as e:
-            logger.error("Failed to initialize AirflowClient: %s", e)
-            raise ValueError(f"Failed to initialize client: {e}")
-
-    async def __aenter__(self) -> "AirflowClient":
-        self._session = aiohttp.ClientSession(headers=self.headers)
+        if not base_url:
+            raise ValueError("Missing required configuration: base_url")
+        if not auth_token:
+            raise ValueError("Missing required configuration: auth_token (JWT)")
+        self.base_url = base_url
+        self.auth_token = auth_token
+        self.headers = {"Authorization": f"Bearer {self.auth_token}"}
+        self._client: httpx.AsyncClient | None = None
+        self.raw_spec = None
+        self.spec = None
+        self._paths = None
+        self._validator = None
+
+    async def __aenter__(self):
+        self._client = httpx.AsyncClient(headers=self.headers)
+        await self._initialize_spec()
         return self
 
-    async def __aexit__(self, *exc) -> None:
-        if hasattr(self, "_session"):
-            await self._session.close()
-            delattr(self, "_session")
+    async def __aexit__(self, exc_type, exc, tb):
+        if self._client:
+            await self._client.aclose()
+            self._client = None
 
-    def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
-        """Get operation details from OpenAPI spec.
-
-        Args:
-            operation_id: The operation ID to look up
-
-        Returns:
-            Tuple of (path, method, operation) where operation is a SimpleNamespace object
-
-        Raises:
-            ValueError: If operation not found
-        """
-        try:
-            # Debug the paths structure
-            logger.debug("Looking for operation %s in paths", operation_id)
-
-            for path, path_item in self._paths.items():
-                for method, operation_data in path_item.items():
-                    # Skip non-operation fields
-                    if method.startswith("x-") or method == "parameters":
-                        continue
-
-                    # Debug each operation
-                    logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))
-
-                    if operation_data.get("operationId") == operation_id:
-                        logger.debug("Found operation %s at %s %s", operation_id, method, path)
-                        # Convert keys to snake_case and create object
-                        converted_data = convert_dict_keys(operation_data)
-                        operation_obj = SimpleNamespace(**converted_data)
-                        return path, method, operation_obj
-
-            raise ValueError(f"Operation {operation_id} not found in spec")
-        except Exception as e:
-            logger.error("Error getting operation %s: %s", operation_id, e)
-            raise
+    async def _initialize_spec(self):
+        openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
+        self.raw_spec = await self._fetch_openapi_spec(openapi_url)
+        if not isinstance(self.raw_spec, dict):
+            raise ValueError("OpenAPI spec must be a dictionary")
+        required_fields = ["openapi", "info", "paths"]
+        for field in required_fields:
+            if field not in self.raw_spec:
+                raise ValueError(f"OpenAPI spec missing required field: {field}")
+        validate(self.raw_spec)
+        self.spec = OpenAPI.from_dict(self.raw_spec)
+        logger.debug("OpenAPI spec loaded successfully")
+        if "paths" not in self.raw_spec:
+            raise ValueError("OpenAPI spec does not contain paths information")
+        self._paths = self.raw_spec["paths"]
+        logger.debug("Using raw spec paths")
+        schema_path = SchemaPath.from_dict(self.raw_spec)
+        self._validator = V31RequestValidator(schema_path)
+
+    async def _fetch_openapi_spec(self, url: str) -> dict:
+        if not self._client:
+            self._client = httpx.AsyncClient(headers=self.headers)
+        try:
+            response = await self._client.get(url)
+            response.raise_for_status()
+        except httpx.RequestError as e:
+            raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
+        return response.json()
+
+    def _get_operation(self, operation_id: str):
+        """Get operation details from OpenAPI spec."""
+        for path, path_item in self._paths.items():
+            for method, operation_data in path_item.items():
+                if method.startswith("x-") or method == "parameters":
+                    continue
+                if operation_data.get("operationId") == operation_id:
+                    converted_data = convert_dict_keys(operation_data)
+                    from types import SimpleNamespace
+
+                    operation_obj = SimpleNamespace(**converted_data)
+                    return path, method, operation_obj
+        raise ValueError(f"Operation {operation_id} not found in spec")
 
-    def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
+    def _validate_path_params(self, path: str, params: dict | None) -> None:
         if not params:
             params = {}
 
-        # Extract path parameter names from the path
         path_params = set(re.findall(r"{([^}]+)}", path))
 
-        # Check for missing required parameters
        missing_params = path_params - set(params.keys())
         if missing_params:
             raise ValueError(f"Missing required path parameters: {missing_params}")
 
-        # Check for invalid parameters
         invalid_params = set(params.keys()) - path_params
         if invalid_params:
             raise ValueError(f"Invalid path parameters: {invalid_params}")
@@ -171,77 +121,42 @@ class AirflowClient:
     async def execute(
         self,
         operation_id: str,
-        path_params: dict[str, Any] | None = None,
-        query_params: dict[str, Any] | None = None,
-        body: dict[str, Any] | None = None,
-    ) -> Any:
-        """Execute an API operation.
-
-        Args:
-            operation_id: Operation ID from OpenAPI spec
-            path_params: URL path parameters
-            query_params: URL query parameters
-            body: Request body data
-
-        Returns:
-            API response data
-
-        Raises:
-            ValueError: If operation not found
-            RuntimeError: If used outside async context
-            aiohttp.ClientError: For HTTP/network errors
-        """
-        if not hasattr(self, "_session") or not self._session:
+        path_params: dict = None,
+        query_params: dict = None,
+        body: dict = None,
+    ) -> dict:
+        """Execute an API operation."""
+        if not self._client:
             raise RuntimeError("Client not in async context")
 
+        # Default all params to empty dict if None
+        path_params = path_params or {}
+        query_params = query_params or {}
+        body = body or {}
+
+        path, method, _ = self._get_operation(operation_id)
+        self._validate_path_params(path, path_params)
+        if path_params:
+            path = path.format(**path_params)
+        url = f"{self.base_url.rstrip('/')}{path}"
+        request_headers = self.headers.copy()
+        if body:
+            request_headers["Content-Type"] = "application/json"
         try:
-            # Get operation details
-            path, method, _ = self._get_operation(operation_id)
-
-            # Validate path parameters
-            self._validate_path_params(path, path_params)
-
-            # Format URL
-            if path_params:
-                path = path.format(**path_params)
-            url = f"{self.base_url}{path}"
-
-            logger.debug("Executing %s %s", method, url)
-            logger.debug("Request body: %s", body)
-            logger.debug("Request query params: %s", query_params)
-
-            # Dynamically set headers based on presence of body
-            request_headers = self.headers.copy()
-            if body is not None:
-                request_headers["Content-Type"] = "application/json"
-
-            # Make request
-            async with self._session.request(
-                method=method,
+            response = await self._client.request(
+                method=method.upper(),
                 url=url,
                 params=query_params,
                 json=body,
-            ) as response:
-                response.raise_for_status()
-                content_type = response.headers.get("Content-Type", "").lower()
-
-                # Status codes that typically have no body
-                no_body_statuses = {204}
-                if response.status in no_body_statuses:
-                    if content_type and "application/json" in content_type:
-                        logger.warning("Unexpected JSON body with status %s", response.status)
-                        return await response.json()  # Parse if present, though rare
-                    logger.debug("Received %s response with no body", response.status)
-                    return response.status
-
-                # For statuses expecting a body, check mimetype
-                if "application/json" in content_type:
-                    logger.debug("Response: %s", await response.text())
-                    return await response.json()
-
-                # Unexpected mimetype with body
-                response_text = await response.text()
-                logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
-                raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")
-        except aiohttp.ClientError as e:
-            logger.error("Error executing operation %s: %s", operation_id, e)
-            raise
+                headers=request_headers,
+            )
+            response.raise_for_status()
+            content_type = response.headers.get("content-type", "").lower()
+            if response.status_code == 204:
+                return response.status_code
+            if "application/json" in content_type:
+                return response.json()
+            return {"content": await response.aread()}
+        except httpx.HTTPStatusError as e:
+            logger.error("HTTP error executing operation %s: %s", operation_id, e)
+            raise
         except Exception as e:
             logger.error("Error executing operation %s: %s", operation_id, e)

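Usage now follows the async context-manager pattern: entering the context creates the `httpx.AsyncClient` and fetches the spec from `${base_url}/openapi.json`. A minimal usage sketch (the `get_dags` operation ID and the token value are assumptions, not part of this diff):

```python
import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient


async def main() -> None:
    # __aenter__ fetches the OpenAPI spec and builds the request validator.
    async with AirflowClient(base_url="http://localhost:8080", auth_token="<jwt_token>") as client:
        # The operation ID must exist in the fetched spec.
        dags = await client.execute("get_dags", query_params={"limit": 10})
        print(dags)


asyncio.run(main())
```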
src/airflow_mcp_server/config.py

@@ -0,0 +1,20 @@
+class AirflowConfig:
+    """Centralized configuration for Airflow MCP server."""
+
+    def __init__(self, base_url: str | None = None, auth_token: str | None = None) -> None:
+        """Initialize configuration with provided values.
+
+        Args:
+            base_url: Airflow API base URL
+            auth_token: Authentication token (JWT)
+
+        Raises:
+            ValueError: If required configuration is missing
+        """
+        self.base_url = base_url
+        if not self.base_url:
+            raise ValueError("Missing required configuration: base_url")
+
+        self.auth_token = auth_token
+        if not self.auth_token:
+            raise ValueError("Missing required configuration: auth_token (JWT)")

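`AirflowConfig` validates eagerly, so a misconfigured server fails at startup rather than on the first request. A short sketch of both paths:

```python
from airflow_mcp_server.config import AirflowConfig

config = AirflowConfig(base_url="http://localhost:8080", auth_token="<jwt_token>")

try:
    AirflowConfig(base_url="http://localhost:8080", auth_token=None)
except ValueError as e:
    print(e)  # Missing required configuration: auth_token (JWT)
```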
src/airflow_mcp_server/parser/operation_parser.py

@@ -20,6 +20,7 @@ class OperationDetails:
     method: str
     parameters: dict[str, Any]
     input_model: type[BaseModel]
+    description: str
 
 
 class OperationParser:
@@ -104,6 +105,7 @@
                 operation["path"] = path
                 operation["path_item"] = path_item
+                description = operation.get("description") or operation.get("summary") or operation_id
 
                 parameters = self.extract_parameters(operation)
@@ -119,7 +121,7 @@
             # Create unified input model
             input_model = self._create_input_model(operation_id, parameters, body_schema)
 
-            return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
+            return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, description=description, input_model=input_model)
 
         raise ValueError(f"Operation {operation_id} not found in spec")

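The tool description now falls back from `description` to `summary` to the operation ID, so every generated tool gets a human-readable description. The fallback chain in isolation (illustrative operation dict):

```python
# Mirrors the new line in parse_operation.
operation = {"operationId": "get_dags", "summary": "Get all DAGs."}
description = operation.get("description") or operation.get("summary") or operation["operationId"]
assert description == "Get all DAGs."
```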
File diff suppressed because it is too large

src/airflow_mcp_server/server.py

@@ -1,11 +1,11 @@
 import logging
-import os
 from typing import Any
 
 from mcp.server import Server
 from mcp.server.stdio import stdio_server
 from mcp.types import TextContent, Tool
 
+from airflow_mcp_server.config import AirflowConfig
 from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
 
 # ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
@@ -20,18 +20,18 @@ from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
 logger = logging.getLogger(__name__)
 
 
-async def serve() -> None:
-    """Start MCP server."""
-    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
-    if not all(var in os.environ for var in required_vars):
-        raise ValueError(f"Missing required environment variables: {required_vars}")
+async def serve(config: AirflowConfig) -> None:
+    """Start MCP server.
+
+    Args:
+        config: Configuration object with auth and URL settings
+    """
     server = Server("airflow-mcp-server")
 
     @server.list_tools()
     async def list_tools() -> list[Tool]:
         try:
-            return await get_airflow_tools()
+            return await get_airflow_tools(config)
         except Exception as e:
             logger.error("Failed to list tools: %s", e)
             raise
@@ -39,7 +39,7 @@ async def serve() -> None:
     @server.call_tool()
     async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
         try:
-            tool = await get_tool(name)
+            tool = await get_tool(config, name)
             async with tool.client:
                 result = await tool.run(body=arguments)
                 return [TextContent(type="text", text=str(result))]

src/airflow_mcp_server/server_safe.py

@@ -1,28 +1,28 @@
 import logging
-import os
 from typing import Any
 
 from mcp.server import Server
 from mcp.server.stdio import stdio_server
 from mcp.types import TextContent, Tool
 
+from airflow_mcp_server.config import AirflowConfig
 from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
 
 logger = logging.getLogger(__name__)
 
 
-async def serve() -> None:
-    """Start MCP server in safe mode (read-only operations)."""
-    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
-    if not all(var in os.environ for var in required_vars):
-        raise ValueError(f"Missing required environment variables: {required_vars}")
+async def serve(config: AirflowConfig) -> None:
+    """Start MCP server in safe mode (read-only operations).
+
+    Args:
+        config: Configuration object with auth and URL settings
+    """
     server = Server("airflow-mcp-server-safe")
 
     @server.list_tools()
     async def list_tools() -> list[Tool]:
         try:
-            return await get_airflow_tools(mode="safe")
+            return await get_airflow_tools(config, mode="safe")
         except Exception as e:
             logger.error("Failed to list tools: %s", e)
             raise
@@ -32,7 +32,7 @@ async def serve() -> None:
         try:
             if not name.startswith("get_"):
                 raise ValueError("Only GET operations allowed in safe mode")
-            tool = await get_tool(name)
+            tool = await get_tool(config, name)
             async with tool.client:
                 result = await tool.run(body=arguments)
                 return [TextContent(type="text", text=str(result))]

src/airflow_mcp_server/server_unsafe.py

@@ -1,28 +1,28 @@
 import logging
-import os
 from typing import Any
 
 from mcp.server import Server
 from mcp.server.stdio import stdio_server
 from mcp.types import TextContent, Tool
 
+from airflow_mcp_server.config import AirflowConfig
 from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
 
 logger = logging.getLogger(__name__)
 
 
-async def serve() -> None:
-    """Start MCP server in unsafe mode (all operations)."""
-    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
-    if not all(var in os.environ for var in required_vars):
-        raise ValueError(f"Missing required environment variables: {required_vars}")
+async def serve(config: AirflowConfig) -> None:
+    """Start MCP server in unsafe mode (all operations).
+
+    Args:
+        config: Configuration object with auth and URL settings
+    """
     server = Server("airflow-mcp-server-unsafe")
 
     @server.list_tools()
     async def list_tools() -> list[Tool]:
         try:
-            return await get_airflow_tools(mode="unsafe")
+            return await get_airflow_tools(config, mode="unsafe")
         except Exception as e:
             logger.error("Failed to list tools: %s", e)
             raise
@@ -30,7 +30,7 @@ async def serve() -> None:
     @server.call_tool()
     async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
         try:
-            tool = await get_tool(name)
+            tool = await get_tool(config, name)
             async with tool.client:
                 result = await tool.run(body=arguments)
                 return [TextContent(type="text", text=str(result))]

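All three serve variants now take the config object instead of reading the environment. A minimal launch sketch under that assumption:

```python
import asyncio

from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.server_safe import serve as serve_safe

config = AirflowConfig(base_url="http://localhost:8080", auth_token="<jwt_token>")
asyncio.run(serve_safe(config))  # exposes only read-only (GET) tools over stdio
```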
src/airflow_mcp_server/tools/base_tools.py

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+
 class BaseTools(ABC):
     """Abstract base class for tools."""

src/airflow_mcp_server/tools/tool_manager.py

@@ -1,10 +1,10 @@
 import logging
-import os
-from importlib import resources
 
 from mcp.types import Tool
+from packaging.version import parse as parse_version
 
 from airflow_mcp_server.client.airflow_client import AirflowClient
+from airflow_mcp_server.config import AirflowConfig
 from airflow_mcp_server.parser.operation_parser import OperationParser
 from airflow_mcp_server.tools.airflow_tool import AirflowTool
@@ -13,54 +13,16 @@ logger = logging.getLogger(__name__)
 _tools_cache: dict[str, AirflowTool] = {}
 
 
-def _initialize_client() -> AirflowClient:
-    """Initialize Airflow client with environment variables or embedded spec.
-
-    Returns:
-        AirflowClient instance
-
-    Raises:
-        ValueError: If required environment variables are missing or default spec is not found
-    """
-    spec_path = os.environ.get("OPENAPI_SPEC")
-    if not spec_path:
-        # Fallback to embedded v1.yaml
-        try:
-            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-                spec_path = f.name
-            logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
-        except Exception as e:
-            raise ValueError("Default OpenAPI spec not found in package resources") from e
-
-    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
-    missing_vars = [var for var in required_vars if var not in os.environ]
-    if missing_vars:
-        raise ValueError(f"Missing required environment variables: {missing_vars}")
-
-    return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
-
-
-async def _initialize_tools() -> None:
-    """Initialize tools cache with Airflow operations.
-
-    Raises:
-        ValueError: If initialization fails
-    """
+async def _initialize_tools(config: AirflowConfig) -> None:
+    """Initialize tools cache with Airflow operations (async)."""
     global _tools_cache
     try:
-        client = _initialize_client()
-        spec_path = os.environ.get("OPENAPI_SPEC")
-        if not spec_path:
-            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-                spec_path = f.name
-        parser = OperationParser(spec_path)
-
-        # Generate tools for each operation
-        for operation_id in parser.get_operations():
-            operation_details = parser.parse_operation(operation_id)
-            tool = AirflowTool(operation_details, client)
-            _tools_cache[operation_id] = tool
+        async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
+            parser = OperationParser(client.raw_spec)
+            for operation_id in parser.get_operations():
+                operation_details = parser.parse_operation(operation_id)
+                tool = AirflowTool(operation_details, client)
+                _tools_cache[operation_id] = tool
     except Exception as e:
         logger.error("Failed to initialize tools: %s", e)
@@ -68,10 +30,11 @@ async def _initialize_tools() -> None:
         raise ValueError(f"Failed to initialize tools: {e}") from e
 
 
-async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
+async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list[Tool]:
     """Get list of available Airflow tools based on mode.
 
     Args:
+        config: Configuration object with auth and URL settings
         mode: "safe" for GET operations only, "unsafe" for all operations (default)
 
     Returns:
@@ -80,8 +43,21 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
     Raises:
         ValueError: If initialization fails
     """
+    # Version check before returning tools
     if not _tools_cache:
-        await _initialize_tools()
+        await _initialize_tools(config)
+
+    # Only check version if get_version tool is present
+    if "get_version" in _tools_cache:
+        version_tool = _tools_cache["get_version"]
+        async with version_tool.client:
+            version_result = await version_tool.run()
+        airflow_version = version_result.get("version")
+        if airflow_version is None:
+            raise RuntimeError("Could not determine Airflow version from get_version tool.")
+        if parse_version(airflow_version) <= parse_version("3.1.0"):
+            raise RuntimeError(f"Airflow version {airflow_version} is not supported. Requires >= 3.1.0.")
 
     tools = []
     for operation_id, tool in _tools_cache.items():
@@ -93,7 +69,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
             tools.append(
                 Tool(
                     name=operation_id,
-                    description=tool.operation.operation_id,
+                    description=tool.operation.description,
                     inputSchema=schema,
                 )
             )
@@ -104,10 +80,11 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
     return tools
 
 
-async def get_tool(name: str) -> AirflowTool:
+async def get_tool(config: AirflowConfig, name: str) -> AirflowTool:
     """Get specific tool by name.
 
     Args:
+        config: Configuration object with auth and URL settings
         name: Tool/operation name
 
     Returns:
@@ -118,7 +95,7 @@ async def get_tool(name: str) -> AirflowTool:
         ValueError: If tool initialization fails
     """
     if not _tools_cache:
-        await _initialize_tools()
+        await _initialize_tools(config)
 
     if name not in _tools_cache:
         raise KeyError(f"Tool {name} not found")

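One detail worth noting in the version gate above: the comparison uses `<=`, so version 3.1.0 itself is rejected even though the error message says "Requires >= 3.1.0". The behavior as written (a sketch with illustrative version strings):

```python
from packaging.version import parse as parse_version

for version in ("3.0.1", "3.1.0", "3.1.1"):
    rejected = parse_version(version) <= parse_version("3.1.0")
    print(version, "rejected" if rejected else "accepted")
# 3.0.1 rejected
# 3.1.0 rejected
# 3.1.1 accepted
```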
tests/client/test_airflow_client.py

@@ -1,186 +1,69 @@
+import asyncio
 import logging
-from importlib import resources
-from pathlib import Path
-from typing import Any
+from unittest.mock import patch
 
-import aiohttp
 import pytest
-import yaml
-from aioresponses import aioresponses
-from airflow_mcp_server.client.airflow_client import AirflowClient
 from openapi_core import OpenAPI
 
+from airflow_mcp_server.client.airflow_client import AirflowClient
+
 logging.basicConfig(level=logging.DEBUG)
 
 
-def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
-    return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
-
-
-@pytest.fixture
-def client() -> AirflowClient:
-    with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-        spec = yaml.safe_load(f)
-    return AirflowClient(
-        spec_path=spec,
-        base_url="http://localhost:8080/api/v1",
-        auth_token="test-token",
-    )
-
-
-def test_init_client_initialization(client: AirflowClient) -> None:
-    assert isinstance(client.spec, OpenAPI)
-    assert client.base_url == "http://localhost:8080/api/v1"
-    assert client.headers["Authorization"] == "Basic test-token"
-
-
-def test_init_load_spec_from_bytes() -> None:
-    spec_bytes = yaml.dump(create_valid_spec()).encode()
-    client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
-    assert client.raw_spec is not None
-    assert isinstance(client.spec, OpenAPI)
-
-
-def test_init_load_spec_from_path(tmp_path: Path) -> None:
-    spec_file = tmp_path / "test_spec.yaml"
-    spec_file.write_text(yaml.dump(create_valid_spec()))
-    client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
-    assert client.raw_spec is not None
-
-
-def test_init_invalid_spec() -> None:
-    with pytest.raises(ValueError):
-        AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
-
-
-def test_init_missing_paths_in_spec() -> None:
-    with pytest.raises(ValueError):
-        AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
-
-
-def test_ops_get_operation(client: AirflowClient) -> None:
-    path, method, operation = client._get_operation("get_dags")
-    assert path == "/dags"
-    assert method == "get"
-    assert operation.operation_id == "get_dags"
-
-    path, method, operation = client._get_operation("get_dag")
-    assert path == "/dags/{dag_id}"
-    assert method == "get"
-    assert operation.operation_id == "get_dag"
-
-
-def test_ops_nonexistent_operation(client: AirflowClient) -> None:
-    with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
-        client._get_operation("nonexistent")
-
-
-def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
-    with pytest.raises(ValueError):
-        client._get_operation("GET_DAGS")
+@pytest.mark.asyncio
+async def test_async_multiple_clients_concurrent():
+    """Test initializing two AirflowClients concurrently to verify async power."""
+
+    async def mock_get(self, url, *args, **kwargs):
+        class MockResponse:
+            def __init__(self):
+                self.status_code = 200
+
+            def raise_for_status(self):
+                pass
+
+            def json(self):
+                return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+
+        return MockResponse()
+
+    with patch("httpx.AsyncClient.get", new=mock_get):
+
+        async def create_and_check():
+            async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
+                assert client.base_url == "http://localhost:8080"
+                assert client.headers["Authorization"] == "Bearer token"
+
+        # Run two clients concurrently
+        await asyncio.gather(create_and_check(), create_and_check())
 
 
 @pytest.mark.asyncio
-async def test_exec_without_context() -> None:
-    client = AirflowClient(
-        spec_path=create_valid_spec(),
-        base_url="http://test",
-        auth_token="test",
-    )
-    with pytest.raises(RuntimeError, match="Client not in async context"):
-        await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_get_dags(client: AirflowClient) -> None:
-    expected_response = {
-        "dags": [
-            {
-                "dag_id": "test_dag",
-                "is_active": True,
-                "is_paused": False,
-            }
-        ],
-        "total_entries": 1,
-    }
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags?limit=100",
-                status=200,
-                payload=expected_response,
-            )
-            response = await client.execute("get_dags", query_params={"limit": 100})
-            assert response == expected_response
-
-
-@pytest.mark.asyncio
-async def test_exec_get_dag(client: AirflowClient) -> None:
-    expected_response = {
-        "dag_id": "test_dag",
-        "is_active": True,
-        "is_paused": False,
-    }
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags/test_dag",
-                status=200,
-                payload=expected_response,
-            )
-            response = await client.execute(
-                "get_dag",
-                path_params={"dag_id": "test_dag"},
-            )
-            assert response == expected_response
-
-
-@pytest.mark.asyncio
-async def test_exec_invalid_params(client: AirflowClient) -> None:
-    with pytest.raises(ValueError):
-        async with client:
-            # Test with missing required parameter
-            await client.execute("get_dag", path_params={})
-
-    with pytest.raises(ValueError):
-        async with client:
-            # Test with invalid parameter name
-            await client.execute("get_dag", path_params={"invalid": "value"})
-
-
-@pytest.mark.asyncio
-async def test_exec_timeout(client: AirflowClient) -> None:
-    with aioresponses() as mock:
-        mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
-        async with client:
-            with pytest.raises(aiohttp.ClientError):
-                await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_error_response(client: AirflowClient) -> None:
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags",
-                status=403,
-                body="Forbidden",
-            )
-            with pytest.raises(aiohttp.ClientResponseError):
-                await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_session_management(client: AirflowClient) -> None:
-    async with client:
-        with aioresponses() as mock:
-            mock.get(
-                "http://localhost:8080/api/v1/dags",
-                status=200,
-                payload={"dags": []},
-            )
-            await client.execute("get_dags")
-
-    with pytest.raises(RuntimeError):
-        await client.execute("get_dags")
+async def test_async_client_initialization():
+    async def mock_get(self, url, *args, **kwargs):
+        class MockResponse:
+            def __init__(self):
+                self.status_code = 200
+
+            def raise_for_status(self):
+                pass
+
+            def json(self):
+                return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+
+        return MockResponse()
+
+    with patch("httpx.AsyncClient.get", new=mock_get):
+        async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
+            assert client.base_url == "http://localhost:8080"
+            assert client.headers["Authorization"] == "Bearer test-token"
+            assert isinstance(client.spec, OpenAPI)
+
+
+def test_init_client_missing_auth():
+    with pytest.raises(ValueError, match="auth_token"):
+        AirflowClient(
+            base_url="http://localhost:8080",
+            auth_token=None,
+        )

tests/parser/openapi.json Normal file
File diff suppressed because it is too large

tests/parser/test_operation_parser.py

@@ -1,26 +1,16 @@
-import logging
-from importlib import resources
+import json
 from typing import Any
 
 import pytest
-from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
 from pydantic import BaseModel
 
-logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger(__name__)
+from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
 
 
 @pytest.fixture
-def spec_file():
-    """Get content of the v1.yaml spec file."""
-    with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
-        return f.read()
-
-
-@pytest.fixture
-def parser(spec_file) -> OperationParser:
-    """Create OperationParser instance."""
-    return OperationParser(spec_path=spec_file)
+def parser() -> OperationParser:
+    """Create OperationParser instance from tests/parser/openapi.json."""
+    with open("tests/parser/openapi.json") as f:
+        spec_dict = json.load(f)
+    return OperationParser(spec_dict)
 
 
 def test_parse_operation_basic(parser: OperationParser) -> None:
@@ -29,8 +19,21 @@ def test_parse_operation_basic(parser: OperationParser) -> None:
     assert isinstance(operation, OperationDetails)
     assert operation.operation_id == "get_dags"
-    assert operation.path == "/dags"
+    assert operation.path == "/api/v2/dags"
     assert operation.method == "get"
+    assert operation.description == "Get all DAGs."
+    assert isinstance(operation.parameters, dict)
+
+
+def test_parse_operation_with_no_description_but_summary(parser: OperationParser) -> None:
+    """Test parsing operation with no description but summary."""
+    operation = parser.parse_operation("get_connections")
+    assert isinstance(operation, OperationDetails)
+    assert operation.operation_id == "get_connections"
+    assert operation.path == "/api/v2/connections"
+    assert operation.method == "get"
+    assert operation.description == "Get all connection entries."
     assert isinstance(operation.parameters, dict)
@@ -38,7 +41,7 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
     """Test parsing operation with path parameters."""
     operation = parser.parse_operation("get_dag")
 
-    assert operation.path == "/dags/{dag_id}"
+    assert operation.path == "/api/v2/dags/{dag_id}"
     assert isinstance(operation.input_model, type(BaseModel))
 
     # Verify path parameter field exists
@@ -65,7 +68,10 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
 def test_parse_operation_with_body_params(parser: OperationParser) -> None:
     """Test parsing operation with request body."""
-    operation = parser.parse_operation("post_dag_run")
+    # Find the correct operationId for posting a dag run in the OpenAPI spec
+    # From the spec, the likely operation is under /api/v2/dags/{dag_id}/dagRuns
+    # Let's use "post_dag_run" if it exists, otherwise use the actual operationId
+    operation = parser.parse_operation("trigger_dag_run")
 
     # Verify body fields exist
     fields = operation.input_model.__annotations__
@@ -149,7 +155,7 @@ def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
     assert isinstance(operation, OperationDetails)
     assert operation.operation_id == "test_connection"
-    assert operation.path == "/connections/test"
+    assert operation.path == "/api/v2/connections/test"
     assert operation.method == "post"
 
     # Verify input model includes fields from allOf schema

tests/tools/test_airflow_tool.py

@@ -1,11 +1,11 @@
 """Tests for AirflowTool."""
 
 import pytest
+from pydantic import ValidationError
 
 from airflow_mcp_server.client.airflow_client import AirflowClient
 from airflow_mcp_server.parser.operation_parser import OperationDetails
 from airflow_mcp_server.tools.airflow_tool import AirflowTool
-from pydantic import ValidationError
 from tests.tools.test_models import TestRequestModel
@@ -41,6 +41,7 @@ def operation_details():
             },
         },
         input_model=model,
+        description="Test operation for AirflowTool",
     )

uv.lock generated
File diff suppressed because it is too large