Clean up to keep only the MCP server

2025-02-24 16:50:08 +00:00
parent 5d199ba154
commit 16cd3f48fe
52 changed files with 66 additions and 1317 deletions

@@ -0,0 +1,35 @@
import asyncio
import logging
import sys
import click
from airflow_mcp_server.server_safe import serve as serve_safe
from airflow_mcp_server.server_unsafe import serve as serve_unsafe
@click.command()
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
def main(verbose: int, safe: bool, unsafe: bool) -> None:
"""MCP server for Airflow"""
    logging_level = logging.WARNING
if verbose == 1:
logging_level = logging.INFO
elif verbose >= 2:
logging_level = logging.DEBUG
logging.basicConfig(level=logging_level, stream=sys.stderr)
if safe and unsafe:
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
if safe:
asyncio.run(serve_safe())
else: # Default to unsafe mode
asyncio.run(serve_unsafe())
if __name__ == "__main__":
main()
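
A quick way to exercise the flag handling above without starting a server is click's built-in test runner. A minimal sketch, assuming the package is importable as airflow_mcp_server (as the entry-point stub below does):

from click.testing import CliRunner

from airflow_mcp_server import main

runner = CliRunner()
# --safe and --unsafe together should trip the UsageError above.
result = runner.invoke(main, ["--safe", "--unsafe"])
assert result.exit_code == 2  # click's exit code for a UsageError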

@@ -0,0 +1,3 @@
from airflow_mcp_server import main
main()

@@ -0,0 +1,3 @@
from airflow_mcp_server.client.airflow_client import AirflowClient
__all__ = ["AirflowClient"]

@@ -0,0 +1,248 @@
import logging
import re
from pathlib import Path
from types import SimpleNamespace
from typing import Any, BinaryIO, TextIO
import aiohttp
import yaml
from jsonschema_path import SchemaPath
from openapi_core import OpenAPI
from openapi_core.validation.request.validators import V31RequestValidator
from openapi_spec_validator import validate
logger = logging.getLogger(__name__)
def camel_to_snake(name: str) -> str:
"""Convert camelCase to snake_case."""
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
def convert_dict_keys(d: dict) -> dict:
"""Recursively convert dictionary keys from camelCase to snake_case."""
if not isinstance(d, dict):
return d
return {camel_to_snake(k): convert_dict_keys(v) if isinstance(v, dict) else v for k, v in d.items()}
class AirflowClient:
"""Client for interacting with Airflow API."""
def __init__(
self,
spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
base_url: str,
auth_token: str,
) -> None:
"""Initialize Airflow client.
Args:
spec_path: OpenAPI spec as file path, dict, bytes, or file object
base_url: Base URL for API
auth_token: Authentication token
Raises:
ValueError: If spec_path is invalid or spec cannot be loaded
"""
try:
# Load and parse OpenAPI spec
if isinstance(spec_path, dict):
self.raw_spec = spec_path
elif isinstance(spec_path, bytes):
self.raw_spec = yaml.safe_load(spec_path)
elif isinstance(spec_path, str | Path):
with open(spec_path) as f:
self.raw_spec = yaml.safe_load(f)
            elif hasattr(spec_path, "read"):
                # yaml.safe_load accepts both str and bytes, so no branching is needed
                self.raw_spec = yaml.safe_load(spec_path.read())
else:
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
# Validate spec has required fields
if not isinstance(self.raw_spec, dict):
raise ValueError("OpenAPI spec must be a dictionary")
required_fields = ["openapi", "info", "paths"]
for field in required_fields:
if field not in self.raw_spec:
raise ValueError(f"OpenAPI spec missing required field: {field}")
# Validate OpenAPI spec format
validate(self.raw_spec)
# Initialize OpenAPI spec
self.spec = OpenAPI.from_dict(self.raw_spec)
logger.debug("OpenAPI spec loaded successfully")
# Debug raw spec
logger.debug("Raw spec keys: %s", self.raw_spec.keys())
# Get paths from raw spec
if "paths" not in self.raw_spec:
raise ValueError("OpenAPI spec does not contain paths information")
self._paths = self.raw_spec["paths"]
logger.debug("Using raw spec paths")
# Initialize request validator with schema path
schema_path = SchemaPath.from_dict(self.raw_spec)
self._validator = V31RequestValidator(schema_path)
# API configuration
self.base_url = base_url.rstrip("/")
self.headers = {
"Authorization": f"Basic {auth_token}",
"Accept": "application/json",
}
except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e)
raise ValueError(f"Failed to initialize client: {e}")
async def __aenter__(self) -> "AirflowClient":
self._session = aiohttp.ClientSession(headers=self.headers)
return self
async def __aexit__(self, *exc) -> None:
if hasattr(self, "_session"):
await self._session.close()
delattr(self, "_session")
def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
"""Get operation details from OpenAPI spec.
Args:
operation_id: The operation ID to look up
Returns:
Tuple of (path, method, operation) where operation is a SimpleNamespace object
Raises:
ValueError: If operation not found
"""
try:
# Debug the paths structure
logger.debug("Looking for operation %s in paths", operation_id)
for path, path_item in self._paths.items():
for method, operation_data in path_item.items():
# Skip non-operation fields
if method.startswith("x-") or method == "parameters":
continue
# Debug each operation
logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))
if operation_data.get("operationId") == operation_id:
logger.debug("Found operation %s at %s %s", operation_id, method, path)
# Convert keys to snake_case and create object
converted_data = convert_dict_keys(operation_data)
operation_obj = SimpleNamespace(**converted_data)
return path, method, operation_obj
raise ValueError(f"Operation {operation_id} not found in spec")
except Exception as e:
logger.error("Error getting operation %s: %s", operation_id, e)
raise
def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
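        # Example: for path "/dags/{dag_id}" the params must be exactly
        # {"dag_id": ...}; a missing or extra key raises ValueError below.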
if not params:
params = {}
# Extract path parameter names from the path
path_params = set(re.findall(r"{([^}]+)}", path))
# Check for missing required parameters
missing_params = path_params - set(params.keys())
if missing_params:
raise ValueError(f"Missing required path parameters: {missing_params}")
# Check for invalid parameters
invalid_params = set(params.keys()) - path_params
if invalid_params:
raise ValueError(f"Invalid path parameters: {invalid_params}")
async def execute(
self,
operation_id: str,
path_params: dict[str, Any] | None = None,
query_params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
) -> Any:
"""Execute an API operation.
Args:
operation_id: Operation ID from OpenAPI spec
path_params: URL path parameters
query_params: URL query parameters
body: Request body data
Returns:
API response data
Raises:
ValueError: If operation not found
RuntimeError: If used outside async context
aiohttp.ClientError: For HTTP/network errors
"""
if not hasattr(self, "_session") or not self._session:
raise RuntimeError("Client not in async context")
try:
# Get operation details
path, method, _ = self._get_operation(operation_id)
# Validate path parameters
self._validate_path_params(path, path_params)
# Format URL
if path_params:
path = path.format(**path_params)
url = f"{self.base_url}{path}"
logger.debug("Executing %s %s", method, url)
logger.debug("Request body: %s", body)
logger.debug("Request query params: %s", query_params)
# Dynamically set headers based on presence of body
request_headers = self.headers.copy()
if body is not None:
request_headers["Content-Type"] = "application/json"
# Make request
            async with self._session.request(
                method=method,
                url=url,
                params=query_params,
                json=body,
                headers=request_headers,
            ) as response:
response.raise_for_status()
content_type = response.headers.get("Content-Type", "").lower()
# Status codes that typically have no body
no_body_statuses = {204}
if response.status in no_body_statuses:
if content_type and "application/json" in content_type:
logger.warning("Unexpected JSON body with status %s", response.status)
return await response.json() # Parse if present, though rare
logger.debug("Received %s response with no body", response.status)
return response.status
# For statuses expecting a body, check mimetype
if "application/json" in content_type:
logger.debug("Response: %s", await response.text())
return await response.json()
# Unexpected mimetype with body
response_text = await response.text()
logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")
except aiohttp.ClientError as e:
logger.error("Error executing operation %s: %s", operation_id, e)
raise
except Exception as e:
logger.error("Error executing operation %s: %s", operation_id, e)
raise ValueError(f"Failed to execute operation: {e}")

@@ -0,0 +1,3 @@
from airflow_mcp_server.parser.operation_parser import OperationParser
__all__ = ["OperationParser"]

@@ -0,0 +1,415 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, BinaryIO, Literal, TextIO, Union
import yaml
from openapi_core import OpenAPI
from pydantic import BaseModel, Field, create_model
logger = logging.getLogger(__name__)
@dataclass
class OperationDetails:
"""Details of an OpenAPI operation."""
operation_id: str
path: str
method: str
parameters: dict[str, Any]
input_model: type[BaseModel]
class OperationParser:
"""Parser for OpenAPI operations."""
    def __init__(self, spec_path: Path | str | dict | bytes | BinaryIO | TextIO) -> None:
"""Initialize parser with OpenAPI specification.
Args:
spec_path: Path to OpenAPI spec file, dict, bytes, or file-like object
Raises:
ValueError: If spec_path is invalid or spec cannot be loaded
"""
try:
if isinstance(spec_path, bytes):
self.raw_spec = yaml.safe_load(spec_path)
elif isinstance(spec_path, dict):
self.raw_spec = spec_path
elif isinstance(spec_path, str | Path):
with open(spec_path) as f:
self.raw_spec = yaml.safe_load(f)
elif hasattr(spec_path, "read"):
self.raw_spec = yaml.safe_load(spec_path)
else:
raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")
spec = OpenAPI.from_dict(self.raw_spec)
self.spec = spec
self._paths = self.raw_spec["paths"]
self._components = self.raw_spec.get("components", {})
self._schema_cache: dict[str, dict[str, Any]] = {}
except Exception as e:
logger.error("Error initializing OperationParser: %s", e)
raise ValueError(f"Failed to initialize parser: {e}") from e
def _merge_allof_schema(self, schema: dict[str, Any]) -> dict[str, Any]:
"""Merge an allOf schema into a single effective schema.
Args:
schema: The schema potentially containing allOf
Returns:
A merged schema with unified properties and required fields
"""
if "allOf" not in schema:
return schema
merged = {"type": "object", "properties": {}, "required": []}
for subschema in schema["allOf"]:
resolved = subschema
if "$ref" in subschema:
resolved = self._resolve_ref(subschema["$ref"])
if "properties" in resolved:
merged["properties"].update(resolved["properties"])
if "required" in resolved:
merged["required"].extend(resolved.get("required", []))
merged["required"] = list(set(merged["required"])) # Remove duplicates
logger.debug("Merged allOf schema: %s", merged)
return merged
def parse_operation(self, operation_id: str) -> OperationDetails:
"""Parse operation details from OpenAPI spec.
Args:
operation_id: Operation ID to parse
Returns:
OperationDetails object containing parsed information
Raises:
ValueError: If operation not found or invalid
"""
try:
for path, path_item in self._paths.items():
for method, operation in path_item.items():
if method.startswith("x-") or method == "parameters":
continue
if operation.get("operationId") == operation_id:
logger.debug("Found operation %s at %s %s", operation_id, method, path)
operation["path"] = path
operation["path_item"] = path_item
parameters = self.extract_parameters(operation)
# Get request body schema if present
body_schema = None
if "requestBody" in operation:
content = operation["requestBody"].get("content", {})
if "application/json" in content:
body_schema = content["application/json"].get("schema", {})
if "$ref" in body_schema:
body_schema = self._resolve_ref(body_schema["$ref"])
# Create unified input model
input_model = self._create_input_model(operation_id, parameters, body_schema)
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
raise ValueError(f"Operation {operation_id} not found in spec")
except Exception as e:
logger.error("Error parsing operation %s: %s", operation_id, e)
raise
def _create_input_model(
self,
operation_id: str,
parameters: dict[str, Any],
body_schema: dict[str, Any] | None = None,
) -> type[BaseModel]:
"""Create unified input model for all parameters."""
fields: dict[str, tuple[type, Any]] = {}
parameter_mapping = {"path": [], "query": [], "body": []}
# Add path parameters
for name, schema in parameters.get("path", {}).items():
field_type = schema["type"] # Use the mapped type from parameter schema
fields[name] = (field_type | None, None) # Make all optional
parameter_mapping["path"].append(name)
# Add query parameters
for name, schema in parameters.get("query", {}).items():
field_type = schema["type"] # Use the mapped type from parameter schema
fields[name] = (field_type | None, None) # Make all optional
parameter_mapping["query"].append(name)
# Handle body schema with allOf support
if body_schema:
effective_schema = self._merge_allof_schema(body_schema)
if "properties" in effective_schema or effective_schema.get("type") == "object":
for prop_name, prop_schema in effective_schema.get("properties", {}).items():
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
                    default = None  # all body fields are optional here, mirroring path/query handling
if prop_name == "schema": # Avoid shadowing BaseModel.schema
fields["connection_schema"] = (field_type | None, Field(default, alias="schema"))
parameter_mapping["body"].append("connection_schema")
else:
fields[prop_name] = (field_type | None, default)
parameter_mapping["body"].append(prop_name)
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
model = create_model(f"{operation_id}_input", **fields)
model.model_config["parameter_mapping"] = parameter_mapping
return model
def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
"""Extract and categorize operation parameters.
Args:
operation: Operation object from OpenAPI spec
Returns:
Dictionary of parameters by category (path, query, header)
"""
parameters: dict[str, dict[str, Any]] = {
"path": {},
"query": {},
"header": {},
}
path_item = operation.get("path_item", {})
if path_item and "parameters" in path_item:
self._process_parameters(path_item["parameters"], parameters)
self._process_parameters(operation.get("parameters", []), parameters)
return parameters
def _process_parameters(self, params: list[dict[str, Any]], target: dict[str, dict[str, Any]]) -> None:
"""Process a list of parameters and add them to the target dict.
Args:
params: List of parameter objects
target: Target dictionary to store processed parameters
"""
for param in params:
if "$ref" in param:
param = self._resolve_ref(param["$ref"])
if not isinstance(param, dict) or "in" not in param:
logger.warning("Invalid parameter format: %s", param)
continue
param_in = param["in"]
if param_in in target:
target[param_in][param["name"]] = self._map_parameter_schema(param)
def _resolve_ref(self, ref: str) -> dict[str, Any]:
"""Resolve OpenAPI reference.
Args:
ref: Reference string (e.g. '#/components/schemas/Model')
Returns:
Resolved object
"""
if ref in self._schema_cache:
return self._schema_cache[ref]
parts = ref.split("/")
current = self.raw_spec
for part in parts[1:]:
current = current[part]
self._schema_cache[ref] = current
return current
def _map_parameter_schema(self, param: dict[str, Any]) -> dict[str, Any]:
"""Map parameter schema to Python type information.
Args:
param: Parameter object from OpenAPI spec
Returns:
Dictionary with Python type information
"""
schema = param.get("schema", {})
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
# Get the type and format from schema
openapi_type = schema.get("type", "string")
format_type = schema.get("format")
return {
"type": self._map_type(openapi_type, format_type, schema), # Pass format_type and full schema
"required": param.get("required", False),
"default": schema.get("default"),
"description": param.get("description"),
}
def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type:
"""Map OpenAPI type to Python type."""
# Handle enums
if schema and "enum" in schema:
return Literal[tuple(schema["enum"])] # type: ignore
# Handle date-time format
if openapi_type == "string" and format_type == "date-time":
return datetime
# Handle nullable fields
base_type = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}.get(openapi_type, Any)
if schema and schema.get("nullable"):
return Union[base_type, None] # noqa: UP007
return base_type
def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
"""Parse response schema into Pydantic model."""
try:
responses = operation.get("responses", {})
if "200" not in responses:
return None
response = responses["200"]
if "$ref" in response:
response = self._resolve_ref(response["$ref"])
content = response.get("content", {})
if "application/json" not in content:
return None
schema = content["application/json"].get("schema", {})
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
logger.debug("Response schema before model creation: %s", schema)
# Ensure schema has properties
if "properties" not in schema:
logger.error("Response schema missing properties")
return None
# Create model with schema properties
model = self._create_model("Response", schema)
logger.debug("Created response model with schema: %s", model.model_json_schema())
return model
except Exception as e:
logger.error("Failed to create response model: %s", e)
return None
    def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:  # noqa: C901
"""Create Pydantic model from schema.
Args:
name: Model name
schema: OpenAPI schema
Returns:
Generated Pydantic model
Raises:
ValueError: If schema is invalid
"""
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
# Handle discriminated unions
if "anyOf" in schema:
discriminator = schema.get("discriminator", {}).get("propertyName")
if discriminator:
union_types = []
for subschema in schema["anyOf"]:
if "$ref" in subschema:
resolved = self._resolve_ref(subschema["$ref"])
sub_model = self._create_model(f"{name}_type", resolved)
union_types.append(sub_model)
fields = {
"type_name": (str, ...), # Use type_name instead of __type
**{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator}, # noqa: UP007
}
return create_model(name, **fields)
if schema.get("type", "object") != "object":
raise ValueError("Schema must be an object type")
fields = {}
for prop_name, prop_schema in schema.get("properties", {}).items():
if "$ref" in prop_schema:
prop_schema = self._resolve_ref(prop_schema["$ref"])
required = prop_name in schema.get("required", [])
if prop_schema.get("type") == "object":
nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
field_type = nested_model
elif prop_schema.get("type") == "array":
items = prop_schema.get("items", {})
if "$ref" in items:
items = self._resolve_ref(items["$ref"])
if items.get("type") == "object":
item_model = self._create_model(f"{name}_{prop_name}_item", items)
field_type = list[item_model]
else:
item_type = self._map_type(items.get("type", "string"))
field_type = list[item_type]
else:
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
# Add validation parameters
field_config = {}
if "pattern" in prop_schema:
field_config["pattern"] = prop_schema["pattern"]
if "minLength" in prop_schema:
field_config["min_length"] = prop_schema["minLength"]
if "maxLength" in prop_schema:
field_config["max_length"] = prop_schema["maxLength"]
if "minimum" in prop_schema:
field_config["ge"] = prop_schema["minimum"]
if "maximum" in prop_schema:
field_config["le"] = prop_schema["maximum"]
# Create field with validation
if field_config:
field = Field(..., **field_config) if required else Field(None, **field_config)
fields[prop_name] = (field_type, field)
continue
fields[prop_name] = (field_type, ... if required else None)
logger.debug("Creating model %s with fields: %s", name, fields)
try:
return create_model(name, **fields)
except Exception as e:
logger.error("Error creating model %s: %s", name, e)
raise ValueError(f"Failed to create model {name}: {e}")
def get_operations(self) -> list[str]:
"""Get list of all operation IDs from spec."""
operations = []
        for path_item in self._paths.values():
            for method, operation in path_item.items():
if method.startswith("x-") or method == "parameters":
continue
if "operationId" in operation:
operations.append(operation["operationId"])
return operations
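
A short sketch of the parser in isolation, driven by a hand-written single-operation spec (the spec below is a made-up minimal example, not the bundled Airflow spec):

from airflow_mcp_server.parser.operation_parser import OperationParser

spec = {
    "openapi": "3.0.2",
    "info": {"title": "demo", "version": "1.0"},
    "paths": {
        "/dags/{dag_id}": {
            "get": {
                "operationId": "get_dag",
                "parameters": [
                    {"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}},
                ],
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}

parser = OperationParser(spec)
print(parser.get_operations())  # ['get_dag']
details = parser.parse_operation("get_dag")
print(details.method, details.path)  # get /dags/{dag_id}
print(details.input_model(dag_id="example").model_dump())  # {'dag_id': 'example'}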

File diff suppressed because it is too large

@@ -0,0 +1,52 @@
import logging
import os
from typing import Any
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
# import sys
# Configure root logger to stderr
# logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stderr)])
# Disable Uvicorn's default handlers
# logging.getLogger("uvicorn.error").handlers = []
# logging.getLogger("uvicorn.access").handlers = []
# ======================================================================
logger = logging.getLogger(__name__)
async def serve() -> None:
"""Start MCP server."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
server = Server("airflow-mcp-server")
@server.list_tools()
async def list_tools() -> list[Tool]:
try:
return await get_airflow_tools()
except Exception as e:
logger.error("Failed to list tools: %s", e)
raise
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
tool = await get_tool(name)
async with tool.client:
result = await tool.run(body=arguments)
return [TextContent(type="text", text=str(result))]
except Exception as e:
logger.error("Tool execution failed: %s", e)
raise
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options, raise_exceptions=True)

@@ -0,0 +1,45 @@
import logging
import os
from typing import Any
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
logger = logging.getLogger(__name__)
async def serve() -> None:
"""Start MCP server in safe mode (read-only operations)."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
server = Server("airflow-mcp-server-safe")
@server.list_tools()
async def list_tools() -> list[Tool]:
try:
return await get_airflow_tools(mode="safe")
except Exception as e:
logger.error("Failed to list tools: %s", e)
raise
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
if not name.startswith("get_"):
raise ValueError("Only GET operations allowed in safe mode")
tool = await get_tool(name)
async with tool.client:
result = await tool.run(body=arguments)
return [TextContent(type="text", text=str(result))]
except Exception as e:
logger.error("Tool execution failed: %s", e)
raise
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options, raise_exceptions=True)

@@ -0,0 +1,43 @@
import logging
import os
from typing import Any
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
logger = logging.getLogger(__name__)
async def serve() -> None:
"""Start MCP server in unsafe mode (all operations)."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
server = Server("airflow-mcp-server-unsafe")
@server.list_tools()
async def list_tools() -> list[Tool]:
try:
return await get_airflow_tools(mode="unsafe")
except Exception as e:
logger.error("Failed to list tools: %s", e)
raise
@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
try:
tool = await get_tool(name)
async with tool.client:
result = await tool.run(body=arguments)
return [TextContent(type="text", text=str(result))]
except Exception as e:
logger.error("Tool execution failed: %s", e)
raise
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options, raise_exceptions=True)
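
All three serve() variants above follow the same launch pattern. A minimal sketch for starting the unsafe variant programmatically (the URL and credentials below are hypothetical placeholders; serve() checks the environment before starting):

import asyncio
import base64
import os

from airflow_mcp_server.server_unsafe import serve

# Hypothetical local-deployment values, for illustration only.
os.environ.setdefault("AIRFLOW_BASE_URL", "http://localhost:8080/api/v1")
os.environ.setdefault("AUTH_TOKEN", base64.b64encode(b"admin:admin").decode())

asyncio.run(serve())  # serves MCP over stdio until the client disconnects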

@@ -0,0 +1,81 @@
import logging
from typing import Any
from pydantic import ValidationError
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.base_tools import BaseTools
logger = logging.getLogger(__name__)
def create_validation_error(field: str, message: str) -> ValidationError:
"""Create a properly formatted validation error.
Args:
field: The field that failed validation
message: The error message
Returns:
ValidationError: A properly formatted validation error
"""
errors = [
{
"loc": (field,),
"msg": message,
"type": "value_error",
"input": None,
"ctx": {"error": message},
}
]
return ValidationError.from_exception_data("validation_error", errors)
class AirflowTool(BaseTools):
"""
Tool for executing Airflow API operations.
AirflowTool is supposed to have objects per operation.
"""
def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
"""Initialize tool with operation details and client.
Args:
operation_details: Operation details
client: AirflowClient instance
"""
super().__init__()
self.operation = operation_details
self.client = client
async def run(
self,
body: dict[str, Any] | None = None,
) -> Any:
"""Execute the operation with provided parameters."""
try:
# Validate input
validated_input = self.operation.input_model(**(body or {}))
validated_body = validated_input.model_dump(exclude_none=True) # Only include non-None values
mapping = self.operation.input_model.model_config["parameter_mapping"]
path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}
# Execute operation and return raw response
response = await self.client.execute(
operation_id=self.operation.operation_id,
path_params=path_params or None,
query_params=query_params or None,
body=body_params or None,
)
return response
except ValidationError:
raise
except Exception as e:
logger.error("Operation execution failed: %s", e)
raise
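
A sketch of wiring the pieces together by hand, outside the tool manager (the environment variable names match those used elsewhere in this commit; get_dags is an assumed operationId):

import asyncio
import os

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool

async def demo() -> None:
    spec = os.environ["OPENAPI_SPEC"]  # path to an Airflow OpenAPI spec
    parser = OperationParser(spec)
    client = AirflowClient(spec_path=spec, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
    tool = AirflowTool(parser.parse_operation("get_dags"), client)
    async with client:
        # run() validates against the generated input model, then splits the
        # fields into path/query/body via parameter_mapping before executing.
        print(await tool.run(body={"limit": 5}))

asyncio.run(demo())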

@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any
class BaseTools(ABC):
"""Abstract base class for tools."""
@abstractmethod
def __init__(self) -> None:
"""Initialize the tool."""
pass
@abstractmethod
    async def run(self, body: dict[str, Any] | None = None) -> Any:
"""Execute the tool's main functionality.
Returns:
Any: The result of the tool execution
"""
raise NotImplementedError

@@ -0,0 +1,126 @@
import logging
import os
from importlib import resources
from mcp.types import Tool
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool
logger = logging.getLogger(__name__)
_tools_cache: dict[str, AirflowTool] = {}
def _initialize_client() -> AirflowClient:
"""Initialize Airflow client with environment variables or embedded spec.
Returns:
AirflowClient instance
Raises:
ValueError: If required environment variables are missing or default spec is not found
"""
spec_path = os.environ.get("OPENAPI_SPEC")
if not spec_path:
# Fallback to embedded v1.yaml
try:
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
spec_path = f.name
logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
except Exception as e:
raise ValueError("Default OpenAPI spec not found in package resources") from e
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise ValueError(f"Missing required environment variables: {missing_vars}")
return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
async def _initialize_tools() -> None:
"""Initialize tools cache with Airflow operations.
Raises:
ValueError: If initialization fails
"""
global _tools_cache
try:
client = _initialize_client()
spec_path = os.environ.get("OPENAPI_SPEC")
if not spec_path:
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
spec_path = f.name
parser = OperationParser(spec_path)
# Generate tools for each operation
for operation_id in parser.get_operations():
operation_details = parser.parse_operation(operation_id)
tool = AirflowTool(operation_details, client)
_tools_cache[operation_id] = tool
except Exception as e:
logger.error("Failed to initialize tools: %s", e)
_tools_cache.clear()
raise ValueError(f"Failed to initialize tools: {e}") from e
async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
"""Get list of available Airflow tools based on mode.
Args:
mode: "safe" for GET operations only, "unsafe" for all operations (default)
Returns:
List of MCP Tool objects representing available operations
Raises:
ValueError: If initialization fails
"""
if not _tools_cache:
await _initialize_tools()
tools = []
for operation_id, tool in _tools_cache.items():
try:
# Skip non-GET operations in safe mode
if mode == "safe" and not tool.operation.method.lower() == "get":
continue
schema = tool.operation.input_model.model_json_schema()
tools.append(
Tool(
name=operation_id,
description=tool.operation.operation_id,
inputSchema=schema,
)
)
except Exception as e:
logger.error("Failed to create tool schema for %s: %s", operation_id, e)
continue
return tools
async def get_tool(name: str) -> AirflowTool:
"""Get specific tool by name.
Args:
name: Tool/operation name
Returns:
AirflowTool instance
Raises:
KeyError: If tool not found
ValueError: If tool initialization fails
"""
if not _tools_cache:
await _initialize_tools()
if name not in _tools_cache:
raise KeyError(f"Tool {name} not found")
return _tools_cache[name]
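
A minimal sketch of the manager in safe mode (assumes AIRFLOW_BASE_URL and AUTH_TOKEN are exported; OPENAPI_SPEC is optional and falls back to the embedded v1.yaml as implemented above):

import asyncio

from airflow_mcp_server.tools.tool_manager import get_airflow_tools

async def demo() -> None:
    tools = await get_airflow_tools(mode="safe")  # GET operations only
    print(f"{len(tools)} read-only tools available")
    print([tool.name for tool in tools][:5])

asyncio.run(demo())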