Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
9ec1cd2020
|
|||
|
c20539c39f
|
|||
|
9837afe13a
|
|||
| f4206ea73c | |||
|
950dc06901
|
|||
|
2031650535
|
|||
|
4263175351
|
|||
|
b5cf563b8f
|
|||
|
d2464ea891
|
|||
|
c5565e6a00
|
|||
|
bba42eea00
|
|||
|
5a864b27c5
|
|||
|
66cd068b33
|
|||
|
4734005ae4
|
|||
| 63ff02fa4b | |||
|
|
407eb00c1b | ||
|
707f3747d7
|
|||
|
a8638d27a5
|
|||
| dbbc3ef5e8 | |||
|
d8887d3a2b
|
|||
|
679523a7c6
|
|||
|
492e79ef2a
|
|||
|
ea60acd54a
|
|||
|
420b6fc68f
|
|||
|
355fb55bdb
|
|||
|
8b38a26e8a
|
|||
|
5663f56621
|
|||
|
2b652c5926
|
|||
| 3fd605b111 | |||
|
c5106f10a8
|
29
.github/workflows/pytest.yml
vendored
Normal file
29
.github/workflows/pytest.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Run Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[dev]
|
||||
|
||||
- name: Run pytest
|
||||
run: |
|
||||
pytest tests/ -v
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -179,3 +179,6 @@ project_resources/
|
||||
|
||||
# Ruff
|
||||
.ruff_cache/
|
||||
|
||||
# Airflow
|
||||
AIRFLOW_HOME/
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.1.11
|
||||
rev: v0.11.8
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
|
||||
42
README.md
42
README.md
@@ -6,7 +6,6 @@
|
||||
<img width="380" height="200" src="https://glama.ai/mcp/servers/6gjq9w80xr/badge" />
|
||||
</a>
|
||||
|
||||
|
||||
## Overview
|
||||
A [Model Context Protocol](https://modelcontextprotocol.io/) server for controlling Airflow via Airflow APIs.
|
||||
|
||||
@@ -14,28 +13,32 @@ A [Model Context Protocol](https://modelcontextprotocol.io/) server for controll
|
||||
|
||||
https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
### Usage with Claude Desktop
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"airflow-mcp-server": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"airflow-mcp-server"
|
||||
],
|
||||
"env": {
|
||||
"AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
|
||||
"AUTH_TOKEN": "<base64_encoded_username_password>"
|
||||
}
|
||||
"mcpServers": {
|
||||
"airflow-mcp-server": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"airflow-mcp-server",
|
||||
"--base-url",
|
||||
"http://localhost:8080",
|
||||
"--auth-token",
|
||||
"<jwt_token>"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> - Set `base_url` to the root Airflow URL (e.g., `http://localhost:8080`).
|
||||
> - Do **not** include `/api/v2` in the base URL. The server will automatically fetch the OpenAPI spec from `${base_url}/openapi.json`.
|
||||
> - Only JWT token is required for authentication. Cookie and basic auth are no longer supported in Airflow 3.0.
|
||||
|
||||
### Operation Modes
|
||||
|
||||
The server supports two operation modes:
|
||||
@@ -55,12 +58,9 @@ airflow-mcp-server --unsafe
|
||||
|
||||
### Considerations
|
||||
|
||||
The MCP Server expects environment variables to be set:
|
||||
- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
|
||||
- `AUTH_TOKEN`: The token to use for authorization (_This should be base64 encoded username:password_)
|
||||
- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
|
||||
**Authentication**
|
||||
|
||||
*Currently, only Basic Auth is supported.*
|
||||
- Only JWT authentication is supported in Airflow 3.0. You must provide a valid `AUTH_TOKEN`.
|
||||
|
||||
**Page Limit**
|
||||
|
||||
@@ -68,9 +68,9 @@ The default is 100 items, but you can change it using `maximum_page_limit` optio
|
||||
|
||||
## Tasks
|
||||
|
||||
- [x] First API
|
||||
- [x] Airflow 3 readiness
|
||||
- [x] Parse OpenAPI Spec
|
||||
- [x] Safe/Unsafe mode implementation
|
||||
- [ ] Parse proper description with list_tools.
|
||||
- [ ] Airflow config fetch (_specifically for page limit_)
|
||||
- [x] Parse proper description with list_tools.
|
||||
- [x] Airflow config fetch (_specifically for page limit_)
|
||||
- [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "airflow-mcp-server"
|
||||
version = "0.3.0"
|
||||
version = "0.6.1"
|
||||
description = "MCP Server for Airflow"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
@@ -12,10 +12,11 @@ dependencies = [
|
||||
"aiohttp>=3.11.11",
|
||||
"aioresponses>=0.7.7",
|
||||
"importlib-resources>=6.5.0",
|
||||
"mcp>=1.2.0",
|
||||
"mcp>=1.7.1",
|
||||
"openapi-core>=0.19.4",
|
||||
"pydantic>=2.10.5",
|
||||
"pydantic>=2.11.4",
|
||||
"pyyaml>=6.0.0",
|
||||
"packaging>=25.0",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
@@ -52,12 +53,12 @@ build-backend = "hatchling.build"
|
||||
exclude = [
|
||||
"*",
|
||||
"!src/**",
|
||||
"!pyproject.toml"
|
||||
"!pyproject.toml",
|
||||
"!assets/**"
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/airflow_mcp_server"]
|
||||
package-data = {"airflow_mcp_server"= ["*.yaml"]}
|
||||
|
||||
[tool.hatch.build.targets.wheel.sources]
|
||||
"src/airflow_mcp_server" = "airflow_mcp_server"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.server_safe import serve as serve_safe
|
||||
from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
||||
|
||||
@@ -12,7 +14,9 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
||||
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
|
||||
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
|
||||
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
|
||||
def main(verbose: int, safe: bool, unsafe: bool) -> None:
|
||||
@click.option("--base-url", help="Airflow API base URL")
|
||||
@click.option("--auth-token", help="Authentication token (JWT)")
|
||||
def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, auth_token: str = None) -> None:
|
||||
"""MCP server for Airflow"""
|
||||
logging_level = logging.WARN
|
||||
if verbose == 1:
|
||||
@@ -22,13 +26,26 @@ def main(verbose: int, safe: bool, unsafe: bool) -> None:
|
||||
|
||||
logging.basicConfig(level=logging_level, stream=sys.stderr)
|
||||
|
||||
# Read environment variables with proper precedence
|
||||
config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url
|
||||
config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token
|
||||
|
||||
# Initialize configuration
|
||||
try:
|
||||
config = AirflowConfig(base_url=config_base_url, auth_token=config_auth_token)
|
||||
except ValueError as e:
|
||||
click.echo(f"Configuration error: {e}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Determine server mode with proper precedence
|
||||
if safe and unsafe:
|
||||
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
|
||||
|
||||
if safe:
|
||||
asyncio.run(serve_safe())
|
||||
else: # Default to unsafe mode
|
||||
asyncio.run(serve_unsafe())
|
||||
elif safe:
|
||||
# CLI argument for safe mode
|
||||
asyncio.run(serve_safe(config))
|
||||
else:
|
||||
# Default to unsafe mode
|
||||
asyncio.run(serve_unsafe(config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, BinaryIO, TextIO
|
||||
|
||||
import aiohttp
|
||||
import yaml
|
||||
import httpx
|
||||
from jsonschema_path import SchemaPath
|
||||
from openapi_core import OpenAPI
|
||||
from openapi_core.validation.request.validators import V31RequestValidator
|
||||
@@ -29,141 +25,95 @@ def convert_dict_keys(d: dict) -> dict:
|
||||
|
||||
|
||||
class AirflowClient:
|
||||
"""Client for interacting with Airflow API."""
|
||||
"""Async client for interacting with Airflow API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
|
||||
base_url: str,
|
||||
auth_token: str,
|
||||
) -> None:
|
||||
"""Initialize Airflow client.
|
||||
|
||||
Args:
|
||||
spec_path: OpenAPI spec as file path, dict, bytes, or file object
|
||||
base_url: Base URL for API
|
||||
auth_token: Authentication token
|
||||
auth_token: Authentication token (JWT)
|
||||
|
||||
Raises:
|
||||
ValueError: If spec_path is invalid or spec cannot be loaded
|
||||
ValueError: If required configuration is missing or OpenAPI spec cannot be loaded
|
||||
"""
|
||||
try:
|
||||
# Load and parse OpenAPI spec
|
||||
if isinstance(spec_path, dict):
|
||||
self.raw_spec = spec_path
|
||||
elif isinstance(spec_path, bytes):
|
||||
self.raw_spec = yaml.safe_load(spec_path)
|
||||
elif isinstance(spec_path, str | Path):
|
||||
with open(spec_path) as f:
|
||||
self.raw_spec = yaml.safe_load(f)
|
||||
elif hasattr(spec_path, "read"):
|
||||
content = spec_path.read()
|
||||
if isinstance(content, bytes):
|
||||
self.raw_spec = yaml.safe_load(content)
|
||||
else:
|
||||
self.raw_spec = yaml.safe_load(content)
|
||||
else:
|
||||
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
|
||||
if not base_url:
|
||||
raise ValueError("Missing required configuration: base_url")
|
||||
if not auth_token:
|
||||
raise ValueError("Missing required configuration: auth_token (JWT)")
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
self.headers = {"Authorization": f"Bearer {self.auth_token}"}
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self.raw_spec = None
|
||||
self.spec = None
|
||||
self._paths = None
|
||||
self._validator = None
|
||||
|
||||
# Validate spec has required fields
|
||||
if not isinstance(self.raw_spec, dict):
|
||||
raise ValueError("OpenAPI spec must be a dictionary")
|
||||
|
||||
required_fields = ["openapi", "info", "paths"]
|
||||
for field in required_fields:
|
||||
if field not in self.raw_spec:
|
||||
raise ValueError(f"OpenAPI spec missing required field: {field}")
|
||||
|
||||
# Validate OpenAPI spec format
|
||||
validate(self.raw_spec)
|
||||
|
||||
# Initialize OpenAPI spec
|
||||
self.spec = OpenAPI.from_dict(self.raw_spec)
|
||||
logger.debug("OpenAPI spec loaded successfully")
|
||||
|
||||
# Debug raw spec
|
||||
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
|
||||
|
||||
# Get paths from raw spec
|
||||
if "paths" not in self.raw_spec:
|
||||
raise ValueError("OpenAPI spec does not contain paths information")
|
||||
self._paths = self.raw_spec["paths"]
|
||||
logger.debug("Using raw spec paths")
|
||||
|
||||
# Initialize request validator with schema path
|
||||
schema_path = SchemaPath.from_dict(self.raw_spec)
|
||||
self._validator = V31RequestValidator(schema_path)
|
||||
|
||||
# API configuration
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.headers = {
|
||||
"Authorization": f"Basic {auth_token}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize AirflowClient: %s", e)
|
||||
raise ValueError(f"Failed to initialize client: {e}")
|
||||
|
||||
async def __aenter__(self) -> "AirflowClient":
|
||||
self._session = aiohttp.ClientSession(headers=self.headers)
|
||||
async def __aenter__(self):
|
||||
self._client = httpx.AsyncClient(headers=self.headers)
|
||||
await self._initialize_spec()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc) -> None:
|
||||
if hasattr(self, "_session"):
|
||||
await self._session.close()
|
||||
delattr(self, "_session")
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
|
||||
"""Get operation details from OpenAPI spec.
|
||||
async def _initialize_spec(self):
|
||||
openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
|
||||
self.raw_spec = await self._fetch_openapi_spec(openapi_url)
|
||||
if not isinstance(self.raw_spec, dict):
|
||||
raise ValueError("OpenAPI spec must be a dictionary")
|
||||
required_fields = ["openapi", "info", "paths"]
|
||||
for field in required_fields:
|
||||
if field not in self.raw_spec:
|
||||
raise ValueError(f"OpenAPI spec missing required field: {field}")
|
||||
validate(self.raw_spec)
|
||||
self.spec = OpenAPI.from_dict(self.raw_spec)
|
||||
logger.debug("OpenAPI spec loaded successfully")
|
||||
if "paths" not in self.raw_spec:
|
||||
raise ValueError("OpenAPI spec does not contain paths information")
|
||||
self._paths = self.raw_spec["paths"]
|
||||
logger.debug("Using raw spec paths")
|
||||
schema_path = SchemaPath.from_dict(self.raw_spec)
|
||||
self._validator = V31RequestValidator(schema_path)
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to look up
|
||||
|
||||
Returns:
|
||||
Tuple of (path, method, operation) where operation is a SimpleNamespace object
|
||||
|
||||
Raises:
|
||||
ValueError: If operation not found
|
||||
"""
|
||||
async def _fetch_openapi_spec(self, url: str) -> dict:
|
||||
if not self._client:
|
||||
self._client = httpx.AsyncClient(headers=self.headers)
|
||||
try:
|
||||
# Debug the paths structure
|
||||
logger.debug("Looking for operation %s in paths", operation_id)
|
||||
response = await self._client.get(url)
|
||||
response.raise_for_status()
|
||||
except httpx.RequestError as e:
|
||||
raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
|
||||
return response.json()
|
||||
|
||||
for path, path_item in self._paths.items():
|
||||
for method, operation_data in path_item.items():
|
||||
# Skip non-operation fields
|
||||
if method.startswith("x-") or method == "parameters":
|
||||
continue
|
||||
def _get_operation(self, operation_id: str):
|
||||
"""Get operation details from OpenAPI spec."""
|
||||
for path, path_item in self._paths.items():
|
||||
for method, operation_data in path_item.items():
|
||||
if method.startswith("x-") or method == "parameters":
|
||||
continue
|
||||
if operation_data.get("operationId") == operation_id:
|
||||
converted_data = convert_dict_keys(operation_data)
|
||||
from types import SimpleNamespace
|
||||
|
||||
# Debug each operation
|
||||
logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))
|
||||
operation_obj = SimpleNamespace(**converted_data)
|
||||
return path, method, operation_obj
|
||||
raise ValueError(f"Operation {operation_id} not found in spec")
|
||||
|
||||
if operation_data.get("operationId") == operation_id:
|
||||
logger.debug("Found operation %s at %s %s", operation_id, method, path)
|
||||
# Convert keys to snake_case and create object
|
||||
converted_data = convert_dict_keys(operation_data)
|
||||
operation_obj = SimpleNamespace(**converted_data)
|
||||
return path, method, operation_obj
|
||||
|
||||
raise ValueError(f"Operation {operation_id} not found in spec")
|
||||
except Exception as e:
|
||||
logger.error("Error getting operation %s: %s", operation_id, e)
|
||||
raise
|
||||
|
||||
def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
|
||||
def _validate_path_params(self, path: str, params: dict | None) -> None:
|
||||
if not params:
|
||||
params = {}
|
||||
|
||||
# Extract path parameter names from the path
|
||||
path_params = set(re.findall(r"{([^}]+)}", path))
|
||||
|
||||
# Check for missing required parameters
|
||||
missing_params = path_params - set(params.keys())
|
||||
if missing_params:
|
||||
raise ValueError(f"Missing required path parameters: {missing_params}")
|
||||
|
||||
# Check for invalid parameters
|
||||
invalid_params = set(params.keys()) - path_params
|
||||
if invalid_params:
|
||||
raise ValueError(f"Invalid path parameters: {invalid_params}")
|
||||
@@ -171,77 +121,42 @@ class AirflowClient:
|
||||
async def execute(
|
||||
self,
|
||||
operation_id: str,
|
||||
path_params: dict[str, Any] | None = None,
|
||||
query_params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Execute an API operation.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID from OpenAPI spec
|
||||
path_params: URL path parameters
|
||||
query_params: URL query parameters
|
||||
body: Request body data
|
||||
|
||||
Returns:
|
||||
API response data
|
||||
|
||||
Raises:
|
||||
ValueError: If operation not found
|
||||
RuntimeError: If used outside async context
|
||||
aiohttp.ClientError: For HTTP/network errors
|
||||
"""
|
||||
if not hasattr(self, "_session") or not self._session:
|
||||
path_params: dict = None,
|
||||
query_params: dict = None,
|
||||
body: dict = None,
|
||||
) -> dict:
|
||||
"""Execute an API operation."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Client not in async context")
|
||||
|
||||
# Default all params to empty dict if None
|
||||
path_params = path_params or {}
|
||||
query_params = query_params or {}
|
||||
body = body or {}
|
||||
path, method, _ = self._get_operation(operation_id)
|
||||
self._validate_path_params(path, path_params)
|
||||
if path_params:
|
||||
path = path.format(**path_params)
|
||||
url = f"{self.base_url.rstrip('/')}{path}"
|
||||
request_headers = self.headers.copy()
|
||||
if body:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
try:
|
||||
# Get operation details
|
||||
path, method, _ = self._get_operation(operation_id)
|
||||
|
||||
# Validate path parameters
|
||||
self._validate_path_params(path, path_params)
|
||||
|
||||
# Format URL
|
||||
if path_params:
|
||||
path = path.format(**path_params)
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
logger.debug("Executing %s %s", method, url)
|
||||
logger.debug("Request body: %s", body)
|
||||
logger.debug("Request query params: %s", query_params)
|
||||
|
||||
# Dynamically set headers based on presence of body
|
||||
request_headers = self.headers.copy()
|
||||
if body is not None:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
# Make request
|
||||
async with self._session.request(
|
||||
method=method,
|
||||
response = await self._client.request(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
params=query_params,
|
||||
json=body,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
# Status codes that typically have no body
|
||||
no_body_statuses = {204}
|
||||
if response.status in no_body_statuses:
|
||||
if content_type and "application/json" in content_type:
|
||||
logger.warning("Unexpected JSON body with status %s", response.status)
|
||||
return await response.json() # Parse if present, though rare
|
||||
logger.debug("Received %s response with no body", response.status)
|
||||
return response.status
|
||||
# For statuses expecting a body, check mimetype
|
||||
if "application/json" in content_type:
|
||||
logger.debug("Response: %s", await response.text())
|
||||
return await response.json()
|
||||
# Unexpected mimetype with body
|
||||
response_text = await response.text()
|
||||
logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
|
||||
raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error("Error executing operation %s: %s", operation_id, e)
|
||||
headers=request_headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
if response.status_code == 204:
|
||||
return response.status_code
|
||||
if "application/json" in content_type:
|
||||
return response.json()
|
||||
return {"content": await response.aread()}
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error("HTTP error executing operation %s: %s", operation_id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error executing operation %s: %s", operation_id, e)
|
||||
|
||||
20
src/airflow_mcp_server/config.py
Normal file
20
src/airflow_mcp_server/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
class AirflowConfig:
|
||||
"""Centralized configuration for Airflow MCP server."""
|
||||
|
||||
def __init__(self, base_url: str | None = None, auth_token: str | None = None) -> None:
|
||||
"""Initialize configuration with provided values.
|
||||
|
||||
Args:
|
||||
base_url: Airflow API base URL
|
||||
auth_token: Authentication token (JWT)
|
||||
|
||||
Raises:
|
||||
ValueError: If required configuration is missing
|
||||
"""
|
||||
self.base_url = base_url
|
||||
if not self.base_url:
|
||||
raise ValueError("Missing required configuration: base_url")
|
||||
|
||||
self.auth_token = auth_token
|
||||
if not self.auth_token:
|
||||
raise ValueError("Missing required configuration: auth_token (JWT)")
|
||||
@@ -20,6 +20,7 @@ class OperationDetails:
|
||||
method: str
|
||||
parameters: dict[str, Any]
|
||||
input_model: type[BaseModel]
|
||||
description: str
|
||||
|
||||
|
||||
class OperationParser:
|
||||
@@ -104,6 +105,7 @@ class OperationParser:
|
||||
|
||||
operation["path"] = path
|
||||
operation["path_item"] = path_item
|
||||
description = operation.get("description") or operation.get("summary") or operation_id
|
||||
|
||||
parameters = self.extract_parameters(operation)
|
||||
|
||||
@@ -119,7 +121,7 @@ class OperationParser:
|
||||
# Create unified input model
|
||||
input_model = self._create_input_model(operation_id, parameters, body_schema)
|
||||
|
||||
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
|
||||
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, description=description, input_model=input_model)
|
||||
|
||||
raise ValueError(f"Operation {operation_id} not found in spec")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
|
||||
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
|
||||
@@ -20,18 +20,18 @@ from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
"""Start MCP server."""
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
if not all(var in os.environ for var in required_vars):
|
||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
||||
async def serve(config: AirflowConfig) -> None:
|
||||
"""Start MCP server.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
"""
|
||||
server = Server("airflow-mcp-server")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
try:
|
||||
return await get_airflow_tools()
|
||||
return await get_airflow_tools(config)
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools: %s", e)
|
||||
raise
|
||||
@@ -39,7 +39,7 @@ async def serve() -> None:
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
try:
|
||||
tool = await get_tool(name)
|
||||
tool = await get_tool(config, name)
|
||||
async with tool.client:
|
||||
result = await tool.run(body=arguments)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
"""Start MCP server in safe mode (read-only operations)."""
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
if not all(var in os.environ for var in required_vars):
|
||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
||||
async def serve(config: AirflowConfig) -> None:
|
||||
"""Start MCP server in safe mode (read-only operations).
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
"""
|
||||
server = Server("airflow-mcp-server-safe")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
try:
|
||||
return await get_airflow_tools(mode="safe")
|
||||
return await get_airflow_tools(config, mode="safe")
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools: %s", e)
|
||||
raise
|
||||
@@ -32,7 +32,7 @@ async def serve() -> None:
|
||||
try:
|
||||
if not name.startswith("get_"):
|
||||
raise ValueError("Only GET operations allowed in safe mode")
|
||||
tool = await get_tool(name)
|
||||
tool = await get_tool(config, name)
|
||||
async with tool.client:
|
||||
result = await tool.run(body=arguments)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
"""Start MCP server in unsafe mode (all operations)."""
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
if not all(var in os.environ for var in required_vars):
|
||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
||||
async def serve(config: AirflowConfig) -> None:
|
||||
"""Start MCP server in unsafe mode (all operations).
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
"""
|
||||
server = Server("airflow-mcp-server-unsafe")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
try:
|
||||
return await get_airflow_tools(mode="unsafe")
|
||||
return await get_airflow_tools(config, mode="unsafe")
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools: %s", e)
|
||||
raise
|
||||
@@ -30,7 +30,7 @@ async def serve() -> None:
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
try:
|
||||
tool = await get_tool(name)
|
||||
tool = await get_tool(config, name)
|
||||
async with tool.client:
|
||||
result = await tool.run(body=arguments)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseTools(ABC):
|
||||
"""Abstract base class for tools."""
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from importlib import resources
|
||||
|
||||
from mcp.types import Tool
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.parser.operation_parser import OperationParser
|
||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||
|
||||
@@ -13,54 +13,16 @@ logger = logging.getLogger(__name__)
|
||||
_tools_cache: dict[str, AirflowTool] = {}
|
||||
|
||||
|
||||
def _initialize_client() -> AirflowClient:
|
||||
"""Initialize Airflow client with environment variables or embedded spec.
|
||||
|
||||
Returns:
|
||||
AirflowClient instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing or default spec is not found
|
||||
"""
|
||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
||||
if not spec_path:
|
||||
# Fallback to embedded v1.yaml
|
||||
try:
|
||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||
spec_path = f.name
|
||||
logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
|
||||
except Exception as e:
|
||||
raise ValueError("Default OpenAPI spec not found in package resources") from e
|
||||
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||
if missing_vars:
|
||||
raise ValueError(f"Missing required environment variables: {missing_vars}")
|
||||
|
||||
return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
|
||||
|
||||
|
||||
async def _initialize_tools() -> None:
|
||||
"""Initialize tools cache with Airflow operations.
|
||||
|
||||
Raises:
|
||||
ValueError: If initialization fails
|
||||
"""
|
||||
async def _initialize_tools(config: AirflowConfig) -> None:
|
||||
"""Initialize tools cache with Airflow operations (async)."""
|
||||
global _tools_cache
|
||||
|
||||
try:
|
||||
client = _initialize_client()
|
||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
||||
if not spec_path:
|
||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||
spec_path = f.name
|
||||
parser = OperationParser(spec_path)
|
||||
|
||||
# Generate tools for each operation
|
||||
for operation_id in parser.get_operations():
|
||||
operation_details = parser.parse_operation(operation_id)
|
||||
tool = AirflowTool(operation_details, client)
|
||||
_tools_cache[operation_id] = tool
|
||||
async with AirflowClient(base_url=config.base_url, auth_token=config.auth_token) as client:
|
||||
parser = OperationParser(client.raw_spec)
|
||||
for operation_id in parser.get_operations():
|
||||
operation_details = parser.parse_operation(operation_id)
|
||||
tool = AirflowTool(operation_details, client)
|
||||
_tools_cache[operation_id] = tool
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize tools: %s", e)
|
||||
@@ -68,10 +30,11 @@ async def _initialize_tools() -> None:
|
||||
raise ValueError(f"Failed to initialize tools: {e}") from e
|
||||
|
||||
|
||||
async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list[Tool]:
|
||||
"""Get list of available Airflow tools based on mode.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
mode: "safe" for GET operations only, "unsafe" for all operations (default)
|
||||
|
||||
Returns:
|
||||
@@ -80,8 +43,21 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
Raises:
|
||||
ValueError: If initialization fails
|
||||
"""
|
||||
|
||||
# Version check before returning tools
|
||||
if not _tools_cache:
|
||||
await _initialize_tools()
|
||||
await _initialize_tools(config)
|
||||
|
||||
# Only check version if get_version tool is present
|
||||
if "get_version" in _tools_cache:
|
||||
version_tool = _tools_cache["get_version"]
|
||||
async with version_tool.client:
|
||||
version_result = await version_tool.run()
|
||||
airflow_version = version_result.get("version")
|
||||
if airflow_version is None:
|
||||
raise RuntimeError("Could not determine Airflow version from get_version tool.")
|
||||
if parse_version(airflow_version) < parse_version("3.1.0"):
|
||||
raise RuntimeError(f"Airflow version {airflow_version} is not supported. Requires >= 3.1.0.")
|
||||
|
||||
tools = []
|
||||
for operation_id, tool in _tools_cache.items():
|
||||
@@ -93,7 +69,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
tools.append(
|
||||
Tool(
|
||||
name=operation_id,
|
||||
description=tool.operation.operation_id,
|
||||
description=tool.operation.description,
|
||||
inputSchema=schema,
|
||||
)
|
||||
)
|
||||
@@ -104,10 +80,11 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
return tools
|
||||
|
||||
|
||||
async def get_tool(name: str) -> AirflowTool:
|
||||
async def get_tool(config: AirflowConfig, name: str) -> AirflowTool:
|
||||
"""Get specific tool by name.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
name: Tool/operation name
|
||||
|
||||
Returns:
|
||||
@@ -118,7 +95,7 @@ async def get_tool(name: str) -> AirflowTool:
|
||||
ValueError: If tool initialization fails
|
||||
"""
|
||||
if not _tools_cache:
|
||||
await _initialize_tools()
|
||||
await _initialize_tools(config)
|
||||
|
||||
if name not in _tools_cache:
|
||||
raise KeyError(f"Tool {name} not found")
|
||||
|
||||
@@ -1,186 +1,69 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import yaml
|
||||
from aioresponses import aioresponses
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from openapi_core import OpenAPI
|
||||
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_multiple_clients_concurrent():
|
||||
"""Test initializing two AirflowClients concurrently to verify async power."""
|
||||
|
||||
async def mock_get(self, url, *args, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> AirflowClient:
|
||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||
spec = yaml.safe_load(f)
|
||||
return AirflowClient(
|
||||
spec_path=spec,
|
||||
base_url="http://localhost:8080/api/v1",
|
||||
auth_token="test-token",
|
||||
)
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
def json(self):
|
||||
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||
|
||||
def test_init_client_initialization(client: AirflowClient) -> None:
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
assert client.base_url == "http://localhost:8080/api/v1"
|
||||
assert client.headers["Authorization"] == "Basic test-token"
|
||||
return MockResponse()
|
||||
|
||||
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||
|
||||
def test_init_load_spec_from_bytes() -> None:
|
||||
spec_bytes = yaml.dump(create_valid_spec()).encode()
|
||||
client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
|
||||
assert client.raw_spec is not None
|
||||
async def create_and_check():
|
||||
async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
|
||||
assert client.base_url == "http://localhost:8080"
|
||||
assert client.headers["Authorization"] == "Bearer token"
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
|
||||
|
||||
def test_init_load_spec_from_path(tmp_path: Path) -> None:
|
||||
spec_file = tmp_path / "test_spec.yaml"
|
||||
spec_file.write_text(yaml.dump(create_valid_spec()))
|
||||
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
|
||||
assert client.raw_spec is not None
|
||||
|
||||
|
||||
def test_init_invalid_spec() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
|
||||
|
||||
|
||||
def test_init_missing_paths_in_spec() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
|
||||
|
||||
|
||||
def test_ops_get_operation(client: AirflowClient) -> None:
|
||||
path, method, operation = client._get_operation("get_dags")
|
||||
assert path == "/dags"
|
||||
assert method == "get"
|
||||
assert operation.operation_id == "get_dags"
|
||||
|
||||
path, method, operation = client._get_operation("get_dag")
|
||||
assert path == "/dags/{dag_id}"
|
||||
assert method == "get"
|
||||
assert operation.operation_id == "get_dag"
|
||||
|
||||
|
||||
def test_ops_nonexistent_operation(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
|
||||
client._get_operation("nonexistent")
|
||||
|
||||
|
||||
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
client._get_operation("GET_DAGS")
|
||||
# Run two clients concurrently
|
||||
await asyncio.gather(create_and_check(), create_and_check())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_without_context() -> None:
|
||||
client = AirflowClient(
|
||||
spec_path=create_valid_spec(),
|
||||
base_url="http://test",
|
||||
auth_token="test",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="Client not in async context"):
|
||||
await client.execute("get_dags")
|
||||
async def test_async_client_initialization():
|
||||
async def mock_get(self, url, *args, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
def json(self):
|
||||
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||
|
||||
return MockResponse()
|
||||
|
||||
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||
async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
|
||||
assert client.base_url == "http://localhost:8080"
|
||||
assert client.headers["Authorization"] == "Bearer test-token"
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_get_dags(client: AirflowClient) -> None:
|
||||
expected_response = {
|
||||
"dags": [
|
||||
{
|
||||
"dag_id": "test_dag",
|
||||
"is_active": True,
|
||||
"is_paused": False,
|
||||
}
|
||||
],
|
||||
"total_entries": 1,
|
||||
}
|
||||
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags?limit=100",
|
||||
status=200,
|
||||
payload=expected_response,
|
||||
)
|
||||
response = await client.execute("get_dags", query_params={"limit": 100})
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_get_dag(client: AirflowClient) -> None:
|
||||
expected_response = {
|
||||
"dag_id": "test_dag",
|
||||
"is_active": True,
|
||||
"is_paused": False,
|
||||
}
|
||||
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags/test_dag",
|
||||
status=200,
|
||||
payload=expected_response,
|
||||
)
|
||||
response = await client.execute(
|
||||
"get_dag",
|
||||
path_params={"dag_id": "test_dag"},
|
||||
)
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_invalid_params(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
async with client:
|
||||
# Test with missing required parameter
|
||||
await client.execute("get_dag", path_params={})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with client:
|
||||
# Test with invalid parameter name
|
||||
await client.execute("get_dag", path_params={"invalid": "value"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_timeout(client: AirflowClient) -> None:
|
||||
with aioresponses() as mock:
|
||||
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
|
||||
async with client:
|
||||
with pytest.raises(aiohttp.ClientError):
|
||||
await client.execute("get_dags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_error_response(client: AirflowClient) -> None:
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags",
|
||||
status=403,
|
||||
body="Forbidden",
|
||||
)
|
||||
with pytest.raises(aiohttp.ClientResponseError):
|
||||
await client.execute("get_dags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_session_management(client: AirflowClient) -> None:
|
||||
async with client:
|
||||
with aioresponses() as mock:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags",
|
||||
status=200,
|
||||
payload={"dags": []},
|
||||
)
|
||||
await client.execute("get_dags")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await client.execute("get_dags")
|
||||
def test_init_client_missing_auth():
|
||||
with pytest.raises(ValueError, match="auth_token"):
|
||||
AirflowClient(
|
||||
base_url="http://localhost:8080",
|
||||
auth_token=None,
|
||||
)
|
||||
|
||||
17319
tests/parser/openapi.json
Normal file
17319
tests/parser/openapi.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,26 +1,16 @@
|
||||
import logging
|
||||
from importlib import resources
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spec_file():
|
||||
"""Get content of the v1.yaml spec file."""
|
||||
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(spec_file) -> OperationParser:
|
||||
"""Create OperationParser instance."""
|
||||
return OperationParser(spec_path=spec_file)
|
||||
def parser() -> OperationParser:
|
||||
"""Create OperationParser instance from tests/parser/openapi.json."""
|
||||
with open("tests/parser/openapi.json") as f:
|
||||
spec_dict = json.load(f)
|
||||
return OperationParser(spec_dict)
|
||||
|
||||
|
||||
def test_parse_operation_basic(parser: OperationParser) -> None:
|
||||
@@ -29,8 +19,21 @@ def test_parse_operation_basic(parser: OperationParser) -> None:
|
||||
|
||||
assert isinstance(operation, OperationDetails)
|
||||
assert operation.operation_id == "get_dags"
|
||||
assert operation.path == "/dags"
|
||||
assert operation.path == "/api/v2/dags"
|
||||
assert operation.method == "get"
|
||||
assert operation.description == "Get all DAGs."
|
||||
assert isinstance(operation.parameters, dict)
|
||||
|
||||
|
||||
def test_parse_operation_with_no_description_but_summary(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with no description but summary."""
|
||||
operation = parser.parse_operation("get_connections")
|
||||
|
||||
assert isinstance(operation, OperationDetails)
|
||||
assert operation.operation_id == "get_connections"
|
||||
assert operation.path == "/api/v2/connections"
|
||||
assert operation.method == "get"
|
||||
assert operation.description == "Get all connection entries."
|
||||
assert isinstance(operation.parameters, dict)
|
||||
|
||||
|
||||
@@ -38,7 +41,7 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with path parameters."""
|
||||
operation = parser.parse_operation("get_dag")
|
||||
|
||||
assert operation.path == "/dags/{dag_id}"
|
||||
assert operation.path == "/api/v2/dags/{dag_id}"
|
||||
assert isinstance(operation.input_model, type(BaseModel))
|
||||
|
||||
# Verify path parameter field exists
|
||||
@@ -65,7 +68,10 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
||||
|
||||
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with request body."""
|
||||
operation = parser.parse_operation("post_dag_run")
|
||||
# Find the correct operationId for posting a dag run in the OpenAPI spec
|
||||
# From the spec, the likely operation is under /api/v2/dags/{dag_id}/dagRuns
|
||||
# Let's use "post_dag_run" if it exists, otherwise use the actual operationId
|
||||
operation = parser.parse_operation("trigger_dag_run")
|
||||
|
||||
# Verify body fields exist
|
||||
fields = operation.input_model.__annotations__
|
||||
@@ -149,7 +155,7 @@ def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
|
||||
|
||||
assert isinstance(operation, OperationDetails)
|
||||
assert operation.operation_id == "test_connection"
|
||||
assert operation.path == "/connections/test"
|
||||
assert operation.path == "/api/v2/connections/test"
|
||||
assert operation.method == "post"
|
||||
|
||||
# Verify input model includes fields from allOf schema
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Tests for AirflowTool."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tests.tools.test_models import TestRequestModel
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ def operation_details():
|
||||
},
|
||||
},
|
||||
input_model=model,
|
||||
description="Test operation for AirflowTool",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user