refactor for dynamic Models

2025-02-13 10:03:23 +00:00
parent 18962618bc
commit b87fe176bd
6 changed files with 90 additions and 385 deletions


@@ -259,6 +259,23 @@ class OperationParser:
return self._create_model("Response", schema)
def get_operations(self) -> list[str]:
"""Get list of all operation IDs from spec.
Returns:
List of operation IDs
"""
operations = []
for path in self._paths.values():
for method, operation in path.items():
if method.startswith("x-") or method == "parameters":
continue
if "operationId" in operation:
operations.append(operation["operationId"])
return operations
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
"""Create Pydantic model from schema.

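The new get_operations() hook is what drives dynamic tool generation elsewhere in this commit. A minimal usage sketch (illustrative, not part of this diff; the spec path is a placeholder and parse_operation() is assumed to take an operation ID, as in tool_manager.py below):

from airflow_mcp_server.parser.operation_parser import OperationParser

parser = OperationParser("openapi.yaml")  # placeholder spec path
for operation_id in parser.get_operations():
    # get_operations() skips vendor extensions ("x-...") and shared "parameters"
    # entries, so each ID maps to a concrete HTTP operation in the spec.
    details = parser.parse_operation(operation_id)
    print(operation_id, details)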

@@ -1,61 +1,37 @@
import logging
import os
from enum import Enum
from typing import Any
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from airflow_mcp_server.tools.models import ListDags
from airflow_mcp_server.tools.tool_manager import get_airflow_dag_tools
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
class AirflowAPITools(str, Enum):
# DAG Operations
LIST_DAGS = "list_dags"
async def process_instruction(instruction: dict[str, Any]) -> dict[str, Any]:
dag_tools = get_airflow_dag_tools()
try:
match instruction["type"]:
case "list_dags":
return {"dags": await dag_tools.list_dags()}
case _:
return {"message": "Invalid instruction type"}
except Exception as e:
return {"error": str(e)}
logger = logging.getLogger(__name__)
async def serve() -> None:
"""Start MCP server."""
required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
if not all(var in os.environ for var in required_vars):
raise ValueError(f"Missing required environment variables: {required_vars}")
server = Server("airflow-mcp-server")
@server.list_tools()
async def list_tools() -> list[Tool]:
tools = [
# DAG Operations
Tool(
name=AirflowAPITools.LIST_DAGS,
description="Lists all DAGs in Airflow",
inputSchema=ListDags.model_json_schema(),
),
]
if "AIRFLOW_BASE_URL" in os.environ and "AUTH_TOKEN" in os.environ:
return tools
else:
return []
return get_airflow_tools()
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
dag_tools = get_airflow_dag_tools()
match name:
case AirflowAPITools.LIST_DAGS:
result = await dag_tools.list_dags()
return [TextContent(type="text", text=result)]
case _:
raise ValueError(f"Unknown tool: {name}")
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
tool = get_tool(name)
result = await tool.run(**arguments)
return [TextContent(type="text", text=str(result))]
except Exception as e:
logger.error("Tool execution failed: %s", e)
raise
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):

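A minimal sketch of launching the refactored server (illustrative; the import path for serve() and the variable values are placeholders, while the three variable names come from the required_vars check above):

import asyncio
import os

from airflow_mcp_server.server import serve  # assumed import path

# serve() raises ValueError unless all three variables are set.
os.environ["OPENAPI_SPEC"] = "/path/to/airflow-openapi.yaml"  # placeholder
os.environ["AIRFLOW_BASE_URL"] = "http://localhost:8080"      # placeholder
os.environ["AUTH_TOKEN"] = "..."                              # placeholder

asyncio.run(serve())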

@@ -1,103 +0,0 @@
import os
import aiohttp
class AirflowDagTools:
def __init__(self):
self.airflow_base_url = os.getenv("AIRFLOW_BASE_URL")
self.auth_token = os.getenv("AUTH_TOKEN")
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"}
async def list_dags(
self,
limit: int | None = 100,
offset: int | None = None,
order_by: str | None = None,
tags: list[str] | None = None,
only_active: bool = True,
paused: bool | None = None,
fields: list[str] | None = None,
dag_id_pattern: str | None = None,
) -> dict | list:
"""
List all DAGs in Airflow.
Sample response:
{
"dags": [
{
"dag_id": "string",
"dag_display_name": "string",
"root_dag_id": "string",
"is_paused": true,
"is_active": true,
"is_subdag": true,
"last_parsed_time": "2019-08-24T14:15:22Z",
"last_pickled": "2019-08-24T14:15:22Z",
"last_expired": "2019-08-24T14:15:22Z",
"scheduler_lock": true,
"pickle_id": "string",
"default_view": "string",
"fileloc": "string",
"file_token": "string",
"owners": [
"string"
],
"description": "string",
"schedule_interval": {
"__type": "string",
"days": 0,
"seconds": 0,
"microseconds": 0
},
"timetable_description": "string",
"tags": [
{
"name": "string"
}
],
"max_active_tasks": 0,
"max_active_runs": 0,
"has_task_concurrency_limits": true,
"has_import_errors": true,
"next_dagrun": "2019-08-24T14:15:22Z",
"next_dagrun_data_interval_start": "2019-08-24T14:15:22Z",
"next_dagrun_data_interval_end": "2019-08-24T14:15:22Z",
"next_dagrun_create_after": "2019-08-24T14:15:22Z",
"max_consecutive_failed_dag_runs": 0
}
],
"total_entries": 0
}
Args:
limit (int, optional): The number of items to return.
offset (int, optional): The number of items to skip before starting to collect the result set.
order_by (str, optional): The name of the field to order the results by. Prefix a field name with - to reverse the sort order. New in version 2.1.0
tags (list[str], optional): List of tags to filter results. New in version 2.2.0
only_active (bool, optional): Only filter active DAGs. New in version 2.1.1
paused (bool, optional): Only filter paused/unpaused DAGs. If absent or null, it returns paused and unpaused DAGs. New in version 2.6.0
fields (list[str], optional): List of fields to return.
dag_id_pattern (str, optional): If set, only return DAGs with dag_ids matching this pattern.
Returns:
dict | list: Parsed JSON response from the DAGs endpoint, or an empty list if the request fails.
"""
dags = []
async with aiohttp.ClientSession() as session:
params = {
"limit": limit,
"offset": offset,
"order_by": order_by,
"tags": tags,
"only_active": only_active,
"paused": paused,
"fields": fields,
"dag_id_pattern": dag_id_pattern,
}
async with session.get(f"{self.airflow_base_url}/api/v1/dags", headers=self.headers, params=params) as response:
if response.status == 200:
dags = await response.json()
return dags


@@ -32,14 +32,17 @@ def create_validation_error(field: str, message: str) -> ValidationError:
class AirflowTool(BaseTools):
"""Tool for executing Airflow API operations."""
"""
Tool for executing Airflow API operations.
One AirflowTool instance is created per API operation.
"""
def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
"""Initialize tool with operation details and client.
Args:
operation_details: Parsed operation details
client: Configured Airflow API client
operation_details: Operation details
client: AirflowClient instance
"""
super().__init__()
self.operation = operation_details

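As the docstring notes, each AirflowTool wraps a single operation. A sketch of wiring one up directly (illustrative; the spec path, base URL, token, operation ID, and run() keyword arguments are placeholders or assumptions, and the AirflowClient constructor and context-manager usage are taken from tool_manager.py below):

import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool

async def main() -> None:
    parser = OperationParser("openapi.yaml")  # placeholder spec path
    client = AirflowClient(spec_path="openapi.yaml", base_url="http://localhost:8080", auth_token="...")  # placeholders
    async with client:  # the client is used as an async context manager in tool_manager.py
        tool = AirflowTool(parser.parse_operation("get_dags"), client)  # "get_dags" is an assumed operationId
        print(await tool.run(limit=10))  # kwargs assumed to mirror the operation's parameters

asyncio.run(main())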

@@ -1,22 +0,0 @@
from pydantic import BaseModel, model_validator
# DAG operations
# ====================================================================
class ListDags(BaseModel):
"""Parameters for listing DAGs."""
limit: int | None
offset: int | None
order_by: str | None
tags: list[str] | None
only_active: bool
paused: bool | None
fields: list[str] | None
dag_id_pattern: str | None
@model_validator(mode="after")
def validate_offset(self) -> "ListDags":
if self.offset is not None and self.offset < 0:
raise ValueError("offset must be non-negative")
return self

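For reference, the pattern the removed model used, in a runnable form (a simplified stand-in with defaults added; names here are illustrative, not from the codebase):

from pydantic import BaseModel, model_validator

class ListDagsParams(BaseModel):  # simplified stand-in for the removed ListDags model
    limit: int | None = 100
    offset: int | None = None

    @model_validator(mode="after")
    def validate_offset(self) -> "ListDagsParams":
        if self.offset is not None and self.offset < 0:
            raise ValueError("offset must be non-negative")
        return self

# model_json_schema() is what the old server exposed as the tool's inputSchema;
# after this refactor an equivalent model is generated per operation from the OpenAPI spec.
print(ListDagsParams.model_json_schema())
ListDagsParams(offset=-1)  # raises pydantic.ValidationError via the mode="after" validator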

@@ -1,243 +1,77 @@
"""Tool manager for handling Airflow API tool instantiation and caching."""
import asyncio
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Any
import os
from mcp.types import Tool
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
from airflow_mcp_server.tools.airflow_tool import AirflowTool
logger = logging.getLogger(__name__)
# Keep existing function for backward compatibility
_dag_tools: AirflowDagTools | None = None
_tools_cache: dict[str, AirflowTool] = {}
_client: AirflowClient | None = None
def get_airflow_dag_tools() -> AirflowDagTools:
global _dag_tools
if not _dag_tools:
_dag_tools = AirflowDagTools()
return _dag_tools
def get_airflow_tools() -> list[Tool]:
"""Get list of all available Airflow tools.
Returns:
List of MCP Tool objects representing available operations
class ToolManagerError(Exception):
"""Base exception for tool manager errors."""
pass
class ToolInitializationError(ToolManagerError):
"""Error during tool initialization."""
pass
class ToolNotFoundError(ToolManagerError):
"""Error when requested tool is not found."""
pass
class ToolManager:
"""Manager for Airflow API tools with caching and lifecycle management.
This class provides a centralized way to manage Airflow API tools with:
- Singleton client management
- Tool caching with size limits
- Thread-safe access
- Proper error handling
Raises:
ValueError: If required environment variables are missing
"""
global _tools_cache, _client
def __init__(
self,
spec_path: Path | str | object,
base_url: str,
auth_token: str,
max_cache_size: int = 100,
) -> None:
"""Initialize tool manager.
if not _tools_cache:
required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
if not all(var in os.environ for var in required_vars):
raise ValueError(f"Missing required environment variables: {required_vars}")
Args:
spec_path: Path to OpenAPI spec file or file-like object
base_url: Base URL for Airflow API
auth_token: Authentication token
max_cache_size: Maximum number of tools to cache (default: 100)
# Initialize client if not exists
if not _client:
_client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
Raises:
ToolManagerError: If initialization fails
ToolInitializationError: If client or parser initialization fails
"""
try:
# Validate inputs
if not spec_path:
raise ValueError("spec_path is required")
if not base_url:
raise ValueError("base_url is required")
if not auth_token:
raise ValueError("auth_token is required")
if max_cache_size < 1:
raise ValueError("max_cache_size must be positive")
# Create parser
parser = OperationParser(os.environ["OPENAPI_SPEC"])
# Store configuration
self._spec_path = spec_path
self._base_url = base_url.rstrip("/")
self._auth_token = auth_token
self._max_cache_size = max_cache_size
# Generate tools for each operation
for operation_id in parser.get_operations():
operation_details = parser.parse_operation(operation_id)
tool = AirflowTool(operation_details, _client)
_tools_cache[operation_id] = tool
# Initialize core components with proper error handling
try:
self._client = AirflowClient(spec_path, self._base_url, auth_token)
logger.info("AirflowClient initialized successfully")
except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e)
raise ToolInitializationError(f"Failed to initialize client: {e}") from e
try:
self._parser = OperationParser(spec_path)
logger.info("OperationParser initialized successfully")
except Exception as e:
logger.error("Failed to initialize OperationParser: %s", e)
raise ToolInitializationError(f"Failed to initialize parser: {e}") from e
# Setup thread safety and caching
self._lock = asyncio.Lock()
self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url)
except ValueError as e:
logger.error("Invalid configuration: %s", e)
raise ToolManagerError(f"Invalid configuration: {e}") from e
except Exception as e:
logger.error("Failed to initialize tool manager: %s", e)
raise ToolInitializationError(f"Component initialization failed: {e}") from e
logger.error("Failed to initialize tools: %s", e)
raise
async def __aenter__(self) -> "ToolManager":
"""Enter async context."""
try:
if not hasattr(self, "_client"):
logger.error("Client not initialized")
raise ToolInitializationError("Client not initialized")
await self._client.__aenter__()
return self
except Exception as e:
logger.error("Failed to enter async context: %s", e)
if hasattr(self, "_client"):
try:
await self._client.__aexit__(None, None, None)
except Exception:
pass
raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
# Convert to MCP Tool format
return [
Tool(
name=operation_id,
description=tool.operation.operation_id,
inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None,
)
for operation_id, tool in _tools_cache.items()
]
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit async context."""
try:
await self._client.__aexit__(exc_type, exc_val, exc_tb)
except Exception as e:
logger.error("Error during context exit: %s", e)
finally:
self.clear_cache()
def _evict_cache_if_needed(self) -> None:
"""Evict oldest items from cache if size limit is reached.
def get_tool(name: str) -> AirflowTool:
"""Get specific tool by name.
This method uses a FIFO (first-in, first-out) eviction policy.
Eviction occurs when the cache reaches max_cache_size.
"""
current_size = len(self._tool_cache)
if current_size >= self._max_cache_size:
logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
Args:
name: Tool/operation name
evicted_count = 0
while len(self._tool_cache) >= self._max_cache_size:
operation_id, _ = self._tool_cache.popitem(last=False)
evicted_count += 1
logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
Returns:
AirflowTool instance
if evicted_count > 0:
logger.info("Evicted %d tools from cache", evicted_count)
Raises:
KeyError: If tool not found
"""
if name not in _tools_cache:
# Ensure cache is populated
get_airflow_tools()
async def get_tool(self, operation_id: str) -> AirflowTool:
"""Get or create a tool instance for the given operation.
Args:
operation_id: Operation ID from OpenAPI spec
Returns:
AirflowTool instance
Raises:
ToolNotFoundError: If operation not found
ToolInitializationError: If tool creation fails
ValueError: If operation_id is invalid
"""
if not operation_id or not isinstance(operation_id, str):
logger.error("Invalid operation_id provided: %s", operation_id)
raise ValueError("Invalid operation_id")
if not hasattr(self, "_client") or not hasattr(self, "_parser"):
logger.error("ToolManager not properly initialized")
raise ToolInitializationError("ToolManager components not initialized")
logger.debug("Requesting tool for operation: %s", operation_id)
async with self._lock:
# Check cache first
if operation_id in self._tool_cache:
logger.debug("Tool cache hit for %s", operation_id)
return self._tool_cache[operation_id]
logger.debug("Tool cache miss for %s, creating new instance", operation_id)
try:
# Parse operation details
try:
operation_details = self._parser.parse_operation(operation_id)
except ValueError as e:
logger.error("Operation %s not found: %s", operation_id, e)
raise ToolNotFoundError(f"Operation {operation_id} not found") from e
except Exception as e:
logger.error("Failed to parse operation %s: %s", operation_id, e)
raise ToolInitializationError(f"Operation parsing failed: {e}") from e
# Create new tool instance
try:
tool = AirflowTool(operation_details, self._client)
except Exception as e:
logger.error("Failed to create tool instance for %s: %s", operation_id, e)
raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
# Update cache
self._evict_cache_if_needed()
self._tool_cache[operation_id] = tool
logger.info("Created and cached new tool for %s", operation_id)
return tool
except (ToolNotFoundError, ToolInitializationError):
raise
except Exception as e:
logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
raise ToolInitializationError(f"Unexpected error: {e}") from e
def clear_cache(self) -> None:
"""Clear the tool cache."""
self._tool_cache.clear()
logger.debug("Tool cache cleared")
@property
def cache_info(self) -> dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
return {
"size": len(self._tool_cache),
"max_size": self._max_cache_size,
"operations": list(self._tool_cache.keys()),
}
return _tools_cache[name]
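Putting the new module-level helpers together (a sketch, assuming OPENAPI_SPEC, AIRFLOW_BASE_URL, and AUTH_TOKEN are already set as get_airflow_tools() requires):

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

# First call validates the environment, builds one AirflowTool per operationId,
# and returns the corresponding MCP Tool definitions.
tools = get_airflow_tools()
print(sorted(tool.name for tool in tools))

# Lookup by operationId; get_airflow_tools() is invoked internally if the cache is empty.
airflow_tool = get_tool(tools[0].name)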