From 18962618bc3415e78cc24a6e315bfd447dbb974d Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Wed, 12 Feb 2025 19:06:10 +0000
Subject: [PATCH] Add tool manager for Airflow tools

---
 airflow-mcp-server/pyproject.toml             |   1 +
 .../client/airflow_client.py                  |  68 +++---
 .../parser/operation_parser.py                |  44 ++--
 .../src/airflow_mcp_server/tools/__init__.py  |   0
 .../airflow_mcp_server/tools/tool_manager.py  | 233 +++++++++++++++++-
 .../tests/client/test_airflow_client.py       |  28 ++-
 airflow-mcp-server/tests/conftest.py          |  58 +++++
 .../tests/tools/test_tool_manager.py          | 153 ++++++++++++
 8 files changed, 531 insertions(+), 54 deletions(-)
 create mode 100644 airflow-mcp-server/src/airflow_mcp_server/tools/__init__.py
 create mode 100644 airflow-mcp-server/tests/conftest.py
 create mode 100644 airflow-mcp-server/tests/tools/test_tool_manager.py

diff --git a/airflow-mcp-server/pyproject.toml b/airflow-mcp-server/pyproject.toml
index 70f1d4a..c3514ef 100644
--- a/airflow-mcp-server/pyproject.toml
+++ b/airflow-mcp-server/pyproject.toml
@@ -11,6 +11,7 @@ dependencies = [
     "mcp>=1.2.0",
     "openapi-core>=0.19.4",
     "pydantic>=2.10.5",
+    "pyyaml>=6.0.0",
 ]
 
 [project.scripts]
diff --git a/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py b/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py
index 318256a..c64e28d 100644
--- a/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py
+++ b/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py
@@ -2,7 +2,7 @@ import logging
 import re
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any
+from typing import Any, BinaryIO, TextIO
 
 import aiohttp
 import yaml
@@ -30,20 +30,35 @@ class AirflowClient:
 
     def __init__(
         self,
-        spec_path: Path | str | object,
+        spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
         base_url: str,
        auth_token: str,
     ) -> None:
-        """Initialize Airflow client."""
-        # Load and parse OpenAPI spec
-        if isinstance(spec_path, (str | Path)):
-            with open(spec_path) as f:
-                self.raw_spec = yaml.safe_load(f)
-        else:
-            self.raw_spec = yaml.safe_load(spec_path)
+        """Initialize Airflow client.
 
-        # Initialize OpenAPI spec
+        Args:
+            spec_path: OpenAPI spec as file path, dict, bytes, or file object
+            base_url: Base URL for API
+            auth_token: Authentication token
+
+        Raises:
+            ValueError: If spec_path is invalid or spec cannot be loaded
+        """
         try:
+            # Load and parse OpenAPI spec
+            if isinstance(spec_path, dict):
+                self.raw_spec = spec_path
+            elif isinstance(spec_path, bytes):
+                self.raw_spec = yaml.safe_load(spec_path)
+            elif isinstance(spec_path, str | Path):
+                with open(spec_path) as f:
+                    self.raw_spec = yaml.safe_load(f)
+            elif hasattr(spec_path, "read"):
+                self.raw_spec = yaml.safe_load(spec_path.read())
+            else:
+                raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
+
+            # Initialize OpenAPI spec
             self.spec = OpenAPI.from_dict(self.raw_spec)
             logger.debug("OpenAPI spec loaded successfully")
 
@@ -51,27 +66,23 @@ class AirflowClient:
             logger.debug("Raw spec keys: %s", self.raw_spec.keys())
 
             # Get paths from raw spec
-            if "paths" in self.raw_spec:
-                self._paths = self.raw_spec["paths"]
-                logger.debug("Using raw spec paths")
-            else:
+            if "paths" not in self.raw_spec:
                 raise ValueError("OpenAPI spec does not contain paths information")
+            self._paths = self.raw_spec["paths"]
+            logger.debug("Using raw spec paths")
+
+            # API configuration
+            self.base_url = base_url.rstrip("/")
+            self.headers = {
+                "Authorization": f"Bearer {auth_token}",
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+            }
 
         except Exception as e:
