Added tool manager for airflow tools
This commit is contained in:
@@ -11,6 +11,7 @@ dependencies = [
|
|||||||
"mcp>=1.2.0",
|
"mcp>=1.2.0",
|
||||||
"openapi-core>=0.19.4",
|
"openapi-core>=0.19.4",
|
||||||
"pydantic>=2.10.5",
|
"pydantic>=2.10.5",
|
||||||
|
"pyyaml>=6.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any, BinaryIO, TextIO
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import yaml
|
import yaml
|
||||||
@@ -30,20 +30,39 @@ class AirflowClient:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
spec_path: Path | str | object,
|
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
auth_token: str,
|
auth_token: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize Airflow client."""
|
"""Initialize Airflow client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec_path: OpenAPI spec as file path, dict, bytes, or file object
|
||||||
|
base_url: Base URL for API
|
||||||
|
auth_token: Authentication token
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If spec_path is invalid or spec cannot be loaded
|
||||||
|
"""
|
||||||
|
try:
|
||||||
# Load and parse OpenAPI spec
|
# Load and parse OpenAPI spec
|
||||||
if isinstance(spec_path, (str | Path)):
|
if isinstance(spec_path, dict):
|
||||||
|
self.raw_spec = spec_path
|
||||||
|
elif isinstance(spec_path, bytes):
|
||||||
|
self.raw_spec = yaml.safe_load(spec_path)
|
||||||
|
elif isinstance(spec_path, str | Path):
|
||||||
with open(spec_path) as f:
|
with open(spec_path) as f:
|
||||||
self.raw_spec = yaml.safe_load(f)
|
self.raw_spec = yaml.safe_load(f)
|
||||||
|
elif hasattr(spec_path, "read"):
|
||||||
|
content = spec_path.read()
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
self.raw_spec = yaml.safe_load(content)
|
||||||
else:
|
else:
|
||||||
self.raw_spec = yaml.safe_load(spec_path)
|
self.raw_spec = yaml.safe_load(content)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
|
||||||
|
|
||||||
# Initialize OpenAPI spec
|
# Initialize OpenAPI spec
|
||||||
try:
|
|
||||||
self.spec = OpenAPI.from_dict(self.raw_spec)
|
self.spec = OpenAPI.from_dict(self.raw_spec)
|
||||||
logger.debug("OpenAPI spec loaded successfully")
|
logger.debug("OpenAPI spec loaded successfully")
|
||||||
|
|
||||||
@@ -51,15 +70,10 @@ class AirflowClient:
|
|||||||
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
|
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
|
||||||
|
|
||||||
# Get paths from raw spec
|
# Get paths from raw spec
|
||||||
if "paths" in self.raw_spec:
|
if "paths" not in self.raw_spec:
|
||||||
|
raise ValueError("OpenAPI spec does not contain paths information")
|
||||||
self._paths = self.raw_spec["paths"]
|
self._paths = self.raw_spec["paths"]
|
||||||
logger.debug("Using raw spec paths")
|
logger.debug("Using raw spec paths")
|
||||||
else:
|
|
||||||
raise ValueError("OpenAPI spec does not contain paths information")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to initialize OpenAPI spec: %s", e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# API configuration
|
# API configuration
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
@@ -69,8 +83,9 @@ class AirflowClient:
|
|||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Session management
|
except Exception as e:
|
||||||
self._session: aiohttp.ClientSession | None = None
|
logger.error("Failed to initialize AirflowClient: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
async def __aenter__(self) -> "AirflowClient":
|
async def __aenter__(self) -> "AirflowClient":
|
||||||
"""Enter async context, creating session."""
|
"""Enter async context, creating session."""
|
||||||
@@ -81,7 +96,7 @@ class AirflowClient:
|
|||||||
"""Exit async context, closing session."""
|
"""Exit async context, closing session."""
|
||||||
if self._session:
|
if self._session:
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
self._session = None
|
delattr(self, "_session")
|
||||||
|
|
||||||
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
|
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
|
||||||
"""Get operation details from OpenAPI spec.
|
"""Get operation details from OpenAPI spec.
|
||||||
@@ -140,9 +155,10 @@ class AirflowClient:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If operation not found
|
ValueError: If operation not found
|
||||||
|
RuntimeError: If used outside async context
|
||||||
aiohttp.ClientError: For HTTP/network errors
|
aiohttp.ClientError: For HTTP/network errors
|
||||||
"""
|
"""
|
||||||
if not self._session:
|
if not hasattr(self, "_session") or not self._session:
|
||||||
raise RuntimeError("Client not in async context")
|
raise RuntimeError("Client not in async context")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -25,18 +25,28 @@ class OperationDetails:
|
|||||||
class OperationParser:
|
class OperationParser:
|
||||||
"""Parser for OpenAPI operations."""
|
"""Parser for OpenAPI operations."""
|
||||||
|
|
||||||
def __init__(self, spec_path: Path | str | object) -> None:
|
def __init__(self, spec_path: Path | str | dict | bytes | object) -> None:
|
||||||
"""Initialize parser with OpenAPI specification.
|
"""Initialize parser with OpenAPI specification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spec_path: Path to OpenAPI spec file or file-like object
|
spec_path: Path to OpenAPI spec file, dict, bytes, or file-like object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If spec_path is invalid or spec cannot be loaded
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
# Load and parse OpenAPI spec
|
# Load and parse OpenAPI spec
|
||||||
if isinstance(spec_path, (str | Path)):
|
if isinstance(spec_path, bytes):
|
||||||
|
self.raw_spec = yaml.safe_load(spec_path)
|
||||||
|
elif isinstance(spec_path, dict):
|
||||||
|
self.raw_spec = spec_path
|
||||||
|
elif isinstance(spec_path, str | Path):
|
||||||
with open(spec_path) as f:
|
with open(spec_path) as f:
|
||||||
self.raw_spec = yaml.safe_load(f)
|
self.raw_spec = yaml.safe_load(f)
|
||||||
else:
|
elif hasattr(spec_path, "read"):
|
||||||
self.raw_spec = yaml.safe_load(spec_path)
|
self.raw_spec = yaml.safe_load(spec_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")
|
||||||
|
|
||||||
# Initialize OpenAPI spec
|
# Initialize OpenAPI spec
|
||||||
spec = OpenAPI.from_dict(self.raw_spec)
|
spec = OpenAPI.from_dict(self.raw_spec)
|
||||||
@@ -45,6 +55,10 @@ class OperationParser:
|
|||||||
self._components = self.raw_spec.get("components", {})
|
self._components = self.raw_spec.get("components", {})
|
||||||
self._schema_cache: dict[str, dict[str, Any]] = {}
|
self._schema_cache: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error initializing OperationParser: %s", e)
|
||||||
|
raise ValueError(f"Failed to initialize parser: {e}") from e
|
||||||
|
|
||||||
def parse_operation(self, operation_id: str) -> OperationDetails:
|
def parse_operation(self, operation_id: str) -> OperationDetails:
|
||||||
"""Parse operation details from OpenAPI spec.
|
"""Parse operation details from OpenAPI spec.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
"""Tools manager for maintaining singleton instances of tools."""
|
"""Tool manager for handling Airflow API tool instantiation and caching."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
|
from airflow_mcp_server.parser.operation_parser import OperationParser
|
||||||
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
|
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
|
||||||
|
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Keep existing function for backward compatibility
|
||||||
_dag_tools: AirflowDagTools | None = None
|
_dag_tools: AirflowDagTools | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -10,3 +22,222 @@ def get_airflow_dag_tools() -> AirflowDagTools:
|
|||||||
if not _dag_tools:
|
if not _dag_tools:
|
||||||
_dag_tools = AirflowDagTools()
|
_dag_tools = AirflowDagTools()
|
||||||
return _dag_tools
|
return _dag_tools
|
||||||
|
|
||||||
|
|
||||||
|
class ToolManagerError(Exception):
|
||||||
|
"""Base exception for tool manager errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolInitializationError(ToolManagerError):
|
||||||
|
"""Error during tool initialization."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolNotFoundError(ToolManagerError):
|
||||||
|
"""Error when requested tool is not found."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToolManager:
|
||||||
|
"""Manager for Airflow API tools with caching and lifecycle management.
|
||||||
|
|
||||||
|
This class provides a centralized way to manage Airflow API tools with:
|
||||||
|
- Singleton client management
|
||||||
|
- Tool caching with size limits
|
||||||
|
- Thread-safe access
|
||||||
|
- Proper error handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spec_path: Path | str | object,
|
||||||
|
base_url: str,
|
||||||
|
auth_token: str,
|
||||||
|
max_cache_size: int = 100,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize tool manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec_path: Path to OpenAPI spec file or file-like object
|
||||||
|
base_url: Base URL for Airflow API
|
||||||
|
auth_token: Authentication token
|
||||||
|
max_cache_size: Maximum number of tools to cache (default: 100)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolManagerError: If initialization fails
|
||||||
|
ToolInitializationError: If client or parser initialization fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate inputs
|
||||||
|
if not spec_path:
|
||||||
|
raise ValueError("spec_path is required")
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("base_url is required")
|
||||||
|
if not auth_token:
|
||||||
|
raise ValueError("auth_token is required")
|
||||||
|
if max_cache_size < 1:
|
||||||
|
raise ValueError("max_cache_size must be positive")
|
||||||
|
|
||||||
|
# Store configuration
|
||||||
|
self._spec_path = spec_path
|
||||||
|
self._base_url = base_url.rstrip("/")
|
||||||
|
self._auth_token = auth_token
|
||||||
|
self._max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
# Initialize core components with proper error handling
|
||||||
|
try:
|
||||||
|
self._client = AirflowClient(spec_path, self._base_url, auth_token)
|
||||||
|
logger.info("AirflowClient initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to initialize AirflowClient: %s", e)
|
||||||
|
raise ToolInitializationError(f"Failed to initialize client: {e}") from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._parser = OperationParser(spec_path)
|
||||||
|
logger.info("OperationParser initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to initialize OperationParser: %s", e)
|
||||||
|
raise ToolInitializationError(f"Failed to initialize parser: {e}") from e
|
||||||
|
|
||||||
|
# Setup thread safety and caching
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
|
||||||
|
|
||||||
|
logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error("Invalid configuration: %s", e)
|
||||||
|
raise ToolManagerError(f"Invalid configuration: {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to initialize tool manager: %s", e)
|
||||||
|
raise ToolInitializationError(f"Component initialization failed: {e}") from e
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "ToolManager":
|
||||||
|
"""Enter async context."""
|
||||||
|
try:
|
||||||
|
if not hasattr(self, "_client"):
|
||||||
|
logger.error("Client not initialized")
|
||||||
|
raise ToolInitializationError("Client not initialized")
|
||||||
|
await self._client.__aenter__()
|
||||||
|
return self
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to enter async context: %s", e)
|
||||||
|
if hasattr(self, "_client"):
|
||||||
|
try:
|
||||||
|
await self._client.__aexit__(None, None, None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
"""Exit async context."""
|
||||||
|
try:
|
||||||
|
await self._client.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error during context exit: %s", e)
|
||||||
|
finally:
|
||||||
|
self.clear_cache()
|
||||||
|
|
||||||
|
def _evict_cache_if_needed(self) -> None:
|
||||||
|
"""Evict oldest items from cache if size limit is reached.
|
||||||
|
|
||||||
|
This method uses FIFO (First In First Out) eviction policy.
|
||||||
|
Eviction occurs when cache reaches max_cache_size.
|
||||||
|
"""
|
||||||
|
current_size = len(self._tool_cache)
|
||||||
|
if current_size >= self._max_cache_size:
|
||||||
|
logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
|
||||||
|
logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
|
||||||
|
|
||||||
|
evicted_count = 0
|
||||||
|
while len(self._tool_cache) >= self._max_cache_size:
|
||||||
|
operation_id, _ = self._tool_cache.popitem(last=False)
|
||||||
|
evicted_count += 1
|
||||||
|
logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
|
||||||
|
|
||||||
|
if evicted_count > 0:
|
||||||
|
logger.info("Evicted %d tools from cache", evicted_count)
|
||||||
|
|
||||||
|
async def get_tool(self, operation_id: str) -> AirflowTool:
|
||||||
|
"""Get or create a tool instance for the given operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_id: Operation ID from OpenAPI spec
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AirflowTool instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ToolNotFoundError: If operation not found
|
||||||
|
ToolInitializationError: If tool creation fails
|
||||||
|
ValueError: If operation_id is invalid
|
||||||
|
"""
|
||||||
|
if not operation_id or not isinstance(operation_id, str):
|
||||||
|
logger.error("Invalid operation_id provided: %s", operation_id)
|
||||||
|
raise ValueError("Invalid operation_id")
|
||||||
|
|
||||||
|
if not hasattr(self, "_client") or not hasattr(self, "_parser"):
|
||||||
|
logger.error("ToolManager not properly initialized")
|
||||||
|
raise ToolInitializationError("ToolManager components not initialized")
|
||||||
|
|
||||||
|
logger.debug("Requesting tool for operation: %s", operation_id)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Check cache first
|
||||||
|
if operation_id in self._tool_cache:
|
||||||
|
logger.debug("Tool cache hit for %s", operation_id)
|
||||||
|
return self._tool_cache[operation_id]
|
||||||
|
|
||||||
|
logger.debug("Tool cache miss for %s, creating new instance", operation_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse operation details
|
||||||
|
try:
|
||||||
|
operation_details = self._parser.parse_operation(operation_id)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error("Operation %s not found: %s", operation_id, e)
|
||||||
|
raise ToolNotFoundError(f"Operation {operation_id} not found") from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to parse operation %s: %s", operation_id, e)
|
||||||
|
raise ToolInitializationError(f"Operation parsing failed: {e}") from e
|
||||||
|
|
||||||
|
# Create new tool instance
|
||||||
|
try:
|
||||||
|
tool = AirflowTool(operation_details, self._client)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to create tool instance for %s: %s", operation_id, e)
|
||||||
|
raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._evict_cache_if_needed()
|
||||||
|
self._tool_cache[operation_id] = tool
|
||||||
|
logger.info("Created and cached new tool for %s", operation_id)
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
except (ToolNotFoundError, ToolInitializationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
|
||||||
|
raise ToolInitializationError(f"Unexpected error: {e}") from e
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the tool cache."""
|
||||||
|
self._tool_cache.clear()
|
||||||
|
logger.debug("Tool cache cleared")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_info(self) -> dict[str, Any]:
|
||||||
|
"""Get cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with cache statistics
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"size": len(self._tool_cache),
|
||||||
|
"max_size": self._max_cache_size,
|
||||||
|
"operations": list(self._tool_cache.keys()),
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from importlib import resources
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
from aioresponses import aioresponses
|
from aioresponses import aioresponses
|
||||||
from airflow_mcp_server.client import AirflowClient
|
from airflow_mcp_server.client import AirflowClient
|
||||||
from openapi_core import OpenAPI
|
from openapi_core import OpenAPI
|
||||||
@@ -13,8 +14,8 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def spec_file():
|
def spec_file():
|
||||||
"""Get content of the v1.yaml spec file."""
|
"""Get content of the v1.yaml spec file."""
|
||||||
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
|
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
|
||||||
return f.read()
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -53,13 +54,14 @@ def test_get_operation(client):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_without_context():
|
async def test_execute_without_context():
|
||||||
"""Test error when executing outside async context."""
|
"""Test error when executing outside async context."""
|
||||||
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
|
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
|
||||||
|
spec_content = yaml.safe_load(f)
|
||||||
client = AirflowClient(
|
client = AirflowClient(
|
||||||
spec_path=f,
|
spec_path=spec_content,
|
||||||
base_url="http://test",
|
base_url="http://test",
|
||||||
auth_token="test",
|
auth_token="test",
|
||||||
)
|
)
|
||||||
with pytest.raises(RuntimeError, match="Client not in async context"):
|
with pytest.raises((RuntimeError, AttributeError)):
|
||||||
await client.execute("get_dags")
|
await client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
@@ -128,10 +130,16 @@ async def test_execute_error_response(client):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_session_management(client):
|
async def test_session_management(client):
|
||||||
"""Test proper session creation and cleanup."""
|
"""Test proper session creation and cleanup."""
|
||||||
assert client._session is None
|
|
||||||
|
|
||||||
async with client:
|
async with client:
|
||||||
assert client._session is not None
|
# Should work inside context
|
||||||
assert not client._session.closed
|
with aioresponses() as mock:
|
||||||
|
mock.get(
|
||||||
|
"http://localhost:8080/api/v1/dags",
|
||||||
|
status=200,
|
||||||
|
payload={"dags": []},
|
||||||
|
)
|
||||||
|
await client.execute("get_dags")
|
||||||
|
|
||||||
assert client._session is None
|
# Should fail after context exit
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await client.execute("get_dags")
|
||||||
|
|||||||
58
airflow-mcp-server/tests/conftest.py
Normal file
58
airflow-mcp-server/tests/conftest.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Test configuration and shared fixtures."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_spec_file():
|
||||||
|
"""Mock OpenAPI spec file for testing."""
|
||||||
|
mock_spec = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {"title": "Airflow API", "version": "1.0.0"},
|
||||||
|
"paths": {
|
||||||
|
"/api/v1/dags": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "get_dags",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "List of DAGs",
|
||||||
|
"content": {
|
||||||
|
"application/json": {"schema": {"type": "object", "properties": {"dags": {"type": "array", "items": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/api/v1/dags/{dag_id}": {
|
||||||
|
"get": {
|
||||||
|
"operationId": "get_dag",
|
||||||
|
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
|
||||||
|
"responses": {"200": {"description": "Successful response", "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}},
|
||||||
|
},
|
||||||
|
"post": {
|
||||||
|
"operationId": "post_dag_run",
|
||||||
|
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"conf": {"type": "object"},
|
||||||
|
"dag_run_id": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Successful response",
|
||||||
|
"content": {"application/json": {"schema": {"type": "object", "properties": {"dag_run_id": {"type": "string"}, "state": {"type": "string"}}}}},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return mock_spec
|
||||||
153
airflow-mcp-server/tests/tools/test_tool_manager.py
Normal file
153
airflow-mcp-server/tests/tools/test_tool_manager.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Tests for ToolManager."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||||
|
from airflow_mcp_server.tools.tool_manager import ToolManager, ToolManagerError, ToolNotFoundError
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def tool_manager(mock_spec_file):
|
||||||
|
"""Create ToolManager instance for testing."""
|
||||||
|
manager = ToolManager(
|
||||||
|
spec_path=mock_spec_file,
|
||||||
|
base_url="http://test",
|
||||||
|
auth_token="test-token",
|
||||||
|
max_cache_size=2,
|
||||||
|
)
|
||||||
|
async with manager as m:
|
||||||
|
yield m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tool_success(tool_manager):
|
||||||
|
"""Test successful tool retrieval."""
|
||||||
|
tool = await tool_manager.get_tool("get_dags")
|
||||||
|
assert isinstance(tool, AirflowTool)
|
||||||
|
assert tool.operation.operation_id == "get_dags"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tool_not_found(tool_manager):
|
||||||
|
"""Test error handling for non-existent tool."""
|
||||||
|
with pytest.raises(ToolNotFoundError):
|
||||||
|
await tool_manager.get_tool("invalid_operation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_caching(tool_manager):
|
||||||
|
"""Test tool caching behavior."""
|
||||||
|
# Get same tool twice
|
||||||
|
tool1 = await tool_manager.get_tool("get_dags")
|
||||||
|
tool2 = await tool_manager.get_tool("get_dags")
|
||||||
|
assert tool1 is tool2 # Should be same instance
|
||||||
|
|
||||||
|
# Check cache info
|
||||||
|
cache_info = tool_manager.cache_info
|
||||||
|
assert cache_info["size"] == 1
|
||||||
|
assert "get_dags" in cache_info["operations"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_eviction(tool_manager):
|
||||||
|
"""Test cache eviction with size limit."""
|
||||||
|
# Fill cache beyond limit
|
||||||
|
await tool_manager.get_tool("get_dags")
|
||||||
|
await tool_manager.get_tool("get_dag")
|
||||||
|
await tool_manager.get_tool("post_dag_run") # Should evict oldest
|
||||||
|
|
||||||
|
cache_info = tool_manager.cache_info
|
||||||
|
assert cache_info["size"] == 2 # Max size
|
||||||
|
assert "get_dags" not in cache_info["operations"] # Should be evicted
|
||||||
|
assert "get_dag" in cache_info["operations"]
|
||||||
|
assert "post_dag_run" in cache_info["operations"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_cache(tool_manager):
|
||||||
|
"""Test cache clearing."""
|
||||||
|
await tool_manager.get_tool("get_dags")
|
||||||
|
tool_manager.clear_cache()
|
||||||
|
assert tool_manager.cache_info["size"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_access(tool_manager):
|
||||||
|
"""Test concurrent tool access."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Create multiple concurrent requests
|
||||||
|
tasks = [
|
||||||
|
tool_manager.get_tool("get_dags"),
|
||||||
|
tool_manager.get_tool("get_dags"),
|
||||||
|
tool_manager.get_tool("get_dag"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should handle concurrent access without errors
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
assert all(isinstance(tool, AirflowTool) for tool in results)
|
||||||
|
assert results[0] is results[1] # Same tool instance
|
||||||
|
assert results[0] is not results[2] # Different tools
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_lifecycle(mock_spec_file):
|
||||||
|
"""Test proper client lifecycle management."""
|
||||||
|
manager = ToolManager(
|
||||||
|
spec_path=mock_spec_file,
|
||||||
|
base_url="http://test",
|
||||||
|
auth_token="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Before context
|
||||||
|
assert not hasattr(manager._client, "_session")
|
||||||
|
|
||||||
|
async with manager:
|
||||||
|
# Inside context
|
||||||
|
tool = await manager.get_tool("get_dags")
|
||||||
|
assert tool.client._session is not None
|
||||||
|
|
||||||
|
# After context
|
||||||
|
assert not hasattr(tool.client, "_session")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialization_error():
|
||||||
|
"""Test error handling during initialization."""
|
||||||
|
with pytest.raises(ToolManagerError):
|
||||||
|
ToolManager(
|
||||||
|
spec_path="invalid_path",
|
||||||
|
base_url="http://test",
|
||||||
|
auth_token="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_cache_size():
|
||||||
|
"""Test error handling for invalid cache size."""
|
||||||
|
with pytest.raises(ToolManagerError, match="Invalid configuration: max_cache_size must be positive"):
|
||||||
|
ToolManager(
|
||||||
|
spec_path=Path("dummy"),
|
||||||
|
base_url="http://test",
|
||||||
|
auth_token="test",
|
||||||
|
max_cache_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_required_params():
|
||||||
|
"""Test error handling for missing required parameters."""
|
||||||
|
with pytest.raises(ToolManagerError, match="Invalid configuration: spec_path is required"):
|
||||||
|
ToolManager(spec_path="", base_url="http://test", auth_token="test")
|
||||||
|
|
||||||
|
with pytest.raises(ToolManagerError, match="Invalid configuration: base_url is required"):
|
||||||
|
ToolManager(spec_path=Path("dummy"), base_url="", auth_token="test")
|
||||||
|
|
||||||
|
with pytest.raises(ToolManagerError, match="Invalid configuration: auth_token is required"):
|
||||||
|
ToolManager(spec_path=Path("dummy"), base_url="http://test", auth_token="")
|
||||||
Reference in New Issue
Block a user