refactor for dynamic Models

2025-02-13 10:03:23 +00:00
parent 18962618bc
commit b87fe176bd
6 changed files with 90 additions and 385 deletions

View File

@@ -259,6 +259,23 @@ class OperationParser:
         return self._create_model("Response", schema)

+    def get_operations(self) -> list[str]:
+        """Get list of all operation IDs from spec.
+
+        Returns:
+            List of operation IDs
+        """
+        operations = []
+        for path in self._paths.values():
+            for method, operation in path.items():
+                if method.startswith("x-") or method == "parameters":
+                    continue
+                if "operationId" in operation:
+                    operations.append(operation["operationId"])
+        return operations
+
     def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
         """Create Pydantic model from schema.

View File

@@ -1,61 +1,37 @@
+import logging
 import os
-from enum import Enum
 from typing import Any

 from mcp.server import Server
 from mcp.server.stdio import stdio_server
 from mcp.types import TextContent, Tool

-from airflow_mcp_server.tools.models import ListDags
-from airflow_mcp_server.tools.tool_manager import get_airflow_dag_tools
+from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

+logger = logging.getLogger(__name__)

-class AirflowAPITools(str, Enum):
-    # DAG Operations
-    LIST_DAGS = "list_dags"
-
-
-async def process_instruction(instruction: dict[str, Any]) -> dict[str, Any]:
-    dag_tools = get_airflow_dag_tools()
-    try:
-        match instruction["type"]:
-            case "list_dags":
-                return {"dags": await dag_tools.list_dags()}
-            case _:
-                return {"message": "Invalid instruction type"}
-    except Exception as e:
-        return {"error": str(e)}
-

 async def serve() -> None:
+    """Start MCP server."""
+    required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
+    if not all(var in os.environ for var in required_vars):
+        raise ValueError(f"Missing required environment variables: {required_vars}")
+
     server = Server("airflow-mcp-server")

     @server.list_tools()
     async def list_tools() -> list[Tool]:
-        tools = [
-            # DAG Operations
-            Tool(
-                name=AirflowAPITools.LIST_DAGS,
-                description="Lists all DAGs in Airflow",
-                inputSchema=ListDags.model_json_schema(),
-            ),
-        ]
-        if "AIRFLOW_BASE_URL" in os.environ and "AUTH_TOKEN" in os.environ:
-            return tools
-        else:
-            return []
+        return get_airflow_tools()

     @server.call_tool()
-    async def call_tool(name: str, arguments: dict) -> list[TextContent]:
-        dag_tools = get_airflow_dag_tools()
-
-        match name:
-            case AirflowAPITools.LIST_DAGS:
-                result = await dag_tools.list_dags()
-                return [TextContent(type="text", text=result)]
-            case _:
-                raise ValueError(f"Unknown tool: {name}")
+    async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
+        try:
+            tool = get_tool(name)
+            result = await tool.run(**arguments)
+            return [TextContent(type="text", text=str(result))]
+        except Exception as e:
+            logger.error("Tool execution failed: %s", e)
+            raise

     options = server.create_initialization_options()
     async with stdio_server() as (read_stream, write_stream):
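
Outside this commit, a rough sketch of launching the refactored server with the three environment variables the new serve() requires (values are placeholders; the module path is assumed from the imports above):

import asyncio
import os

os.environ["OPENAPI_SPEC"] = "/path/to/airflow-openapi.yaml"  # placeholder
os.environ["AIRFLOW_BASE_URL"] = "http://localhost:8080"      # placeholder
os.environ["AUTH_TOKEN"] = "changeme"                         # placeholder

from airflow_mcp_server.server import serve  # module path assumed

asyncio.run(serve())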

View File

@@ -1,103 +0,0 @@
-import os
-
-import aiohttp
-
-
-class AirflowDagTools:
-    def __init__(self):
-        self.airflow_base_url = os.getenv("AIRFLOW_BASE_URL")
-        self.auth_token = os.getenv("AUTH_TOKEN")
-        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.auth_token}"}
-
-    async def list_dags(
-        self,
-        limit: int | None = 100,
-        offset: int | None = None,
-        order_by: str | None = None,
-        tags: list[str] | None = None,
-        only_active: bool = True,
-        paused: bool | None = None,
-        fields: list[str] | None = None,
-        dag_id_pattern: str | None = None,
-    ) -> list[str]:
-        """
-        List all DAGs in Airflow.
-
-        Sample response:
-        {
-            "dags": [
-                {
-                    "dag_id": "string",
-                    "dag_display_name": "string",
-                    "root_dag_id": "string",
-                    "is_paused": true,
-                    "is_active": true,
-                    "is_subdag": true,
-                    "last_parsed_time": "2019-08-24T14:15:22Z",
-                    "last_pickled": "2019-08-24T14:15:22Z",
-                    "last_expired": "2019-08-24T14:15:22Z",
-                    "scheduler_lock": true,
-                    "pickle_id": "string",
-                    "default_view": "string",
-                    "fileloc": "string",
-                    "file_token": "string",
-                    "owners": [
-                        "string"
-                    ],
-                    "description": "string",
-                    "schedule_interval": {
-                        "__type": "string",
-                        "days": 0,
-                        "seconds": 0,
-                        "microseconds": 0
-                    },
-                    "timetable_description": "string",
-                    "tags": [
-                        {
-                            "name": "string"
-                        }
-                    ],
-                    "max_active_tasks": 0,
-                    "max_active_runs": 0,
-                    "has_task_concurrency_limits": true,
-                    "has_import_errors": true,
-                    "next_dagrun": "2019-08-24T14:15:22Z",
-                    "next_dagrun_data_interval_start": "2019-08-24T14:15:22Z",
-                    "next_dagrun_data_interval_end": "2019-08-24T14:15:22Z",
-                    "next_dagrun_create_after": "2019-08-24T14:15:22Z",
-                    "max_consecutive_failed_dag_runs": 0
-                }
-            ],
-            "total_entries": 0
-        }
-
-        Args:
-            limit (int, optional): The numbers of items to return.
-            offset (int, optional): The number of items to skip before starting to collect the result set.
-            order_by (str, optional): The name of the field to order the results by. Prefix a field name with - to reverse the sort order. New in version 2.1.0
-            tags (list[str], optional): List of tags to filter results. New in version 2.2.0
-            only_active (bool, optional): Only filter active DAGs. New in version 2.1.1
-            paused (bool, optional): Only filter paused/unpaused DAGs. If absent or null, it returns paused and unpaused DAGs. New in version 2.6.0
-            fields (list[str], optional): List of field for return.
-            dag_id_pattern (str, optional): If set, only return DAGs with dag_ids matching this pattern.
-
-        Returns:
-            list[str]: A list of DAG names.
-        """
-        dags = []
-        async with aiohttp.ClientSession() as session:
-            params = {
-                "limit": limit,
-                "offset": offset,
-                "order_by": order_by,
-                "tags": tags,
-                "only_active": only_active,
-                "paused": paused,
-                "fields": fields,
-                "dag_id_pattern": dag_id_pattern,
-            }
-            async with session.get(f"{self.airflow_base_url}/api/v1/dags", headers=self.headers, params=params) as response:
-                if response.status == 200:
-                    dags = await response.json()
-        return dags

View File

@@ -32,14 +32,17 @@ def create_validation_error(field: str, message: str) -> ValidationError:
 class AirflowTool(BaseTools):
-    """Tool for executing Airflow API operations."""
+    """
+    Tool for executing Airflow API operations.
+
+    AirflowTool is supposed to have objects per operation.
+    """

     def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
         """Initialize tool with operation details and client.

         Args:
-            operation_details: Parsed operation details
-            client: Configured Airflow API client
+            operation_details: Operation details
+            client: AirflowClient instance
         """
         super().__init__()
         self.operation = operation_details
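
As a sketch only (parser, client, and the operation id below are assumptions, not part of this file): each AirflowTool wraps one parsed operation and is executed through run(), the same call used by call_tool() in the server diff above.

# Hypothetical wiring of a single-operation tool:
async def demo(parser, client):
    details = parser.parse_operation("get_dags")  # illustrative operationId
    tool = AirflowTool(details, client)           # one AirflowTool object per operation
    return await tool.run(limit=5)                # kwargs mirror the operation's parameters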

View File

@@ -1,22 +0,0 @@
-from pydantic import BaseModel, model_validator
-
-
-# DAG operations
-# ====================================================================
-class ListDags(BaseModel):
-    """Parameters for listing DAGs."""
-
-    limit: int | None
-    offset: int | None
-    order_by: str | None
-    tags: list[str] | None
-    only_active: bool
-    paused: bool | None
-    fields: list[str] | None
-    dag_id_pattern: str | None
-
-    @model_validator(mode="after")
-    def validate_offset(self) -> "ListDags":
-        if self.offset is not None and self.offset < 0:
-            raise ValueError("offset must be non-negative")
-        return self

View File

@@ -1,243 +1,77 @@
"""Tool manager for handling Airflow API tool instantiation and caching."""
import asyncio
import logging import logging
from collections import OrderedDict import os
from pathlib import Path
from typing import Any from mcp.types import Tool
from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_dag_tools import AirflowDagTools
from airflow_mcp_server.tools.airflow_tool import AirflowTool from airflow_mcp_server.tools.airflow_tool import AirflowTool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Keep existing function for backward compatibility _tools_cache: dict[str, AirflowTool] = {}
_dag_tools: AirflowDagTools | None = None _client: AirflowClient | None = None
def get_airflow_dag_tools() -> AirflowDagTools: def get_airflow_tools() -> list[Tool]:
global _dag_tools """Get list of all available Airflow tools.
if not _dag_tools:
_dag_tools = AirflowDagTools()
return _dag_tools
Returns:
List of MCP Tool objects representing available operations
class ToolManagerError(Exception): Raises:
"""Base exception for tool manager errors.""" ValueError: If required environment variables are missing
pass
class ToolInitializationError(ToolManagerError):
"""Error during tool initialization."""
pass
class ToolNotFoundError(ToolManagerError):
"""Error when requested tool is not found."""
pass
class ToolManager:
"""Manager for Airflow API tools with caching and lifecycle management.
This class provides a centralized way to manage Airflow API tools with:
- Singleton client management
- Tool caching with size limits
- Thread-safe access
- Proper error handling
""" """
global _tools_cache, _client
def __init__( if not _tools_cache:
self, required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
spec_path: Path | str | object, if not all(var in os.environ for var in required_vars):
base_url: str, raise ValueError(f"Missing required environment variables: {required_vars}")
auth_token: str,
max_cache_size: int = 100,
) -> None:
"""Initialize tool manager.
Args: # Initialize client if not exists
spec_path: Path to OpenAPI spec file or file-like object if not _client:
base_url: Base URL for Airflow API _client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
auth_token: Authentication token
max_cache_size: Maximum number of tools to cache (default: 100)
Raises:
ToolManagerError: If initialization fails
ToolInitializationError: If client or parser initialization fails
"""
try: try:
# Validate inputs # Create parser
if not spec_path: parser = OperationParser(os.environ["OPENAPI_SPEC"])
raise ValueError("spec_path is required")
if not base_url:
raise ValueError("base_url is required")
if not auth_token:
raise ValueError("auth_token is required")
if max_cache_size < 1:
raise ValueError("max_cache_size must be positive")
# Store configuration # Generate tools for each operation
self._spec_path = spec_path for operation_id in parser.get_operations():
self._base_url = base_url.rstrip("/") operation_details = parser.parse_operation(operation_id)
self._auth_token = auth_token tool = AirflowTool(operation_details, _client)
self._max_cache_size = max_cache_size _tools_cache[operation_id] = tool
# Initialize core components with proper error handling
try:
self._client = AirflowClient(spec_path, self._base_url, auth_token)
logger.info("AirflowClient initialized successfully")
except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e)
raise ToolInitializationError(f"Failed to initialize client: {e}") from e
try:
self._parser = OperationParser(spec_path)
logger.info("OperationParser initialized successfully")
except Exception as e:
logger.error("Failed to initialize OperationParser: %s", e)
raise ToolInitializationError(f"Failed to initialize parser: {e}") from e
# Setup thread safety and caching
self._lock = asyncio.Lock()
self._tool_cache: OrderedDict[str, AirflowTool] = OrderedDict()
logger.info("Tool manager initialized successfully (cache_size=%d, base_url=%s)", max_cache_size, self._base_url)
except ValueError as e:
logger.error("Invalid configuration: %s", e)
raise ToolManagerError(f"Invalid configuration: {e}") from e
except Exception as e: except Exception as e:
logger.error("Failed to initialize tool manager: %s", e) logger.error("Failed to initialize tools: %s", e)
raise ToolInitializationError(f"Component initialization failed: {e}") from e raise
async def __aenter__(self) -> "ToolManager": # Convert to MCP Tool format
"""Enter async context.""" return [
try: Tool(
if not hasattr(self, "_client"): name=operation_id,
logger.error("Client not initialized") description=tool.operation.operation_id,
raise ToolInitializationError("Client not initialized") inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None,
await self._client.__aenter__() )
return self for operation_id, tool in _tools_cache.items()
except Exception as e: ]
logger.error("Failed to enter async context: %s", e)
if hasattr(self, "_client"):
try:
await self._client.__aexit__(None, None, None)
except Exception:
pass
raise ToolInitializationError(f"Failed to initialize client session: {e}") from e
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit async context."""
try:
await self._client.__aexit__(exc_type, exc_val, exc_tb)
except Exception as e:
logger.error("Error during context exit: %s", e)
finally:
self.clear_cache()
def _evict_cache_if_needed(self) -> None: def get_tool(name: str) -> AirflowTool:
"""Evict oldest items from cache if size limit is reached. """Get specific tool by name.
This method uses FIFO (First In First Out) eviction policy. Args:
Eviction occurs when cache reaches max_cache_size. name: Tool/operation name
"""
current_size = len(self._tool_cache)
if current_size >= self._max_cache_size:
logger.info("Cache limit reached (%d/%d). Starting eviction.", current_size, self._max_cache_size)
logger.debug("Current cache contents: %s", list(self._tool_cache.keys()))
evicted_count = 0 Returns:
while len(self._tool_cache) >= self._max_cache_size: AirflowTool instance
operation_id, _ = self._tool_cache.popitem(last=False)
evicted_count += 1
logger.debug("Evicted tool %s from cache (size: %d/%d)", operation_id, len(self._tool_cache), self._max_cache_size)
if evicted_count > 0: Raises:
logger.info("Evicted %d tools from cache", evicted_count) KeyError: If tool not found
"""
if name not in _tools_cache:
# Ensure cache is populated
get_airflow_tools()
async def get_tool(self, operation_id: str) -> AirflowTool: return _tools_cache[name]
"""Get or create a tool instance for the given operation.
Args:
operation_id: Operation ID from OpenAPI spec
Returns:
AirflowTool instance
Raises:
ToolNotFoundError: If operation not found
ToolInitializationError: If tool creation fails
ValueError: If operation_id is invalid
"""
if not operation_id or not isinstance(operation_id, str):
logger.error("Invalid operation_id provided: %s", operation_id)
raise ValueError("Invalid operation_id")
if not hasattr(self, "_client") or not hasattr(self, "_parser"):
logger.error("ToolManager not properly initialized")
raise ToolInitializationError("ToolManager components not initialized")
logger.debug("Requesting tool for operation: %s", operation_id)
async with self._lock:
# Check cache first
if operation_id in self._tool_cache:
logger.debug("Tool cache hit for %s", operation_id)
return self._tool_cache[operation_id]
logger.debug("Tool cache miss for %s, creating new instance", operation_id)
try:
# Parse operation details
try:
operation_details = self._parser.parse_operation(operation_id)
except ValueError as e:
logger.error("Operation %s not found: %s", operation_id, e)
raise ToolNotFoundError(f"Operation {operation_id} not found") from e
except Exception as e:
logger.error("Failed to parse operation %s: %s", operation_id, e)
raise ToolInitializationError(f"Operation parsing failed: {e}") from e
# Create new tool instance
try:
tool = AirflowTool(operation_details, self._client)
except Exception as e:
logger.error("Failed to create tool instance for %s: %s", operation_id, e)
raise ToolInitializationError(f"Tool instantiation failed: {e}") from e
# Update cache
self._evict_cache_if_needed()
self._tool_cache[operation_id] = tool
logger.info("Created and cached new tool for %s", operation_id)
return tool
except (ToolNotFoundError, ToolInitializationError):
raise
except Exception as e:
logger.error("Unexpected error creating tool for %s: %s", operation_id, e)
raise ToolInitializationError(f"Unexpected error: {e}") from e
def clear_cache(self) -> None:
"""Clear the tool cache."""
self._tool_cache.clear()
logger.debug("Tool cache cleared")
@property
def cache_info(self) -> dict[str, Any]:
"""Get cache statistics.
Returns:
Dictionary with cache statistics
"""
return {
"size": len(self._tool_cache),
"max_size": self._max_cache_size,
"operations": list(self._tool_cache.keys()),
}