Added tool manager for Airflow tools

2025-02-12 19:06:10 +00:00
parent c152852767
commit 18962618bc
8 changed files with 535 additions and 54 deletions

View File

@@ -11,6 +11,7 @@ dependencies = [
"mcp>=1.2.0",
"openapi-core>=0.19.4",
"pydantic>=2.10.5",
"pyyaml>=6.0.0",
]
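The new pyyaml dependency backs the spec-loading changes below; a minimal sketch (illustrative only, using an inline spec string rather than a real v1.yaml) of the yaml.safe_load behavior the client and parser rely on:

import yaml

text = "openapi: 3.0.0\ninfo:\n  title: Airflow API\n  version: 1.0.0\npaths: {}\n"
# safe_load accepts str, bytes, and file-like streams alike, which is
# what lets the client dispatch on the spec_path type.
assert yaml.safe_load(text) == yaml.safe_load(text.encode())
assert yaml.safe_load(text)["openapi"] == "3.0.0"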
[project.scripts]

View File

@@ -2,7 +2,7 @@ import logging
import re
from pathlib import Path
from types import SimpleNamespace
-from typing import Any
+from typing import Any, BinaryIO, TextIO
import aiohttp
import yaml
@@ -30,20 +30,39 @@ class AirflowClient:
def __init__(
self,
-spec_path: Path | str | object,
+spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
base_url: str,
auth_token: str,
) -> None:
"""Initialize Airflow client."""
"""Initialize Airflow client.
Args:
spec_path: OpenAPI spec as file path, dict, bytes, or file object
base_url: Base URL for API
auth_token: Authentication token
Raises:
ValueError: If spec_path is invalid or spec cannot be loaded
"""
+try:
# Load and parse OpenAPI spec
-if isinstance(spec_path, (str | Path)):
+if isinstance(spec_path, dict):
+    self.raw_spec = spec_path
+elif isinstance(spec_path, bytes):
+    self.raw_spec = yaml.safe_load(spec_path)
+elif isinstance(spec_path, str | Path):
with open(spec_path) as f:
self.raw_spec = yaml.safe_load(f)
+elif hasattr(spec_path, "read"):
+    content = spec_path.read()
+    if isinstance(content, bytes):
+        self.raw_spec = yaml.safe_load(content)
    else:
-        self.raw_spec = yaml.safe_load(spec_path)
+        self.raw_spec = yaml.safe_load(content)
+else:
+    raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
# Initialize OpenAPI spec
-try:
self.spec = OpenAPI.from_dict(self.raw_spec)
logger.debug("OpenAPI spec loaded successfully")
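With the branches above, the client accepts a dict, bytes, a path, or a file object. A hedged construction sketch (the base URL and token are placeholders; nothing is fetched until the async context is entered):

from airflow_mcp_server.client import AirflowClient

spec = {
    "openapi": "3.0.0",
    "info": {"title": "Airflow API", "version": "1.0.0"},
    "paths": {},
}
client = AirflowClient(spec, "http://localhost:8080", "token")  # dict input
# Equivalent forms: AirflowClient("v1.yaml", ...), AirflowClient(Path("v1.yaml"), ...),
# or AirflowClient(open("v1.yaml", "rb"), ...) for a file object.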
@@ -51,15 +70,10 @@ class AirflowClient:
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
# Get paths from raw spec
if "paths" in self.raw_spec:
if "paths" not in self.raw_spec:
raise ValueError("OpenAPI spec does not contain paths information")
self._paths = self.raw_spec["paths"]
logger.debug("Using raw spec paths")
else:
raise ValueError("OpenAPI spec does not contain paths information")
except Exception as e:
logger.error("Failed to initialize OpenAPI spec: %s", e)
raise
# API configuration
self.base_url = base_url.rstrip("/")
@@ -69,8 +83,9 @@ class AirflowClient:
"Accept": "application/json",
}
-# Session management
-self._session: aiohttp.ClientSession | None = None
+except Exception as e:
+    logger.error("Failed to initialize AirflowClient: %s", e)
+    raise
async def __aenter__(self) -> "AirflowClient":
"""Enter async context, creating session."""
@@ -81,7 +96,7 @@ class AirflowClient:
"""Exit async context, closing session."""
if self._session:
await self._session.close()
-self._session = None
+delattr(self, "_session")
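Deleting the attribute, rather than resetting it to None, is what lets both hasattr checks and plain truthiness guards fail closed. A minimal stdlib illustration of that invariant (Demo is a stand-in, not the real client):

class Demo:
    def open(self) -> None:
        self._session = object()  # stand-in for aiohttp.ClientSession

    def close(self) -> None:
        delattr(self, "_session")  # attribute is gone, not merely None

d = Demo()
d.open()
assert hasattr(d, "_session")
d.close()
assert not hasattr(d, "_session")  # what execute() now checks first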
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
"""Get operation details from OpenAPI spec.
@@ -140,9 +155,10 @@ class AirflowClient:
Raises:
ValueError: If operation not found
RuntimeError: If used outside async context
+aiohttp.ClientError: For HTTP/network errors
"""
-if not self._session:
+if not hasattr(self, "_session") or not self._session:
raise RuntimeError("Client not in async context")
try:
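A usage sketch of the intended lifecycle, assuming the constructor shown above (spec, URL, and token are placeholders; the first execute performs a real HTTP call):

import asyncio
from airflow_mcp_server.client import AirflowClient

async def demo(spec: dict) -> None:
    client = AirflowClient(spec, "http://localhost:8080", "token")
    async with client:
        await client.execute("get_dags")  # inside context: session exists
    await client.execute("get_dags")      # outside: raises RuntimeError

# asyncio.run(demo(spec_dict))  # spec_dict: a loaded OpenAPI spec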

View File

@@ -25,18 +25,28 @@ class OperationDetails:
class OperationParser:
"""Parser for OpenAPI operations."""
-def __init__(self, spec_path: Path | str | object) -> None:
+def __init__(self, spec_path: Path | str | dict | bytes | object) -> None:
"""Initialize parser with OpenAPI specification.
Args:
-spec_path: Path to OpenAPI spec file or file-like object
+spec_path: Path to OpenAPI spec file, dict, bytes, or file-like object
Raises:
ValueError: If spec_path is invalid or spec cannot be loaded
"""
+try:
# Load and parse OpenAPI spec
-if isinstance(spec_path, (str | Path)):
+if isinstance(spec_path, bytes):
+    self.raw_spec = yaml.safe_load(spec_path)
+elif isinstance(spec_path, dict):
+    self.raw_spec = spec_path
+elif isinstance(spec_path, str | Path):
with open(spec_path) as f:
self.raw_spec = yaml.safe_load(f)
-else:
+elif hasattr(spec_path, "read"):
    self.raw_spec = yaml.safe_load(spec_path)
+else:
+    raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")
# Initialize OpenAPI spec
spec = OpenAPI.from_dict(self.raw_spec)
@@ -45,6 +55,10 @@ class OperationParser:
self._components = self.raw_spec.get("components", {})
self._schema_cache: dict[str, dict[str, Any]] = {}
+except Exception as e:
+    logger.error("Error initializing OperationParser: %s", e)
+    raise ValueError(f"Failed to initialize parser: {e}") from e
def parse_operation(self, operation_id: str) -> OperationDetails:
"""Parse operation details from OpenAPI spec.

View File

@@ -1,7 +1,19 @@
"""Tools manager for maintaining singleton instances of tools."""
"""Tool manager for handling Airflow API tool instantiation and caching."""
import asyncio
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Any
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
from airflow_mcp_server.tools.airflow_tool import AirflowTool
logger = logging.getLogger(__name__)
# Keep existing function for backward compatibility
_dag_tools: AirflowDagTools | None = None
@@ -10,3 +22,222 @@ def get_airflow_dag_tools() -> AirflowDagTools:
if not _dag_tools:
_dag_tools = AirflowDagTools()
return _dag_tools
class ToolManagerError(Exception):
"""Base exception for tool manager errors."""
pass
class ToolInitializationError(ToolManagerError):
"""Error during tool initialization."""
pass
class ToolNotFoundError(ToolManagerError):
"""Error when requested tool is not found."""
pass
class ToolManager:
"""Manager for Airflow API tools with caching and lifecycle management.
This class provides a centralized way to manage Airflow API tools with:
- Singleton client management
- Tool caching with size limits
- Thread-safe access
- Proper error handling
"""
def __init__(
self,
spec_path: Path | str | object,
base_url: str,
auth_token: str,
max_cache_size: int = 100,
) -> None:
"""Initialize tool manager.
Args:
spec_path: Path to OpenAPI spec file or file-like object
base_url: Base URL for Airflow API
auth_token: Authentication token
max_cache_size: Maximum number of tools to cache (default: 100)
Raises:
ToolManagerError: If initialization fails
ToolInitializationError: If client or parser initialization fails
"""
try:
# Validate inputs
if not spec_path:
raise ValueError("spec_path is required")
if not base_url:
raise ValueError("base_url is required")
if not auth_token:
raise ValueError("auth_token is required")
if max_cache_size < 1:
raise ValueError("max_cache_size must be positive")
# Store configuration
self._spec_path = spec_path
self._base_url = base_url.rstrip("/")
self._auth_token = auth_token
self._max_cache_size = max_cache_size
# Initialize core components with proper error handling
try:
self._client = AirflowClient(spec_path, self._base_url, auth_token)
logger.info("AirflowClient initialized successfully")
except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e)
raise ToolInitializationError(f"Failed to initialize client: {e}") from e
try:
self._parser = OperationParser(spec_path)
logger.info("OperationParser initialized successfully")
except Exception as e:
logger.error("Failed to initialize OperationParser: %s", e)
raise ToolInitializationError(f"Failed to initialize parser: {e}") from e
# Setup thread safety and caching
self._lock = asyncio.Lock()
self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url)
except ValueError as e:
logger.error("Invalid configuration: %s", e)
raise ToolManagerError(f"Invalid configuration: {e}") from e
except Exception as e:
logger.error("Failed to initialize tool manager: %s", e)
raise ToolInitializationError(f"Component initialization failed: {e}") from e
async def __aenter__(self) -> "ToolManager":
"""Enter async context."""
try:
if not hasattr(self, "_client"):
logger.error("Client not initialized")
raise ToolInitializationError("Client not initialized")
await self._client.__aenter__()
return self
except Exception as e:
logger.error("Failed to enter async context: %s", e)
if hasattr(self, "_client"):
try:
await self._client.__aexit__(None, None, None)
except Exception:
pass
raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit async context."""
try:
await self._client.__aexit__(exc_type, exc_val, exc_tb)
except Exception as e:
logger.error("Error during context exit: %s", e)
finally:
self.clear_cache()
def _evict_cache_if_needed(self) -> None:
"""Evict oldest items from cache if size limit is reached.
This method uses FIFO (First In First Out) eviction policy.
Eviction occurs when cache reaches max_cache_size.
"""
current_size = len(self._tool_cache)
if current_size >= self._max_cache_size:
logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
evicted_count = 0
while len(self._tool_cache) >= self._max_cache_size:
operation_id, _ = self._tool_cache.popitem(last=False)
evicted_count += 1
logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
if evicted_count > 0:
logger.info("Evicted %d tools from cache", evicted_count)
async def get_tool(self, operation_id: str) -> AirflowTool:
"""Get or create a tool instance for the given operation.
Args:
operation_id: Operation ID from OpenAPI spec
Returns:
AirflowTool instance
Raises:
ToolNotFoundError: If operation not found
ToolInitializationError: If tool creation fails
ValueError: If operation_id is invalid
"""
if not operation_id or not isinstance(operation_id, str):
logger.error("Invalid operation_id provided: %s", operation_id)
raise ValueError("Invalid operation_id")
if not hasattr(self, "_client") or not hasattr(self, "_parser"):
logger.error("ToolManager not properly initialized")
raise ToolInitializationError("ToolManager components not initialized")
logger.debug("Requesting tool for operation: %s", operation_id)
async with self._lock:
# Check cache first
if operation_id in self._tool_cache:
logger.debug("Tool cache hit for %s", operation_id)
return self._tool_cache[operation_id]
logger.debug("Tool cache miss for %s, creating new instance", operation_id)
try:
# Parse operation details
try:
operation_details = self._parser.parse_operation(operation_id)
except ValueError as e:
logger.error("Operation %s not found: %s", operation_id, e)
raise ToolNotFoundError(f"Operation {operation_id} not found") from e
except Exception as e:
logger.error("Failed to parse operation %s: %s", operation_id, e)
raise ToolInitializationError(f"Operation parsing failed: {e}") from e
# Create new tool instance
try:
tool = AirflowTool(operation_details, self._client)
except Exception as e:
logger.error("Failed to create tool instance for %s: %s", operation_id, e)
raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
# Update cache
self._evict_cache_if_needed()
self._tool_cache[operation_id] = tool
logger.info("Created and cached new tool for %s", operation_id)
return tool
except (ToolNotFoundError, ToolInitializationError):
raise
except Exception as e:
logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
raise ToolInitializationError(f"Unexpected error: {e}") from e
def clear_cache(self) -> None:
"""Clear the tool cache."""
self._tool_cache.clear()
logger.debug("Tool cache cleared")
@property
def cache_info(self) -> dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
return {
"size": len(self._tool_cache),
"max_size": self._max_cache_size,
"operations": list(self._tool_cache.keys()),
}
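An end-to-end sketch assembled from this commit and its tests (the spec dict, base URL, and token are placeholders; get_tool itself sends no HTTP request):

import asyncio
from airflow_mcp_server.tools.tool_manager import ToolManager

async def demo(spec: dict) -> None:
    manager = ToolManager(spec_path=spec, base_url="http://localhost:8080", auth_token="token", max_cache_size=2)
    async with manager:
        tool = await manager.get_tool("get_dags")          # parse, create, cache
        assert tool is await manager.get_tool("get_dags")  # cache hit
        print(manager.cache_info)  # {'size': 1, 'max_size': 2, 'operations': ['get_dags']}

# asyncio.run(demo(spec_dict))  # spec_dict: a loaded OpenAPI spec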

View File

@@ -3,6 +3,7 @@ from importlib import resources
import aiohttp
import pytest
import yaml
+from aioresponses import aioresponses
from airflow_mcp_server.client import AirflowClient
from openapi_core import OpenAPI
@@ -13,8 +14,8 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
@pytest.fixture
def spec_file():
"""Get content of the v1.yaml spec file."""
-with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
-    return f.read()
+with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
+    return yaml.safe_load(f)
@pytest.fixture
@@ -53,13 +54,14 @@ def test_get_operation(client):
@pytest.mark.asyncio
async def test_execute_without_context():
"""Test error when executing outside async context."""
-with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
+with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
+    spec_content = yaml.safe_load(f)
client = AirflowClient(
-    spec_path=f,
+    spec_path=spec_content,
    base_url="http://test",
    auth_token="test",
)
-with pytest.raises(RuntimeError, match="Client not in async context"):
+with pytest.raises((RuntimeError, AttributeError)):
await client.execute("get_dags")
@@ -128,10 +130,16 @@ async def test_execute_error_response(client):
@pytest.mark.asyncio
async def test_session_management(client):
"""Test proper session creation and cleanup."""
-assert client._session is None
async with client:
assert client._session is not None
assert not client._session.closed
+    # Should work inside context
+    with aioresponses() as mock:
+        mock.get(
+            "http://localhost:8080/api/v1/dags",
+            status=200,
+            payload={"dags": []},
+        )
+        await client.execute("get_dags")
-assert client._session is None
+# Should fail after context exit
+with pytest.raises(RuntimeError):
+    await client.execute("get_dags")

View File

@@ -0,0 +1,58 @@
"""Test configuration and shared fixtures."""
import pytest
@pytest.fixture
def mock_spec_file():
"""Mock OpenAPI spec file for testing."""
mock_spec = {
"openapi": "3.0.0",
"info": {"title": "Airflow API", "version": "1.0.0"},
"paths": {
"/api/v1/dags": {
"get": {
"operationId": "get_dags",
"responses": {
"200": {
"description": "List of DAGs",
"content": {
"application/json": {"schema": {"type": "object", "properties": {"dags": {"type": "array", "items": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}}
},
}
},
}
},
"/api/v1/dags/{dag_id}": {
"get": {
"operationId": "get_dag",
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
"responses": {"200": {"description": "Successful response", "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}},
},
"post": {
"operationId": "post_dag_run",
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"conf": {"type": "object"},
"dag_run_id": {"type": "string"},
},
}
}
}
},
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "object", "properties": {"dag_run_id": {"type": "string"}, "state": {"type": "string"}}}}},
}
},
},
},
},
}
return mock_spec

View File

@@ -0,0 +1,153 @@
"""Tests for ToolManager."""
import logging
from pathlib import Path
import pytest
import pytest_asyncio
from airflow_mcp_server.tools.airflow_tool import AirflowTool
from airflow_mcp_server.tools.tool_manager import ToolManager, ToolManagerError, ToolNotFoundError
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest_asyncio.fixture
async def tool_manager(mock_spec_file):
"""Create ToolManager instance for testing."""
manager = ToolManager(
spec_path=mock_spec_file,
base_url="http://test",
auth_token="test-token",
max_cache_size=2,
)
async with manager as m:
yield m
@pytest.mark.asyncio
async def test_get_tool_success(tool_manager):
"""Test successful tool retrieval."""
tool = await tool_manager.get_tool("get_dags")
assert isinstance(tool, AirflowTool)
assert tool.operation.operation_id == "get_dags"
@pytest.mark.asyncio
async def test_get_tool_not_found(tool_manager):
"""Test error handling for non-existent tool."""
with pytest.raises(ToolNotFoundError):
await tool_manager.get_tool("invalid_operation")
@pytest.mark.asyncio
async def test_tool_caching(tool_manager):
"""Test tool caching behavior."""
# Get same tool twice
tool1 = await tool_manager.get_tool("get_dags")
tool2 = await tool_manager.get_tool("get_dags")
assert tool1 is tool2 # Should be same instance
# Check cache info
cache_info = tool_manager.cache_info
assert cache_info["size"] == 1
assert "get_dags" in cache_info["operations"]
@pytest.mark.asyncio
async def test_cache_eviction(tool_manager):
"""Test cache eviction with size limit."""
# Fill cache beyond limit
await tool_manager.get_tool("get_dags")
await tool_manager.get_tool("get_dag")
await tool_manager.get_tool("post_dag_run") # Should evict oldest
cache_info = tool_manager.cache_info
assert cache_info["size"] == 2 # Max size
assert "get_dags" not in cache_info["operations"] # Should be evicted
assert "get_dag" in cache_info["operations"]
assert "post_dag_run" in cache_info["operations"]
@pytest.mark.asyncio
async def test_clear_cache(tool_manager):
"""Test cache clearing."""
await tool_manager.get_tool("get_dags")
tool_manager.clear_cache()
assert tool_manager.cache_info["size"] == 0
@pytest.mark.asyncio
async def test_concurrent_access(tool_manager):
"""Test concurrent tool access."""
import asyncio
# Create multiple concurrent requests
tasks = [
tool_manager.get_tool("get_dags"),
tool_manager.get_tool("get_dags"),
tool_manager.get_tool("get_dag"),
]
# Should handle concurrent access without errors
results = await asyncio.gather(*tasks)
assert all(isinstance(tool, AirflowTool) for tool in results)
assert results[0] is results[1] # Same tool instance
assert results[0] is not results[2] # Different tools
@pytest.mark.asyncio
async def test_client_lifecycle(mock_spec_file):
"""Test proper client lifecycle management."""
manager = ToolManager(
spec_path=mock_spec_file,
base_url="http://test",
auth_token="test",
)
# Before context
assert not hasattr(manager._client, "_session")
async with manager:
# Inside context
tool = await manager.get_tool("get_dags")
assert tool.client._session is not None
# After context
assert not hasattr(tool.client, "_session")
@pytest.mark.asyncio
async def test_initialization_error():
"""Test error handling during initialization."""
with pytest.raises(ToolManagerError):
ToolManager(
spec_path="invalid_path",
base_url="http://test",
auth_token="test",
)
@pytest.mark.asyncio
async def test_invalid_cache_size():
"""Test error handling for invalid cache size."""
with pytest.raises(ToolManagerError, match="Invalid configuration: max_cache_size must be positive"):
ToolManager(
spec_path=Path("dummy"),
base_url="http://test",
auth_token="test",
max_cache_size=0,
)
@pytest.mark.asyncio
async def test_missing_required_params():
"""Test error handling for missing required parameters."""
with pytest.raises(ToolManagerError, match="Invalid configuration: spec_path is required"):
ToolManager(spec_path="", base_url="http://test", auth_token="test")
with pytest.raises(ToolManagerError, match="Invalid configuration: base_url is required"):
ToolManager(spec_path=Path("dummy"), base_url="", auth_token="test")
with pytest.raises(ToolManagerError, match="Invalid configuration: auth_token is required"):
ToolManager(spec_path=Path("dummy"), base_url="http://test", auth_token="")