Clean up for MCP Server only
src/airflow_mcp_server/__init__.py (normal file, 35 lines)
@@ -0,0 +1,35 @@
import asyncio
import logging
import sys

import click

from airflow_mcp_server.server_safe import serve as serve_safe
from airflow_mcp_server.server_unsafe import serve as serve_unsafe


@click.command()
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
def main(verbose: int, safe: bool, unsafe: bool) -> None:
    """MCP server for Airflow"""
    logging_level = logging.WARNING
    if verbose == 1:
        logging_level = logging.INFO
    elif verbose >= 2:
        logging_level = logging.DEBUG

    logging.basicConfig(level=logging_level, stream=sys.stderr)

    if safe and unsafe:
        raise click.UsageError("Options --safe and --unsafe are mutually exclusive")

    if safe:
        asyncio.run(serve_safe())
    else:  # Default to unsafe mode
        asyncio.run(serve_unsafe())


if __name__ == "__main__":
    main()
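The CLI above wires the two serve coroutines to mutually exclusive flags. A minimal sketch of exercising the flag-conflict check with click's test runner (assuming the package is importable; running a single mode flag would start the actual server, so only the conflict path is shown):

from click.testing import CliRunner

from airflow_mcp_server import main

runner = CliRunner()

# Passing both mode flags should trip the UsageError raised in main()
result = runner.invoke(main, ["--safe", "--unsafe"])
assert result.exit_code != 0  # click reports the UsageError and exits non-zero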
src/airflow_mcp_server/__main__.py (normal file, 3 lines)
@@ -0,0 +1,3 @@
from airflow_mcp_server import main

main()
src/airflow_mcp_server/client/__init__.py (normal file, 3 lines)
@@ -0,0 +1,3 @@
from airflow_mcp_server.client.airflow_client import AirflowClient

__all__ = ["AirflowClient"]
src/airflow_mcp_server/client/airflow_client.py (normal file, 248 lines)
@@ -0,0 +1,248 @@
import logging
import re
from pathlib import Path
from types import SimpleNamespace
from typing import Any, BinaryIO, TextIO

import aiohttp
import yaml
from jsonschema_path import SchemaPath
from openapi_core import OpenAPI
from openapi_core.validation.request.validators import V31RequestValidator
from openapi_spec_validator import validate

logger = logging.getLogger(__name__)


def camel_to_snake(name: str) -> str:
    """Convert camelCase to snake_case."""
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()


def convert_dict_keys(d: dict) -> dict:
    """Recursively convert dictionary keys from camelCase to snake_case."""
    if not isinstance(d, dict):
        return d

    return {camel_to_snake(k): convert_dict_keys(v) if isinstance(v, dict) else v for k, v in d.items()}


class AirflowClient:
    """Client for interacting with the Airflow API."""

    def __init__(
        self,
        spec_path: Path | str | dict | bytes | BinaryIO | TextIO,
        base_url: str,
        auth_token: str,
    ) -> None:
        """Initialize Airflow client.

        Args:
            spec_path: OpenAPI spec as file path, dict, bytes, or file object
            base_url: Base URL for API
            auth_token: Authentication token

        Raises:
            ValueError: If spec_path is invalid or spec cannot be loaded
        """
        try:
            # Load and parse OpenAPI spec
            if isinstance(spec_path, dict):
                self.raw_spec = spec_path
            elif isinstance(spec_path, bytes):
                self.raw_spec = yaml.safe_load(spec_path)
            elif isinstance(spec_path, str | Path):
                with open(spec_path) as f:
                    self.raw_spec = yaml.safe_load(f)
            elif hasattr(spec_path, "read"):
                # yaml.safe_load accepts both bytes and str content
                self.raw_spec = yaml.safe_load(spec_path.read())
            else:
                raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")

            # Validate spec has required fields
            if not isinstance(self.raw_spec, dict):
                raise ValueError("OpenAPI spec must be a dictionary")

            required_fields = ["openapi", "info", "paths"]
            for field in required_fields:
                if field not in self.raw_spec:
                    raise ValueError(f"OpenAPI spec missing required field: {field}")

            # Validate OpenAPI spec format
            validate(self.raw_spec)

            # Initialize OpenAPI spec
            self.spec = OpenAPI.from_dict(self.raw_spec)
            logger.debug("OpenAPI spec loaded successfully")

            # Debug raw spec
            logger.debug("Raw spec keys: %s", self.raw_spec.keys())

            # Get paths from raw spec (presence already validated above)
            self._paths = self.raw_spec["paths"]
            logger.debug("Using raw spec paths")

            # Initialize request validator with schema path
            schema_path = SchemaPath.from_dict(self.raw_spec)
            self._validator = V31RequestValidator(schema_path)

            # API configuration
            self.base_url = base_url.rstrip("/")
            self.headers = {
                "Authorization": f"Basic {auth_token}",
                "Accept": "application/json",
            }

        except Exception as e:
            logger.error("Failed to initialize AirflowClient: %s", e)
            raise ValueError(f"Failed to initialize client: {e}") from e

    async def __aenter__(self) -> "AirflowClient":
        self._session = aiohttp.ClientSession(headers=self.headers)
        return self

    async def __aexit__(self, *exc) -> None:
        if hasattr(self, "_session"):
            await self._session.close()
            delattr(self, "_session")

    def _get_operation(self, operation_id: str) -> tuple[str, str, SimpleNamespace]:
        """Get operation details from OpenAPI spec.

        Args:
            operation_id: The operation ID to look up

        Returns:
            Tuple of (path, method, operation) where operation is a SimpleNamespace object

        Raises:
            ValueError: If operation not found
        """
        try:
            logger.debug("Looking for operation %s in paths", operation_id)

            for path, path_item in self._paths.items():
                for method, operation_data in path_item.items():
                    # Skip non-operation fields
                    if method.startswith("x-") or method == "parameters":
                        continue

                    logger.debug("Checking %s %s: %s", method, path, operation_data.get("operationId"))

                    if operation_data.get("operationId") == operation_id:
                        logger.debug("Found operation %s at %s %s", operation_id, method, path)
                        # Convert keys to snake_case and create object
                        converted_data = convert_dict_keys(operation_data)
                        operation_obj = SimpleNamespace(**converted_data)
                        return path, method, operation_obj

            raise ValueError(f"Operation {operation_id} not found in spec")
        except Exception as e:
            logger.error("Error getting operation %s: %s", operation_id, e)
            raise

    def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
        if not params:
            params = {}

        # Extract path parameter names from the path template
        path_params = set(re.findall(r"{([^}]+)}", path))

        # Check for missing required parameters
        missing_params = path_params - set(params.keys())
        if missing_params:
            raise ValueError(f"Missing required path parameters: {missing_params}")

        # Check for parameters that do not appear in the path
        invalid_params = set(params.keys()) - path_params
        if invalid_params:
            raise ValueError(f"Invalid path parameters: {invalid_params}")

    async def execute(
        self,
        operation_id: str,
        path_params: dict[str, Any] | None = None,
        query_params: dict[str, Any] | None = None,
        body: dict[str, Any] | None = None,
    ) -> Any:
        """Execute an API operation.

        Args:
            operation_id: Operation ID from OpenAPI spec
            path_params: URL path parameters
            query_params: URL query parameters
            body: Request body data

        Returns:
            API response data

        Raises:
            ValueError: If operation not found
            RuntimeError: If used outside async context
            aiohttp.ClientError: For HTTP/network errors
        """
        if not hasattr(self, "_session") or not self._session:
            raise RuntimeError("Client not in async context")

        try:
            # Get operation details
            path, method, _ = self._get_operation(operation_id)

            # Validate path parameters
            self._validate_path_params(path, path_params)

            # Format URL
            if path_params:
                path = path.format(**path_params)
            url = f"{self.base_url}{path}"

            logger.debug("Executing %s %s", method, url)
            logger.debug("Request body: %s", body)
            logger.debug("Request query params: %s", query_params)

            # Dynamically set headers based on presence of body
            request_headers = self.headers.copy()
            if body is not None:
                request_headers["Content-Type"] = "application/json"

            # Make request
            async with self._session.request(
                method=method,
                url=url,
                headers=request_headers,
                params=query_params,
                json=body,
            ) as response:
                response.raise_for_status()
                content_type = response.headers.get("Content-Type", "").lower()

                # Status codes that typically have no body
                no_body_statuses = {204}
                if response.status in no_body_statuses:
                    if "application/json" in content_type:
                        logger.warning("Unexpected JSON body with status %s", response.status)
                        return await response.json()  # Parse if present, though rare
                    logger.debug("Received %s response with no body", response.status)
                    return response.status

                # For statuses expecting a body, check the mimetype
                if "application/json" in content_type:
                    logger.debug("Response: %s", await response.text())
                    return await response.json()

                # Unexpected mimetype with body
                response_text = await response.text()
                logger.error("Unexpected mimetype %s for status %s: %s", content_type, response.status, response_text)
                raise ValueError(f"Cannot parse response with mimetype {content_type} as JSON")

        except aiohttp.ClientError as e:
            logger.error("Error executing operation %s: %s", operation_id, e)
            raise
        except Exception as e:
            logger.error("Error executing operation %s: %s", operation_id, e)
            raise ValueError(f"Failed to execute operation: {e}") from e
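For reference, a minimal usage sketch of the client above. The spec path, base URL, token value, and the get_dags operation ID are illustrative placeholders, not values defined by this commit; the session is created in __aenter__, so execute() must run inside the async context:

import asyncio

from airflow_mcp_server.client.airflow_client import AirflowClient


async def demo() -> None:
    async with AirflowClient(
        spec_path="v1.yaml",  # placeholder path to an OpenAPI spec
        base_url="http://localhost:8080/api/v1",  # placeholder Airflow endpoint
        auth_token="dXNlcjpwYXNz",  # placeholder base64 user:pass for the Basic header
    ) as client:
        # Query parameters are passed straight through to aiohttp
        dags = await client.execute("get_dags", query_params={"limit": 5})
        print(dags)


asyncio.run(demo())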
src/airflow_mcp_server/parser/__init__.py (normal file, 3 lines)
@@ -0,0 +1,3 @@
from airflow_mcp_server.parser.operation_parser import OperationParser

__all__ = ["OperationParser"]
src/airflow_mcp_server/parser/operation_parser.py (normal file, 415 lines)
@@ -0,0 +1,415 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Literal, Union

import yaml
from openapi_core import OpenAPI
from pydantic import BaseModel, Field, create_model

logger = logging.getLogger(__name__)


@dataclass
class OperationDetails:
    """Details of an OpenAPI operation."""

    operation_id: str
    path: str
    method: str
    parameters: dict[str, Any]
    input_model: type[BaseModel]


class OperationParser:
    """Parser for OpenAPI operations."""

    def __init__(self, spec_path: Path | str | dict | bytes | object) -> None:
        """Initialize parser with OpenAPI specification.

        Args:
            spec_path: Path to OpenAPI spec file, dict, bytes, or file-like object

        Raises:
            ValueError: If spec_path is invalid or spec cannot be loaded
        """
        try:
            if isinstance(spec_path, bytes):
                self.raw_spec = yaml.safe_load(spec_path)
            elif isinstance(spec_path, dict):
                self.raw_spec = spec_path
            elif isinstance(spec_path, str | Path):
                with open(spec_path) as f:
                    self.raw_spec = yaml.safe_load(f)
            elif hasattr(spec_path, "read"):
                self.raw_spec = yaml.safe_load(spec_path)
            else:
                raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")

            self.spec = OpenAPI.from_dict(self.raw_spec)
            self._paths = self.raw_spec["paths"]
            self._components = self.raw_spec.get("components", {})
            self._schema_cache: dict[str, dict[str, Any]] = {}

        except Exception as e:
            logger.error("Error initializing OperationParser: %s", e)
            raise ValueError(f"Failed to initialize parser: {e}") from e

    def _merge_allof_schema(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Merge an allOf schema into a single effective schema.

        Args:
            schema: The schema potentially containing allOf

        Returns:
            A merged schema with unified properties and required fields
        """
        if "allOf" not in schema:
            return schema

        merged = {"type": "object", "properties": {}, "required": []}
        for subschema in schema["allOf"]:
            resolved = subschema
            if "$ref" in subschema:
                resolved = self._resolve_ref(subschema["$ref"])
            if "properties" in resolved:
                merged["properties"].update(resolved["properties"])
            if "required" in resolved:
                merged["required"].extend(resolved["required"])
        merged["required"] = list(set(merged["required"]))  # Remove duplicates
        logger.debug("Merged allOf schema: %s", merged)
        return merged

    def parse_operation(self, operation_id: str) -> OperationDetails:
        """Parse operation details from OpenAPI spec.

        Args:
            operation_id: Operation ID to parse

        Returns:
            OperationDetails object containing parsed information

        Raises:
            ValueError: If operation not found or invalid
        """
        try:
            for path, path_item in self._paths.items():
                for method, operation in path_item.items():
                    # Skip non-operation fields
                    if method.startswith("x-") or method == "parameters":
                        continue

                    if operation.get("operationId") == operation_id:
                        logger.debug("Found operation %s at %s %s", operation_id, method, path)

                        operation["path"] = path
                        operation["path_item"] = path_item

                        parameters = self.extract_parameters(operation)

                        # Get request body schema if present
                        body_schema = None
                        if "requestBody" in operation:
                            content = operation["requestBody"].get("content", {})
                            if "application/json" in content:
                                body_schema = content["application/json"].get("schema", {})
                                if "$ref" in body_schema:
                                    body_schema = self._resolve_ref(body_schema["$ref"])

                        # Create unified input model
                        input_model = self._create_input_model(operation_id, parameters, body_schema)

                        return OperationDetails(
                            operation_id=operation_id,
                            path=str(path),
                            method=method,
                            parameters=parameters,
                            input_model=input_model,
                        )

            raise ValueError(f"Operation {operation_id} not found in spec")

        except Exception as e:
            logger.error("Error parsing operation %s: %s", operation_id, e)
            raise

    def _create_input_model(
        self,
        operation_id: str,
        parameters: dict[str, Any],
        body_schema: dict[str, Any] | None = None,
    ) -> type[BaseModel]:
        """Create unified input model for all parameters."""
        fields: dict[str, tuple[type, Any]] = {}
        parameter_mapping = {"path": [], "query": [], "body": []}

        # Add path parameters
        for name, schema in parameters.get("path", {}).items():
            field_type = schema["type"]  # Use the mapped type from parameter schema
            fields[name] = (field_type | None, None)  # Make all optional
            parameter_mapping["path"].append(name)

        # Add query parameters
        for name, schema in parameters.get("query", {}).items():
            field_type = schema["type"]  # Use the mapped type from parameter schema
            fields[name] = (field_type | None, None)  # Make all optional
            parameter_mapping["query"].append(name)

        # Handle body schema with allOf support
        if body_schema:
            effective_schema = self._merge_allof_schema(body_schema)
            if "properties" in effective_schema or effective_schema.get("type") == "object":
                for prop_name, prop_schema in effective_schema.get("properties", {}).items():
                    field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
                    if prop_name == "schema":  # Avoid shadowing BaseModel.schema
                        fields["connection_schema"] = (field_type | None, Field(None, alias="schema"))
                        parameter_mapping["body"].append("connection_schema")
                    else:
                        fields[prop_name] = (field_type | None, None)
                        parameter_mapping["body"].append(prop_name)

        logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
        model = create_model(f"{operation_id}_input", **fields)
        model.model_config["parameter_mapping"] = parameter_mapping
        return model

    def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
        """Extract and categorize operation parameters.

        Args:
            operation: Operation object from OpenAPI spec

        Returns:
            Dictionary of parameters by category (path, query, header)
        """
        parameters: dict[str, dict[str, Any]] = {
            "path": {},
            "query": {},
            "header": {},
        }

        path_item = operation.get("path_item", {})
        if path_item and "parameters" in path_item:
            self._process_parameters(path_item["parameters"], parameters)

        self._process_parameters(operation.get("parameters", []), parameters)

        return parameters

    def _process_parameters(self, params: list[dict[str, Any]], target: dict[str, dict[str, Any]]) -> None:
        """Process a list of parameters and add them to the target dict.

        Args:
            params: List of parameter objects
            target: Target dictionary to store processed parameters
        """
        for param in params:
            if "$ref" in param:
                param = self._resolve_ref(param["$ref"])

            if not isinstance(param, dict) or "in" not in param:
                logger.warning("Invalid parameter format: %s", param)
                continue

            param_in = param["in"]
            if param_in in target:
                target[param_in][param["name"]] = self._map_parameter_schema(param)

    def _resolve_ref(self, ref: str) -> dict[str, Any]:
        """Resolve OpenAPI reference.

        Args:
            ref: Reference string (e.g. '#/components/schemas/Model')

        Returns:
            Resolved object
        """
        if ref in self._schema_cache:
            return self._schema_cache[ref]

        parts = ref.split("/")
        current = self.raw_spec
        for part in parts[1:]:
            current = current[part]

        self._schema_cache[ref] = current
        return current

    def _map_parameter_schema(self, param: dict[str, Any]) -> dict[str, Any]:
        """Map parameter schema to Python type information.

        Args:
            param: Parameter object from OpenAPI spec

        Returns:
            Dictionary with Python type information
        """
        schema = param.get("schema", {})
        if "$ref" in schema:
            schema = self._resolve_ref(schema["$ref"])

        # Get the type and format from schema
        openapi_type = schema.get("type", "string")
        format_type = schema.get("format")

        return {
            "type": self._map_type(openapi_type, format_type, schema),  # Pass format_type and full schema
            "required": param.get("required", False),
            "default": schema.get("default"),
            "description": param.get("description"),
        }

    def _map_type(self, openapi_type: str, format_type: str | None = None, schema: dict[str, Any] | None = None) -> type:
        """Map OpenAPI type to Python type."""
        # Handle enums
        if schema and "enum" in schema:
            return Literal[tuple(schema["enum"])]  # type: ignore

        # Handle date-time format
        if openapi_type == "string" and format_type == "date-time":
            return datetime

        base_type = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
        }.get(openapi_type, Any)

        # Handle nullable fields
        if schema and schema.get("nullable"):
            return Union[base_type, None]  # noqa: UP007

        return base_type

    def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
        """Parse response schema into Pydantic model."""
        try:
            responses = operation.get("responses", {})
            if "200" not in responses:
                return None

            response = responses["200"]
            if "$ref" in response:
                response = self._resolve_ref(response["$ref"])

            content = response.get("content", {})
            if "application/json" not in content:
                return None

            schema = content["application/json"].get("schema", {})
            if "$ref" in schema:
                schema = self._resolve_ref(schema["$ref"])

            logger.debug("Response schema before model creation: %s", schema)

            # Ensure schema has properties
            if "properties" not in schema:
                logger.error("Response schema missing properties")
                return None

            # Create model with schema properties
            model = self._create_model("Response", schema)
            logger.debug("Created response model with schema: %s", model.model_json_schema())
            return model

        except Exception as e:
            logger.error("Failed to create response model: %s", e)
            return None

    def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:  # noqa: C901
        """Create Pydantic model from schema.

        Args:
            name: Model name
            schema: OpenAPI schema

        Returns:
            Generated Pydantic model

        Raises:
            ValueError: If schema is invalid
        """
        if "$ref" in schema:
            schema = self._resolve_ref(schema["$ref"])

        # Handle discriminated unions
        if "anyOf" in schema:
            discriminator = schema.get("discriminator", {}).get("propertyName")
            if discriminator:
                union_types = []
                for subschema in schema["anyOf"]:
                    if "$ref" in subschema:
                        resolved = self._resolve_ref(subschema["$ref"])
                        sub_model = self._create_model(f"{name}_type", resolved)
                        union_types.append(sub_model)

                fields = {
                    "type_name": (str, ...),  # Use type_name instead of __type
                    **{k: (Union[tuple(union_types)], ...) for k in schema.get("properties", {}) if k != discriminator},  # noqa: UP007
                }
                return create_model(name, **fields)

        if schema.get("type", "object") != "object":
            raise ValueError("Schema must be an object type")

        fields = {}
        for prop_name, prop_schema in schema.get("properties", {}).items():
            if "$ref" in prop_schema:
                prop_schema = self._resolve_ref(prop_schema["$ref"])

            required = prop_name in schema.get("required", [])

            if prop_schema.get("type") == "object":
                nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
                field_type = nested_model
            elif prop_schema.get("type") == "array":
                items = prop_schema.get("items", {})
                if "$ref" in items:
                    items = self._resolve_ref(items["$ref"])
                if items.get("type") == "object":
                    item_model = self._create_model(f"{name}_{prop_name}_item", items)
                    field_type = list[item_model]
                else:
                    item_type = self._map_type(items.get("type", "string"))
                    field_type = list[item_type]
            else:
                field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)

            # Add validation parameters
            field_config = {}
            if "pattern" in prop_schema:
                field_config["pattern"] = prop_schema["pattern"]
            if "minLength" in prop_schema:
                field_config["min_length"] = prop_schema["minLength"]
            if "maxLength" in prop_schema:
                field_config["max_length"] = prop_schema["maxLength"]
            if "minimum" in prop_schema:
                field_config["ge"] = prop_schema["minimum"]
            if "maximum" in prop_schema:
                field_config["le"] = prop_schema["maximum"]

            # Create field with validation
            if field_config:
                field = Field(..., **field_config) if required else Field(None, **field_config)
                fields[prop_name] = (field_type, field)
                continue

            fields[prop_name] = (field_type, ... if required else None)

        logger.debug("Creating model %s with fields: %s", name, fields)
        try:
            return create_model(name, **fields)
        except Exception as e:
            logger.error("Error creating model %s: %s", name, e)
            raise ValueError(f"Failed to create model {name}: {e}") from e

    def get_operations(self) -> list[str]:
        """Get list of all operation IDs from spec."""
        operations = []

        for path in self._paths.values():
            for method, operation in path.items():
                if method.startswith("x-") or method == "parameters":
                    continue
                if "operationId" in operation:
                    operations.append(operation["operationId"])

        return operations
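A short sketch of how the parser above is typically driven. The spec path and the get_dags operation ID are placeholders, not values fixed by this commit:

from airflow_mcp_server.parser.operation_parser import OperationParser

parser = OperationParser("v1.yaml")  # placeholder spec location

# Every operationId in the spec becomes a candidate tool name
print(parser.get_operations()[:5])

# parse_operation() returns the path/method plus a generated Pydantic input model
details = parser.parse_operation("get_dags")  # placeholder operation ID
print(details.method, details.path)
print(details.input_model.model_json_schema())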
src/airflow_mcp_server/resources/v1.yaml (normal file, 6161 lines)
File diff suppressed because it is too large.
src/airflow_mcp_server/server.py (normal file, 52 lines)
@@ -0,0 +1,52 @@
import logging
import os
from typing import Any

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

# ===========THIS IS FOR DEBUGGING WITH MCP INSPECTOR===================
# import sys
# Configure root logger to stderr
# logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stderr)])

# Disable Uvicorn's default handlers
# logging.getLogger("uvicorn.error").handlers = []
# logging.getLogger("uvicorn.access").handlers = []
# ======================================================================
logger = logging.getLogger(__name__)


async def serve() -> None:
    """Start MCP server."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    if not all(var in os.environ for var in required_vars):
        raise ValueError(f"Missing required environment variables: {required_vars}")

    server = Server("airflow-mcp-server")

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        try:
            return await get_airflow_tools()
        except Exception as e:
            logger.error("Failed to list tools: %s", e)
            raise

    @server.call_tool()
    async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
        try:
            tool = await get_tool(name)
            async with tool.client:
                result = await tool.run(body=arguments)
                return [TextContent(type="text", text=str(result))]
        except Exception as e:
            logger.error("Tool execution failed: %s", e)
            raise

    options = server.create_initialization_options()
    async with stdio_server() as (read_stream, write_stream):
        await server.run(read_stream, write_stream, options, raise_exceptions=True)
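serve() blocks on the stdio transport, so a host process would run it with asyncio after exporting the two required variables. A minimal sketch, assuming placeholder credentials:

import asyncio
import os

from airflow_mcp_server.server import serve

# Both variables are validated at the top of serve()
os.environ["AIRFLOW_BASE_URL"] = "http://localhost:8080/api/v1"  # placeholder
os.environ["AUTH_TOKEN"] = "dXNlcjpwYXNz"  # placeholder Basic-auth token

asyncio.run(serve())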
src/airflow_mcp_server/server_safe.py (normal file, 45 lines)
@@ -0,0 +1,45 @@
import logging
import os
from typing import Any

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

logger = logging.getLogger(__name__)


async def serve() -> None:
    """Start MCP server in safe mode (read-only operations)."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    if not all(var in os.environ for var in required_vars):
        raise ValueError(f"Missing required environment variables: {required_vars}")

    server = Server("airflow-mcp-server-safe")

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        try:
            return await get_airflow_tools(mode="safe")
        except Exception as e:
            logger.error("Failed to list tools: %s", e)
            raise

    @server.call_tool()
    async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
        try:
            if not name.startswith("get_"):
                raise ValueError("Only GET operations allowed in safe mode")
            tool = await get_tool(name)
            async with tool.client:
                result = await tool.run(body=arguments)
                return [TextContent(type="text", text=str(result))]
        except Exception as e:
            logger.error("Tool execution failed: %s", e)
            raise

    options = server.create_initialization_options()
    async with stdio_server() as (read_stream, write_stream):
        await server.run(read_stream, write_stream, options, raise_exceptions=True)
src/airflow_mcp_server/server_unsafe.py (normal file, 43 lines)
@@ -0,0 +1,43 @@
import logging
import os
from typing import Any

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

logger = logging.getLogger(__name__)


async def serve() -> None:
    """Start MCP server in unsafe mode (all operations)."""
    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    if not all(var in os.environ for var in required_vars):
        raise ValueError(f"Missing required environment variables: {required_vars}")

    server = Server("airflow-mcp-server-unsafe")

    @server.list_tools()
    async def list_tools() -> list[Tool]:
        try:
            return await get_airflow_tools(mode="unsafe")
        except Exception as e:
            logger.error("Failed to list tools: %s", e)
            raise

    @server.call_tool()
    async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
        try:
            tool = await get_tool(name)
            async with tool.client:
                result = await tool.run(body=arguments)
                return [TextContent(type="text", text=str(result))]
        except Exception as e:
            logger.error("Tool execution failed: %s", e)
            raise

    options = server.create_initialization_options()
    async with stdio_server() as (read_stream, write_stream):
        await server.run(read_stream, write_stream, options, raise_exceptions=True)
src/airflow_mcp_server/tools/__init__.py (normal file, 0 lines)
src/airflow_mcp_server/tools/airflow_tool.py (normal file, 81 lines)
@@ -0,0 +1,81 @@
import logging
from typing import Any

from pydantic import ValidationError

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.base_tools import BaseTools

logger = logging.getLogger(__name__)


def create_validation_error(field: str, message: str) -> ValidationError:
    """Create a properly formatted validation error.

    Args:
        field: The field that failed validation
        message: The error message

    Returns:
        ValidationError: A properly formatted validation error
    """
    errors = [
        {
            "loc": (field,),
            "msg": message,
            "type": "value_error",
            "input": None,
            "ctx": {"error": message},
        }
    ]
    return ValidationError.from_exception_data("validation_error", errors)


class AirflowTool(BaseTools):
    """Tool for executing Airflow API operations.

    One AirflowTool instance is created per operation.
    """

    def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
        """Initialize tool with operation details and client.

        Args:
            operation_details: Operation details
            client: AirflowClient instance
        """
        super().__init__()
        self.operation = operation_details
        self.client = client

    async def run(
        self,
        body: dict[str, Any] | None = None,
    ) -> Any:
        """Execute the operation with provided parameters."""
        try:
            # Validate input
            validated_input = self.operation.input_model(**(body or {}))
            validated_body = validated_input.model_dump(exclude_none=True)  # Only include non-None values

            # Route each validated field back to its HTTP location
            mapping = self.operation.input_model.model_config["parameter_mapping"]
            path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
            query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
            body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}

            # Execute operation and return raw response
            response = await self.client.execute(
                operation_id=self.operation.operation_id,
                path_params=path_params or None,
                query_params=query_params or None,
                body=body_params or None,
            )

            return response

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Operation execution failed: %s", e)
            raise
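The parameter_mapping split in run() above routes each validated field back to its HTTP location. A self-contained sketch of the same splitting logic, using a hand-built mapping and values rather than a parsed spec (all names here are hypothetical):

# Hypothetical mapping, shaped like OperationParser._create_input_model output
mapping = {"path": ["dag_id"], "query": ["limit"], "body": ["is_paused"]}
validated_body = {"dag_id": "example_dag", "limit": 10, "is_paused": False}

path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}

assert path_params == {"dag_id": "example_dag"}
assert query_params == {"limit": 10}
assert body_params == {"is_paused": False}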
src/airflow_mcp_server/tools/base_tools.py (normal file, 19 lines)
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any


class BaseTools(ABC):
    """Abstract base class for tools."""

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the tool."""
        pass

    @abstractmethod
    def run(self) -> Any:
        """Execute the tool's main functionality.

        Returns:
            Any: The result of the tool execution
        """
        raise NotImplementedError
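Concrete tools subclass this ABC, as AirflowTool does above. A minimal hypothetical subclass, just to illustrate the contract (EchoTool is not part of this commit):

from typing import Any

from airflow_mcp_server.tools.base_tools import BaseTools


class EchoTool(BaseTools):
    """Toy tool used only to illustrate the BaseTools contract."""

    def __init__(self) -> None:
        super().__init__()

    def run(self, message: str = "hello") -> Any:
        return message


print(EchoTool().run("ping"))  # -> "ping"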
src/airflow_mcp_server/tools/tool_manager.py (normal file, 126 lines)
@@ -0,0 +1,126 @@
import logging
import os
from importlib import resources

from mcp.types import Tool

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool

logger = logging.getLogger(__name__)

_tools_cache: dict[str, AirflowTool] = {}


def _initialize_client() -> AirflowClient:
    """Initialize Airflow client with environment variables or embedded spec.

    Returns:
        AirflowClient instance

    Raises:
        ValueError: If required environment variables are missing or default spec is not found
    """
    spec_path = os.environ.get("OPENAPI_SPEC")
    if not spec_path:
        # Fall back to embedded v1.yaml
        try:
            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
                spec_path = f.name
            logger.info("OPENAPI_SPEC not set; using embedded v1.yaml from %s", spec_path)
        except Exception as e:
            raise ValueError("Default OpenAPI spec not found in package resources") from e

    required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")

    return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])


async def _initialize_tools() -> None:
    """Initialize tools cache with Airflow operations.

    Raises:
        ValueError: If initialization fails
    """
    global _tools_cache

    try:
        client = _initialize_client()
        spec_path = os.environ.get("OPENAPI_SPEC")
        if not spec_path:
            with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
                spec_path = f.name
        parser = OperationParser(spec_path)

        # Generate tools for each operation
        for operation_id in parser.get_operations():
            operation_details = parser.parse_operation(operation_id)
            tool = AirflowTool(operation_details, client)
            _tools_cache[operation_id] = tool

    except Exception as e:
        logger.error("Failed to initialize tools: %s", e)
        _tools_cache.clear()
        raise ValueError(f"Failed to initialize tools: {e}") from e


async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
    """Get list of available Airflow tools based on mode.

    Args:
        mode: "safe" for GET operations only, "unsafe" for all operations (default)

    Returns:
        List of MCP Tool objects representing available operations

    Raises:
        ValueError: If initialization fails
    """
    if not _tools_cache:
        await _initialize_tools()

    tools = []
    for operation_id, tool in _tools_cache.items():
        try:
            # Skip non-GET operations in safe mode
            if mode == "safe" and tool.operation.method.lower() != "get":
                continue
            schema = tool.operation.input_model.model_json_schema()
            tools.append(
                Tool(
                    name=operation_id,
                    description=tool.operation.operation_id,
                    inputSchema=schema,
                )
            )
        except Exception as e:
            logger.error("Failed to create tool schema for %s: %s", operation_id, e)
            continue

    return tools


async def get_tool(name: str) -> AirflowTool:
    """Get specific tool by name.

    Args:
        name: Tool/operation name

    Returns:
        AirflowTool instance

    Raises:
        KeyError: If tool not found
        ValueError: If tool initialization fails
    """
    if not _tools_cache:
        await _initialize_tools()

    if name not in _tools_cache:
        raise KeyError(f"Tool {name} not found")

    return _tools_cache[name]
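End to end, this manager is what the servers above call. A minimal sketch, assuming the same placeholder environment values used earlier and a reachable Airflow instance (get_dags is a placeholder operation ID):

import asyncio
import os

from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

os.environ.setdefault("AIRFLOW_BASE_URL", "http://localhost:8080/api/v1")  # placeholder
os.environ.setdefault("AUTH_TOKEN", "dXNlcjpwYXNz")  # placeholder


async def demo() -> None:
    # Safe mode filters the cached tools down to GET operations
    tools = await get_airflow_tools(mode="safe")
    print(len(tools), "read-only tools")

    tool = await get_tool("get_dags")  # placeholder operation ID
    async with tool.client:
        print(await tool.run(body={"limit": 5}))


asyncio.run(demo())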