Airflow Tools Generation
This commit is contained in:
@@ -22,6 +22,7 @@ dev = [
|
|||||||
"pre-commit>=4.0.1",
|
"pre-commit>=4.0.1",
|
||||||
"pytest>=8.3.4",
|
"pytest>=8.3.4",
|
||||||
"pytest-asyncio>=0.25.0",
|
"pytest-asyncio>=0.25.0",
|
||||||
|
"pytest-mock>=3.14.0",
|
||||||
"ruff>=0.9.2"
|
"ruff>=0.9.2"
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -33,6 +34,8 @@ build-backend = "hatchling.build"
|
|||||||
pythonpath = ["src"]
|
pythonpath = ["src"]
|
||||||
asyncio_mode = "strict"
|
asyncio_mode = "strict"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
python_classes = "!TestRequestModel,!TestResponseModel"
|
||||||
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 200
|
line-length = 200
|
||||||
@@ -69,3 +72,6 @@ skip-magic-trailing-comma = false
|
|||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
combine-as-imports = true
|
combine-as-imports = true
|
||||||
|
|
||||||
|
[tool.ruff.lint.mccabe]
|
||||||
|
max-complexity = 12
|
||||||
|
|||||||
166
airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py
Normal file
166
airflow-mcp-server/src/airflow_mcp_server/tools/airflow_tool.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
|
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||||
|
from airflow_mcp_server.tools.base_tools import BaseTools
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_validation_error(field: str, message: str) -> ValidationError:
|
||||||
|
"""Create a properly formatted validation error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The field that failed validation
|
||||||
|
message: The error message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationError: A properly formatted validation error
|
||||||
|
"""
|
||||||
|
errors = [
|
||||||
|
{
|
||||||
|
"loc": (field,),
|
||||||
|
"msg": message,
|
||||||
|
"type": "value_error",
|
||||||
|
"input": None,
|
||||||
|
"ctx": {"error": message},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return ValidationError.from_exception_data("validation_error", errors)
|
||||||
|
|
||||||
|
|
||||||
|
class AirflowTool(BaseTools):
|
||||||
|
"""Tool for executing Airflow API operations."""
|
||||||
|
|
||||||
|
def __init__(self, operation_details: OperationDetails, client: AirflowClient) -> None:
|
||||||
|
"""Initialize tool with operation details and client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_details: Parsed operation details
|
||||||
|
client: Configured Airflow API client
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.operation = operation_details
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def _validate_parameters(
|
||||||
|
self,
|
||||||
|
path_params: dict[str, Any] | None = None,
|
||||||
|
query_params: dict[str, Any] | None = None,
|
||||||
|
body: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None]:
|
||||||
|
"""Validate input parameters against operation schemas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_params: URL path parameters
|
||||||
|
query_params: URL query parameters
|
||||||
|
body: Request body data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of validated (path_params, query_params, body)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If parameters fail validation
|
||||||
|
"""
|
||||||
|
validated_params: dict[str, dict[str, Any] | None] = {
|
||||||
|
"path": None,
|
||||||
|
"query": None,
|
||||||
|
"body": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate path parameters
|
||||||
|
if path_params and "path" in self.operation.parameters:
|
||||||
|
path_schema = self.operation.parameters["path"]
|
||||||
|
for name, value in path_params.items():
|
||||||
|
if name in path_schema:
|
||||||
|
param_type = path_schema[name]["type"]
|
||||||
|
if not isinstance(value, param_type):
|
||||||
|
raise create_validation_error(
|
||||||
|
field=name,
|
||||||
|
message=f"Path parameter {name} must be of type {param_type.__name__}",
|
||||||
|
)
|
||||||
|
validated_params["path"] = path_params
|
||||||
|
|
||||||
|
# Validate query parameters
|
||||||
|
if query_params and "query" in self.operation.parameters:
|
||||||
|
query_schema = self.operation.parameters["query"]
|
||||||
|
for name, value in query_params.items():
|
||||||
|
if name in query_schema:
|
||||||
|
param_type = query_schema[name]["type"]
|
||||||
|
if not isinstance(value, param_type):
|
||||||
|
raise create_validation_error(
|
||||||
|
field=name,
|
||||||
|
message=f"Query parameter {name} must be of type {param_type.__name__}",
|
||||||
|
)
|
||||||
|
validated_params["query"] = query_params
|
||||||
|
|
||||||
|
# Validate request body
|
||||||
|
if body and self.operation.request_body:
|
||||||
|
try:
|
||||||
|
model: type[BaseModel] = self.operation.request_body
|
||||||
|
validated_body = model(**body)
|
||||||
|
validated_params["body"] = validated_body.model_dump()
|
||||||
|
except ValidationError as e:
|
||||||
|
# Re-raise Pydantic validation errors directly
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return (
|
||||||
|
validated_params["path"],
|
||||||
|
validated_params["query"],
|
||||||
|
validated_params["body"],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Parameter validation failed: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
path_params: dict[str, Any] | None = None,
|
||||||
|
query_params: dict[str, Any] | None = None,
|
||||||
|
body: dict[str, Any] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute the operation with provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_params: URL path parameters
|
||||||
|
query_params: URL query parameters
|
||||||
|
body: Request body data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API response data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If parameters fail validation
|
||||||
|
RuntimeError: If client execution fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate parameters
|
||||||
|
validated_path_params, validated_query_params, validated_body = self._validate_parameters(path_params, query_params, body)
|
||||||
|
|
||||||
|
# Execute operation
|
||||||
|
response = await self.client.execute(
|
||||||
|
operation_id=self.operation.operation_id,
|
||||||
|
path_params=validated_path_params,
|
||||||
|
query_params=validated_query_params,
|
||||||
|
body=validated_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate response if model exists
|
||||||
|
if self.operation.response_model and isinstance(response, dict):
|
||||||
|
try:
|
||||||
|
model: type[BaseModel] = self.operation.response_model
|
||||||
|
validated_response = model(**response)
|
||||||
|
return validated_response.model_dump()
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error("Response validation failed: %s", e)
|
||||||
|
raise RuntimeError(f"Invalid response format: {e}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Operation execution failed: %s", e)
|
||||||
|
raise
|
||||||
0
airflow-mcp-server/tests/tools/__init__.py
Normal file
0
airflow-mcp-server/tests/tools/__init__.py
Normal file
112
airflow-mcp-server/tests/tools/test_airflow_tool.py
Normal file
112
airflow-mcp-server/tests/tools/test_airflow_tool.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""Tests for AirflowTool."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
|
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||||
|
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from tests.tools.test_models import TestRequestModel, TestResponseModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_client(mocker):
|
||||||
|
"""Create mock Airflow client."""
|
||||||
|
client = mocker.Mock(spec=AirflowClient)
|
||||||
|
client.execute = mocker.AsyncMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def operation_details():
|
||||||
|
"""Create test operation details."""
|
||||||
|
return OperationDetails(
|
||||||
|
operation_id="test_operation",
|
||||||
|
path="/test/{id}",
|
||||||
|
method="POST",
|
||||||
|
parameters={
|
||||||
|
"path": {
|
||||||
|
"id": {"type": int, "required": True},
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"filter": {"type": str, "required": False},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request_body=TestRequestModel,
|
||||||
|
response_model=TestResponseModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def airflow_tool(mock_client, operation_details):
|
||||||
|
"""Create AirflowTool instance for testing."""
|
||||||
|
return AirflowTool(operation_details, mock_client)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_successful_execution(airflow_tool, mock_client):
|
||||||
|
"""Test successful operation execution with valid parameters."""
|
||||||
|
# Setup mock response
|
||||||
|
mock_client.execute.return_value = {"item_id": 1, "result": "success"}
|
||||||
|
|
||||||
|
# Execute operation
|
||||||
|
result = await airflow_tool.run(
|
||||||
|
path_params={"id": 123},
|
||||||
|
query_params={"filter": "test"},
|
||||||
|
body={"name": "test", "value": 42},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert result == {"item_id": 1, "result": "success"}
|
||||||
|
mock_client.execute.assert_called_once_with(
|
||||||
|
operation_id="test_operation",
|
||||||
|
path_params={"id": 123},
|
||||||
|
query_params={"filter": "test"},
|
||||||
|
body={"name": "test", "value": 42},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_path_parameter(airflow_tool):
|
||||||
|
"""Test validation error for invalid path parameter type."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await airflow_tool.run(
|
||||||
|
path_params={"id": "not_an_integer"},
|
||||||
|
body={"name": "test", "value": 42},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_request_body(airflow_tool):
|
||||||
|
"""Test validation error for invalid request body."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await airflow_tool.run(
|
||||||
|
path_params={"id": 123},
|
||||||
|
body={"name": "test", "value": "not_an_integer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_response_format(airflow_tool, mock_client):
|
||||||
|
"""Test error handling for invalid response format."""
|
||||||
|
# Setup mock response with invalid format
|
||||||
|
mock_client.execute.return_value = {"invalid": "response"}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await airflow_tool.run(
|
||||||
|
path_params={"id": 123},
|
||||||
|
body={"name": "test", "value": 42},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_error(airflow_tool, mock_client):
|
||||||
|
"""Test error handling for client execution failure."""
|
||||||
|
# Setup mock to raise exception
|
||||||
|
mock_client.execute.side_effect = RuntimeError("API Error")
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await airflow_tool.run(
|
||||||
|
path_params={"id": 123},
|
||||||
|
body={"name": "test", "value": 42},
|
||||||
|
)
|
||||||
17
airflow-mcp-server/tests/tools/test_models.py
Normal file
17
airflow-mcp-server/tests/tools/test_models.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Test models for Airflow tool tests."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestModel(BaseModel):
|
||||||
|
"""Test request model."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponseModel(BaseModel):
|
||||||
|
"""Test response model."""
|
||||||
|
|
||||||
|
item_id: int
|
||||||
|
result: str
|
||||||
Reference in New Issue
Block a user