fix tests

This commit is contained in:
2025-02-16 08:51:35 +00:00
parent 869cffb549
commit c952e945c7
7 changed files with 95 additions and 102 deletions

View File

@@ -20,7 +20,6 @@ class OperationDetails:
method: str method: str
parameters: dict[str, Any] parameters: dict[str, Any]
input_model: type[BaseModel] input_model: type[BaseModel]
response_model: type[BaseModel] | None = None
class OperationParser: class OperationParser:
@@ -83,7 +82,6 @@ class OperationParser:
operation["path_item"] = path_item operation["path_item"] = path_item
parameters = self.extract_parameters(operation) parameters = self.extract_parameters(operation)
response_model = self._parse_response_model(operation)
# Get request body schema if present # Get request body schema if present
body_schema = None body_schema = None
@@ -97,14 +95,7 @@ class OperationParser:
# Create unified input model # Create unified input model
input_model = self._create_input_model(operation_id, parameters, body_schema) input_model = self._create_input_model(operation_id, parameters, body_schema)
return OperationDetails( return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
operation_id=operation_id,
path=str(path),
method=method,
parameters=parameters,
input_model=input_model,
response_model=response_model,
)
raise ValueError(f"Operation {operation_id} not found in spec") raise ValueError(f"Operation {operation_id} not found in spec")
@@ -124,24 +115,21 @@ class OperationParser:
# Add path parameters # Add path parameters
for name, schema in parameters.get("path", {}).items(): for name, schema in parameters.get("path", {}).items():
field_type = schema["type"] field_type = schema["type"] # Use the mapped type from parameter schema
required = schema.get("required", True) # Path parameters are required by default fields[name] = (field_type | None, None) # Make all optional
fields[name] = (field_type, ... if required else None)
parameter_mapping["path"].append(name) parameter_mapping["path"].append(name)
# Add query parameters # Add query parameters
for name, schema in parameters.get("query", {}).items(): for name, schema in parameters.get("query", {}).items():
field_type = schema["type"] field_type = schema["type"] # Use the mapped type from parameter schema
required = schema.get("required", False) # Query parameters are optional by default fields[name] = (field_type | None, None) # Make all optional
fields[name] = (field_type, ... if required else None)
parameter_mapping["query"].append(name) parameter_mapping["query"].append(name)
# Add body fields if present # Add body fields if present
if body_schema and body_schema.get("type") == "object": if body_schema and body_schema.get("type") == "object":
for prop_name, prop_schema in body_schema.get("properties", {}).items(): for prop_name, prop_schema in body_schema.get("properties", {}).items():
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format")) field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
required = prop_name in body_schema.get("required", []) fields[prop_name] = (field_type | None, None) # Make all optional
fields[prop_name] = (field_type, ... if required else None)
parameter_mapping["body"].append(prop_name) parameter_mapping["body"].append(prop_name)
logger.debug("Creating input model for %s with fields: %s", operation_id, fields) logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
@@ -224,8 +212,12 @@ class OperationParser:
if "$ref" in schema: if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"]) schema = self._resolve_ref(schema["$ref"])
# Get the type and format from schema
openapi_type = schema.get("type", "string")
format_type = schema.get("format")
return { return {
"type": self._map_type(schema.get("type", "string")), "type": self._map_type(openapi_type, format_type, schema), # Pass format_type and full schema
"required": param.get("required", False), "required": param.get("required", False),
"default": schema.get("default"), "default": schema.get("default"),
"description": param.get("description"), "description": param.get("description"),
@@ -291,7 +283,7 @@ class OperationParser:
logger.error("Failed to create response model: %s", e) logger.error("Failed to create response model: %s", e)
return None return None
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: # noqa C901
"""Create Pydantic model from schema. """Create Pydantic model from schema.
Args: Args:

View File

@@ -1,7 +1,7 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, ValidationError from pydantic import ValidationError
from airflow_mcp_server.client.airflow_client import AirflowClient from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails from airflow_mcp_server.parser.operation_parser import OperationDetails
@@ -55,40 +55,27 @@ class AirflowTool(BaseTools):
) -> Any: ) -> Any:
"""Execute the operation with provided parameters.""" """Execute the operation with provided parameters."""
try: try:
mapping = self.operation.input_model.model_config["parameter_mapping"] # Validate input
body = body or {} validated_input = self.operation.input_model(**(body or {}))
path_params = {k: body[k] for k in mapping.get("path", []) if k in body} validated_body = validated_input.model_dump(exclude_none=True) # Only include non-None values
query_params = {k: body[k] for k in mapping.get("query", []) if k in body}
body_params = {k: body[k] for k in mapping.get("body", []) if k in body}
# Execute operation mapping = self.operation.input_model.model_config["parameter_mapping"]
path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}
# Execute operation and return raw response
response = await self.client.execute( response = await self.client.execute(
operation_id=self.operation.operation_id, operation_id=self.operation.operation_id,
path_params=path_params, path_params=path_params or None,
query_params=query_params, query_params=query_params or None,
body=body_params, body=body_params or None,
) )
logger.debug("Raw response: %s", response)
# Validate response if model exists
if self.operation.response_model and isinstance(response, dict):
try:
logger.debug("Response model schema: %s", self.operation.response_model.model_json_schema())
model: type[BaseModel] = self.operation.response_model
logger.debug("Attempting to validate response with model: %s", model.__name__)
validated_response = model(**response)
logger.debug("Response validation successful")
result = validated_response.model_dump()
logger.debug("Final response after model_dump: %s", result)
return result
except ValidationError as e:
logger.error("Response validation failed: %s", e)
logger.error("Validation error details: %s", e.errors())
raise RuntimeError(f"Invalid response format: {e}")
return response return response
except ValidationError:
raise
except Exception as e: except Exception as e:
logger.error("Operation execution failed: %s", e) logger.error("Operation execution failed: %s", e)
raise raise

