Add AirflowConfig class for centralized configuration management
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from importlib import resources
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from airflow_mcp_server.config import AirflowConfig
|
||||
from airflow_mcp_server.parser.operation_parser import OperationParser
|
||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||
|
||||
@@ -13,16 +13,19 @@ logger = logging.getLogger(__name__)
|
||||
_tools_cache: dict[str, AirflowTool] = {}
|
||||
|
||||
|
||||
def _initialize_client() -> AirflowClient:
|
||||
"""Initialize Airflow client with environment variables or embedded spec.
|
||||
def _initialize_client(config: AirflowConfig) -> AirflowClient:
|
||||
"""Initialize Airflow client with configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
|
||||
Returns:
|
||||
AirflowClient instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing or default spec is not found
|
||||
ValueError: If default spec is not found
|
||||
"""
|
||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
||||
spec_path = config.spec_path
|
||||
if not spec_path:
|
||||
# Fallback to embedded v1.yaml
|
||||
try:
|
||||
@@ -32,41 +35,33 @@ def _initialize_client() -> AirflowClient:
|
||||
except Exception as e:
|
||||
raise ValueError("Default OpenAPI spec not found in package resources") from e
|
||||
|
||||
# Check for base URL
|
||||
if "AIRFLOW_BASE_URL" not in os.environ:
|
||||
raise ValueError("Missing required environment variable: AIRFLOW_BASE_URL")
|
||||
|
||||
# Check for either AUTH_TOKEN or COOKIE
|
||||
has_auth_token = "AUTH_TOKEN" in os.environ
|
||||
has_cookie = "COOKIE" in os.environ
|
||||
|
||||
if not has_auth_token and not has_cookie:
|
||||
raise ValueError("Either AUTH_TOKEN or COOKIE environment variable must be provided")
|
||||
|
||||
# Initialize client with appropriate authentication method
|
||||
client_args = {"spec_path": spec_path, "base_url": os.environ["AIRFLOW_BASE_URL"]}
|
||||
client_args = {"spec_path": spec_path, "base_url": config.base_url}
|
||||
|
||||
# Apply cookie auth first if available (highest precedence)
|
||||
if has_cookie:
|
||||
client_args["cookie"] = os.environ["COOKIE"]
|
||||
if config.cookie:
|
||||
client_args["cookie"] = config.cookie
|
||||
# Otherwise use auth token if available
|
||||
elif has_auth_token:
|
||||
client_args["auth_token"] = os.environ["AUTH_TOKEN"]
|
||||
elif config.auth_token:
|
||||
client_args["auth_token"] = config.auth_token
|
||||
|
||||
return AirflowClient(**client_args)
|
||||
|
||||
|
||||
async def _initialize_tools() -> None:
|
||||
async def _initialize_tools(config: AirflowConfig) -> None:
|
||||
"""Initialize tools cache with Airflow operations.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
|
||||
Raises:
|
||||
ValueError: If initialization fails
|
||||
"""
|
||||
global _tools_cache
|
||||
|
||||
try:
|
||||
client = _initialize_client()
|
||||
spec_path = os.environ.get("OPENAPI_SPEC")
|
||||
client = _initialize_client(config)
|
||||
spec_path = config.spec_path
|
||||
if not spec_path:
|
||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||
spec_path = f.name
|
||||
@@ -84,10 +79,11 @@ async def _initialize_tools() -> None:
|
||||
raise ValueError(f"Failed to initialize tools: {e}") from e
|
||||
|
||||
|
||||
async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
async def get_airflow_tools(config: AirflowConfig, mode: str = "unsafe") -> list[Tool]:
|
||||
"""Get list of available Airflow tools based on mode.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
mode: "safe" for GET operations only, "unsafe" for all operations (default)
|
||||
|
||||
Returns:
|
||||
@@ -97,7 +93,7 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
ValueError: If initialization fails
|
||||
"""
|
||||
if not _tools_cache:
|
||||
await _initialize_tools()
|
||||
await _initialize_tools(config)
|
||||
|
||||
tools = []
|
||||
for operation_id, tool in _tools_cache.items():
|
||||
@@ -120,10 +116,11 @@ async def get_airflow_tools(mode: str = "unsafe") -> list[Tool]:
|
||||
return tools
|
||||
|
||||
|
||||
async def get_tool(name: str) -> AirflowTool:
|
||||
async def get_tool(config: AirflowConfig, name: str) -> AirflowTool:
|
||||
"""Get specific tool by name.
|
||||
|
||||
Args:
|
||||
config: Configuration object with auth and URL settings
|
||||
name: Tool/operation name
|
||||
|
||||
Returns:
|
||||
@@ -134,7 +131,7 @@ async def get_tool(name: str) -> AirflowTool:
|
||||
ValueError: If tool initialization fails
|
||||
"""
|
||||
if not _tools_cache:
|
||||
await _initialize_tools()
|
||||
await _initialize_tools(config)
|
||||
|
||||
if name not in _tools_cache:
|
||||
raise KeyError(f"Tool {name} not found")
|
||||
|
||||
Reference in New Issue
Block a user