diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 0000000..7e53814 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,29 @@ +name: Run Tests + +on: + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12"] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + + - name: Run pytest + run: | + pytest tests/ -v diff --git a/README.md b/README.md index fa81cf0..bc7de51 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,10 @@ https://github.com/user-attachments/assets/f3e60fff-8680-4dd9-b08e-fa7db655a705 ], "env": { "AIRFLOW_BASE_URL": "http:///api/v1", - "AUTH_TOKEN": "" + // Either use AUTH_TOKEN for basic auth + "AUTH_TOKEN": "", + // Or use COOKIE for cookie-based auth + "COOKIE": "" } } } @@ -57,10 +60,17 @@ airflow-mcp-server --unsafe The MCP Server expects environment variables to be set: - `AIRFLOW_BASE_URL`: The base URL of the Airflow API -- `AUTH_TOKEN`: The token to use for authorization (_This should be base64 encoded username:password_) +- `AUTH_TOKEN`: The token to use for basic auth (_This should be base64 encoded username:password_) (_Optional if COOKIE is provided_) +- `COOKIE`: The session cookie to use for authentication (_Optional if AUTH_TOKEN is provided_) - `OPENAPI_SPEC`: The path to the OpenAPI spec file (_Optional_) (_defaults to latest stable release_) -*Currently, only Basic Auth is supported.* +**Authentication** + +The server supports two authentication methods: +- **Basic Auth**: Using base64 encoded username:password via `AUTH_TOKEN` environment variable +- **Cookie**: Using session cookie via `COOKIE` environment variable + +At least one of these authentication methods must be provided. **Page Limit** @@ -71,6 +81,7 @@ The default is 100 items, but you can change it using `maximum_page_limit` optio - [x] First API - [x] Parse OpenAPI Spec - [x] Safe/Unsafe mode implementation +- [x] Allow session auth - [ ] Parse proper description with list_tools. - [ ] Airflow config fetch (_specifically for page limit_) - [ ] Env variables optional (_env variables might not be ideal for airflow plugins_) diff --git a/pyproject.toml b/pyproject.toml index e3f292a..b8dc0ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,8 @@ build-backend = "hatchling.build" exclude = [ "*", "!src/**", - "!pyproject.toml" + "!pyproject.toml", + "!assets/**" ] [tool.hatch.build.targets.wheel] diff --git a/src/airflow_mcp_server/__init__.py b/src/airflow_mcp_server/__init__.py index 8decf7f..4f860f2 100644 --- a/src/airflow_mcp_server/__init__.py +++ b/src/airflow_mcp_server/__init__.py @@ -1,9 +1,11 @@ import asyncio import logging +import os import sys import click +from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.server_safe import serve as serve_safe from airflow_mcp_server.server_unsafe import serve as serve_unsafe @@ -12,7 +14,11 @@ from airflow_mcp_server.server_unsafe import serve as serve_unsafe @click.option("-v", "--verbose", count=True, help="Increase verbosity") @click.option("--safe", "-s", is_flag=True, help="Use only read-only tools") @click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)") -def main(verbose: int, safe: bool, unsafe: bool) -> None: +@click.option("--base-url", help="Airflow API base URL") +@click.option("--spec-path", help="Path to OpenAPI spec file") +@click.option("--auth-token", help="Authentication token") +@click.option("--cookie", help="Session cookie") +def main(verbose: int, safe: bool, unsafe: bool, base_url: str = None, spec_path: str = None, auth_token: str = None, cookie: str = None) -> None: """MCP server for Airflow""" logging_level = logging.WARN if verbose == 1: @@ -22,13 +28,30 @@ def main(verbose: int, safe: bool, unsafe: bool) -> None: logging.basicConfig(level=logging_level, stream=sys.stderr) - if safe and unsafe: - raise click.UsageError("Options --safe and --unsafe are mutually exclusive") + # Read environment variables with proper precedence + # Environment variables take precedence over CLI arguments + config_base_url = os.environ.get("AIRFLOW_BASE_URL") or base_url + config_spec_path = os.environ.get("OPENAPI_SPEC") or spec_path + config_auth_token = os.environ.get("AUTH_TOKEN") or auth_token + config_cookie = os.environ.get("COOKIE") or cookie - if safe: - asyncio.run(serve_safe()) - else: # Default to unsafe mode - asyncio.run(serve_unsafe()) + # Initialize configuration + try: + config = AirflowConfig(base_url=config_base_url, spec_path=config_spec_path, auth_token=config_auth_token, cookie=config_cookie) + except ValueError as e: + click.echo(f"Configuration error: {e}", err=True) + sys.exit(1) + + # Determine server mode with proper precedence + if safe and unsafe: + # CLI argument validation + raise click.UsageError("Options --safe and --unsafe are mutually exclusive") + elif safe: + # CLI argument for safe mode + asyncio.run(serve_safe(config)) + else: + # Default to unsafe mode + asyncio.run(serve_unsafe(config)) if __name__ == "__main__": diff --git a/src/airflow_mcp_server/client/airflow_client.py b/src/airflow_mcp_server/client/airflow_client.py index e9f247c..e1e5117 100644 --- a/src/airflow_mcp_server/client/airflow_client.py +++ b/src/airflow_mcp_server/client/airflow_client.py @@ -35,18 +35,22 @@ class AirflowClient: self, spec_path: Path | str | dict | bytes | BinaryIO | TextIO, base_url: str, - auth_token: str, + auth_token: str | None = None, + cookie: str | None = None, ) -> None: """Initialize Airflow client. Args: spec_path: OpenAPI spec as file path, dict, bytes, or file object base_url: Base URL for API - auth_token: Authentication token + auth_token: Authentication token (optional if cookie is provided) + cookie: Session cookie (optional if auth_token is provided) Raises: - ValueError: If spec_path is invalid or spec cannot be loaded + ValueError: If spec_path is invalid or spec cannot be loaded or if neither auth_token nor cookie is provided """ + if not auth_token and not cookie: + raise ValueError("Either auth_token or cookie must be provided") try: # Load and parse OpenAPI spec if isinstance(spec_path, dict): @@ -96,10 +100,13 @@ class AirflowClient: # API configuration self.base_url = base_url.rstrip("/") - self.headers = { - "Authorization": f"Basic {auth_token}", - "Accept": "application/json", - } + self.headers = {"Accept": "application/json"} + + # Set authentication header based on precedence (cookie > auth_token) + if cookie: + self.headers["Cookie"] = cookie + elif auth_token: + self.headers["Authorization"] = f"Basic {auth_token}" except Exception as e: logger.error("Failed to initialize AirflowClient: %s", e) diff --git a/src/airflow_mcp_server/config.py b/src/airflow_mcp_server/config.py new file mode 100644 index 0000000..f3bb9d2 --- /dev/null +++ b/src/airflow_mcp_server/config.py @@ -0,0 +1,25 @@ +class AirflowConfig: + """Centralized configuration for Airflow MCP server.""" + + def __init__(self, base_url: str | None = None, spec_path: str | None = None, auth_token: str | None = None, cookie: str | None = None) -> None: + """Initialize configuration with provided values. + + Args: + base_url: Airflow API base URL + spec_path: Path to OpenAPI spec file + auth_token: Authentication token + cookie: Session cookie + + Raises: + ValueError: If required configuration is missing + """ + self.base_url = base_url + if not self.base_url: + raise ValueError("Missing required configuration: base_url") + + self.spec_path = spec_path + self.auth_token = auth_token + self.cookie = cookie + + if not self.auth_token and not self.cookie: + raise ValueError("Either auth_token or cookie must be provided") diff --git a/src/airflow_mcp_server/server.py b/src/airflow_mcp_server/server.py index 09933d2..c9063b4 100644 --- a/src/airflow_mcp_server/server.py +++ b/src/airflow_mcp_server/server.py @@ -1,11 +1,11 @@ import logging -import os from typing import Any from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool +from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool # ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR=================== @@ -20,18 +20,18 @@ from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool logger = logging.getLogger(__name__) -async def serve() -> None: - """Start MCP server.""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") +async def serve(config: AirflowConfig) -> None: + """Start MCP server. + Args: + config: Configuration object with auth and URL settings + """ server = Server("airflow-mcp-server") @server.list_tools() async def list_tools() -> list[Tool]: try: - return await get_airflow_tools() + return await get_airflow_tools(config) except Exception as e: logger.error("Failed to list tools: %s", e) raise @@ -39,7 +39,7 @@ async def serve() -> None: @server.call_tool() async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: try: - tool = await get_tool(name) + tool = await get_tool(config, name) async with tool.client: result = await tool.run(body=arguments) return [TextContent(type="text", text=str(result))] diff --git a/src/airflow_mcp_server/server_safe.py b/src/airflow_mcp_server/server_safe.py index bf81b3e..543b3b5 100644 --- a/src/airflow_mcp_server/server_safe.py +++ b/src/airflow_mcp_server/server_safe.py @@ -1,28 +1,28 @@ import logging -import os from typing import Any from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool +from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool logger = logging.getLogger(__name__) -async def serve() -> None: - """Start MCP server in safe mode (read-only operations).""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") +async def serve(config: AirflowConfig) -> None: + """Start MCP server in safe mode (read-only operations). + Args: + config: Configuration object with auth and URL settings + """ server = Server("airflow-mcp-server-safe") @server.list_tools() async def list_tools() -> list[Tool]: try: - return await get_airflow_tools(mode="safe") + return await get_airflow_tools(config, mode="safe") except Exception as e: logger.error("Failed to list tools: %s", e) raise @@ -32,7 +32,7 @@ async def serve() -> None: try: if not name.startswith("get_"): raise ValueError("Only GET operations allowed in safe mode") - tool = await get_tool(name) + tool = await get_tool(config, name) async with tool.client: result = await tool.run(body=arguments) return [TextContent(type="text", text=str(result))] diff --git a/src/airflow_mcp_server/server_unsafe.py b/src/airflow_mcp_server/server_unsafe.py index bcc7932..347db5c 100644 --- a/src/airflow_mcp_server/server_unsafe.py +++ b/src/airflow_mcp_server/server_unsafe.py @@ -1,28 +1,28 @@ import logging -import os from typing import Any from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool +from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool logger = logging.getLogger(__name__) -async def serve() -> None: - """Start MCP server in unsafe mode (all operations).""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") +async def serve(config: AirflowConfig) -> None: + """Start MCP server in unsafe mode (all operations). + Args: + config: Configuration object with auth and URL settings + """ server = Server("airflow-mcp-server-unsafe") @server.list_tools() async def list_tools() -> list[Tool]: try: - return await get_airflow_tools(mode="unsafe") + return await get_airflow_tools(config, mode="unsafe") except Exception as e: logger.error("Failed to list tools: %s", e) raise @@ -30,7 +30,7 @@ async def serve() -> None: @server.call_tool() async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: try: - tool = await get_tool(name) + tool = await get_tool(config, name) async with tool.client: result = await tool.run(body=arguments) return [TextContent(type="text", text=str(result))] diff --git a/src/airflow_mcp_server/tools/tool_manager.py b/src/airflow_mcp_server/tools/tool_manager.py index ee70e86..3588cc7 100644 --- a/src/airflow_mcp_server/tools/tool_manager.py +++ b/src/airflow_mcp_server/tools/tool_manager.py @@ -1,10 +1,10 @@ import logging -import os from importlib import resources from mcp.types import Tool from airflow_mcp_server.client.airflow_client import AirflowClient +from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.parser.operation_parser import OperationParser from airflow_mcp_server.tools.airflow_tool import AirflowTool @@ -13,16 +13,19 @@ logger = logging.getLogger(__name__) _tools_cache: dict[str, AirflowTool] = {} -def _initialize_client() -> AirflowClient: - """Initialize Airflow client with environment variables or embedded spec. +def _initialize_client(config: AirflowConfig) -> AirflowClient: + """Initialize Airflow client with configuration. + + Args: + config: Configuration object with auth and URL settings Returns: AirflowClient instance Raises: - ValueError: If required environment variables are missing or default spec is not found + ValueError: If default spec is not found """ - spec_path = os.environ.get("OPENAPI_SPEC") + spec_path = config.spec_path if not spec_path: # Fallback to embedded v1.yaml try: @@ -32,25 +35,33 @@ def _initialize_client() -> AirflowClient: except Exception as e: raise ValueError("Default OpenAPI spec not found in package resources") from e - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - missing_vars = [var for var in required_vars if var not in os.environ] - if missing_vars: - raise ValueError(f"Missing required environment variables: {missing_vars}") + # Initialize client with appropriate authentication method + client_args = {"spec_path": spec_path, "base_url": config.base_url} - return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"]) + # Apply cookie auth first if available (highest precedence) + if config.cookie: + client_args["cookie"] = config.cookie + # Otherwise use auth token if available + elif config.auth_token: + client_args["auth_token"] = config.auth_token + + return AirflowClient(**client_args) -async def _initialize_tools() -> None: +async def _initialize_tools(config: AirflowConfig) -> None: """Initialize tools cache with Airflow operations. + Args: + config: Configuration object with auth and URL settings + Raises: ValueError: If initialization fails """ global _tools_cache try: - client = _initialize_client() - spec_path = os.environ.get("OPENAPI_SPEC") + client = _initialize_client(config) + spec_path = config.spec_path if not spec_path: with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f: spec_path = f.name @@ -68,10 +79,11 @@ async def _initialize_tools() -> None: raise ValueError(f"Failed to initialize tools: {e}") from e -async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]: +async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list[Tool]: """Get list of available Airflow tools based on mode. Args: + config: Configuration object with auth and URL settings mode: "safe" for GET operations only, "unsafe" for all operations (default) Returns: @@ -81,7 +93,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]: ValueError: If initialization fails """ if not _tools_cache: - await _initialize_tools() + await _initialize_tools(config) tools = [] for operation_id, tool in _tools_cache.items(): @@ -104,10 +116,11 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]: return tools -async def get_tool(name: str) -> AirflowTool: +async def get_tool(config: AirflowConfig, name: str) -> AirflowTool: """Get specific tool by name. Args: + config: Configuration object with auth and URL settings name: Tool/operation name Returns: @@ -118,7 +131,7 @@ async def get_tool(name: str) -> AirflowTool: ValueError: If tool initialization fails """ if not _tools_cache: - await _initialize_tools() + await _initialize_tools(config) if name not in _tools_cache: raise KeyError(f"Tool {name} not found") diff --git a/tests/client/test_airflow_client.py b/tests/client/test_airflow_client.py index 3fdc03d..0b4460c 100644 --- a/tests/client/test_airflow_client.py +++ b/tests/client/test_airflow_client.py @@ -32,6 +32,31 @@ def test_init_client_initialization(client: AirflowClient) -> None: assert isinstance(client.spec, OpenAPI) assert client.base_url == "http://localhost:8080/api/v1" assert client.headers["Authorization"] == "Basic test-token" + assert "Cookie" not in client.headers + + +def test_init_client_with_cookie() -> None: + with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f: + spec = yaml.safe_load(f) + client = AirflowClient( + spec_path=spec, + base_url="http://localhost:8080/api/v1", + cookie="session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM", + ) + assert isinstance(client.spec, OpenAPI) + assert client.base_url == "http://localhost:8080/api/v1" + assert "Authorization" not in client.headers + assert client.headers["Cookie"] == "session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM" + + +def test_init_client_missing_auth() -> None: + with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f: + spec = yaml.safe_load(f) + with pytest.raises(ValueError, match="Either auth_token or cookie must be provided"): + AirflowClient( + spec_path=spec, + base_url="http://localhost:8080/api/v1", + ) def test_init_load_spec_from_bytes() -> None: diff --git a/uv.lock b/uv.lock index 15e19b1..362c5e4 100644 --- a/uv.lock +++ b/uv.lock @@ -111,7 +111,7 @@ wheels = [ [[package]] name = "airflow-mcp-server" -version = "0.2.0" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "aiofiles" },