View File

@@ -35,7 +35,7 @@ def client(spec_file: dict[str, Any]) -> AirflowClient:
def test_init_client_initialization(client: AirflowClient) -> None: def test_init_client_initialization(client: AirflowClient) -> None:
assert isinstance(client.spec, OpenAPI) assert isinstance(client.spec, OpenAPI)
assert client.base_url == "http://localhost:8080/api/v1" assert client.base_url == "http://localhost:8080/api/v1"
assert client.headers["Authorization"] == "Bearer test-token" assert client.headers["Authorization"] == "Basic test-token"
def test_init_load_spec_from_bytes() -> None: def test_init_load_spec_from_bytes() -> None:

View File

@@ -43,8 +43,11 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
# Verify path parameter field exists # Verify path parameter field exists
fields = operation.input_model.__annotations__ fields = operation.input_model.__annotations__
assert "path_dag_id" in fields assert "dag_id" in fields
assert isinstance(fields["path_dag_id"], type(str)) assert str in fields["dag_id"].__args__ # Check if str is in the Union types
# Verify parameter is mapped correctly
assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"]
def test_parse_operation_with_query_params(parser: OperationParser) -> None: def test_parse_operation_with_query_params(parser: OperationParser) -> None:
@@ -53,8 +56,11 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
# Verify query parameter field exists # Verify query parameter field exists
fields = operation.input_model.__annotations__ fields = operation.input_model.__annotations__
assert "query_limit" in fields assert "limit" in fields
assert isinstance(fields["query_limit"], type(int)) assert int in fields["limit"].__args__ # Check if int is in the Union types
# Verify parameter is mapped correctly
assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"]
def test_parse_operation_with_body_params(parser: OperationParser) -> None: def test_parse_operation_with_body_params(parser: OperationParser) -> None:
@@ -63,21 +69,11 @@ def test_parse_operation_with_body_params(parser: OperationParser) -> None:
# Verify body fields exist # Verify body fields exist
fields = operation.input_model.__annotations__ fields = operation.input_model.__annotations__
assert "body_dag_run_id" in fields assert "dag_run_id" in fields
assert isinstance(fields["body_dag_run_id"], type(str)) assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types
# Verify parameter is mapped correctly
def test_parse_operation_with_response_model(parser: OperationParser) -> None: assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"]
"""Test parsing operation with response model."""
operation = parser.parse_operation("get_dag")
assert operation.response_model is not None
assert issubclass(operation.response_model, BaseModel)
# Test model fields
fields = operation.response_model.__annotations__
assert "dag_id" in fields
assert "is_paused" in fields
def test_parse_operation_not_found(parser: OperationParser) -> None: def test_parse_operation_not_found(parser: OperationParser) -> None:
@@ -118,7 +114,9 @@ def test_map_parameter_schema_nullable(parser: OperationParser) -> None:
} }
result = parser._map_parameter_schema(param) result = parser._map_parameter_schema(param)
assert isinstance(result["type"], type(str)) # Check that str is in the Union types
assert str in result["type"].__args__
assert None.__class__ in result["type"].__args__ # Check for NoneType
assert not result["required"] assert not result["required"]

View File

