diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index 588cbf2..fca8ca5 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -119,28 +119,34 @@ class OperationParser: ) -> type[BaseModel]: """Create unified input model for all parameters.""" fields: dict[str, tuple[type, Any]] = {} + parameter_mapping = {"path": [], "query": [], "body": []} # Add path parameters for name, schema in parameters.get("path", {}).items(): field_type = schema["type"] required = schema.get("required", True) # Path parameters are required by default - fields[f"path_{name}"] = (field_type, ... if required else None) + fields[name] = (field_type, ... if required else None) + parameter_mapping["path"].append(name) # Add query parameters for name, schema in parameters.get("query", {}).items(): field_type = schema["type"] required = schema.get("required", False) # Query parameters are optional by default - fields[f"query_{name}"] = (field_type, ... if required else None) + fields[name] = (field_type, ... if required else None) + parameter_mapping["query"].append(name) # Add body fields if present if body_schema and body_schema.get("type") == "object": for prop_name, prop_schema in body_schema.get("properties", {}).items(): field_type = self._map_type(prop_schema.get("type", "string")) required = prop_name in body_schema.get("required", []) - fields[f"body_{prop_name}"] = (field_type, ... if required else None) + fields[prop_name] = (field_type, ... if required else None) + parameter_mapping["body"].append(prop_name) logger.debug("Creating input model for %s with fields: %s", operation_id, fields) - return create_model(f"{operation_id}_input", **fields) + model = create_model(f"{operation_id}_input", **fields) + model.model_config["parameter_mapping"] = parameter_mapping + return model def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]: """Extract and categorize operation parameters. diff --git a/airflow-mcp-server/src/airflow_mcp_server/server.py b/airflow-mcp-server/src/airflow_mcp_server/server.py index 8691efe..faf2dbe 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/server.py +++ b/airflow-mcp-server/src/airflow_mcp_server/server.py @@ -21,13 +21,18 @@ async def serve() -> None: @server.list_tools() async def list_tools() -> list[Tool]: - return get_airflow_tools() + try: + return await get_airflow_tools() + except Exception as e: + logger.error("Failed to list tools: %s", e) + raise @server.call_tool() async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: try: - tool = get_tool(name) - result = await tool.run(**arguments) + tool = await get_tool(name) + async with tool.client: + result = await tool.run(body=arguments) return [TextContent(type="text", text=str(result))] except Exception as e: logger.error("Tool execution failed: %s", e) diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py index 75977df..fa6f42d 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py @@ -49,82 +49,23 @@ class AirflowTool(BaseTools): self.operation = operation_details self.client = client - def _validate_input( - self, - path_params: dict[str, Any] | None = None, - query_params: dict[str, Any] | None = None, - body: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Validate input parameters using unified input model. - - Args: - path_params: Path parameters - query_params: Query parameters - body: Body parameters - - Returns: - dict[str, Any]: Validated input parameters - """ - try: - input_data = {} - - if path_params: - input_data.update({f"path_{k}": v for k, v in path_params.items()}) - - if query_params: - input_data.update({f"query_{k}": v for k, v in query_params.items()}) - - if body: - input_data.update({f"body_{k}": v for k, v in body.items()}) - - validated = self.operation.input_model(**input_data) - return validated.model_dump() - - except ValidationError as e: - logger.error("Input validation failed: %s", e) - raise - - def _extract_parameters(self, validated_input: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: - """Extract validated parameters by type.""" - path_params = {} - query_params = {} - body = {} - - # Extract parameters based on operation definition - for key, value in validated_input.items(): - # Remove prefix from key if present - param_key = key - if key.startswith(("path_", "query_", "body_")): - param_key = key.split("_", 1)[1] - - if key.startswith("path_"): - path_params[param_key] = value - elif key.startswith("query_"): - query_params[param_key] = value - elif key.startswith("body_"): - body[param_key] = value - else: - body[key] = value - - return path_params, query_params, body - async def run( self, - path_params: dict[str, Any] | None = None, - query_params: dict[str, Any] | None = None, body: dict[str, Any] | None = None, ) -> Any: """Execute the operation with provided parameters.""" try: - validated_input = self._validate_input(path_params, query_params, body) - path_params, query_params, body = self._extract_parameters(validated_input) + mapping = self.operation.input_model.model_config["parameter_mapping"] + path_params = {k: body[k] for k in mapping.get("path", []) if k in body} + query_params = {k: body[k] for k in mapping.get("query", []) if k in body} + body_params = {k: body[k] for k in mapping.get("body", []) if k in body} # Execute operation response = await self.client.execute( operation_id=self.operation.operation_id, path_params=path_params, query_params=query_params, - body=body, + body=body_params, ) # Validate response if model exists diff --git a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py index fc7792a..96af499 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py +++ b/airflow-mcp-server/src/airflow_mcp_server/tools/tool_manager.py @@ -10,55 +10,80 @@ from airflow_mcp_server.tools.airflow_tool import AirflowTool logger = logging.getLogger(__name__) _tools_cache: dict[str, AirflowTool] = {} -_client: AirflowClient | None = None -def get_airflow_tools() -> list[Tool]: +def _initialize_client() -> AirflowClient: + """Initialize Airflow client with environment variables. + + Returns: + AirflowClient instance + + Raises: + ValueError: If required environment variables are missing + """ + required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"] + missing_vars = [var for var in required_vars if var not in os.environ] + if missing_vars: + raise ValueError(f"Missing required environment variables: {missing_vars}") + + return AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"]) + + +async def _initialize_tools() -> None: + """Initialize tools cache with Airflow operations. + + Raises: + ValueError: If initialization fails + """ + global _tools_cache + + try: + client = _initialize_client() + parser = OperationParser(os.environ["OPENAPI_SPEC"]) + + # Generate tools for each operation + for operation_id in parser.get_operations(): + operation_details = parser.parse_operation(operation_id) + tool = AirflowTool(operation_details, client) + _tools_cache[operation_id] = tool + + except Exception as e: + logger.error("Failed to initialize tools: %s", e) + _tools_cache.clear() + raise ValueError(f"Failed to initialize tools: {e}") from e + + +async def get_airflow_tools() -> list[Tool]: """Get list of all available Airflow tools. Returns: List of MCP Tool objects representing available operations Raises: - ValueError: If required environment variables are missing + ValueError: If required environment variables are missing or initialization fails """ - global _tools_cache, _client - if not _tools_cache: - required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") - - # Initialize client if not exists - if not _client: - _client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"]) + await _initialize_tools() + tools = [] + for operation_id, tool in _tools_cache.items(): try: - # Create parser - parser = OperationParser(os.environ["OPENAPI_SPEC"]) - - # Generate tools for each operation - for operation_id in parser.get_operations(): - operation_details = parser.parse_operation(operation_id) - tool = AirflowTool(operation_details, _client) - _tools_cache[operation_id] = tool - + schema = tool.operation.input_model.model_json_schema() + tools.append( + Tool( + name=operation_id, + description=tool.operation.operation_id, + inputSchema=schema, + ) + ) except Exception as e: - logger.error("Failed to initialize tools: %s", e) - raise + logger.error("Failed to create tool schema for %s: %s", operation_id, e) + continue - # Convert to MCP Tool format - return [ - Tool( - name=operation_id, - description=tool.operation.operation_id, - inputSchema=tool.operation.input_model.model_json_schema(), - ) - for operation_id, tool in _tools_cache.items() - ] + return tools -def get_tool(name: str) -> AirflowTool: +async def get_tool(name: str) -> AirflowTool: """Get specific tool by name. Args: @@ -69,9 +94,12 @@ def get_tool(name: str) -> AirflowTool: Raises: KeyError: If tool not found + ValueError: If tool initialization fails """ + if not _tools_cache: + await _initialize_tools() + if name not in _tools_cache: - # Ensure cache is populated - get_airflow_tools() + raise KeyError(f"Tool {name} not found") return _tools_cache[name]