refactor for dynamic Models

2025-02-13 10:03:23 +00:00
parent 18962618bc
commit b87fe176bd
6 changed files with 90 additions and 385 deletions

View File

@@ -259,6 +259,23 @@ class OperationParser:
         return self._create_model("Response", schema)

+    def get_operations(self) -> list[str]:
+        """Get list of all operation IDs from spec.
+
+        Returns:
+            List of operation IDs
+        """
+        operations = []
+        for path in self._paths.values():
+            for method, operation in path.items():
+                if method.startswith("x-") or method == "parameters":
+                    continue
+                if "operationId" in operation:
+                    operations.append(operation["operationId"])
+        return operations
+
     def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
         """Create Pydantic model from schema.

View File

@@ -1,61 +1,37 @@
+import logging
 import os
-from enum import Enum
 from typing import Any

 from mcp.server import Server
 from mcp.server.stdio import stdio_server
 from mcp.types import TextContent, Tool
-from airflow_mcp_server.tools.models import ListDags
-from airflow_mcp_server.tools.tool_manager import get_airflow_dag_tools
+from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
+
+logger = logging.getLogger(__name__)

-class AirflowAPITools(str, Enum):
-    # DAG Operations
-    LIST_DAGS = "list_dags"
-
-
-async def process_instruction(instruction: dict[str, Any]) -> dict[str, Any]:
-    dag_tools = get_airflow_dag_tools()
-    try:
-        match instruction["type"]:
-            case "list_dags":
-                return {"dags": await dag_tools.list_dags()}
-            case _:
-                return {"message": "Invalid instruction type"}
-    except Exception as e:
-        return {"error": str(e)}

 async def serve() -> None:
+    """Start MCP server."""
+    required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
+    if not all(var in os.environ for var in required_vars):
+        raise ValueError(f"Missing required environment variables: {required_vars}")
+
     server = Server("airflow-mcp-server")

     @server.list_tools()
     async def list_tools() -> list[Tool]:
-        tools = [
-            # DAG Operations
-            Tool(
-                name=AirflowAPITools.LIST_DAGS,
-                description="Lists all DAGs in Airflow",
-                inputSchema=ListDags.model_json_schema(),
-            ),
-        ]
-        if "AIRFLOW_BASE_URL" in os.environ and "AUTH_TOKEN" in os.environ:
-            return tools
-        else:
-            return []
+        return get_airflow_tools()

     @server.call_tool()
-    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
-        dag_tools = get_airflow_dag_tools()
-
-        match name:
-            case AirflowAPITools.LIST_DAGS:
-                result = await dag_tools.list_dags()
-                return [TextContent(type="text", text=result)]
-            case _:
-                raise ValueError(f"Unknown tool: {name}")
+    async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
+        try:
+            tool = get_tool(name)
+            result = await tool.run(**arguments)
+            return [TextContent(type="text", text=str(result))]
+        except Exception as e:
+            logger.error("Tool execution failed: %s", e)
+            raise

     options = server.create_initialization_options()
     async with stdio_server() as (read_stream, write_stream):
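
A minimal sketch of how the refactored serve() might be launched. The module path airflow_mcp_server.server, the __main__ wrapper, and the environment variable values are assumptions for illustration, not part of this commit.

# Hypothetical launcher for the refactored server; spec path, URL and token are placeholders.
import asyncio
import os

from airflow_mcp_server.server import serve  # assumed module path for the file above

os.environ.setdefault("OPENAPI_SPEC", "/path/to/airflow-openapi.yaml")
os.environ.setdefault("AIRFLOW_BASE_URL", "http://localhost:8080")
os.environ.setdefault("AUTH_TOKEN", "changeme")

if __name__ == "__main__":
    # serve() raises ValueError early if any of the three variables is missing.
    asyncio.run(serve())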

View File

@@ -1,103 +0,0 @@
-import os
-
-import aiohttp
-
-
-class AirflowDagTools:
-    def __init__(self):
-        self.airflow_base_url = os.getenv("AIRFLOW_BASE_URL")
-        self.auth_token = os.getenv("AUTH_TOKEN")
-        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"}
-
-    async def list_dags(
-        self,
-        limit: int | None = 100,
-        offset: int | None = None,
-        order_by: str | None = None,
-        tags: list[str] | None = None,
-        only_active: bool = True,
-        paused: bool | None = None,
-        fields: list[str] | None = None,
-        dag_id_pattern: str | None = None,
-    ) -> list[str]:
-        """
-        List all DAGs in Airflow.
-
-        Sample response:
-        {
-            "dags": [
-                {
-                    "dag_id": "string",
-                    "dag_display_name": "string",
-                    "root_dag_id": "string",
-                    "is_paused": true,
-                    "is_active": true,
-                    "is_subdag": true,
-                    "last_parsed_time": "2019-08-24T14:15:22Z",
-                    "last_pickled": "2019-08-24T14:15:22Z",
-                    "last_expired": "2019-08-24T14:15:22Z",
-                    "scheduler_lock": true,
-                    "pickle_id": "string",
-                    "default_view": "string",
-                    "fileloc": "string",
-                    "file_token": "string",
-                    "owners": [
-                        "string"
-                    ],
-                    "description": "string",
-                    "schedule_interval": {
-                        "__type": "string",
-                        "days": 0,
-                        "seconds": 0,
-                        "microseconds": 0
-                    },
-                    "timetable_description": "string",
-                    "tags": [
-                        {
-                            "name": "string"
-                        }
-                    ],
-                    "max_active_tasks": 0,
-                    "max_active_runs": 0,
-                    "has_task_concurrency_limits": true,
-                    "has_import_errors": true,
-                    "next_dagrun": "2019-08-24T14:15:22Z",
-                    "next_dagrun_data_interval_start": "2019-08-24T14:15:22Z",
-                    "next_dagrun_data_interval_end": "2019-08-24T14:15:22Z",
-                    "next_dagrun_create_after": "2019-08-24T14:15:22Z",
-                    "max_consecutive_failed_dag_runs": 0
-                }
-            ],
-            "total_entries": 0
-        }
-
-        Args:
-            limit (int, optional): The numbers of items to return.
-            offset (int, optional): The number of items to skip before starting to collect the result set.
-            order_by (str, optional): The name of the field to order the results by. Prefix a field name with - to reverse the sort order. New in version 2.1.0
-            tags (list[str], optional): List of tags to filter results. New in version 2.2.0
-            only_active (bool, optional): Only filter active DAGs. New in version 2.1.1
-            paused (bool, optional): Only filter paused/unpaused DAGs. If absent or null, it returns paused and unpaused DAGs. New in version 2.6.0
-            fields (list[str], optional): List of field for return.
-            dag_id_pattern (str, optional): If set, only return DAGs with dag_ids matching this pattern.
-
-        Returns:
-            list[str]: A list of DAG names.
-        """
-        dags = []
-        async with aiohttp.ClientSession() as session:
-            params = {
-                "limit": limit,
-                "offset": offset,
-                "order_by": order_by,
-                "tags": tags,
-                "only_active": only_active,
-                "paused": paused,
-                "fields": fields,
-                "dag_id_pattern": dag_id_pattern,
-            }
-            async with session.get(f"{self.airflow_base_url}/api/v1/dags", headers=self.headers, params=params) as response:
-                if response.status == 200:
-                    dags = await response.json()
-        return dags
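
The hand-written aiohttp client above is removed in favour of the spec-driven path. A rough sketch of how the same DAG listing could now be performed, assuming the Airflow OpenAPI spec exposes an operation id like "get_dags" and that AirflowTool.run accepts the query parameters as keyword arguments (both inferred from the new call_tool handler earlier in this commit, not guaranteed by it):

# Sketch only: operation id and run() signature are assumptions inferred from the
# new call_tool handler; requires OPENAPI_SPEC, AIRFLOW_BASE_URL and AUTH_TOKEN.
import asyncio

from airflow_mcp_server.tools.tool_manager import get_tool


async def list_dags_dynamically() -> None:
    tool = get_tool("get_dags")  # hypothetical operation id from the Airflow spec
    result = await tool.run(limit=100, only_active=True)
    print(result)


if __name__ == "__main__":
    asyncio.run(list_dags_dynamically())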

View File

@@ -32,14 +32,17 @@ def create_validation_error(field: str, message: str) -> ValidationError:
 class AirflowTool(BaseTools):
-    """Tool for executing Airflow API operations."""
+    """
+    Tool for executing Airflow API operations.
+
+    AirflowTool is supposed to have objects per operation.
+    """

     def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
         """Initialize tool with operation details and client.

         Args:
-            operation_details: Parsed operation details
-            client: Configured Airflow API client
+            operation_details: Operation details
+            client: AirflowClient instance
         """
         super().__init__()
         self.operation = operation_details
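
To make the "objects per operation" note concrete: each operation id gets its own AirflowTool instance while all instances share one client. The sketch below mirrors the tool_manager changes later in this commit; the spec path, base URL, and token are placeholders.

# Illustrative only: one AirflowTool instance per operation, all sharing a client.
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool

client = AirflowClient(spec_path="openapi.yaml", base_url="http://localhost:8080", auth_token="changeme")
parser = OperationParser("openapi.yaml")

# One tool object per operation id found in the spec.
tools = {
    operation_id: AirflowTool(parser.parse_operation(operation_id), client)
    for operation_id in parser.get_operations()
}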

View File

@@ -1,22 +0,0 @@
-from pydantic import BaseModel, model_validator
-
-
-# DAG operations
-# ====================================================================
-class ListDags(BaseModel):
-    """Parameters for listing DAGs."""
-
-    limit: int | None
-    offset: int | None
-    order_by: str | None
-    tags: list[str] | None
-    only_active: bool
-    paused: bool | None
-    fields: list[str] | None
-    dag_id_pattern: str | None
-
-    @model_validator(mode="after")
-    def validate_offset(self) -> "ListDags":
-        if self.offset is not None and self.offset < 0:
-            raise ValueError("offset must be non-negative")
-        return self
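
The hand-written ListDags model above goes away because request models are now derived from the spec at runtime. The parser's _create_model is not shown in this commit, so the snippet below is only an illustration of the general technique using pydantic.create_model on a hand-rolled field mapping; it is not the project's actual implementation.

# Illustration of building a request model dynamically instead of hand-writing it.
# This is NOT the parser's _create_model; it only shows the general technique.
from pydantic import create_model

ListDagsDynamic = create_model(
    "ListDags",
    limit=(int | None, 100),
    offset=(int | None, None),
    order_by=(str | None, None),
    tags=(list[str] | None, None),
    only_active=(bool, True),
    paused=(bool | None, None),
    fields=(list[str] | None, None),
    dag_id_pattern=(str | None, None),
)

params = ListDagsDynamic(limit=50, only_active=True)
print(params.model_dump(exclude_none=True))  # {'limit': 50, 'only_active': True}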

View File

@@ -1,243 +1,77 @@
"""Tool manager for handling Airflow API tool instantiation and caching."""
import asyncio
import logging import logging
from collections import OrderedDict import os
from pathlib import Path
from typing import Any from mcp.types import Tool
from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
from airflow_mcp_server.tools.airflow_tool import AirflowTool from airflow_mcp_server.tools.airflow_tool import AirflowTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Keep existing function for backward compatibility _tools_cache: dict[str, AirflowTool] = {}
_dag_tools: AirflowDagTools | None = None _client: AirflowClient | None = None
def get_airflow_dag_tools() -> AirflowDagTools: def get_airflow_tools() -> list[Tool]:
global _dag_tools """Get list of all available Airflow tools.
if not _dag_tools:
_dag_tools = AirflowDagTools()
return _dag_tools
Returns:
class ToolManagerError(Exception): List of MCP Tool objects representing available operations
"""Base exception for tool manager errors."""
pass
class ToolInitializationError(ToolManagerError):
"""Error during tool initialization."""
pass
class ToolNotFoundError(ToolManagerError):
"""Error when requested tool is not found."""
pass
class ToolManager:
"""Manager for Airflow API tools with caching and lifecycle management.
This class provides a centralized way to manage Airflow API tools with:
- Singleton client management
- Tool caching with size limits
- Thread-safe access
- Proper error handling
"""
def __init__(
self,
spec_path: Path | str | object,
base_url: str,
auth_token: str,
max_cache_size: int = 100,
) -> None:
"""Initialize tool manager.
Args:
spec_path: Path to OpenAPI spec file or file-like object
base_url: Base URL for Airflow API
auth_token: Authentication token
max_cache_size: Maximum number of tools to cache (default: 100)
Raises: Raises:
ToolManagerError: If initialization fails ValueError: If required environment variables are missing
ToolInitializationError: If client or parser initialization fails
""" """
try: global _tools_cache, _client
# Validate inputs
if not spec_path:
raise ValueError("spec_path is required")
if not base_url:
raise ValueError("base_url is required")
if not auth_token:
raise ValueError("auth_token is required")
if max_cache_size < 1:
raise ValueError("max_cache_size must be positive")
# Store configuration if not _tools_cache:
self._spec_path = spec_path required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
self._base_url = base_url.rstrip("/") if not all(var in os.environ for var in required_vars):
self._auth_token = auth_token raise ValueError(f"Missing required environment variables: {required_vars}")
self._max_cache_size = max_cache_size
# Initialize client if not exists
if not _client:
_client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
# Initialize core components with proper error handling
try: try:
self._client = AirflowClient(spec_path, self._base_url, auth_token) # Create parser
logger.info("AirflowClient initialized successfully") parser = OperationParser(os.environ["OPENAPI_SPEC"])
# Generate tools for each operation
for operation_id in parser.get_operations():
operation_details = parser.parse_operation(operation_id)
tool = AirflowTool(operation_details, _client)
_tools_cache[operation_id] = tool
except Exception as e: except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e) logger.error("Failed to initialize tools: %s", e)
raise ToolInitializationError(f"Failed to initialize client: {e}") from e raise
try: # Convert to MCP Tool format
self._parser = OperationParser(spec_path) return [
logger.info("OperationParser initialized successfully") Tool(
except Exception as e: name=operation_id,
logger.error("Failed to initialize OperationParser: %s", e) description=tool.operation.operation_id,
raise ToolInitializationError(f"Failed to initialize parser: {e}") from e inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None,
)
for operation_id, tool in _tools_cache.items()
]
# Setup thread safety and caching
self._lock = asyncio.Lock()
self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url) def get_tool(name: str) -> AirflowTool:
except ValueError as e: """Get specific tool by name.
logger.error("Invalid configuration: %s", e)
raise ToolManagerError(f"Invalid configuration: {e}") from e
except Exception as e:
logger.error("Failed to initialize tool manager: %s", e)
raise ToolInitializationError(f"Component initialization failed: {e}") from e
async def __aenter__(self) -> "ToolManager":
"""Enter async context."""
try:
if not hasattr(self, "_client"):
logger.error("Client not initialized")
raise ToolInitializationError("Client not initialized")
await self._client.__aenter__()
return self
except Exception as e:
logger.error("Failed to enter async context: %s", e)
if hasattr(self, "_client"):
try:
await self._client.__aexit__(None, None, None)
except Exception:
pass
raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit async context."""
try:
await self._client.__aexit__(exc_type, exc_val, exc_tb)
except Exception as e:
logger.error("Error during context exit: %s", e)
finally:
self.clear_cache()
def _evict_cache_if_needed(self) -> None:
"""Evict oldest items from cache if size limit is reached.
This method uses FIFO (First In First Out) eviction policy.
Eviction occurs when cache reaches max_cache_size.
"""
current_size = len(self._tool_cache)
if current_size >= self._max_cache_size:
logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
evicted_count = 0
while len(self._tool_cache) >= self._max_cache_size:
operation_id, _ = self._tool_cache.popitem(last=False)
evicted_count += 1
logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
if evicted_count > 0:
logger.info("Evicted %d tools from cache", evicted_count)
async def get_tool(self, operation_id: str) -> AirflowTool:
"""Get or create a tool instance for the given operation.
Args: Args:
operation_id: Operation ID from OpenAPI spec name: Tool/operation name
Returns: Returns:
AirflowTool instance AirflowTool instance
Raises: Raises:
ToolNotFoundError: If operation not found KeyError: If tool not found
ToolInitializationError: If tool creation fails
ValueError: If operation_id is invalid
""" """
if not operation_id or not isinstance(operation_id, str): if name not in _tools_cache:
logger.error("Invalid operation_id provided: %s", operation_id) # Ensure cache is populated
raise ValueError("Invalid operation_id") get_airflow_tools()
if not hasattr(self, "_client") or not hasattr(self, "_parser"): return _tools_cache[name]
logger.error("ToolManager not properly initialized")
raise ToolInitializationError("ToolManager components not initialized")
logger.debug("Requesting tool for operation: %s", operation_id)
async with self._lock:
# Check cache first
if operation_id in self._tool_cache:
logger.debug("Tool cache hit for %s", operation_id)
return self._tool_cache[operation_id]
logger.debug("Tool cache miss for %s, creating new instance", operation_id)
try:
# Parse operation details
try:
operation_details = self._parser.parse_operation(operation_id)
except ValueError as e:
logger.error("Operation %s not found: %s", operation_id, e)
raise ToolNotFoundError(f"Operation {operation_id} not found") from e
except Exception as e:
logger.error("Failed to parse operation %s: %s", operation_id, e)
raise ToolInitializationError(f"Operation parsing failed: {e}") from e
# Create new tool instance
try:
tool = AirflowTool(operation_details, self._client)
except Exception as e:
logger.error("Failed to create tool instance for %s: %s", operation_id, e)
raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
# Update cache
self._evict_cache_if_needed()
self._tool_cache[operation_id] = tool
logger.info("Created and cached new tool for %s", operation_id)
return tool
except (ToolNotFoundError, ToolInitializationError):
raise
except Exception as e:
logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
raise ToolInitializationError(f"Unexpected error: {e}") from e
def clear_cache(self) -> None:
"""Clear the tool cache."""
self._tool_cache.clear()
logger.debug("Tool cache cleared")
@property
def cache_info(self) -> dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
return {
"size": len(self._tool_cache),
"max_size": self._max_cache_size,
"operations": list(self._tool_cache.keys()),
}
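
A short usage sketch of the new module-level API, assuming the three environment variables are set. The operation id passed to get_tool is a placeholder, and run()'s keyword arguments are inferred from the call_tool handler rather than stated by this commit.

# Usage sketch for the refactored tool manager. Requires OPENAPI_SPEC,
# AIRFLOW_BASE_URL and AUTH_TOKEN in the environment; operation id is a placeholder.
import asyncio

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool


async def main() -> None:
    # First call parses the spec once and fills the module-level cache.
    for tool in get_airflow_tools():
        print(tool.name, "-", tool.description)

    # Subsequent lookups are served from the cache.
    dag_tool = get_tool("get_dags")  # hypothetical operation id
    result = await dag_tool.run(limit=10)
    print(result)


if __name__ == "__main__":
    asyncio.run(main())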