@@ -6,7 +6,7 @@ from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.airflow_tool import AirflowTool from airflow_mcp_server.tools.airflow_tool import AirflowTool
from pydantic import ValidationError from pydantic import ValidationError
from tests.tools.test_models import TestRequestModel, TestResponseModel from tests.tools.test_models import TestRequestModel
@pytest.fixture @pytest.fixture
@@ -20,20 +20,27 @@ def mock_client(mocker):
@pytest.fixture @pytest.fixture
def operation_details(): def operation_details():
"""Create test operation details.""" """Create test operation details."""
model = TestRequestModel
# Add parameter mapping to model config
model.model_config["parameter_mapping"] = {
"path": ["path_id"],
"query": ["query_filter"],
"body": ["body_name", "body_value"],
}
return OperationDetails( return OperationDetails(
operation_id="test_operation", operation_id="test_operation",
path="/test/{id}", path="/test/{path_id}",
method="POST", method="POST",
parameters={ parameters={
"path": { "path": {
"id": {"type": int, "required": True}, "path_id": {"type": int, "required": True},
}, },
"query": { "query": {
"filter": {"type": str, "required": False}, "query_filter": {"type": str, "required": False},
}, },
}, },
input_model=TestRequestModel, input_model=model,
response_model=TestResponseModel,
) )
@@ -49,20 +56,23 @@ async def test_successful_execution(airflow_tool, mock_client):
# Setup mock response # Setup mock response
mock_client.execute.return_value = {"item_id": 1, "result": "success"} mock_client.execute.return_value = {"item_id": 1, "result": "success"}
# Execute operation # Execute operation with unified body
result = await airflow_tool.run( result = await airflow_tool.run(
path_params={"id": 123}, body={
query_params={"filter": "test"}, "path_id": 123,
body={"name": "test", "value": 42}, "query_filter": "test",
"body_name": "test",
"body_value": 42,
}
) )
# Verify response # Verify response
assert result == {"item_id": 1, "result": "success"} assert result == {"item_id": 1, "result": "success"}
mock_client.execute.assert_called_once_with( mock_client.execute.assert_called_once_with(
operation_id="test_operation", operation_id="test_operation",
path_params={"id": 123}, path_params={"path_id": 123},
query_params={"filter": "test"}, query_params={"query_filter": "test"},
body={"name": "test", "value": 42}, body={"body_name": "test", "body_value": 42},
) )
@@ -71,8 +81,11 @@ async def test_invalid_path_parameter(airflow_tool):
"""Test validation error for invalid path parameter type.""" """Test validation error for invalid path parameter type."""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
await airflow_tool.run( await airflow_tool.run(
path_params={"id": "not_an_integer"}, body={
body={"name": "test", "value": 42}, "path_id": "not_an_integer", # Invalid type
"body_name": "test",
"body_value": 42,
}
) )
@@ -81,22 +94,29 @@ async def test_invalid_request_body(airflow_tool):
"""Test validation error for invalid request body.""" """Test validation error for invalid request body."""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
await airflow_tool.run( await airflow_tool.run(
path_params={"id": 123}, body={
body={"name": "test", "value": "not_an_integer"}, "path_id": 123,
"body_name": "test",
"body_value": "not_an_integer", # Invalid type
}
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_response_format(airflow_tool, mock_client): async def test_invalid_response_format(airflow_tool, mock_client):
"""Test error handling for invalid response format.""" """Test error handling for invalid response format."""
# Setup mock response with invalid format # Setup mock response
mock_client.execute.return_value = {"invalid": "response"} mock_client.execute.return_value = {"invalid": "response"}
with pytest.raises(RuntimeError): # Should not raise any validation error
await airflow_tool.run( result = await airflow_tool.run(
path_params={"id": 123}, body={
body={"name": "test", "value": 42}, "path_id": 123,
) "body_name": "test",
"body_value": 42,
}
)
assert result == {"invalid": "response"}
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -107,6 +127,9 @@ async def test_client_error(airflow_tool, mock_client):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await airflow_tool.run( await airflow_tool.run(
path_params={"id": 123}, body={
body={"name": "test", "value": 42}, "path_id": 123,
"body_name": "test",
"body_value": 42,
}
) )

View File

@@ -10,10 +10,3 @@ class TestRequestModel(BaseModel):
query_filter: str | None = None query_filter: str | None = None
body_name: str body_name: str
body_value: int body_value: int
class TestResponseModel(BaseModel):
"""Test response model."""
item_id: int
result: str

View File

@@ -111,7 +111,7 @@ wheels = [
[[package]] [[package]]
name = "airflow-mcp-server" name = "airflow-mcp-server"
version = "0.1.0" version = "0.2.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiofiles" }, { name = "aiofiles" },