diff --git a/airflow-mcp-server/README.md b/airflow-mcp-server/README.md
index b49b9bb..857a861 100644
--- a/airflow-mcp-server/README.md
+++ b/airflow-mcp-server/README.md
@@ -3,3 +3,22 @@
 
 ## Overview
 A [Model Context Protocol](https://modelcontextprotocol.io/) server for controlling Airflow via Airflow APIs.
+
+
+### Considerations
+
+The MCP Server expects environment variables to be set:
+- `AIRFLOW_BASE_URL`: The base URL of the Airflow API
+- `AUTH_TOKEN`: The token to use for authorization
+
+*Currently, only Session mode is supported.*
+
+**Page Limit**
+
+The default is 100 items, but you can change it using the `maximum_page_limit` option in the `[api]` section of the `airflow.cfg` file.
+
+## Tasks
+
+- [x] First API
+- [ ] Airflow config fetch (_specifically for page limit_)
+- [ ] Env variables optional (_env variables might not be ideal for airflow plugins_)
diff --git a/airflow-mcp-server/src/airflow_mcp_server/server.py b/airflow-mcp-server/src/airflow_mcp_server/server.py
index e69de29..fe2a72b 100644
--- a/airflow-mcp-server/src/airflow_mcp_server/server.py
+++ b/airflow-mcp-server/src/airflow_mcp_server/server.py
@@ -0,0 +1,62 @@
+import os
+from enum import Enum
+from typing import Any
+
+from mcp.server import Server
+from mcp.server.stdio import stdio_server
+from mcp.types import TextContent, Tool
+
+from airflow_mcp_server.tools.models import ListDags
+from airflow_mcp_server.tools.tool_manager import get_airflow_dag_tools
+
+
+class AirflowAPITools(str, Enum):
+    # DAG Operations
+    LIST_DAGS = "list_dags"
+
+
+async def process_instruction(instruction: dict[str, Any]) -> dict[str, Any]:
+    dag_tools = get_airflow_dag_tools()
+
+    try:
+        match instruction["type"]:
+            case "list_dags":
+                return {"dags": await dag_tools.list_dags()}
+            case _:
+                return {"message": "Invalid instruction type"}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def serve() -> None:
+    server = Server("airflow-mcp-server")
+
+    @server.list_tools()
+    async def list_tools() -> list[Tool]:
+        tools = [
+            # DAG Operations
+            Tool(
+                name=AirflowAPITools.LIST_DAGS,
+                description="Lists all DAGs in Airflow",
+                inputSchema=ListDags.model_json_schema(),
+            ),
+        ]
+        if "AIRFLOW_BASE_URL" in os.environ and "AUTH_TOKEN" in os.environ:
+            return tools
+        else:
+            return []
+
+    @server.call_tool()
+    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
+        dag_tools = get_airflow_dag_tools()
+
+        match name:
+            case AirflowAPITools.LIST_DAGS:
+                result = await dag_tools.list_dags()
+                return [TextContent(type="text", text=str(result))]
+            case _:
+                raise ValueError(f"Unknown tool: {name}")
+
+    options = server.create_initialization_options()
+    async with stdio_server() as (read_stream, write_stream):
+        await server.run(read_stream, write_stream, options, raise_exceptions=True)
diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py
new file mode 100644
index 0000000..75bf53f
--- /dev/null
+++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py
@@ -0,0 +1,103 @@
+import os
+
+import aiohttp
+
+
+class AirflowDagTools:
+    def __init__(self):
+        self.airflow_base_url = os.getenv("AIRFLOW_BASE_URL")
+        self.auth_token = os.getenv("AUTH_TOKEN")
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"}
+
+    async def list_dags(
+        self,
+        limit: int | None = 100,
+        offset: int | None = None,
+        order_by: str | None = None,
+        tags: list[str] | None = None,
+        only_active: bool = True,
+        paused: bool | None = None,
+        fields: list[str] | None = None,
+        dag_id_pattern: str | None = None,
+    ) -> dict | list:
+        """
+        List all DAGs in Airflow.
+
+        Sample response:
+        {
+            "dags": [
+                {
+                    "dag_id": "string",
+                    "dag_display_name": "string",
+                    "root_dag_id": "string",
+                    "is_paused": true,
+                    "is_active": true,
+                    "is_subdag": true,
+                    "last_parsed_time": "2019-08-24T14:15:22Z",
+                    "last_pickled": "2019-08-24T14:15:22Z",
+                    "last_expired": "2019-08-24T14:15:22Z",
+                    "scheduler_lock": true,
+                    "pickle_id": "string",
+                    "default_view": "string",
+                    "fileloc": "string",
+                    "file_token": "string",
+                    "owners": [
+                        "string"
+                    ],
+                    "description": "string",
+                    "schedule_interval": {
+                        "__type": "string",
+                        "days": 0,
+                        "seconds": 0,
+                        "microseconds": 0
+                    },
+                    "timetable_description": "string",
+                    "tags": [
+                        {
+                            "name": "string"
+                        }
+                    ],
+                    "max_active_tasks": 0,
+                    "max_active_runs": 0,
+                    "has_task_concurrency_limits": true,
+                    "has_import_errors": true,
+                    "next_dagrun": "2019-08-24T14:15:22Z",
+                    "next_dagrun_data_interval_start": "2019-08-24T14:15:22Z",
+                    "next_dagrun_data_interval_end": "2019-08-24T14:15:22Z",
+                    "next_dagrun_create_after": "2019-08-24T14:15:22Z",
+                    "max_consecutive_failed_dag_runs": 0
+                }
+            ],
+            "total_entries": 0
+        }
+
+        Args:
+            limit (int, optional): The number of items to return.
+            offset (int, optional): The number of items to skip before starting to collect the result set.
+            order_by (str, optional): The name of the field to order the results by. Prefix a field name with - to reverse the sort order. New in version 2.1.0
+            tags (list[str], optional): List of tags to filter results. New in version 2.2.0
+            only_active (bool, optional): Only filter active DAGs. New in version 2.1.1
+            paused (bool, optional): Only filter paused/unpaused DAGs. If absent or null, it returns paused and unpaused DAGs. New in version 2.6.0
+            fields (list[str], optional): List of fields to return.
+            dag_id_pattern (str, optional): If set, only return DAGs with dag_ids matching this pattern.
+
+        Returns:
+            dict | list: The parsed /dags API response (keys ``dags`` and ``total_entries``) on success, or an empty list on a non-200 response.
+ """ + dags = [] + async with aiohttp.ClientSession() as session: + params = { + "limit": limit, + "offset": offset, + "order_by": order_by, + "tags": tags, + "only_active": only_active, + "paused": paused, + "fields": fields, + "dag_id_pattern": dag_id_pattern, + } + async with session.get(f"{self.airflow_base_url}/api/v1/dags", headers=self.headers, params=params) as response: + if response.status == 200: + dags = await response.json() + + return dags diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/models.py b/airflow-mcp-server/src/airflow_mcp_server/tools/models.py new file mode 100644 index 0000000..774dfa7 --- /dev/null +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/models.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, model_validator + + +# DAG operations +# ==================================================================== +class ListDags(BaseModel): + """Parameters for listing DAGs.""" + + limit: int | None + offset: int | None + order_by: str | None + tags: list[str] | None + only_active: bool + paused: bool | None + fields: list[str] | None + dag_id_pattern: str | None + + @model_validator(mode="after") + def validate_offset(self) -> "ListDags": + if self.offset is not None and self.offset < 0: + raise ValueError("offset must be non-negative") + return self diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py new file mode 100644 index 0000000..3a828db --- /dev/null +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py @@ -0,0 +1,12 @@ +"""Tools manager for maintaining singleton instances of tools.""" + +from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools + +_dag_tools: AirflowDagTools | None = None + + +def get_airflow_dag_tools() -> AirflowDagTools: + global _dag_tools + if not _dag_tools: + _dag_tools = AirflowDagTools() + return _dag_tools