separate safe and unsafe servers
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from airflow_mcp_server.server import serve
|
||||
from airflow_mcp_server.server_safe import serve as serve_safe
|
||||
from airflow_mcp_server.server_unsafe import serve as serve_unsafe
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("-v", "--verbose", count=True)
|
||||
def main(verbose: bool) -> None:
|
||||
@click.option("-v", "--verbose", count=True, help="Increase verbosity")
|
||||
@click.option("--safe", "-s", is_flag=True, help="Use only read-only tools")
|
||||
@click.option("--unsafe", "-u", is_flag=True, help="Use all tools (default)")
|
||||
def main(verbose: int, safe: bool, unsafe: bool) -> None:
|
||||
"""MCP server for Airflow"""
|
||||
import asyncio
|
||||
|
||||
logging_level = logging.WARN
|
||||
if verbose == 1:
|
||||
logging_level = logging.INFO
|
||||
@@ -21,7 +21,14 @@ def main(verbose: bool) -> None:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(level=logging_level, stream=sys.stderr)
|
||||
asyncio.run(serve())
|
||||
|
||||
if safe and unsafe:
|
||||
raise click.UsageError("Options --safe and --unsafe are mutually exclusive")
|
||||
|
||||
if safe:
|
||||
asyncio.run(serve_safe())
|
||||
else: # Default to unsafe mode
|
||||
asyncio.run(serve_unsafe())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
45
airflow-mcp-server/src/airflow_mcp_server/server_safe.py
Normal file
45
airflow-mcp-server/src/airflow_mcp_server/server_safe.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
"""Start MCP server in safe mode (read-only operations)."""
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
if not all(var in os.environ for var in required_vars):
|
||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
||||
|
||||
server = Server("airflow-mcp-server-safe")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
try:
|
||||
return await get_airflow_tools(mode="safe")
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools: %s", e)
|
||||
raise
|
||||
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
try:
|
||||
if not name.startswith("get_"):
|
||||
raise ValueError("Only GET operations allowed in safe mode")
|
||||
tool = await get_tool(name)
|
||||
async with tool.client:
|
||||
result = await tool.run(body=arguments)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
except Exception as e:
|
||||
logger.error("Tool execution failed: %s", e)
|
||||
raise
|
||||
|
||||
options = server.create_initialization_options()
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await server.run(read_stream, write_stream, options, raise_exceptions=True)
|
||||
43
airflow-mcp-server/src/airflow_mcp_server/server_unsafe.py
Normal file
43
airflow-mcp-server/src/airflow_mcp_server/server_unsafe.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def serve() -> None:
|
||||
"""Start MCP server in unsafe mode (all operations)."""
|
||||
required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||
if not all(var in os.environ for var in required_vars):
|
||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
||||
|
||||
server = Server("airflow-mcp-server-unsafe")
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
try:
|
||||
return await get_airflow_tools(mode="unsafe")
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools: %s", e)
|
||||
raise
|
||||
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||
try:
|
||||
tool = await get_tool(name)
|
||||
async with tool.client:
|
||||
result = await tool.run(body=arguments)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
except Exception as e:
|
||||
logger.error("Tool execution failed: %s", e)
|
||||
raise
|
||||
|
||||
options = server.create_initialization_options()
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await server.run(read_stream, write_stream, options, raise_exceptions=True)
|
||||
@@ -68,8 +68,11 @@ async def _initialize_tools() -> None:
|
||||
raise ValueError(f"Failed to initialize tools: {e}") from e
|
||||
|
||||
|
||||
async def get_airflow_tools() -> list[Tool]:
|
||||
"""Get list of all available Airflow tools.
|
||||
async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
"""Get list of available Airflow tools based on mode.
|
||||
|
||||
Args:
|
||||
mode: "safe" for GET operations only, "unsafe" for all operations (default)
|
||||
|
||||
Returns:
|
||||
List of MCP Tool objects representing available operations
|
||||
@@ -83,6 +86,9 @@ async def get_airflow_tools() -> list[Tool]:
|
||||
tools = []
|
||||
for operation_id, tool in _tools_cache.items():
|
||||
try:
|
||||
# Skip non-GET operations in safe mode
|
||||
if mode == "safe" and not tool.operation.method.lower() == "get":
|
||||
continue
|
||||
schema = tool.operation.input_model.model_json_schema()
|
||||
tools.append(
|
||||
Tool(
|
||||
|
||||
Reference in New Issue
Block a user