Mapping dict for path query and body
This commit is contained in:
@@ -119,28 +119,34 @@ class OperationParser:
|
|||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
"""Create unified input model for all parameters."""
|
"""Create unified input model for all parameters."""
|
||||||
fields: dict[str, tuple[type, Any]] = {}
|
fields: dict[str, tuple[type, Any]] = {}
|
||||||
|
parameter_mapping = {"path": [], "query": [], "body": []}
|
||||||
|
|
||||||
# Add path parameters
|
# Add path parameters
|
||||||
for name, schema in parameters.get("path", {}).items():
|
for name, schema in parameters.get("path", {}).items():
|
||||||
field_type = schema["type"]
|
field_type = schema["type"]
|
||||||
required = schema.get("required", True) # Path parameters are required by default
|
required = schema.get("required", True) # Path parameters are required by default
|
||||||
fields[f"path_{name}"] = (field_type, ... if required else None)
|
fields[name] = (field_type, ... if required else None)
|
||||||
|
parameter_mapping["path"].append(name)
|
||||||
|
|
||||||
# Add query parameters
|
# Add query parameters
|
||||||
for name, schema in parameters.get("query", {}).items():
|
for name, schema in parameters.get("query", {}).items():
|
||||||
field_type = schema["type"]
|
field_type = schema["type"]
|
||||||
required = schema.get("required", False) # Query parameters are optional by default
|
required = schema.get("required", False) # Query parameters are optional by default
|
||||||
fields[f"query_{name}"] = (field_type, ... if required else None)
|
fields[name] = (field_type, ... if required else None)
|
||||||
|
parameter_mapping["query"].append(name)
|
||||||
|
|
||||||
# Add body fields if present
|
# Add body fields if present
|
||||||
if body_schema and body_schema.get("type") == "object":
|
if body_schema and body_schema.get("type") == "object":
|
||||||
for prop_name, prop_schema in body_schema.get("properties", {}).items():
|
for prop_name, prop_schema in body_schema.get("properties", {}).items():
|
||||||
field_type = self._map_type(prop_schema.get("type", "string"))
|
field_type = self._map_type(prop_schema.get("type", "string"))
|
||||||
required = prop_name in body_schema.get("required", [])
|
required = prop_name in body_schema.get("required", [])
|
||||||
fields[f"body_{prop_name}"] = (field_type, ... if required else None)
|
fields[prop_name] = (field_type, ... if required else None)
|
||||||
|
parameter_mapping["body"].append(prop_name)
|
||||||
|
|
||||||
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
|
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
|
||||||
return create_model(f"{operation_id}_input", **fields)
|
model = create_model(f"{operation_id}_input", **fields)
|
||||||
|
model.model_config["parameter_mapping"] = parameter_mapping
|
||||||
|
return model
|
||||||
|
|
||||||
def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
|
def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract and categorize operation parameters.
|
"""Extract and categorize operation parameters.
|
||||||
|
|||||||
@@ -21,13 +21,18 @@ async def serve() -> None:
|
|||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
return get_airflow_tools()
|
try:
|
||||||
|
return await get_airflow_tools()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to list tools: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
try:
|
try:
|
||||||
tool = get_tool(name)
|
tool = await get_tool(name)
|
||||||
result = await tool.run(**arguments)
|
async with tool.client:
|
||||||
|
result = await tool.run(body=arguments)
|
||||||
return [TextContent(type="text", text=str(result))]
|
return [TextContent(type="text", text=str(result))]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Tool execution failed: %s", e)
|
logger.error("Tool execution failed: %s", e)
|
||||||
|
|||||||
@@ -49,82 +49,23 @@ class AirflowTool(BaseTools):
|
|||||||
self.operation = operation_details
|
self.operation = operation_details
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
def _validate_input(
|
|
||||||
self,
|
|
||||||
path_params: dict[str, Any] | None = None,
|
|
||||||
query_params: dict[str, Any] | None = None,
|
|
||||||
body: dict[str, Any] | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate input parameters using unified input model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path_params: Path parameters
|
|
||||||
query_params: Query parameters
|
|
||||||
body: Body parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[str, Any]: Validated input parameters
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
input_data = {}
|
|
||||||
|
|
||||||
if path_params:
|
|
||||||
input_data.update({f"path_{k}": v for k, v in path_params.items()})
|
|
||||||
|
|
||||||
if query_params:
|
|
||||||
input_data.update({f"query_{k}": v for k, v in query_params.items()})
|
|
||||||
|
|
||||||
if body:
|
|
||||||
input_data.update({f"body_{k}": v for k, v in body.items()})
|
|
||||||
|
|
||||||
validated = self.operation.input_model(**input_data)
|
|
||||||
return validated.model_dump()
|
|
||||||
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.error("Input validation failed: %s", e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _extract_parameters(self, validated_input: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
|
||||||
"""Extract validated parameters by type."""
|
|
||||||
path_params = {}
|
|
||||||
query_params = {}
|
|
||||||
body = {}
|
|
||||||
|
|
||||||
# Extract parameters based on operation definition
|
|
||||||
for key, value in validated_input.items():
|
|
||||||
# Remove prefix from key if present
|
|
||||||
param_key = key
|
|
||||||
if key.startswith(("path_", "query_", "body_")):
|
|
||||||
param_key = key.split("_", 1)[1]
|
|
||||||
|
|
||||||
if key.startswith("path_"):
|
|
||||||
path_params[param_key] = value
|
|
||||||
elif key.startswith("query_"):
|
|
||||||
query_params[param_key] = value
|
|
||||||
elif key.startswith("body_"):
|
|
||||||
body[param_key] = value
|
|
||||||
else:
|
|
||||||
body[key] = value
|
|
||||||
|
|
||||||
return path_params, query_params, body
|
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
path_params: dict[str, Any] | None = None,
|
|
||||||
query_params: dict[str, Any] | None = None,
|
|
||||||
body: dict[str, Any] | None = None,
|
body: dict[str, Any] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Execute the operation with provided parameters."""
|
"""Execute the operation with provided parameters."""
|
||||||
try:
|
try:
|
||||||
validated_input = self._validate_input(path_params, query_params, body)
|
mapping = self.operation.input_model.model_config["parameter_mapping"]
|
||||||
path_params, query_params, body = self._extract_parameters(validated_input)
|
path_params = {k: body[k] for k in mapping.get("path", []) if k in body}
|
||||||
|
query_params = {k: body[k] for k in mapping.get("query", []) if k in body}
|
||||||
|
body_params = {k: body[k] for k in mapping.get("body", []) if k in body}
|
||||||
|
|
||||||
# Execute operation
|
# Execute operation
|
||||||
response = await self.client.execute(
|
response = await self.client.execute(
|
||||||
operation_id=self.operation.operation_id,
|
operation_id=self.operation.operation_id,
|
||||||
path_params=path_params,
|
path_params=path_params,
|
||||||
query_params=query_params,
|
query_params=query_params,
|
||||||
body=body,
|
body=body_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate response if model exists
|
# Validate response if model exists
|
||||||
|
|||||||
@@ -10,55 +10,80 @@ from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_tools_cache: dict[str, AirflowTool] = {}
|
_tools_cache: dict[str, AirflowTool] = {}
|
||||||
_client: AirflowClient | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_airflow_tools() -> list[Tool]:
|
def _initialize_client() -> AirflowClient:
|
||||||
|
"""Initialize Airflow client with environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AirflowClient instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required environment variables are missing
|
||||||
|
"""
|
||||||
|
required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
||||||
|
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||||
|
if missing_vars:
|
||||||
|
raise ValueError(f"Missing required environment variables: {missing_vars}")
|
||||||
|
|
||||||
|
return AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _initialize_tools() -> None:
|
||||||
|
"""Initialize tools cache with Airflow operations.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If initialization fails
|
||||||
|
"""
|
||||||
|
global _tools_cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = _initialize_client()
|
||||||
|
parser = OperationParser(os.environ["OPENAPI_SPEC"])
|
||||||
|
|
||||||
|
# Generate tools for each operation
|
||||||
|
for operation_id in parser.get_operations():
|
||||||
|
operation_details = parser.parse_operation(operation_id)
|
||||||
|
tool = AirflowTool(operation_details, client)
|
||||||
|
_tools_cache[operation_id] = tool
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to initialize tools: %s", e)
|
||||||
|
_tools_cache.clear()
|
||||||
|
raise ValueError(f"Failed to initialize tools: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
async def get_airflow_tools() -> list[Tool]:
|
||||||
"""Get list of all available Airflow tools.
|
"""Get list of all available Airflow tools.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of MCP Tool objects representing available operations
|
List of MCP Tool objects representing available operations
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If required environment variables are missing
|
ValueError: If required environment variables are missing or initialization fails
|
||||||
"""
|
"""
|
||||||
global _tools_cache, _client
|
|
||||||
|
|
||||||
if not _tools_cache:
|
if not _tools_cache:
|
||||||
required_vars = ["OPENAPI_SPEC", "AIRFLOW_BASE_URL", "AUTH_TOKEN"]
|
await _initialize_tools()
|
||||||
if not all(var in os.environ for var in required_vars):
|
|
||||||
raise ValueError(f"Missing required environment variables: {required_vars}")
|
|
||||||
|
|
||||||
# Initialize client if not exists
|
|
||||||
if not _client:
|
|
||||||
_client = AirflowClient(spec_path=os.environ["OPENAPI_SPEC"], base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"])
|
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for operation_id, tool in _tools_cache.items():
|
||||||
try:
|
try:
|
||||||
# Create parser
|
schema = tool.operation.input_model.model_json_schema()
|
||||||
parser = OperationParser(os.environ["OPENAPI_SPEC"])
|
tools.append(
|
||||||
|
|
||||||
# Generate tools for each operation
|
|
||||||
for operation_id in parser.get_operations():
|
|
||||||
operation_details = parser.parse_operation(operation_id)
|
|
||||||
tool = AirflowTool(operation_details, _client)
|
|
||||||
_tools_cache[operation_id] = tool
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to initialize tools: %s", e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Convert to MCP Tool format
|
|
||||||
return [
|
|
||||||
Tool(
|
Tool(
|
||||||
name=operation_id,
|
name=operation_id,
|
||||||
description=tool.operation.operation_id,
|
description=tool.operation.operation_id,
|
||||||
inputSchema=tool.operation.input_model.model_json_schema(),
|
inputSchema=schema,
|
||||||
)
|
)
|
||||||
for operation_id, tool in _tools_cache.items()
|
)
|
||||||
]
|
except Exception as e:
|
||||||
|
logger.error("Failed to create tool schema for %s: %s", operation_id, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def get_tool(name: str) -> AirflowTool:
|
async def get_tool(name: str) -> AirflowTool:
|
||||||
"""Get specific tool by name.
|
"""Get specific tool by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -69,9 +94,12 @@ def get_tool(name: str) -> AirflowTool:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: If tool not found
|
KeyError: If tool not found
|
||||||
|
ValueError: If tool initialization fails
|
||||||
"""
|
"""
|
||||||
|
if not _tools_cache:
|
||||||
|
await _initialize_tools()
|
||||||
|
|
||||||
if name not in _tools_cache:
|
if name not in _tools_cache:
|
||||||
# Ensure cache is populated
|
raise KeyError(f"Tool {name} not found")
|
||||||
get_airflow_tools()
|
|
||||||
|
|
||||||
return _tools_cache[name]
|
return _tools_cache[name]
|
||||||
|
|||||||
Reference in New Issue
Block a user