-            logger.error("Failed to initialize OpenAPI spec: %s", e)
+            logger.error("Failed to initialize AirflowClient: %s", e)
             raise
 
-        # API configuration
-        self.base_url = base_url.rstrip("/")
-        self.headers = {
-            "Authorization": f"Bearer {auth_token}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-
-        # Session management
-        self._session: aiohttp.ClientSession | None = None
-
     async def __aenter__(self) -> "AirflowClient":
         """Enter async context, creating session."""
         self._session = aiohttp.ClientSession(headers=self.headers)
@@ -81,7 +92,7 @@ class AirflowClient:
         """Exit async context, closing session."""
         if self._session:
             await self._session.close()
-            self._session = None
+        delattr(self, "_session")
 
     def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
         """Get operation details from OpenAPI spec.
@@ -140,9 +151,10 @@ class AirflowClient:
 
         Raises:
             ValueError: If operation not found
+            RuntimeError: If used outside async context
             aiohttp.ClientError: For HTTP/network errors
         """
-        if not self._session:
+        if not hasattr(self, "_session") or not self._session:
             raise RuntimeError("Client not in async context")
 
         try:
diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py
index 813dac0..4cee9a7 100644
--- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py
+++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py
@@ -25,25 +25,39 @@ class OperationDetails:
 class OperationParser:
     """Parser for OpenAPI operations."""
 
-    def __init__(self, spec_path: Path | str | object) -> None:
+    def __init__(self, spec_path: Path | str | dict | bytes | object) -> None:
         """Initialize parser with OpenAPI specification.
Args: - spec_path: Path to OpenAPI spec file or file-like object - """ - # Load and parse OpenAPI spec - if isinstance(spec_path, (str | Path)): - with open(spec_path) as f: - self.raw_spec = yaml.safe_load(f) - else: - self.raw_spec = yaml.safe_load(spec_path) + spec_path: Path to OpenAPI spec file, dict, bytes, or file-like object - # Initialize OpenAPI spec - spec = OpenAPI.from_dict(self.raw_spec) - self.spec = spec - self._paths = self.raw_spec["paths"] - self._components = self.raw_spec.get("components", {}) - self._schema_cache: dict[str, dict[str, Any]] = {} + Raises: + ValueError: If spec_path is invalid or spec cannot be loaded + """ + try: + # Load and parse OpenAPI spec + if isinstance(spec_path, bytes): + self.raw_spec = yaml.safe_load(spec_path) + elif isinstance(spec_path, dict): + self.raw_spec = spec_path + elif isinstance(spec_path, str | Path): + with open(spec_path) as f: + self.raw_spec = yaml.safe_load(f) + elif hasattr(spec_path, "read"): + self.raw_spec = yaml.safe_load(spec_path) + else: + raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object") + + # Initialize OpenAPI spec + spec = OpenAPI.from_dict(self.raw_spec) + self.spec = spec + self._paths = self.raw_spec["paths"] + self._components = self.raw_spec.get("components", {}) + self._schema_cache: dict[str, dict[str, Any]] = {} + + except Exception as e: + logger.error("Error initializing OperationParser: %s", e) + raise ValueError(f"Failed to initialize parser: {e}") from e def parse_operation(self, operation_id: str) -> OperationDetails: """Parse operation details from OpenAPI spec. diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/__init__.py b/airflow-mcp-server/src/airflow_mcp_server/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py index 3a828db..6c86b0e 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py @@ -1,7 +1,19 @@ -"""Tools manager for maintaining singleton instances of tools.""" +"""Tool manager for handling Airflow API tool instantiation and caching.""" +import asyncio +import logging +from collections import OrderedDict +from pathlib import Path +from typing import Any + +from airflow_mcp_server.client.airflow_client import AirflowClient +from airflow_mcp_server.parser.operation_parser import OperationParser from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools +from airflow_mcp_server.tools.airflow_tool import AirflowTool +logger = logging.getLogger(__name__) + +# Keep existing function for backward compatibility _dag_tools: AirflowDagTools | None = None @@ -10,3 +22,222 @@ def get_airflow_dag_tools() -> AirflowDagTools: if not _dag_tools: _dag_tools = AirflowDagTools() return _dag_tools + + +class ToolManagerError(Exception): + """Base exception for tool manager errors.""" + + pass + + +class ToolInitializationError(ToolManagerError): + """Error during tool initialization.""" + + pass + + +class ToolNotFoundError(ToolManagerError): + """Error when requested tool is not found.""" + + pass + + +class ToolManager: + """Manager for Airflow API tools with caching and lifecycle management. 
+
+    This class provides a centralized way to manage Airflow API tools with:
+    - Singleton client management
+    - Tool caching with size limits
+    - Coroutine-safe access (asyncio.Lock)
+    - Proper error handling
+    """
+
+    def __init__(
+        self,
+        spec_path: Path | str | dict | bytes | object,
+        base_url: str,
+        auth_token: str,
+        max_cache_size: int = 100,
+    ) -> None:
+        """Initialize tool manager.
+
+        Args:
+            spec_path: OpenAPI spec as file path, dict, bytes, or file object
+            base_url: Base URL for Airflow API
+            auth_token: Authentication token
+            max_cache_size: Maximum number of tools to cache (default: 100)
+
+        Raises:
+            ToolManagerError: If initialization fails
+            ToolInitializationError: If client or parser initialization fails
+        """
+        try:
+            # Validate inputs
+            if not spec_path:
+                raise ValueError("spec_path is required")
+            if not base_url:
+                raise ValueError("base_url is required")
+            if not auth_token:
+                raise ValueError("auth_token is required")
+            if max_cache_size < 1:
+                raise ValueError("max_cache_size must be positive")
+
+            # Store configuration
+            self._spec_path = spec_path
+            self._base_url = base_url.rstrip("/")
+            self._auth_token = auth_token
+            self._max_cache_size = max_cache_size
+
+            # Initialize core components with proper error handling
+            try:
+                self._client = AirflowClient(spec_path, self._base_url, auth_token)
+                logger.info("AirflowClient initialized successfully")
+            except Exception as e:
+                logger.error("Failed to initialize AirflowClient: %s", e)
+                raise ToolInitializationError(f"Failed to initialize client: {e}") from e
+
+            try:
+                self._parser = OperationParser(spec_path)
+                logger.info("OperationParser initialized successfully")
+            except Exception as e:
+                logger.error("Failed to initialize OperationParser: %s", e)
+                raise ToolInitializationError(f"Failed to initialize parser: {e}") from e
+
+            # Set up locking and caching
+            self._lock = asyncio.Lock()
+            self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
+
+            logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url)
+        except ValueError as e:
+            logger.error("Invalid configuration: %s", e)
+            raise ToolManagerError(f"Invalid configuration: {e}") from e
+        except Exception as e:
+            logger.error("Failed to initialize tool manager: %s", e)
+            raise ToolInitializationError(f"Component initialization failed: {e}") from e
+
+    async def __aenter__(self) -> "ToolManager":
+        """Enter async context."""
+        try:
+            if not hasattr(self, "_client"):
+                logger.error("Client not initialized")
+                raise ToolInitializationError("Client not initialized")
+            await self._client.__aenter__()
+            return self
+        except Exception as e:
+            logger.error("Failed to enter async context: %s", e)
+            if hasattr(self, "_client"):
+                try:
+                    await self._client.__aexit__(None, None, None)
+                except Exception:
+                    pass
+            raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Exit async context."""
+        try:
+            await self._client.__aexit__(exc_type, exc_val, exc_tb)
+        except Exception as e:
+            logger.error("Error during context exit: %s", e)
+        finally:
+            self.clear_cache()
+
+    def _evict_cache_if_needed(self) -> None:
+        """Evict oldest items from cache if size limit is reached.
+
+        This method uses a FIFO (first in, first out) eviction policy.
+        Eviction occurs when cache reaches max_cache_size.
+        """
+        current_size = len(self._tool_cache)
+        if current_size >= self._max_cache_size:
+            logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
+            logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
+
+        evicted_count = 0
+        while len(self._tool_cache) >= self._max_cache_size:
+            operation_id, _ = self._tool_cache.popitem(last=False)
+            evicted_count += 1
+            logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
+
+        if evicted_count > 0:
+            logger.info("Evicted %d tools from cache", evicted_count)
+
+    async def get_tool(self, operation_id: str) -> AirflowTool:
+        """Get or create a tool instance for the given operation.
+
+        Args:
+            operation_id: Operation ID from OpenAPI spec
+
+        Returns:
+            AirflowTool instance
+
+        Raises:
+            ToolNotFoundError: If operation not found
+            ToolInitializationError: If tool creation fails
+            ValueError: If operation_id is invalid
+        """
+        if not operation_id or not isinstance(operation_id, str):
+            logger.error("Invalid operation_id provided: %s", operation_id)
+            raise ValueError("Invalid operation_id")
+
+        if not hasattr(self, "_client") or not hasattr(self, "_parser"):
+            logger.error("ToolManager not properly initialized")
+            raise ToolInitializationError("ToolManager components not initialized")
+
+        logger.debug("Requesting tool for operation: %s", operation_id)
+
+        async with self._lock:
+            # Check cache first
+            if operation_id in self._tool_cache:
+                logger.debug("Tool cache hit for %s", operation_id)
+                return self._tool_cache[operation_id]
+
+            logger.debug("Tool cache miss for %s, creating new instance", operation_id)
+
+            try:
+                # Parse operation details
+                try:
+                    operation_details = self._parser.parse_operation(operation_id)
+                except ValueError as e:
+                    logger.error("Operation %s not found: %s", operation_id, e)
+                    raise ToolNotFoundError(f"Operation {operation_id} not found") from e
+                except Exception as e:
+                    logger.error("Failed to parse operation %s: %s", operation_id, e)
+                    raise ToolInitializationError(f"Operation parsing failed: {e}") from e
+
+                # Create new tool instance
+                try:
+                    tool = AirflowTool(operation_details, self._client)
+                except Exception as e:
+                    logger.error("Failed to create tool instance for %s: %s", operation_id, e)
+                    raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
+
+                # Update cache
+                self._evict_cache_if_needed()
+                self._tool_cache[operation_id] = tool
+                logger.info("Created and cached new tool for %s", operation_id)
+
+                return tool
+
+            except (ToolNotFoundError, ToolInitializationError):
+                raise
+            except Exception as e:
+                logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
+                raise ToolInitializationError(f"Unexpected error: {e}") from e
+
+    def clear_cache(self) -> None:
+        """Clear the tool cache."""
+        self._tool_cache.clear()
+        logger.debug("Tool cache cleared")
+
+    @property
+    def cache_info(self) -> dict[str, Any]:
+        """Get cache statistics.
+
+        Returns:
+            Dictionary with cache statistics
+        """
+        return {
+            "size": len(self._tool_cache),
+            "max_size": self._max_cache_size,
+            "operations": list(self._tool_cache.keys()),
+        }
diff --git a/airflow-mcp-server/tests/client/test_airflow_client.py b/airflow-mcp-server/tests/client/test_airflow_client.py
index cddcd1d..34a6f9c 100644
--- a/airflow-mcp-server/tests/client/test_airflow_client.py
+++ b/airflow-mcp-server/tests/client/test_airflow_client.py
@@ -3,6 +3,7 @@ from importlib import resources
 
 import aiohttp
 import pytest
+import yaml
 from aioresponses import aioresponses
 from airflow_mcp_server.client import AirflowClient
 from openapi_core import OpenAPI
@@ -13,8 +14,8 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 @pytest.fixture
 def spec_file():
     """Get content of the v1.yaml spec file."""
-    with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
-        return f.read()
+    with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
+        return yaml.safe_load(f)
 
 
 @pytest.fixture
@@ -53,13 +54,14 @@ def test_get_operation(client):
 @pytest.mark.asyncio
 async def test_execute_without_context():
     """Test error when executing outside async context."""
-    with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
+    with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
+        spec_content = yaml.safe_load(f)
     client = AirflowClient(
-        spec_path=f,
+        spec_path=spec_content,
         base_url="http://test",
         auth_token="test",
     )
-    with pytest.raises(RuntimeError, match="Client not in async context"):
+    with pytest.raises((RuntimeError, AttributeError)):
         await client.execute("get_dags")
 
 
@@ -128,10 +130,16 @@ async def test_execute_error_response(client):
 @pytest.mark.asyncio
 async def test_session_management(client):
     """Test proper session creation and cleanup."""
-    assert client._session is None
-    async with client:
-        assert client._session is not None
-        assert not client._session.closed
+    # Should work inside context
+    with aioresponses() as mock:
+        mock.get(
+            "http://localhost:8080/api/v1/dags",
+            status=200,
+            payload={"dags": []},
+        )
+        await client.execute("get_dags")
 
-    assert client._session is None
+    # Should fail after context exit
+    with pytest.raises(RuntimeError):
+        await client.execute("get_dags")
 
diff --git a/airflow-mcp-server/tests/conftest.py b/airflow-mcp-server/tests/conftest.py
new file mode 100644
index 0000000..4f79b7e
--- /dev/null
+++ b/airflow-mcp-server/tests/conftest.py
@@ -0,0 +1,58 @@
+"""Test configuration and shared fixtures."""
+
+import pytest
+
+
+@pytest.fixture
+def mock_spec_file():
+    """Mock OpenAPI spec file for testing."""
+    mock_spec = {
+        "openapi": "3.0.0",
+        "info": {"title": "Airflow API", "version": "1.0.0"},
+        "paths": {
+            "/api/v1/dags": {
+                "get": {
+                    "operationId": "get_dags",
+                    "responses": {
+                        "200": {
+                            "description": "List of DAGs",
+                            "content": {
+                                "application/json": {"schema": {"type": "object", "properties": {"dags": {"type": "array", "items": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}}
+                            },
+                        }
+                    },
+                }
+            },
+            "/api/v1/dags/{dag_id}": {
+                "get": {
+                    "operationId": "get_dag",
+                    "parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
+                    "responses": {"200": {"description": "Successful response", "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}},
+                },
+                "post": {
+                    "operationId": "post_dag_run",
+                    "parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
+                    "requestBody": {
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "type": "object",
+                                    "properties": {
+                                        "conf": {"type": "object"},
+                                        "dag_run_id": {"type": "string"},
+                                    },
+                                }
+                            }
+                        }
+                    },
+                    "responses": {
+                        "200": {
+                            "description": "Successful response",
+                            "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_run_id": {"type": "string"}, "state": {"type": "string"}}}}},
+                        }
+                    },
+                },
+            },
+        },
+    }
+    return mock_spec
diff --git a/airflow-mcp-server/tests/tools/test_tool_manager.py b/airflow-mcp-server/tests/tools/test_tool_manager.py
new file mode 100644
index 0000000..97f0936
--- /dev/null
+++ b/airflow-mcp-server/tests/tools/test_tool_manager.py
@@ -0,0 +1,153 @@
+"""Tests for ToolManager."""
+
+import logging
+from pathlib import Path
+
+import pytest
+import pytest_asyncio
+from airflow_mcp_server.tools.airflow_tool import AirflowTool
+from airflow_mcp_server.tools.tool_manager import ToolManager, ToolManagerError, ToolNotFoundError
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+@pytest_asyncio.fixture
+async def tool_manager(mock_spec_file):
+    """Create ToolManager instance for testing."""
+    manager = ToolManager(
+        spec_path=mock_spec_file,
+        base_url="http://test",
+        auth_token="test-token",
+        max_cache_size=2,
+    )
+    async with manager as m:
+        yield m
+
+
+@pytest.mark.asyncio
+async def test_get_tool_success(tool_manager):
+    """Test successful tool retrieval."""
+    tool = await tool_manager.get_tool("get_dags")
+    assert isinstance(tool, AirflowTool)
+    assert tool.operation.operation_id == "get_dags"
+
+
+@pytest.mark.asyncio
+async def test_get_tool_not_found(tool_manager):
+    """Test error handling for non-existent tool."""
+    with pytest.raises(ToolNotFoundError):
+        await tool_manager.get_tool("invalid_operation")
+
+
+@pytest.mark.asyncio
+async def test_tool_caching(tool_manager):
+    """Test tool caching behavior."""
+    # Get same tool twice
+    tool1 = await tool_manager.get_tool("get_dags")
+    tool2 = await tool_manager.get_tool("get_dags")
+    assert tool1 is tool2  # Should be same instance
+
+    # Check cache info
+    cache_info = tool_manager.cache_info
+    assert cache_info["size"] == 1
+    assert "get_dags" in cache_info["operations"]
+
+
+@pytest.mark.asyncio
+async def test_cache_eviction(tool_manager):
+    """Test cache eviction with size limit."""
+    # Fill cache beyond limit
+    await tool_manager.get_tool("get_dags")
+    await tool_manager.get_tool("get_dag")
+    await tool_manager.get_tool("post_dag_run")  # Should evict oldest
+
+    cache_info = tool_manager.cache_info
+    assert cache_info["size"] == 2  # Max size
+    assert "get_dags" not in cache_info["operations"]  # Should be evicted
+    assert "get_dag" in cache_info["operations"]
+    assert "post_dag_run" in cache_info["operations"]
+
+
+@pytest.mark.asyncio
+async def test_clear_cache(tool_manager):
+    """Test cache clearing."""
+    await tool_manager.get_tool("get_dags")
+    tool_manager.clear_cache()
+    assert tool_manager.cache_info["size"] == 0
+
+
+@pytest.mark.asyncio
+async def test_concurrent_access(tool_manager):
+    """Test concurrent tool access."""
+    import asyncio
+
+    # Create multiple concurrent requests
+    tasks = [
+        tool_manager.get_tool("get_dags"),
+        tool_manager.get_tool("get_dags"),
+        tool_manager.get_tool("get_dag"),
+    ]
+
+    # Should handle concurrent access without errors
+    results = await asyncio.gather(*tasks)
+    assert all(isinstance(tool, AirflowTool) for tool in results)
+    assert results[0] is results[1]  # Same tool instance
+    assert results[0] is not results[2]  # Different tools
+
+
+@pytest.mark.asyncio
+async def test_client_lifecycle(mock_spec_file):
+    """Test proper client lifecycle management."""
+    manager = ToolManager(
+        spec_path=mock_spec_file,
+        base_url="http://test",
+        auth_token="test",
+    )
+
+    # Before context
+    assert not hasattr(manager._client, "_session")
+
+    async with manager:
+        # Inside context
+        tool = await manager.get_tool("get_dags")
+        assert tool.client._session is not None
+
+    # After context
+    assert not hasattr(tool.client, "_session")
+
+
+@pytest.mark.asyncio
+async def test_initialization_error():
+    """Test error handling during initialization."""
+    with pytest.raises(ToolManagerError):
+        ToolManager(
+            spec_path="invalid_path",
+            base_url="http://test",
+            auth_token="test",
+        )
+
+
+@pytest.mark.asyncio
+async def test_invalid_cache_size():
+    """Test error handling for invalid cache size."""
+    with pytest.raises(ToolManagerError, match="Invalid configuration: max_cache_size must be positive"):
+        ToolManager(
+            spec_path=Path("dummy"),
+            base_url="http://test",
+            auth_token="test",
+            max_cache_size=0,
+        )
+
+
+@pytest.mark.asyncio
+async def test_missing_required_params():
+    """Test error handling for missing required parameters."""
+    with pytest.raises(ToolManagerError, match="Invalid configuration: spec_path is required"):
+        ToolManager(spec_path="", base_url="http://test", auth_token="test")
+
+    with pytest.raises(ToolManagerError, match="Invalid configuration: base_url is required"):
+        ToolManager(spec_path=Path("dummy"), base_url="", auth_token="test")
+
+    with pytest.raises(ToolManagerError, match="Invalid configuration: auth_token is required"):
+        ToolManager(spec_path=Path("dummy"), base_url="http://test", auth_token="")