From b87fe176bd29e2839cd73e6fabc218a5257e113c Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Thu, 13 Feb 2025 10:03:23 +0000 Subject: [PATCH] refactor for dynamic Models --- .../parser/operation_parser.py | 17 ++ .../src/airflow_mcp_server/server.py | 58 ++-- .../tools/airflow_dag_tools.py | 103 ------- .../airflow_mcp_server/tools/airflow_tool.py | 9 +- .../src/airflow_mcp_server/tools/models.py | 22 -- .../airflow_mcp_server/tools/tool_manager.py | 266 ++++-------------- 6 files changed, 90 insertions(+), 385 deletions(-) delete mode 100644 airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py delete mode 100644 airflow-mcp-server/src/airflow_mcp_server/tools/models.py diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index 4cee9a7..9b5a7c1 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -259,6 +259,23 @@ class OperationParser: return self._create_model("Response", schema) + def get_operations(self) -> list[str]: + """Get list of all operation IDs from spec. + + Returns: + List of operation IDs + """ + operations = [] + + for path in self._paths.values(): + for method, operation in path.items(): + if method.startswith("x-") or method == "parameters": + continue + if "operationId" in operation: + operations.append(operation["operationId"]) + + return operations + def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: """Create Pydantic model from schema. diff --git a/airflow-mcp-server/src/airflow_mcp_server/server.py b/airflow-mcp-server/src/airflow_mcp_server/server.py index fe2a72b..0332da1 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/server.py +++ b/airflow-mcp-server/src/airflow_mcp_server/server.py @@ -1,61 +1,37 @@ +import logging import os -from enum import Enum from typing import Any from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool -from airflow_mcp_server.tools.models import ListDags -from airflow_mcp_server.tools.tool_manager import get_airflow_dag_tools +from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool - -class AirflowAPITools(str, Enum): - # DAG Operations - LIST_DAGS = "list_dags" - - -async def process_instruction(instruction: dict[str, Any]) -> dict[str, Any]: - dag_tools = get_airflow_dag_tools() - - try: - match instruction["type"]: - case "list_dags": - return {"dags": await dag_tools.list_dags()} - case _: - return {"message": "Invalid instruction type"} - except Exception as e: - return {"error": str(e)} +logger = logging.getLogger(__name__) async def serve() -> None: + """Start MCP server.""" + required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"] + if not all(var in os.environ for var in required_vars): + raise ValueError(f"Missing required environment variables: {required_vars}") + server = Server("airflow-mcp-server") @server.list_tools() async def list_tools() -> list[Tool]: - tools = [ - # DAG Operations - Tool( - name=AirflowAPITools.LIST_DAGS, - description="Lists all DAGs in Airflow", - inputSchema=ListDags.model_json_schema(), - ), - ] - if "AIRFLOW_BASE_URL" in os.environ and "AUTH_TOKEN" in os.environ: - return tools - else: - return [] + return get_airflow_tools() @server.call_tool() - async def call_tool(name: str, arguments: dict) -> list[TextContent]: - dag_tools = 
get_airflow_dag_tools() - - match name: - case AirflowAPITools.LIST_DAGS: - result = await dag_tools.list_dags() - return [TextContent(type="text", text=result)] - case _: - raise ValueError(f"Unknown tool: {name}") + async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + try: + tool = get_tool(name) + result = await tool.run(**arguments) + return [TextContent(type="text", text=str(result))] + except Exception as e: + logger.error("Tool execution failed: %s", e) + raise options = server.create_initialization_options() async with stdio_server() as (read_stream, write_stream): diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py deleted file mode 100644 index 75bf53f..0000000 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_dag_tools.py +++ /dev/null @@ -1,103 +0,0 @@ -import os - -import aiohttp - - -class AirflowDagTools: - def __init__(self): - self.airflow_base_url = os.getenv("AIRFLOW_BASE_URL") - self.auth_token = os.getenv("AUTH_TOKEN") - self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"} - - async def list_dags( - self, - limit: int | None = 100, - offset: int | None = None, - order_by: str | None = None, - tags: list[str] | None = None, - only_active: bool = True, - paused: bool | None = None, - fields: list[str] | None = None, - dag_id_pattern: str | None = None, - ) -> list[str]: - """ - List all DAGs in Airflow. - - Sample response: - { - "dags": [ - { - "dag_id": "string", - "dag_display_name": "string", - "root_dag_id": "string", - "is_paused": true, - "is_active": true, - "is_subdag": true, - "last_parsed_time": "2019-08-24T14:15:22Z", - "last_pickled": "2019-08-24T14:15:22Z", - "last_expired": "2019-08-24T14:15:22Z", - "scheduler_lock": true, - "pickle_id": "string", - "default_view": "string", - "fileloc": "string", - "file_token": "string", - "owners": [ - "string" - ], - "description": "string", - "schedule_interval": { - "__type": "string", - "days": 0, - "seconds": 0, - "microseconds": 0 - }, - "timetable_description": "string", - "tags": [ - { - "name": "string" - } - ], - "max_active_tasks": 0, - "max_active_runs": 0, - "has_task_concurrency_limits": true, - "has_import_errors": true, - "next_dagrun": "2019-08-24T14:15:22Z", - "next_dagrun_data_interval_start": "2019-08-24T14:15:22Z", - "next_dagrun_data_interval_end": "2019-08-24T14:15:22Z", - "next_dagrun_create_after": "2019-08-24T14:15:22Z", - "max_consecutive_failed_dag_runs": 0 - } - ], - "total_entries": 0 - } - - Args: - limit (int, optional): The numbers of items to return. - offset (int, optional): The number of items to skip before starting to collect the result set. - order_by (str, optional): The name of the field to order the results by. Prefix a field name with - to reverse the sort order. New in version 2.1.0 - tags (list[str], optional): List of tags to filter results. New in version 2.2.0 - only_active (bool, optional): Only filter active DAGs. New in version 2.1.1 - paused (bool, optional): Only filter paused/unpaused DAGs. If absent or null, it returns paused and unpaused DAGs. New in version 2.6.0 - fields (list[str], optional): List of field for return. - dag_id_pattern (str, optional): If set, only return DAGs with dag_ids matching this pattern. - - Returns: - list[str]: A list of DAG names. 
- """ - dags = [] - async with aiohttp.ClientSession() as session: - params = { - "limit": limit, - "offset": offset, - "order_by": order_by, - "tags": tags, - "only_active": only_active, - "paused": paused, - "fields": fields, - "dag_id_pattern": dag_id_pattern, - } - async with session.get(f"{self.airflow_base_url}/api/v1/dags", headers=self.headers, params=params) as response: - if response.status == 200: - dags = await response.json() - - return dags diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py index 096cc88..96d1920 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py @@ -32,14 +32,17 @@ def create_validation_error(field: str, message: str) -> ValidationError: class AirflowTool(BaseTools): - """Tool for executing Airflow API operations.""" + """ + Tool for executing Airflow API operations. + AirflowTool is supposed to have objects per operation. + """ def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None: """Initialize tool with operation details and client. Args: - operation_details: Parsed operation details - client: Configured Airflow API client + operation_details: Operation details + client: AirflowClient instance """ super().__init__() self.operation = operation_details diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/models.py b/airflow-mcp-server/src/airflow_mcp_server/tools/models.py deleted file mode 100644 index 774dfa7..0000000 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/models.py +++ /dev/null @@ -1,22 +0,0 @@ -from pydantic import BaseModel, model_validator - - -# DAG operations -# ==================================================================== -class ListDags(BaseModel): - """Parameters for listing DAGs.""" - - limit: int | None - offset: int | None - order_by: str | None - tags: list[str] | None - only_active: bool - paused: bool | None - fields: list[str] | None - dag_id_pattern: str | None - - @model_validator(mode="after") - def validate_offset(self) -> "ListDags": - if self.offset is not None and self.offset < 0: - raise ValueError("offset must be non-negative") - return self diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py index 6c86b0e..9306729 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py @@ -1,243 +1,77 @@ -"""Tool manager for handling Airflow API tool instantiation and caching.""" - -import asyncio import logging -from collections import OrderedDict -from pathlib import Path -from typing import Any +import os + +from mcp.types import Tool from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.parser.operation_parser import OperationParser -from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools from airflow_mcp_server.tools.airflow_tool import AirflowTool logger = logging.getLogger(__name__) -# Keep existing function for backward compatibility -_dag_tools: AirflowDagTools | None = None +_tools_cache: dict[str, AirflowTool] = {} +_client: AirflowClient | None = None -def get_airflow_dag_tools() -> AirflowDagTools: - global _dag_tools - if not _dag_tools: - _dag_tools = AirflowDagTools() - return _dag_tools +def get_airflow_tools() -> list[Tool]: + """Get list of all 
available Airflow tools. + Returns: + List of MCP Tool objects representing available operations -class ToolManagerError(Exception): - """Base exception for tool manager errors.""" - - pass - - -class ToolInitializationError(ToolManagerError): - """Error during tool initialization.""" - - pass - - -class ToolNotFoundError(ToolManagerError): - """Error when requested tool is not found.""" - - pass - - -class ToolManager: - """Manager for Airflow API tools with caching and lifecycle management. - - This class provides a centralized way to manage Airflow API tools with: - - Singleton client management - - Tool caching with size limits - - Thread-safe access - - Proper error handling + Raises: + ValueError: If required environment variables are missing """ + global _tools_cache, _client - def __init__( - self, - spec_path: Path | str | object, - base_url: str, - auth_token: str, - max_cache_size: int = 100, - ) -> None: - """Initialize tool manager. + if not _tools_cache: + required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"] + if not all(var in os.environ for var in required_vars): + raise ValueError(f"Missing required environment variables: {required_vars}") - Args: - spec_path: Path to OpenAPI spec file or file-like object - base_url: Base URL for Airflow API - auth_token: Authentication token - max_cache_size: Maximum number of tools to cache (default: 100) + # Initialize client if not exists + if not _client: + _client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"]) - Raises: - ToolManagerError: If initialization fails - ToolInitializationError: If client or parser initialization fails - """ try: - # Validate inputs - if not spec_path: - raise ValueError("spec_path is required") - if not base_url: - raise ValueError("base_url is required") - if not auth_token: - raise ValueError("auth_token is required") - if max_cache_size < 1: - raise ValueError("max_cache_size must be positive") + # Create parser + parser = OperationParser(os.environ["OPENAPI_SPEC"]) - # Store configuration - self._spec_path = spec_path - self._base_url = base_url.rstrip("/") - self._auth_token = auth_token - self._max_cache_size = max_cache_size + # Generate tools for each operation + for operation_id in parser.get_operations(): + operation_details = parser.parse_operation(operation_id) + tool = AirflowTool(operation_details, _client) + _tools_cache[operation_id] = tool - # Initialize core components with proper error handling - try: - self._client = AirflowClient(spec_path, self._base_url, auth_token) - logger.info("AirflowClient initialized successfully") - except Exception as e: - logger.error("Failed to initialize AirflowClient: %s", e) - raise ToolInitializationError(f"Failed to initialize client: {e}") from e - - try: - self._parser = OperationParser(spec_path) - logger.info("OperationParser initialized successfully") - except Exception as e: - logger.error("Failed to initialize OperationParser: %s", e) - raise ToolInitializationError(f"Failed to initialize parser: {e}") from e - - # Setup thread safety and caching - self._lock = asyncio.Lock() - self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict() - - logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url) - except ValueError as e: - logger.error("Invalid configuration: %s", e) - raise ToolManagerError(f"Invalid configuration: {e}") from e except Exception as e: - logger.error("Failed to 
initialize tool manager: %s", e) - raise ToolInitializationError(f"Component initialization failed: {e}") from e + logger.error("Failed to initialize tools: %s", e) + raise - async def __aenter__(self) -> "ToolManager": - """Enter async context.""" - try: - if not hasattr(self, "_client"): - logger.error("Client not initialized") - raise ToolInitializationError("Client not initialized") - await self._client.__aenter__() - return self - except Exception as e: - logger.error("Failed to enter async context: %s", e) - if hasattr(self, "_client"): - try: - await self._client.__aexit__(None, None, None) - except Exception: - pass - raise ToolInitializationError(f"Failed to initialize client session: {e}") from e + # Convert to MCP Tool format + return [ + Tool( + name=operation_id, + description=tool.operation.operation_id, + inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None, + ) + for operation_id, tool in _tools_cache.items() + ] - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - """Exit async context.""" - try: - await self._client.__aexit__(exc_type, exc_val, exc_tb) - except Exception as e: - logger.error("Error during context exit: %s", e) - finally: - self.clear_cache() - def _evict_cache_if_needed(self) -> None: - """Evict oldest items from cache if size limit is reached. +def get_tool(name: str) -> AirflowTool: + """Get specific tool by name. - This method uses FIFO (First In First Out) eviction policy. - Eviction occurs when cache reaches max_cache_size. - """ - current_size = len(self._tool_cache) - if current_size >= self._max_cache_size: - logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size) - logger.debug("Current cache contents: %s", list(self._tool_cache.keys())) + Args: + name: Tool/operation name - evicted_count = 0 - while len(self._tool_cache) >= self._max_cache_size: - operation_id, _ = self._tool_cache.popitem(last=False) - evicted_count += 1 - logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size) + Returns: + AirflowTool instance - if evicted_count > 0: - logger.info("Evicted %d tools from cache", evicted_count) + Raises: + KeyError: If tool not found + """ + if name not in _tools_cache: + # Ensure cache is populated + get_airflow_tools() - async def get_tool(self, operation_id: str) -> AirflowTool: - """Get or create a tool instance for the given operation. 
- - Args: - operation_id: Operation ID from OpenAPI spec - - Returns: - AirflowTool instance - - Raises: - ToolNotFoundError: If operation not found - ToolInitializationError: If tool creation fails - ValueError: If operation_id is invalid - """ - if not operation_id or not isinstance(operation_id, str): - logger.error("Invalid operation_id provided: %s", operation_id) - raise ValueError("Invalid operation_id") - - if not hasattr(self, "_client") or not hasattr(self, "_parser"): - logger.error("ToolManager not properly initialized") - raise ToolInitializationError("ToolManager components not initialized") - - logger.debug("Requesting tool for operation: %s", operation_id) - - async with self._lock: - # Check cache first - if operation_id in self._tool_cache: - logger.debug("Tool cache hit for %s", operation_id) - return self._tool_cache[operation_id] - - logger.debug("Tool cache miss for %s, creating new instance", operation_id) - - try: - # Parse operation details - try: - operation_details = self._parser.parse_operation(operation_id) - except ValueError as e: - logger.error("Operation %s not found: %s", operation_id, e) - raise ToolNotFoundError(f"Operation {operation_id} not found") from e - except Exception as e: - logger.error("Failed to parse operation %s: %s", operation_id, e) - raise ToolInitializationError(f"Operation parsing failed: {e}") from e - - # Create new tool instance - try: - tool = AirflowTool(operation_details, self._client) - except Exception as e: - logger.error("Failed to create tool instance for %s: %s", operation_id, e) - raise ToolInitializationError(f"Tool instantiation failed: {e}") from e - - # Update cache - self._evict_cache_if_needed() - self._tool_cache[operation_id] = tool - logger.info("Created and cached new tool for %s", operation_id) - - return tool - - except (ToolNotFoundError, ToolInitializationError): - raise - except Exception as e: - logger.error("Unexpected error creating tool for %s: %s", operation_id, e) - raise ToolInitializationError(f"Unexpected error: {e}") from e - - def clear_cache(self) -> None: - """Clear the tool cache.""" - self._tool_cache.clear() - logger.debug("Tool cache cleared") - - @property - def cache_info(self) -> dict[str, Any]: - """Get cache statistics. - - Returns: - Dictionary with cache statistics - """ - return { - "size": len(self._tool_cache), - "max_size": self._max_cache_size, - "operations": list(self._tool_cache.keys()), - } + return _tools_cache[name]
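
Usage sketch (illustrative, not part of the patch): the refactor replaces the hand-written ListDags tool with tools generated dynamically from the OpenAPI spec via OperationParser.get_operations() and one AirflowTool per operation. Assuming the three environment variables checked in serve() are set, the snippet below mirrors what list_tools() and call_tool() now do; the spec path, base URL, token, and the operation id "get_dags" are placeholders, not values defined by this patch.

    import asyncio
    import os

    from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool


    async def main() -> None:
        # The refactored tool manager reads its configuration from the
        # environment, mirroring the checks in serve() and get_airflow_tools().
        os.environ.setdefault("OPENAPI_SPEC", "/path/to/airflow-openapi.yaml")  # placeholder path
        os.environ.setdefault("AIRFLOW_BASE_URL", "http://localhost:8080")      # placeholder URL
        os.environ.setdefault("AUTH_TOKEN", "redacted")                          # placeholder token

        # One MCP Tool per operationId discovered in the spec, as list_tools() returns.
        tools = get_airflow_tools()
        print(f"{len(tools)} operations exposed as MCP tools")

        # Fetch a single cached AirflowTool and invoke it, as call_tool() does
        # with the arguments supplied by the MCP client.
        tool = get_tool("get_dags")      # example operation id from the spec
        result = await tool.run(limit=5)  # keyword arguments are forwarded to the operation
        print(result)


    if __name__ == "__main__":
        asyncio.run(main())

Because get_airflow_tools() populates the module-level _tools_cache on first use and get_tool() falls back to it before raising KeyError, repeated lookups reuse the same AirflowTool instances and the single shared AirflowClient.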