Compare commits
16 Commits
0.3.0
...
63ff02fa4b
| Author | SHA1 | Date | |
|---|---|---|---|
| 63ff02fa4b | |||
|
|
407eb00c1b | ||
|
707f3747d7
|
|||
|
a8638d27a5
|
|||
| dbbc3ef5e8 | |||
|
d8887d3a2b
|
|||
|
679523a7c6
|
|||
|
492e79ef2a
|
|||
|
ea60acd54a
|
|||
|
420b6fc68f
|
|||
|
355fb55bdb
|
|||
|
8b38a26e8a
|
|||
|
5663f56621
|
|||
|
2b652c5926
|
|||
| 3fd605b111 | |||
|
c5106f10a8
|
29
.github/workflows/pytest.yml
vendored
Normal file
29
.github/workflows/pytest.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: Run Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.11", "3.12"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install .[dev]
|
||||||
|
|
||||||
|
- name: Run pytest
|
||||||
|
run: |
|
||||||
|
pytest tests/ -v
|
||||||
17
README.md
17
README.md
@@ -29,7 +29,10 @@ https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705
|
|||||||
],
|
],
|
||||||
"env": {
|
"env": {
|
||||||
"AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
|
"AIRFLOW_BASE_URL": "http://<host:port>/api/v1",
|
||||||
"AUTH_TOKEN": "<base64_encoded_username_password>"
|
// Either use AUTH_TOKEN for basic auth
|
||||||
|
"AUTH_TOKEN": "<base64_encoded_username_password>",
|
||||||
|
// Or use COOKIE for cookie-based auth
|
||||||
|
"COOKIE": "<session_cookie>"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -57,10 +60,17 @@ airflow-mcp-server --unsafe
|
|||||||
|
|
||||||
The MCP Server expects environment variables to be set:
|
The MCP Server expects environment variables to be set:
|
||||||
- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
|
- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
|
||||||
- `AUTH_TOKEN`: The token to use for authorization (_This should be base64 encoded username:password_)
|
- `AUTH_TOKEN`: The token to use for basic auth (_This should be base64 encoded username:password_) (_Optional if COOKIE is provided_)
|
||||||
|
- `COOKIE`: The session cookie to use for authentication (_Optional if AUTH_TOKEN is provided_)
|
||||||
- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
|
- `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_)
|
||||||
|
|
||||||
*Currently, only Basic Auth is supported.*
|
**Authentication**
|
||||||
|
|
||||||
|
The server supports two authentication methods:
|
||||||
|
- **Basic Auth**: Using base64 encoded username:password via `AUTH_TOKEN` environment variable
|
||||||
|
- **Cookie**: Using session cookie via `COOKIE` environment variable
|
||||||
|
|
||||||
|
At least one of these authentication methods must be provided.
|
||||||
|
|
||||||
**Page Limit**
|
**Page Limit**
|
||||||
|
|
||||||
@@ -71,6 +81,7 @@ The default is 100 items, but you can change it using `maximum_page_limit` optio
|
|||||||
- [x] First API
|
- [x] First API
|
||||||
- [x] Parse OpenAPI Spec
|
- [x] Parse OpenAPI Spec
|
||||||
- [x] Safe/Unsafe mode implementation
|
- [x] Safe/Unsafe mode implementation
|
||||||
|
- [x] Allow session auth
|
||||||
- [ ] Parse proper description with list_tools.
|
- [ ] Parse proper description with list_tools.
|
||||||
- [ ] Airflow config fetch (_specifically for page limit_)
|
- [ ] Airflow config fetch (_specifically for page limit_)
|
||||||
- [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)
|
- [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "airflow-mcp-server"
|
name = "airflow-mcp-server"
|
||||||
version = "0.3.0"
|
version = "0.4.0"
|
||||||
description = "MCP Server for Airflow"
|
description = "MCP Server for Airflow"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
@@ -52,7 +52,8 @@ build-backend = "hatchling.build"
|
|||||||
exclude = [
|
exclude = [
|
||||||
"*",
|
"*",
|
||||||
"!src/**",
|
"!src/**",
|
||||||
"!pyproject.toml"
|
"!pyproject.toml",
|
||||||
|
"!assets/**"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from airflow_mcp_server.config import AirflowConfig
|
||||||
from airflow_mcp_server.server_safe import serve as serve_safe
|
from airflow_mcp_server.server_safe import serve as serve_safe
|
||||||
from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
||||||
|
|
||||||
@@ -12,7 +14,11 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
|||||||
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
|
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
|
||||||
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
|
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
|
||||||
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
|
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
|
||||||
def main(verbose: int, safe: bool, unsafe: bool) -> None:
|
@click.option("--base-url", help="Airflow API base URL")
|
||||||
|
@click.option("--spec-path", help="Path to OpenAPI spec file")
|
||||||
|
@click.option("--auth-token", help="Authentication token")
|
||||||
|
@click.option("--cookie", help="Session cookie")
|
||||||
|
def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path: str = None, auth_token: str = None, cookie: str = None) -> None:
|
||||||
"""MCP server for Airflow"""
|
"""MCP server for Airflow"""
|
||||||
logging_level = logging.WARN
|
logging_level = logging.WARN
|
||||||
if verbose == 1:
|
if verbose == 1:
|
||||||
@@ -22,13 +28,30 @@ def main(verbose: int, safe: bool, unsafe: bool) -> None:
|
|||||||
|
|
||||||
logging.basicConfig(level=logging_level, stream=sys.stderr)
|
logging.basicConfig(level=logging_level, stream=sys.stderr)
|
||||||
|
|
||||||
if safe and unsafe:
|
# Read environment variables with proper precedence
|
||||||
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
|
# Environment variables take precedence over CLI arguments
|
||||||
|
config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url
|
||||||
|
config_spec_path = os.environ.get("OPENAPI_SPEC") or spec_path
|
||||||
|
config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token
|
||||||
|
config_cookie = os.environ.get("COOKIE") or cookie
|
||||||
|
|
||||||
if safe:
|
# Initialize configuration
|
||||||
asyncio.run(serve_safe())
|
try:
|
||||||
else: # Default to unsafe mode
|
config = AirflowConfig(base_url=config_base_url, spec_path=config_spec_path, auth_token=config_auth_token, cookie=config_cookie)
|
||||||
asyncio.run(serve_unsafe())
|
except ValueError as e:
|
||||||
|
click.echo(f"Configuration error: {e}", err=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Determine server mode with proper precedence
|
||||||
|
if safe and unsafe:
|
||||||
|
# CLI argument validation
|
||||||
|
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
|
||||||
|
elif safe:
|
||||||
|
# CLI argument for safe mode
|
||||||
|
asyncio.run(serve_safe(config))
|
||||||
|
else:
|
||||||
|
# Default to unsafe mode
|
||||||
|
asyncio.run(serve_unsafe(config))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -35,18 +35,22 @@ class AirflowClient:
|
|||||||
self,
|
self,
|
||||||
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
|
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
auth_token: str,
|
auth_token: str | None = None,
|
||||||
|
cookie: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize Airflow client.
|
"""Initialize Airflow client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spec_path: OpenAPI spec as file path, dict, bytes, or file object
|
spec_path: OpenAPI spec as file path, dict, bytes, or file object
|
||||||
base_url: Base URL for API
|
base_url: Base URL for API
|
||||||
auth_token: Authentication token
|
auth_token: Authentication token (optional if cookie is provided)
|
||||||
|
cookie: Session cookie (optional if auth_token is provided)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If spec_path is invalid or spec cannot be loaded
|
ValueError: If spec_path is invalid or spec cannot be loaded or if neither auth_token nor cookie is provided
|
||||||
"""
|
"""
|
||||||
|
if not auth_token and not cookie:
|
||||||
|
raise ValueError("Either auth_token or cookie must be provided")
|
||||||
try:
|
try:
|
||||||
# Load and parse OpenAPI spec
|
# Load and parse OpenAPI spec
|
||||||
if isinstance(spec_path, dict):
|
if isinstance(spec_path, dict):
|
||||||
@@ -96,10 +100,13 @@ class AirflowClient:
|
|||||||
|
|
||||||
# API configuration
|
# API configuration
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.headers = {
|
self.headers = {"Accept": "application/json"}
|
||||||
"Authorization": f"Basic {auth_token}",
|
|
||||||
"Accept": "application/json",
|
# Set authentication header based on precedence (cookie > auth_token)
|
||||||
}
|
if cookie:
|
||||||
|
self.headers["Cookie"] = cookie
|
||||||
|
elif auth_token:
|
||||||
|
self.headers["Authorization"] = f"Basic {auth_token}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize AirflowClient: %s", e)
|
logger.error("Failed to initialize AirflowClient: %s", e)
|
||||||
|
|||||||
25
src/airflow_mcp_server/config.py
Normal file
25
src/airflow_mcp_server/config.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
class AirflowConfig:
|
||||||
|
"""Centralized configuration for Airflow MCP server."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str | None = None, spec_path: str | None = None, auth_token: str | None = None, cookie: str | None = None) -> None:
|
||||||
|
"""Initialize configuration with provided values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Airflow API base URL
|
||||||
|
spec_path: Path to OpenAPI spec file
|
||||||
|
auth_token: Authentication token
|
||||||
|
cookie: Session cookie
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required configuration is missing
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
if not self.base_url:
|
||||||
|
raise ValueError("Missing required configuration: base_url")
|
||||||
|
|
||||||
|
self.spec_path = spec_path
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.cookie = cookie
|
||||||
|
|
||||||
|
if not self.auth_token and not self.cookie:
|
||||||
|
raise ValueError("Either auth_token or cookie must be provided")
|
||||||
@@ -20,6 +20,7 @@ class OperationDetails:
|
|||||||
method: str
|
method: str
|
||||||
parameters: dict[str, Any]
|
parameters: dict[str, Any]
|
||||||
input_model: type[BaseModel]
|
input_model: type[BaseModel]
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
class OperationParser:
|
class OperationParser:
|
||||||
@@ -104,6 +105,7 @@ class OperationParser:
|
|||||||
|
|
||||||
operation["path"] = path
|
operation["path"] = path
|
||||||
operation["path_item"] = path_item
|
operation["path_item"] = path_item
|
||||||
|
description = operation.get("description") or operation.get("summary") or operation_id
|
||||||
|
|
||||||
parameters = self.extract_parameters(operation)
|
parameters = self.extract_parameters(operation)
|
||||||
|
|
||||||
@@ -119,7 +121,7 @@ class OperationParser:
|
|||||||
# Create unified input model
|
# Create unified input model
|
||||||
input_model = self._create_input_model(operation_id, parameters, body_schema)
|
input_model = self._create_input_model(operation_id, parameters, body_schema)
|
||||||
|
|
||||||
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
|
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, description=description, input_model=input_model)
|
||||||
|
|
||||||
raise ValueError(f"Operation {operation_id} not found in spec")
|
raise ValueError(f"Operation {operation_id} not found in spec")
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.types import TextContent, Tool
|
from mcp.types import TextContent, Tool
|
||||||
|
|
||||||
|
from airflow_mcp_server.config import AirflowConfig
|
||||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||||
|
|
||||||
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
|
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
|
||||||
@@ -20,18 +20,18 @@ from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def serve() -> None:
|
async def serve(config: AirflowConfig) -> None:
|
||||||
"""Start MCP server."""
|
"""Start MCP server.
|
||||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
|
||||||
if not all(var in os.environ for var in required_vars):
|
|
||||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
|
"""
|
||||||
server = Server("airflow-mcp-server")
|
server = Server("airflow-mcp-server")
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
try:
|
try:
|
||||||
return await get_airflow_tools()
|
return await get_airflow_tools(config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to list tools: %s", e)
|
logger.error("Failed to list tools: %s", e)
|
||||||
raise
|
raise
|
||||||
@@ -39,7 +39,7 @@ async def serve() -> None:
|
|||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
try:
|
try:
|
||||||
tool = await get_tool(name)
|
tool = await get_tool(config, name)
|
||||||
async with tool.client:
|
async with tool.client:
|
||||||
result = await tool.run(body=arguments)
|
result = await tool.run(body=arguments)
|
||||||
return [TextContent(type="text", text=str(result))]
|
return [TextContent(type="text", text=str(result))]
|
||||||
|
|||||||
@@ -1,28 +1,28 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.types import TextContent, Tool
|
from mcp.types import TextContent, Tool
|
||||||
|
|
||||||
|
from airflow_mcp_server.config import AirflowConfig
|
||||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def serve() -> None:
|
async def serve(config: AirflowConfig) -> None:
|
||||||
"""Start MCP server in safe mode (read-only operations)."""
|
"""Start MCP server in safe mode (read-only operations).
|
||||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
|
||||||
if not all(var in os.environ for var in required_vars):
|
|
||||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
|
"""
|
||||||
server = Server("airflow-mcp-server-safe")
|
server = Server("airflow-mcp-server-safe")
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
try:
|
try:
|
||||||
return await get_airflow_tools(mode="safe")
|
return await get_airflow_tools(config, mode="safe")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to list tools: %s", e)
|
logger.error("Failed to list tools: %s", e)
|
||||||
raise
|
raise
|
||||||
@@ -32,7 +32,7 @@ async def serve() -> None:
|
|||||||
try:
|
try:
|
||||||
if not name.startswith("get_"):
|
if not name.startswith("get_"):
|
||||||
raise ValueError("Only GET operations allowed in safe mode")
|
raise ValueError("Only GET operations allowed in safe mode")
|
||||||
tool = await get_tool(name)
|
tool = await get_tool(config, name)
|
||||||
async with tool.client:
|
async with tool.client:
|
||||||
result = await tool.run(body=arguments)
|
result = await tool.run(body=arguments)
|
||||||
return [TextContent(type="text", text=str(result))]
|
return [TextContent(type="text", text=str(result))]
|
||||||
|
|||||||
@@ -1,28 +1,28 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
from mcp.types import TextContent, Tool
|
from mcp.types import TextContent, Tool
|
||||||
|
|
||||||
|
from airflow_mcp_server.config import AirflowConfig
|
||||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def serve() -> None:
|
async def serve(config: AirflowConfig) -> None:
|
||||||
"""Start MCP server in unsafe mode (all operations)."""
|
"""Start MCP server in unsafe mode (all operations).
|
||||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
|
||||||
if not all(var in os.environ for var in required_vars):
|
|
||||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
|
"""
|
||||||
server = Server("airflow-mcp-server-unsafe")
|
server = Server("airflow-mcp-server-unsafe")
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
try:
|
try:
|
||||||
return await get_airflow_tools(mode="unsafe")
|
return await get_airflow_tools(config, mode="unsafe")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to list tools: %s", e)
|
logger.error("Failed to list tools: %s", e)
|
||||||
raise
|
raise
|
||||||
@@ -30,7 +30,7 @@ async def serve() -> None:
|
|||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
try:
|
try:
|
||||||
tool = await get_tool(name)
|
tool = await get_tool(config, name)
|
||||||
async with tool.client:
|
async with tool.client:
|
||||||
result = await tool.run(body=arguments)
|
result = await tool.run(body=arguments)
|
||||||
return [TextContent(type="text", text=str(result))]
|
return [TextContent(type="text", text=str(result))]
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class BaseTools(ABC):
|
class BaseTools(ABC):
|
||||||
"""Abstract base class for tools."""
|
"""Abstract base class for tools."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the tool."""
|
"""Initialize the tool."""
|
||||||
@@ -12,7 +13,7 @@ class BaseTools(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self) -> Any:
|
def run(self) -> Any:
|
||||||
"""Execute the tool's main functionality.
|
"""Execute the tool's main functionality.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Any: The result of the tool execution
|
Any: The result of the tool execution
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from importlib import resources
|
from importlib import resources
|
||||||
|
|
||||||
from mcp.types import Tool
|
from mcp.types import Tool
|
||||||
|
|
||||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
|
from airflow_mcp_server.config import AirflowConfig
|
||||||
from airflow_mcp_server.parser.operation_parser import OperationParser
|
from airflow_mcp_server.parser.operation_parser import OperationParser
|
||||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||||
|
|
||||||
@@ -13,16 +13,19 @@ logger = logging.getLogger(__name__)
|
|||||||
_tools_cache: dict[str, AirflowTool] = {}
|
_tools_cache: dict[str, AirflowTool] = {}
|
||||||
|
|
||||||
|
|
||||||
def _initialize_client() -> AirflowClient:
|
def _initialize_client(config: AirflowConfig) -> AirflowClient:
|
||||||
"""Initialize Airflow client with environment variables or embedded spec.
|
"""Initialize Airflow client with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AirflowClient instance
|
AirflowClient instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If required environment variables are missing or default spec is not found
|
ValueError: If default spec is not found
|
||||||
"""
|
"""
|
||||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
spec_path = config.spec_path
|
||||||
if not spec_path:
|
if not spec_path:
|
||||||
# Fallback to embedded v1.yaml
|
# Fallback to embedded v1.yaml
|
||||||
try:
|
try:
|
||||||
@@ -32,25 +35,33 @@ def _initialize_client() -> AirflowClient:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError("Default OpenAPI spec not found in package resources") from e
|
raise ValueError("Default OpenAPI spec not found in package resources") from e
|
||||||
|
|
||||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
# Initialize client with appropriate authentication method
|
||||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
client_args = {"spec_path": spec_path, "base_url": config.base_url}
|
||||||
if missing_vars:
|
|
||||||
raise ValueError(f"Missing required environment variables: {missing_vars}")
|
|
||||||
|
|
||||||
return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
|
# Apply cookie auth first if available (highest precedence)
|
||||||
|
if config.cookie:
|
||||||
|
client_args["cookie"] = config.cookie
|
||||||
|
# Otherwise use auth token if available
|
||||||
|
elif config.auth_token:
|
||||||
|
client_args["auth_token"] = config.auth_token
|
||||||
|
|
||||||
|
return AirflowClient(**client_args)
|
||||||
|
|
||||||
|
|
||||||
async def _initialize_tools() -> None:
|
async def _initialize_tools(config: AirflowConfig) -> None:
|
||||||
"""Initialize tools cache with Airflow operations.
|
"""Initialize tools cache with Airflow operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If initialization fails
|
ValueError: If initialization fails
|
||||||
"""
|
"""
|
||||||
global _tools_cache
|
global _tools_cache
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = _initialize_client()
|
client = _initialize_client(config)
|
||||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
spec_path = config.spec_path
|
||||||
if not spec_path:
|
if not spec_path:
|
||||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||||
spec_path = f.name
|
spec_path = f.name
|
||||||
@@ -68,10 +79,11 @@ async def _initialize_tools() -> None:
|
|||||||
raise ValueError(f"Failed to initialize tools: {e}") from e
|
raise ValueError(f"Failed to initialize tools: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list[Tool]:
|
||||||
"""Get list of available Airflow tools based on mode.
|
"""Get list of available Airflow tools based on mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
mode: "safe" for GET operations only, "unsafe" for all operations (default)
|
mode: "safe" for GET operations only, "unsafe" for all operations (default)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -81,7 +93,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
|||||||
ValueError: If initialization fails
|
ValueError: If initialization fails
|
||||||
"""
|
"""
|
||||||
if not _tools_cache:
|
if not _tools_cache:
|
||||||
await _initialize_tools()
|
await _initialize_tools(config)
|
||||||
|
|
||||||
tools = []
|
tools = []
|
||||||
for operation_id, tool in _tools_cache.items():
|
for operation_id, tool in _tools_cache.items():
|
||||||
@@ -93,7 +105,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
|||||||
tools.append(
|
tools.append(
|
||||||
Tool(
|
Tool(
|
||||||
name=operation_id,
|
name=operation_id,
|
||||||
description=tool.operation.operation_id,
|
description=tool.operation.description,
|
||||||
inputSchema=schema,
|
inputSchema=schema,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -104,10 +116,11 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
|||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
async def get_tool(name: str) -> AirflowTool:
|
async def get_tool(config: AirflowConfig, name: str) -> AirflowTool:
|
||||||
"""Get specific tool by name.
|
"""Get specific tool by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
config: Configuration object with auth and URL settings
|
||||||
name: Tool/operation name
|
name: Tool/operation name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -118,7 +131,7 @@ async def get_tool(name: str) -> AirflowTool:
|
|||||||
ValueError: If tool initialization fails
|
ValueError: If tool initialization fails
|
||||||
"""
|
"""
|
||||||
if not _tools_cache:
|
if not _tools_cache:
|
||||||
await _initialize_tools()
|
await _initialize_tools(config)
|
||||||
|
|
||||||
if name not in _tools_cache:
|
if name not in _tools_cache:
|
||||||
raise KeyError(f"Tool {name} not found")
|
raise KeyError(f"Tool {name} not found")
|
||||||
|
|||||||
@@ -32,6 +32,31 @@ def test_init_client_initialization(client: AirflowClient) -> None:
|
|||||||
assert isinstance(client.spec, OpenAPI)
|
assert isinstance(client.spec, OpenAPI)
|
||||||
assert client.base_url == "http://localhost:8080/api/v1"
|
assert client.base_url == "http://localhost:8080/api/v1"
|
||||||
assert client.headers["Authorization"] == "Basic test-token"
|
assert client.headers["Authorization"] == "Basic test-token"
|
||||||
|
assert "Cookie" not in client.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_client_with_cookie() -> None:
|
||||||
|
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||||
|
spec = yaml.safe_load(f)
|
||||||
|
client = AirflowClient(
|
||||||
|
spec_path=spec,
|
||||||
|
base_url="http://localhost:8080/api/v1",
|
||||||
|
cookie="session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM",
|
||||||
|
)
|
||||||
|
assert isinstance(client.spec, OpenAPI)
|
||||||
|
assert client.base_url == "http://localhost:8080/api/v1"
|
||||||
|
assert "Authorization" not in client.headers
|
||||||
|
assert client.headers["Cookie"] == "session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_client_missing_auth() -> None:
|
||||||
|
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||||
|
spec = yaml.safe_load(f)
|
||||||
|
with pytest.raises(ValueError, match="Either auth_token or cookie must be provided"):
|
||||||
|
AirflowClient(
|
||||||
|
spec_path=spec,
|
||||||
|
base_url="http://localhost:8080/api/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_init_load_spec_from_bytes() -> None:
|
def test_init_load_spec_from_bytes() -> None:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def spec_file():
|
def spec_file():
|
||||||
"""Get content of the v1.yaml spec file."""
|
"""Get content of the v1.yaml spec file."""
|
||||||
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
|
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +31,24 @@ def test_parse_operation_basic(parser: OperationParser) -> None:
|
|||||||
assert operation.operation_id == "get_dags"
|
assert operation.operation_id == "get_dags"
|
||||||
assert operation.path == "/dags"
|
assert operation.path == "/dags"
|
||||||
assert operation.method == "get"
|
assert operation.method == "get"
|
||||||
|
assert (
|
||||||
|
operation.description
|
||||||
|
== """List DAGs in the database.
|
||||||
|
`dag_id_pattern` can be set to match dags of a specific pattern
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
assert isinstance(operation.parameters, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_operation_with_no_description_but_summary(parser: OperationParser) -> None:
|
||||||
|
"""Test parsing operation with no description but summary."""
|
||||||
|
operation = parser.parse_operation("get_connections")
|
||||||
|
|
||||||
|
assert isinstance(operation, OperationDetails)
|
||||||
|
assert operation.operation_id == "get_connections"
|
||||||
|
assert operation.path == "/connections"
|
||||||
|
assert operation.method == "get"
|
||||||
|
assert operation.description == "List connections"
|
||||||
assert isinstance(operation.parameters, dict)
|
assert isinstance(operation.parameters, dict)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -111,7 +111,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "airflow-mcp-server"
|
name = "airflow-mcp-server"
|
||||||
version = "0.2.0"
|
version = "0.4.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|||||||
Reference in New Issue
Block a user