Compare commits

14 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 9837afe13a | |
| | f4206ea73c | |
| | 950dc06901 | |
| | 2031650535 | |
| | 4263175351 | |
| | b5cf563b8f | |
| | d2464ea891 | |
| | c5565e6a00 | |
| | bba42eea00 | |
| | 5a864b27c5 | |
| | 66cd068b33 | |
| | 4734005ae4 | |
| | 63ff02fa4b | |
| | 407eb00c1b | |
.gitignore (vendored), 3 additions:

```diff
@@ -179,3 +179,6 @@ project_resources/
 
 # Ruff
 .ruff_cache/
+
+# Airflow
+AIRFLOW_HOME/
```
.pre-commit-config.yaml:

```diff
@@ -1,13 +1,13 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.11
+    rev: v0.11.8
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
```
README.md, 51 changes:

````diff
@@ -6,7 +6,6 @@
 <img width="380" height="200" src="https://glama.ai/mcp/servers/6gjq9w80xr/badge" />
 </a>
 
-
 ## Overview
 A [Model Context Protocol](https://modelcontextprotocol.io/) server for controlling Airflow via Airflow APIs.
 
@@ -14,31 +13,32 @@ A [Model Context Protocol](https://modelcontextprotocol.io/) server for controll
 
 https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
 
 
 ## Setup
 
 ### Usage with Claude Desktop
 
 ```json
 {
   "mcpServers": {
     "airflow-mcp-server": {
       "command": "uvx",
       "args": [
-        "airflow-mcp-server"
-      ],
-      "env": {
-        "AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
-        "AUTH_TOKEN": "<base64_encoded_username_password>",
-        "COOKIE": "<session_cookie>"
-      }
+        "airflow-mcp-server",
+        "--base-url",
+        "http://localhost:8080",
+        "--auth-token",
+        "<jwt_token>"
+      ]
     }
   }
 }
 ```
 
+> **Note:**
+> - Set `base_url` to the root Airflow URL (e.g., `http://localhost:8080`).
+> - Do **not** include `/api/v2` in the base URL. The server will automatically fetch the OpenAPI spec from `${base_url}/openapi.json`.
+> - Only JWT token is required for authentication. Cookie and basic auth are no longer supported in Airflow 3.0.
 
 ### Operation Modes
 
 The server supports two operation modes:
@@ -58,19 +58,9 @@ airflow-mcp-server --unsafe
 
 ### Considerations
 
-The MCP Server expects environment variables to be set:
-- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
-- `AUTH_TOKEN`: The token to use for basic auth (_This should be base64 encoded username:password_) (_Optional if COOKIE is provided_)
-- `COOKIE`: The session cookie to use for authentication (_Optional if AUTH_TOKEN is provided_)
-- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
-
 **Authentication**
 
-The server supports two authentication methods:
-- **Basic Auth**: Using base64 encoded username:password via `AUTH_TOKEN` environment variable
-- **Cookie**: Using session cookie via `COOKIE` environment variable
-
-At least one of these authentication methods must be provided.
+- Only JWT authentication is supported in Airflow 3.0. You must provide a valid `AUTH_TOKEN`.
 
 **Page Limit**
 
@@ -78,10 +68,9 @@ The default is 100 items, but you can change it using `maximum_page_limit` option
 
 ## Tasks
 
-- [x] First API
+- [x] Airflow 3 readiness
 - [x] Parse OpenAPI Spec
 - [x] Safe/Unsafe mode implementation
-- [x] Allow session auth
-- [ ] Parse proper description with list_tools.
-- [ ] Airflow config fetch (_specifically for page limit_)
+- [x] Parse proper description with list_tools.
+- [x] Airflow config fetch (_specifically for page limit_)
 - [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)
````
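A hedged sketch of obtaining the JWT that the new configuration expects. The `/auth/token` endpoint, the request payload, and the `access_token` response field are assumptions about Airflow 3's simple auth manager, not something this diff confirms; check the docs for your deployment's auth manager.

```python
# Sketch: fetch a JWT for the MCP server. Endpoint path, payload shape, and the
# "access_token" field are assumptions; credentials below are placeholders.
import httpx

resp = httpx.post(
    "http://localhost:8080/auth/token",  # assumed token endpoint
    json={"username": "airflow", "password": "airflow"},  # placeholder credentials
)
resp.raise_for_status()
jwt_token = resp.json()["access_token"]  # assumed response field
print(jwt_token)  # pass this as --auth-token / AUTH_TOKEN
```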
pyproject.toml:

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "airflow-mcp-server"
-version = "0.4.0"
+version = "0.6.0"
 description = "MCP Server for Airflow"
 readme = "README.md"
 requires-python = ">=3.11"
@@ -12,10 +12,11 @@ dependencies = [
     "aiohttp>=3.11.11",
     "aioresponses>=0.7.7",
     "importlib-resources>=6.5.0",
-    "mcp>=1.2.0",
+    "mcp>=1.7.1",
     "openapi-core>=0.19.4",
-    "pydantic>=2.10.5",
+    "pydantic>=2.11.4",
     "pyyaml>=6.0.0",
+    "packaging>=25.0",
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -58,7 +59,6 @@ exclude = [
 
 [tool.hatch.build.targets.wheel]
 packages = ["src/airflow_mcp_server"]
-package-data = {"airflow_mcp_server"= ["*.yaml"]}
 
 [tool.hatch.build.targets.wheel.sources]
 "src/airflow_mcp_server" = "airflow_mcp_server"
```
CLI entry point:

```diff
@@ -15,10 +15,8 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe
 @click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
 @click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
 @click.option("--base-url", help="Airflow API base URL")
-@click.option("--spec-path", help="Path to OpenAPI spec file")
-@click.option("--auth-token", help="Authentication token")
-@click.option("--cookie", help="Session cookie")
-def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path: str = None, auth_token: str = None, cookie: str = None) -> None:
+@click.option("--auth-token", help="Authentication token (JWT)")
+def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, auth_token: str = None) -> None:
     """MCP server for Airflow"""
     logging_level = logging.WARN
     if verbose == 1:
@@ -29,22 +27,18 @@ def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path
     logging.basicConfig(level=logging_level, stream=sys.stderr)
 
     # Read environment variables with proper precedence
-    # Environment variables take precedence over CLI arguments
     config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url
-    config_spec_path = os.environ.get("OPENAPI_SPEC") or spec_path
     config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token
-    config_cookie = os.environ.get("COOKIE") or cookie
 
     # Initialize configuration
     try:
-        config = AirflowConfig(base_url=config_base_url, spec_path=config_spec_path, auth_token=config_auth_token, cookie=config_cookie)
+        config = AirflowConfig(base_url=config_base_url, auth_token=config_auth_token)
     except ValueError as e:
         click.echo(f"Configuration error: {e}", err=True)
         sys.exit(1)
 
     # Determine server mode with proper precedence
     if safe and unsafe:
-        # CLI argument validation
         raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
     elif safe:
         # CLI argument for safe mode
```
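The resolution rule kept in this diff gives environment variables priority over CLI flags via the `or` chain. A minimal sketch of that behavior (`resolve` is a hypothetical helper that mirrors the diff's expressions):

```python
import os


def resolve(env_name: str, cli_value: str | None) -> str | None:
    # Mirrors config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url:
    # a set, non-empty environment variable wins; otherwise the CLI flag is used.
    return os.environ.get(env_name) or cli_value


os.environ["AIRFLOW_BASE_URL"] = "http://airflow.internal:8080"
assert resolve("AIRFLOW_BASE_URL", "http://localhost:8080") == "http://airflow.internal:8080"

del os.environ["AIRFLOW_BASE_URL"]
assert resolve("AIRFLOW_BASE_URL", "http://localhost:8080") == "http://localhost:8080"
```

Note that an environment variable set to the empty string is falsy in Python, so it also falls through to the CLI value.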
src/airflow_mcp_server/client/airflow_client.py:

```diff
@@ -1,11 +1,7 @@
 import logging
 import re
-from pathlib import Path
-from types import SimpleNamespace
-from typing import Any, BinaryIO, TextIO
 
-import aiohttp
-import yaml
+import httpx
 from jsonschema_path import SchemaPath
 from openapi_core import OpenAPI
 from openapi_core.validation.request.validators import V31RequestValidator
@@ -29,148 +25,95 @@ def convert_dict_keys(d: dict) -> dict:
 
 
 class AirflowClient:
-    """Client for interacting with Airflow API."""
+    """Async client for interacting with Airflow API."""
 
     def __init__(
         self,
-        spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
         base_url: str,
-        auth_token: str | None = None,
-        cookie: str | None = None,
+        auth_token: str,
     ) -> None:
         """Initialize Airflow client.
 
         Args:
-            spec_path: OpenAPI spec as file path, dict, bytes, or file object
             base_url: Base URL for API
-            auth_token: Authentication token (optional if cookie is provided)
-            cookie: Session cookie (optional if auth_token is provided)
+            auth_token: Authentication token (JWT)
 
         Raises:
-            ValueError: If spec_path is invalid or spec cannot be loaded or if neither auth_token nor cookie is provided
+            ValueError: If required configuration is missing or OpenAPI spec cannot be loaded
         """
-        if not auth_token and not cookie:
-            raise ValueError("Either auth_token or cookie must be provided")
-        try:
-            # Load and parse OpenAPI spec
-            if isinstance(spec_path, dict):
-                self.raw_spec = spec_path
-            elif isinstance(spec_path, bytes):
-                self.raw_spec = yaml.safe_load(spec_path)
-            elif isinstance(spec_path, str | Path):
-                with open(spec_path) as f:
-                    self.raw_spec = yaml.safe_load(f)
-            elif hasattr(spec_path, "read"):
-                content = spec_path.read()
-                if isinstance(content, bytes):
-                    self.raw_spec = yaml.safe_load(content)
-                else:
-                    self.raw_spec = yaml.safe_load(content)
-            else:
-                raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
-
-            # Validate spec has required fields
-            if not isinstance(self.raw_spec, dict):
-                raise ValueError("OpenAPI spec must be a dictionary")
-
-            required_fields = ["openapi", "info", "paths"]
-            for field in required_fields:
-                if field not in self.raw_spec:
-                    raise ValueError(f"OpenAPI spec missing required field: {field}")
-
-            # Validate OpenAPI spec format
-            validate(self.raw_spec)
-
-            # Initialize OpenAPI spec
-            self.spec = OpenAPI.from_dict(self.raw_spec)
-            logger.debug("OpenAPI spec loaded successfully")
-
-            # Debug raw spec
-            logger.debug("Raw spec keys: %s", self.raw_spec.keys())
-
-            # Get paths from raw spec
-            if "paths" not in self.raw_spec:
-                raise ValueError("OpenAPI spec does not contain paths information")
-            self._paths = self.raw_spec["paths"]
-            logger.debug("Using raw spec paths")
-
-            # Initialize request validator with schema path
-            schema_path = SchemaPath.from_dict(self.raw_spec)
-            self._validator = V31RequestValidator(schema_path)
-
-            # API configuration
-            self.base_url = base_url.rstrip("/")
-            self.headers = {"Accept": "application/json"}
-
-            # Set authentication header based on precedence (cookie > auth_token)
-            if cookie:
-                self.headers["Cookie"] = cookie
-            elif auth_token:
-                self.headers["Authorization"] = f"Basic {auth_token}"
-
-        except Exception as e:
-            logger.error("Failed to initialize AirflowClient: %s", e)
-            raise ValueError(f"Failed to initialize client: {e}")
-
-    async def __aenter__(self) -> "AirflowClient":
-        self._session = aiohttp.ClientSession(headers=self.headers)
+        if not base_url:
+            raise ValueError("Missing required configuration: base_url")
+        if not auth_token:
+            raise ValueError("Missing required configuration: auth_token (JWT)")
+        self.base_url = base_url
+        self.auth_token = auth_token
+        self.headers = {"Authorization": f"Bearer {self.auth_token}"}
+        self._client: httpx.AsyncClient | None = None
+        self.raw_spec = None
+        self.spec = None
+        self._paths = None
+        self._validator = None
+
+    async def __aenter__(self):
+        self._client = httpx.AsyncClient(headers=self.headers)
+        await self._initialize_spec()
         return self
 
-    async def __aexit__(self, *exc) -> None:
-        if hasattr(self, "_session"):
-            await self._session.close()
-            delattr(self, "_session")
+    async def __aexit__(self, exc_type, exc, tb):
+        if self._client:
+            await self._client.aclose()
+            self._client = None
 
-    def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
-        """Get operation details from OpenAPI spec.
-
-        Args:
-            operation_id: The operation ID to look up
-
-        Returns:
-            Tuple of (path, method, operation) where operation is a SimpleNamespace object
-
-        Raises:
-            ValueError: If operation not found
-        """
-        try:
-            # Debug the paths structure
-            logger.debug("Looking for operation %s in paths", operation_id)
-
-            for path, path_item in self._paths.items():
-                for method, operation_data in path_item.items():
-                    # Skip non-operation fields
-                    if method.startswith("x-") or method == "parameters":
-                        continue
-
-                    # Debug each operation
-                    logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))
-
-                    if operation_data.get("operationId") == operation_id:
-                        logger.debug("Found operation %s at %s %s", operation_id, method, path)
-                        # Convert keys to snake_case and create object
-                        converted_data = convert_dict_keys(operation_data)
-                        operation_obj = SimpleNamespace(**converted_data)
-                        return path, method, operation_obj
-
-            raise ValueError(f"Operation {operation_id} not found in spec")
-        except Exception as e:
-            logger.error("Error getting operation %s: %s", operation_id, e)
-            raise
-
-    def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
+    async def _initialize_spec(self):
+        openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
+        self.raw_spec = await self._fetch_openapi_spec(openapi_url)
+        if not isinstance(self.raw_spec, dict):
+            raise ValueError("OpenAPI spec must be a dictionary")
+        required_fields = ["openapi", "info", "paths"]
+        for field in required_fields:
+            if field not in self.raw_spec:
+                raise ValueError(f"OpenAPI spec missing required field: {field}")
+        validate(self.raw_spec)
+        self.spec = OpenAPI.from_dict(self.raw_spec)
+        logger.debug("OpenAPI spec loaded successfully")
+        if "paths" not in self.raw_spec:
+            raise ValueError("OpenAPI spec does not contain paths information")
+        self._paths = self.raw_spec["paths"]
+        logger.debug("Using raw spec paths")
+        schema_path = SchemaPath.from_dict(self.raw_spec)
+        self._validator = V31RequestValidator(schema_path)
+
+    async def _fetch_openapi_spec(self, url: str) -> dict:
+        if not self._client:
+            self._client = httpx.AsyncClient(headers=self.headers)
+        try:
+            response = await self._client.get(url)
+            response.raise_for_status()
+        except httpx.RequestError as e:
+            raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
+        return response.json()
+
+    def _get_operation(self, operation_id: str):
+        """Get operation details from OpenAPI spec."""
+        for path, path_item in self._paths.items():
+            for method, operation_data in path_item.items():
+                if method.startswith("x-") or method == "parameters":
+                    continue
+                if operation_data.get("operationId") == operation_id:
+                    converted_data = convert_dict_keys(operation_data)
+                    from types import SimpleNamespace
+
+                    operation_obj = SimpleNamespace(**converted_data)
+                    return path, method, operation_obj
+        raise ValueError(f"Operation {operation_id} not found in spec")
+
+    def _validate_path_params(self, path: str, params: dict | None) -> None:
         if not params:
             params = {}
-
-        # Extract path parameter names from the path
         path_params = set(re.findall(r"{([^}]+)}", path))
-
-        # Check for missing required parameters
         missing_params = path_params - set(params.keys())
         if missing_params:
             raise ValueError(f"Missing required path parameters: {missing_params}")
-
-        # Check for invalid parameters
         invalid_params = set(params.keys()) - path_params
         if invalid_params:
             raise ValueError(f"Invalid path parameters: {invalid_params}")
@@ -178,77 +121,42 @@ class AirflowClient:
     async def execute(
         self,
         operation_id: str,
-        path_params: dict[str, Any] | None = None,
-        query_params: dict[str, Any] | None = None,
-        body: dict[str, Any] | None = None,
-    ) -> Any:
-        """Execute an API operation.
-
-        Args:
-            operation_id: Operation ID from OpenAPI spec
-            path_params: URL path parameters
-            query_params: URL query parameters
-            body: Request body data
-
-        Returns:
-            API response data
-
-        Raises:
-            ValueError: If operation not found
-            RuntimeError: If used outside async context
-            aiohttp.ClientError: For HTTP/network errors
-        """
-        if not hasattr(self, "_session") or not self._session:
+        path_params: dict = None,
+        query_params: dict = None,
+        body: dict = None,
+    ) -> dict:
+        """Execute an API operation."""
+        if not self._client:
             raise RuntimeError("Client not in async context")
+        # Default all params to empty dict if None
+        path_params = path_params or {}
+        query_params = query_params or {}
+        body = body or {}
+        path, method, _ = self._get_operation(operation_id)
+        self._validate_path_params(path, path_params)
+        if path_params:
+            path = path.format(**path_params)
+        url = f"{self.base_url.rstrip('/')}{path}"
+        request_headers = self.headers.copy()
+        if body:
+            request_headers["Content-Type"] = "application/json"
         try:
-            # Get operation details
-            path, method, _ = self._get_operation(operation_id)
-
-            # Validate path parameters
-            self._validate_path_params(path, path_params)
-
-            # Format URL
-            if path_params:
-                path = path.format(**path_params)
-            url = f"{self.base_url}{path}"
-
-            logger.debug("Executing %s %s", method, url)
-            logger.debug("Request body: %s", body)
-            logger.debug("Request query params: %s", query_params)
-
-            # Dynamically set headers based on presence of body
-            request_headers = self.headers.copy()
-            if body is not None:
-                request_headers["Content-Type"] = "application/json"
-            # Make request
-            async with self._session.request(
-                method=method,
+            response = await self._client.request(
+                method=method.upper(),
                 url=url,
                 params=query_params,
                 json=body,
-            ) as response:
-                response.raise_for_status()
-                content_type = response.headers.get("Content-Type", "").lower()
-                # Status codes that typically have no body
-                no_body_statuses = {204}
-                if response.status in no_body_statuses:
-                    if content_type and "application/json" in content_type:
-                        logger.warning("Unexpected JSON body with status %s", response.status)
-                        return await response.json()  # Parse if present, though rare
-                    logger.debug("Received %s response with no body", response.status)
-                    return response.status
-                # For statuses expecting a body, check mimetype
-                if "application/json" in content_type:
-                    logger.debug("Response: %s", await response.text())
-                    return await response.json()
-                # Unexpected mimetype with body
-                response_text = await response.text()
-                logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
-                raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")
-
-        except aiohttp.ClientError as e:
-            logger.error("Error executing operation %s: %s", operation_id, e)
-            raise
+                headers=request_headers,
+            )
+            response.raise_for_status()
+            content_type = response.headers.get("content-type", "").lower()
+            if response.status_code == 204:
+                return response.status_code
+            if "application/json" in content_type:
+                return response.json()
+            return {"content": await response.aread()}
+        except httpx.HTTPStatusError as e:
+            logger.error("HTTP error executing operation %s: %s", operation_id, e)
+            raise
         except Exception as e:
            logger.error("Error executing operation %s: %s", operation_id, e)
```
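A minimal usage sketch of the reworked client, assuming an Airflow 3 instance on localhost:8080, a valid JWT, and a `get_dags` operationId in the served spec; per the diff, the spec is fetched and validated from `${base_url}/openapi.json` on context entry:

```python
import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient


async def main() -> None:
    # __aenter__ opens the httpx.AsyncClient and fetches/validates openapi.json.
    async with AirflowClient(base_url="http://localhost:8080", auth_token="<jwt_token>") as client:
        dags = await client.execute("get_dags", query_params={"limit": 10})
        print(dags)


asyncio.run(main())
```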
src/airflow_mcp_server/config.py:

```diff
@@ -1,14 +1,12 @@
 class AirflowConfig:
     """Centralized configuration for Airflow MCP server."""
 
-    def __init__(self, base_url: str | None = None, spec_path: str | None = None, auth_token: str | None = None, cookie: str | None = None) -> None:
+    def __init__(self, base_url: str | None = None, auth_token: str | None = None) -> None:
         """Initialize configuration with provided values.
 
         Args:
             base_url: Airflow API base URL
-            spec_path: Path to OpenAPI spec file
-            auth_token: Authentication token
-            cookie: Session cookie
+            auth_token: Authentication token (JWT)
 
         Raises:
             ValueError: If required configuration is missing
@@ -17,9 +15,6 @@ class AirflowConfig:
         if not self.base_url:
             raise ValueError("Missing required configuration: base_url")
 
-        self.spec_path = spec_path
         self.auth_token = auth_token
-        self.cookie = cookie
-
-        if not self.auth_token and not self.cookie:
-            raise ValueError("Either auth_token or cookie must be provided")
+        if not self.auth_token:
+            raise ValueError("Missing required configuration: auth_token (JWT)")
```
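With `spec_path` and `cookie` gone, both remaining fields are validated eagerly. A small sketch of the fail-fast behavior:

```python
from airflow_mcp_server.config import AirflowConfig

try:
    AirflowConfig(base_url="http://localhost:8080")  # auth_token omitted
except ValueError as e:
    print(e)  # Missing required configuration: auth_token (JWT)
```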
src/airflow_mcp_server/parser/operation_parser.py:

```diff
@@ -20,6 +20,7 @@ class OperationDetails:
     method: str
     parameters: dict[str, Any]
     input_model: type[BaseModel]
+    description: str
 
 
 class OperationParser:
@@ -104,6 +105,7 @@ class OperationParser:
 
         operation["path"] = path
         operation["path_item"] = path_item
+        description = operation.get("description") or operation.get("summary") or operation_id
 
         parameters = self.extract_parameters(operation)
 
@@ -119,7 +121,7 @@ class OperationParser:
         # Create unified input model
         input_model = self._create_input_model(operation_id, parameters, body_schema)
 
-        return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
+        return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, description=description, input_model=input_model)
 
         raise ValueError(f"Operation {operation_id} not found in spec")
```
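The new `description` field falls back from `description` to `summary` to the bare `operationId`. The chain in isolation (`pick_description` is a hypothetical helper mirroring the added line):

```python
def pick_description(operation: dict, operation_id: str) -> str:
    # Mirrors: operation.get("description") or operation.get("summary") or operation_id
    return operation.get("description") or operation.get("summary") or operation_id


assert pick_description({"description": "Get all DAGs."}, "get_dags") == "Get all DAGs."
assert pick_description({"summary": "Get all connection entries."}, "get_connections") == "Get all connection entries."
assert pick_description({}, "trigger_dag_run") == "trigger_dag_run"
```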
(File diff suppressed because it is too large.)
BaseTools module:

```diff
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
+
 
 class BaseTools(ABC):
     """Abstract base class for tools."""
```
Tools initialization module (get_airflow_tools):

```diff
@@ -1,7 +1,7 @@
 import logging
-from importlib import resources
 
 from mcp.types import Tool
+from packaging.version import parse as parse_version
 
 from airflow_mcp_server.client.airflow_client import AirflowClient
 from airflow_mcp_server.config import AirflowConfig
@@ -13,65 +13,16 @@ logger = logging.getLogger(__name__)
 _tools_cache: dict[str, AirflowTool] = {}
 
 
-def _initialize_client(config: AirflowConfig) -> AirflowClient:
-    """Initialize Airflow client with configuration.
-
-    Args:
-        config: Configuration object with auth and URL settings
-
-    Returns:
-        AirflowClient instance
-
-    Raises:
-        ValueError: If default spec is not found
-    """
-    spec_path = config.spec_path
-    if not spec_path:
-        # Fallback to embedded v1.yaml
-        try:
-            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-                spec_path = f.name
-            logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
-        except Exception as e:
-            raise ValueError("Default OpenAPI spec not found in package resources") from e
-
-    # Initialize client with appropriate authentication method
-    client_args = {"spec_path": spec_path, "base_url": config.base_url}
-
-    # Apply cookie auth first if available (highest precedence)
-    if config.cookie:
-        client_args["cookie"] = config.cookie
-    # Otherwise use auth token if available
-    elif config.auth_token:
-        client_args["auth_token"] = config.auth_token
-
-    return AirflowClient(**client_args)
-
-
 async def _initialize_tools(config: AirflowConfig) -> None:
-    """Initialize tools cache with Airflow operations.
-
-    Args:
-        config: Configuration object with auth and URL settings
-
-    Raises:
-        ValueError: If initialization fails
-    """
+    """Initialize tools cache with Airflow operations (async)."""
     global _tools_cache
 
     try:
-        client = _initialize_client(config)
-        spec_path = config.spec_path
-        if not spec_path:
-            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-                spec_path = f.name
-        parser = OperationParser(spec_path)
-
-        # Generate tools for each operation
-        for operation_id in parser.get_operations():
-            operation_details = parser.parse_operation(operation_id)
-            tool = AirflowTool(operation_details, client)
-            _tools_cache[operation_id] = tool
-
+        async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
+            parser = OperationParser(client.raw_spec)
+            for operation_id in parser.get_operations():
+                operation_details = parser.parse_operation(operation_id)
+                tool = AirflowTool(operation_details, client)
+                _tools_cache[operation_id] = tool
     except Exception as e:
         logger.error("Failed to initialize tools: %s", e)
@@ -92,9 +43,22 @@ async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list
     Raises:
         ValueError: If initialization fails
     """
+    # Version check before returning tools
     if not _tools_cache:
         await _initialize_tools(config)
 
+    # Only check version if get_version tool is present
+    if "get_version" in _tools_cache:
+        version_tool = _tools_cache["get_version"]
+        async with version_tool.client:
+            version_result = await version_tool.run()
+            airflow_version = version_result.get("version")
+            if airflow_version is None:
+                raise RuntimeError("Could not determine Airflow version from get_version tool.")
+            if parse_version(airflow_version) <= parse_version("3.1.0"):
+                raise RuntimeError(f"Airflow version {airflow_version} is not supported. Requires >= 3.1.0.")
+
     tools = []
     for operation_id, tool in _tools_cache.items():
         try:
@@ -105,7 +69,7 @@ async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list
             tools.append(
                 Tool(
                     name=operation_id,
-                    description=tool.operation.operation_id,
+                    description=tool.operation.description,
                     inputSchema=schema,
                 )
             )
```
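One boundary worth noting in the version gate above: as written, the `<=` comparison also rejects 3.1.0 itself, even though the error message reads "Requires >= 3.1.0". A quick check with the same `packaging` calls:

```python
from packaging.version import parse as parse_version

assert parse_version("3.0.2") <= parse_version("3.1.0")        # rejected by the gate
assert parse_version("3.1.0") <= parse_version("3.1.0")        # also rejected, despite the message
assert not (parse_version("3.1.1") <= parse_version("3.1.0"))  # accepted
```

If 3.1.0 itself is meant to pass, the comparison would need to be `<` rather than `<=`.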
AirflowClient tests:

```diff
@@ -1,211 +1,69 @@
+import asyncio
 import logging
-from importlib import resources
-from pathlib import Path
-from typing import Any
+from unittest.mock import patch
 
-import aiohttp
 import pytest
-import yaml
-from aioresponses import aioresponses
-from airflow_mcp_server.client.airflow_client import AirflowClient
 from openapi_core import OpenAPI
 
+from airflow_mcp_server.client.airflow_client import AirflowClient
+
 logging.basicConfig(level=logging.DEBUG)
 
 
-def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
-    return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
-
-
-@pytest.fixture
-def client() -> AirflowClient:
-    with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-        spec = yaml.safe_load(f)
-    return AirflowClient(
-        spec_path=spec,
-        base_url="http://localhost:8080/api/v1",
-        auth_token="test-token",
-    )
-
-
-def test_init_client_initialization(client: AirflowClient) -> None:
-    assert isinstance(client.spec, OpenAPI)
-    assert client.base_url == "http://localhost:8080/api/v1"
-    assert client.headers["Authorization"] == "Basic test-token"
-    assert "Cookie" not in client.headers
-
-
-def test_init_client_with_cookie() -> None:
-    with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-        spec = yaml.safe_load(f)
-    client = AirflowClient(
-        spec_path=spec,
-        base_url="http://localhost:8080/api/v1",
-        cookie="session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM",
-    )
-    assert isinstance(client.spec, OpenAPI)
-    assert client.base_url == "http://localhost:8080/api/v1"
-    assert "Authorization" not in client.headers
-    assert client.headers["Cookie"] == "session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM"
-
-
-def test_init_client_missing_auth() -> None:
-    with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-        spec = yaml.safe_load(f)
-    with pytest.raises(ValueError, match="Either auth_token or cookie must be provided"):
-        AirflowClient(
-            spec_path=spec,
-            base_url="http://localhost:8080/api/v1",
-        )
-
-
-def test_init_load_spec_from_bytes() -> None:
-    spec_bytes = yaml.dump(create_valid_spec()).encode()
-    client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
-    assert client.raw_spec is not None
-
-
-def test_init_load_spec_from_path(tmp_path: Path) -> None:
-    spec_file = tmp_path / "test_spec.yaml"
-    spec_file.write_text(yaml.dump(create_valid_spec()))
-    client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
-    assert client.raw_spec is not None
-
-
-def test_init_invalid_spec() -> None:
-    with pytest.raises(ValueError):
-        AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
-
-
-def test_init_missing_paths_in_spec() -> None:
-    with pytest.raises(ValueError):
-        AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
-
-
-def test_ops_get_operation(client: AirflowClient) -> None:
-    path, method, operation = client._get_operation("get_dags")
-    assert path == "/dags"
-    assert method == "get"
-    assert operation.operation_id == "get_dags"
-
-    path, method, operation = client._get_operation("get_dag")
-    assert path == "/dags/{dag_id}"
-    assert method == "get"
-    assert operation.operation_id == "get_dag"
-
-
-def test_ops_nonexistent_operation(client: AirflowClient) -> None:
-    with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
-        client._get_operation("nonexistent")
-
-
-def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
-    with pytest.raises(ValueError):
-        client._get_operation("GET_DAGS")
-
-
-@pytest.mark.asyncio
-async def test_exec_without_context() -> None:
-    client = AirflowClient(
-        spec_path=create_valid_spec(),
-        base_url="http://test",
-        auth_token="test",
-    )
-    with pytest.raises(RuntimeError, match="Client not in async context"):
-        await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_get_dags(client: AirflowClient) -> None:
-    expected_response = {
-        "dags": [
-            {
-                "dag_id": "test_dag",
-                "is_active": True,
-                "is_paused": False,
-            }
-        ],
-        "total_entries": 1,
-    }
-
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags?limit=100",
-                status=200,
-                payload=expected_response,
-            )
-            response = await client.execute("get_dags", query_params={"limit": 100})
-            assert response == expected_response
-
-
-@pytest.mark.asyncio
-async def test_exec_get_dag(client: AirflowClient) -> None:
-    expected_response = {
-        "dag_id": "test_dag",
-        "is_active": True,
-        "is_paused": False,
-    }
-
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags/test_dag",
-                status=200,
-                payload=expected_response,
-            )
-            response = await client.execute(
-                "get_dag",
-                path_params={"dag_id": "test_dag"},
-            )
-            assert response == expected_response
-
-
-@pytest.mark.asyncio
-async def test_exec_invalid_params(client: AirflowClient) -> None:
-    with pytest.raises(ValueError):
-        async with client:
-            # Test with missing required parameter
-            await client.execute("get_dag", path_params={})
-
-    with pytest.raises(ValueError):
-        async with client:
-            # Test with invalid parameter name
-            await client.execute("get_dag", path_params={"invalid": "value"})
-
-
-@pytest.mark.asyncio
-async def test_exec_timeout(client: AirflowClient) -> None:
-    with aioresponses() as mock:
-        mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
-        async with client:
-            with pytest.raises(aiohttp.ClientError):
-                await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_error_response(client: AirflowClient) -> None:
-    with aioresponses() as mock:
-        async with client:
-            mock.get(
-                "http://localhost:8080/api/v1/dags",
-                status=403,
-                body="Forbidden",
-            )
-            with pytest.raises(aiohttp.ClientResponseError):
-                await client.execute("get_dags")
-
-
-@pytest.mark.asyncio
-async def test_exec_session_management(client: AirflowClient) -> None:
-    async with client:
-        with aioresponses() as mock:
-            mock.get(
-                "http://localhost:8080/api/v1/dags",
-                status=200,
-                payload={"dags": []},
-            )
-            await client.execute("get_dags")
-
-    with pytest.raises(RuntimeError):
-        await client.execute("get_dags")
+@pytest.mark.asyncio
+async def test_async_multiple_clients_concurrent():
+    """Test initializing two AirflowClients concurrently to verify async power."""
+
+    async def mock_get(self, url, *args, **kwargs):
+        class MockResponse:
+            def __init__(self):
+                self.status_code = 200
+
+            def raise_for_status(self):
+                pass
+
+            def json(self):
+                return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+
+        return MockResponse()
+
+    with patch("httpx.AsyncClient.get", new=mock_get):
+
+        async def create_and_check():
+            async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
+                assert client.base_url == "http://localhost:8080"
+                assert client.headers["Authorization"] == "Bearer token"
+                assert isinstance(client.spec, OpenAPI)
+
+        # Run two clients concurrently
+        await asyncio.gather(create_and_check(), create_and_check())
+
+
+@pytest.mark.asyncio
+async def test_async_client_initialization():
+    async def mock_get(self, url, *args, **kwargs):
+        class MockResponse:
+            def __init__(self):
+                self.status_code = 200
+
+            def raise_for_status(self):
+                pass
+
+            def json(self):
+                return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
+
+        return MockResponse()
+
+    with patch("httpx.AsyncClient.get", new=mock_get):
+        async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
+            assert client.base_url == "http://localhost:8080"
+            assert client.headers["Authorization"] == "Bearer test-token"
+            assert isinstance(client.spec, OpenAPI)
+
+
+def test_init_client_missing_auth():
+    with pytest.raises(ValueError, match="auth_token"):
+        AirflowClient(
+            base_url="http://localhost:8080",
+            auth_token=None,
+        )
```
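The tests patch `httpx.AsyncClient.get` directly. An alternative sketch uses httpx's built-in `MockTransport`; note this is shown against a bare `AsyncClient`, and the production client would need to accept a `transport` argument for the same approach to apply there:

```python
import asyncio

import httpx


def handler(request: httpx.Request) -> httpx.Response:
    # Serve a minimal valid spec for any request to /openapi.json.
    assert request.url.path == "/openapi.json"
    return httpx.Response(200, json={"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}})


async def demo() -> None:
    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
        resp = await client.get("http://localhost:8080/openapi.json")
        assert resp.json()["openapi"] == "3.1.0"


asyncio.run(demo())
```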
tests/parser/openapi.json: new file, 17319 lines (diff suppressed because it is too large).
OperationParser tests:

```diff
@@ -1,26 +1,16 @@
-import logging
-from importlib import resources
-from typing import Any
+import json
 
 import pytest
-from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
+from typing import Any
 from pydantic import BaseModel
+from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
 
-logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger(__name__)
-
 
 @pytest.fixture
-def spec_file():
-    """Get content of the v1.yaml spec file."""
-    with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
-        return f.read()
-
-
-@pytest.fixture
-def parser(spec_file) -> OperationParser:
-    """Create OperationParser instance."""
-    return OperationParser(spec_path=spec_file)
+def parser() -> OperationParser:
+    """Create OperationParser instance from tests/parser/openapi.json."""
+    with open("tests/parser/openapi.json") as f:
+        spec_dict = json.load(f)
+    return OperationParser(spec_dict)
 
 
 def test_parse_operation_basic(parser: OperationParser) -> None:
@@ -29,8 +19,21 @@ def test_parse_operation_basic(parser: OperationParser) -> None:
 
     assert isinstance(operation, OperationDetails)
     assert operation.operation_id == "get_dags"
-    assert operation.path == "/dags"
+    assert operation.path == "/api/v2/dags"
     assert operation.method == "get"
+    assert operation.description == "Get all DAGs."
+    assert isinstance(operation.parameters, dict)
+
+
+def test_parse_operation_with_no_description_but_summary(parser: OperationParser) -> None:
+    """Test parsing operation with no description but summary."""
+    operation = parser.parse_operation("get_connections")
+
+    assert isinstance(operation, OperationDetails)
+    assert operation.operation_id == "get_connections"
+    assert operation.path == "/api/v2/connections"
+    assert operation.method == "get"
+    assert operation.description == "Get all connection entries."
     assert isinstance(operation.parameters, dict)
 
 
@@ -38,7 +41,7 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
     """Test parsing operation with path parameters."""
     operation = parser.parse_operation("get_dag")
 
-    assert operation.path == "/dags/{dag_id}"
+    assert operation.path == "/api/v2/dags/{dag_id}"
     assert isinstance(operation.input_model, type(BaseModel))
 
     # Verify path parameter field exists
@@ -65,7 +68,10 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
 
 def test_parse_operation_with_body_params(parser: OperationParser) -> None:
     """Test parsing operation with request body."""
-    operation = parser.parse_operation("post_dag_run")
+    # Find the correct operationId for posting a dag run in the OpenAPI spec
+    # From the spec, the likely operation is under /api/v2/dags/{dag_id}/dagRuns
+    # Let's use "post_dag_run" if it exists, otherwise use the actual operationId
+    operation = parser.parse_operation("trigger_dag_run")
 
     # Verify body fields exist
     fields = operation.input_model.__annotations__
@@ -149,7 +155,7 @@ def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
 
     assert isinstance(operation, OperationDetails)
     assert operation.operation_id == "test_connection"
-    assert operation.path == "/connections/test"
+    assert operation.path == "/api/v2/connections/test"
     assert operation.method == "post"
 
     # Verify input model includes fields from allOf schema
```
AirflowTool tests:

```diff
@@ -1,11 +1,11 @@
 """Tests for AirflowTool."""
 
 import pytest
+from pydantic import ValidationError
 
 from airflow_mcp_server.client.airflow_client import AirflowClient
 from airflow_mcp_server.parser.operation_parser import OperationDetails
 from airflow_mcp_server.tools.airflow_tool import AirflowTool
-from pydantic import ValidationError
 
 from tests.tools.test_models import TestRequestModel
@@ -41,6 +41,7 @@ def operation_details():
             },
         },
         input_model=model,
+        description="Test operation for AirflowTool",
     )
```