Clean up for only MCP Server
0    src/airflow_mcp_server/tools/__init__.py    Normal file
81    src/airflow_mcp_server/tools/airflow_tool.py    Normal file
@@ -0,0 +1,81 @@
import logging
from typing import Any

from pydantic import ValidationError

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.base_tools import BaseTools

logger = logging.getLogger(__name__)


def create_validation_error(field: str, message: str) -> ValidationError:
    """Create a properly formatted validation error.

    Args:
        field: The field that failed validation
        message: The error message

    Returns:
        ValidationError: A properly formatted validation error
    """
    errors = [
        {
            "loc": (field,),
            "msg": message,
            "type": "value_error",
            "input": None,
            "ctx": {"error": message},
        }
    ]
    return ValidationError.from_exception_data("validation_error", errors)


class AirflowTool(BaseTools):
    """Tool for executing Airflow API operations.

    One AirflowTool instance is created per API operation.
    """

    def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
        """Initialize tool with operation details and client.

        Args:
            operation_details: Operation details
            client: AirflowClient instance
        """
        super().__init__()
        self.operation = operation_details
        self.client = client

    async def run(
        self,
        body: dict[str, Any] | None = None,
    ) -> Any:
        """Execute the operation with provided parameters."""
        try:
            # Validate input
            validated_input = self.operation.input_model(**(body or {}))
            validated_body = validated_input.model_dump(exclude_none=True)  # Only include non-None values

            mapping = self.operation.input_model.model_config["parameter_mapping"]
            path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
            query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
            body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}

            # Execute operation and return raw response
            response = await self.client.execute(
                operation_id=self.operation.operation_id,
                path_params=path_params or None,
                query_params=query_params or None,
                body=body_params or None,
            )

            return response

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Operation execution failed: %s", e)
            raise
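For context, a rough usage sketch of AirflowTool on its own. The spec path, base URL, token, operation id (get_dags) and the limit parameter are illustrative assumptions, not values fixed by this commit; the constructor and method signatures follow the code above.

import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool


async def main() -> None:
    # Hypothetical manual wiring; tool_manager (below) does this for every operation in the spec.
    client = AirflowClient(spec_path="v1.yaml", base_url="http://localhost:8080/api/v1", auth_token="<token>")
    parser = OperationParser("v1.yaml")
    details = parser.parse_operation("get_dags")  # assumed operation id from the Airflow OpenAPI spec
    tool = AirflowTool(details, client)
    result = await tool.run(body={"limit": 10})  # validated against details.input_model before the HTTP call
    print(result)


asyncio.run(main())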
19    src/airflow_mcp_server/tools/base_tools.py    Normal file
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any


class BaseTools(ABC):
    """Abstract base class for tools."""

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the tool."""
        pass

    @abstractmethod
    def run(self) -> Any:
        """Execute the tool's main functionality.

        Returns:
            Any: The result of the tool execution
        """
        raise NotImplementedError
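To illustrate the contract, a minimal concrete subclass might look like this; EchoTool is a made-up example, not part of the commit.

from typing import Any

from airflow_mcp_server.tools.base_tools import BaseTools


class EchoTool(BaseTools):
    """Trivial tool that returns the payload it was constructed with."""

    def __init__(self, payload: Any) -> None:
        super().__init__()  # the abstract __init__ above still has a body (pass), so this call is safe
        self.payload = payload

    def run(self) -> Any:
        return self.payload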
126    src/airflow_mcp_server/tools/tool_manager.py    Normal file
@@ -0,0 +1,126 @@
import logging
import os
from importlib import resources

from mcp.types import Tool

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool

logger = logging.getLogger(__name__)

_tools_cache: dict[str, AirflowTool] = {}


def _initialize_client() -> AirflowClient:
    """Initialize Airflow client with environment variables or embedded spec.

    Returns:
        AirflowClient instance

    Raises:
        ValueError: If required environment variables are missing or default spec is not found
    """
    spec_path = os.environ.get("OPENAPI_SPEC")
    if not spec_path:
        # Fallback to embedded v1.yaml
        try:
            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
                spec_path = f.name
            logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
        except Exception as e:
            raise ValueError("Default OpenAPI spec not found in package resources") from e

    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")

    return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])


async def _initialize_tools() -> None:
    """Initialize tools cache with Airflow operations.

    Raises:
        ValueError: If initialization fails
    """
    global _tools_cache

    try:
        client = _initialize_client()
        spec_path = os.environ.get("OPENAPI_SPEC")
        if not spec_path:
            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
                spec_path = f.name
        parser = OperationParser(spec_path)

        # Generate tools for each operation
        for operation_id in parser.get_operations():
            operation_details = parser.parse_operation(operation_id)
            tool = AirflowTool(operation_details, client)
            _tools_cache[operation_id] = tool

    except Exception as e:
        logger.error("Failed to initialize tools: %s", e)
        _tools_cache.clear()
        raise ValueError(f"Failed to initialize tools: {e}") from e


async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
    """Get list of available Airflow tools based on mode.

    Args:
        mode: "safe" for GET operations only, "unsafe" for all operations (default)

    Returns:
        List of MCP Tool objects representing available operations

    Raises:
        ValueError: If initialization fails
    """
    if not _tools_cache:
        await _initialize_tools()

    tools = []
    for operation_id, tool in _tools_cache.items():
        try:
            # Skip non-GET operations in safe mode
            if mode == "safe" and not tool.operation.method.lower() == "get":
                continue
            schema = tool.operation.input_model.model_json_schema()
            tools.append(
                Tool(
                    name=operation_id,
                    description=tool.operation.operation_id,
                    inputSchema=schema,
                )
            )
        except Exception as e:
            logger.error("Failed to create tool schema for %s: %s", operation_id, e)
            continue

    return tools


async def get_tool(name: str) -> AirflowTool:
    """Get specific tool by name.

    Args:
        name: Tool/operation name

    Returns:
        AirflowTool instance

    Raises:
        KeyError: If tool not found
        ValueError: If tool initialization fails
    """
    if not _tools_cache:
        await _initialize_tools()

    if name not in _tools_cache:
        raise KeyError(f"Tool {name} not found")

    return _tools_cache[name]
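A minimal sketch of how an MCP server could consume these helpers, assuming the standard mcp Python SDK Server API; the actual server wiring is not part of this commit, and the handler bodies here are illustrative only.

from mcp.server import Server
from mcp.types import TextContent, Tool

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

server = Server("airflow-mcp-server")


@server.list_tools()
async def list_tools() -> list[Tool]:
    # "safe" limits the listing to read-only GET operations.
    return await get_airflow_tools(mode="safe")


@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    tool = await get_tool(name)
    result = await tool.run(body=arguments)
    return [TextContent(type="text", text=str(result))]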