fix tests

This commit is contained in:
2025-02-16 08:51:35 +00:00
parent 869cffb549
commit c952e945c7
7 changed files with 95 additions and 102 deletions

View File

@@ -20,7 +20,6 @@ class OperationDetails:
method: str
parameters: dict[str, Any]
input_model: type[BaseModel]
response_model: type[BaseModel] | None = None
class OperationParser:
@@ -83,7 +82,6 @@ class OperationParser:
operation["path_item"] = path_item
parameters = self.extract_parameters(operation)
response_model = self._parse_response_model(operation)
# Get request body schema if present
body_schema = None
@@ -97,14 +95,7 @@ class OperationParser:
# Create unified input model
input_model = self._create_input_model(operation_id, parameters, body_schema)
return OperationDetails(
operation_id=operation_id,
path=str(path),
method=method,
parameters=parameters,
input_model=input_model,
response_model=response_model,
)
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
raise ValueError(f"Operation {operation_id} not found in spec")
@@ -124,24 +115,21 @@ class OperationParser:
# Add path parameters
for name, schema in parameters.get("path", {}).items():
field_type = schema["type"]
required = schema.get("required", True) # Path parameters are required by default
fields[name] = (field_type, ... if required else None)
field_type = schema["type"] # Use the mapped type from parameter schema
fields[name] = (field_type | None, None) # Make all optional
parameter_mapping["path"].append(name)
# Add query parameters
for name, schema in parameters.get("query", {}).items():
field_type = schema["type"]
required = schema.get("required", False) # Query parameters are optional by default
fields[name] = (field_type, ... if required else None)
field_type = schema["type"] # Use the mapped type from parameter schema
fields[name] = (field_type | None, None) # Make all optional
parameter_mapping["query"].append(name)
# Add body fields if present
if body_schema and body_schema.get("type") == "object":
for prop_name, prop_schema in body_schema.get("properties", {}).items():
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"))
required = prop_name in body_schema.get("required", [])
fields[prop_name] = (field_type, ... if required else None)
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
fields[prop_name] = (field_type | None, None) # Make all optional
parameter_mapping["body"].append(prop_name)
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
@@ -224,8 +212,12 @@ class OperationParser:
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
# Get the type and format from schema
openapi_type = schema.get("type", "string")
format_type = schema.get("format")
return {
"type": self._map_type(schema.get("type", "string")),
"type": self._map_type(openapi_type, format_type, schema), # Pass format_type and full schema
"required": param.get("required", False),
"default": schema.get("default"),
"description": param.get("description"),
@@ -291,7 +283,7 @@ class OperationParser:
logger.error("Failed to create response model: %s", e)
return None
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: # noqa C901
"""Create Pydantic model from schema.
Args:

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
@@ -55,40 +55,27 @@ class AirflowTool(BaseTools):
) -> Any:
"""Execute the operation with provided parameters."""
try:
mapping = self.operation.input_model.model_config["parameter_mapping"]
body = body or {}
path_params = {k: body[k] for k in mapping.get("path", []) if k in body}
query_params = {k: body[k] for k in mapping.get("query", []) if k in body}
body_params = {k: body[k] for k in mapping.get("body", []) if k in body}
# Validate input
validated_input = self.operation.input_model(**(body or {}))
validated_body = validated_input.model_dump(exclude_none=True) # Only include non-None values
# Execute operation
mapping = self.operation.input_model.model_config["parameter_mapping"]
path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}
# Execute operation and return raw response
response = await self.client.execute(
operation_id=self.operation.operation_id,
path_params=path_params,
query_params=query_params,
body=body_params,
path_params=path_params or None,
query_params=query_params or None,
body=body_params or None,
)
logger.debug("Raw response: %s", response)
# Validate response if model exists
if self.operation.response_model and isinstance(response, dict):
try:
logger.debug("Response model schema: %s", self.operation.response_model.model_json_schema())
model: type[BaseModel] = self.operation.response_model
logger.debug("Attempting to validate response with model: %s", model.__name__)
validated_response = model(**response)
logger.debug("Response validation successful")
result = validated_response.model_dump()
logger.debug("Final response after model_dump: %s", result)
return result
except ValidationError as e:
logger.error("Response validation failed: %s", e)
logger.error("Validation error details: %s", e.errors())
raise RuntimeError(f"Invalid response format: {e}")
return response
except ValidationError:
raise
except Exception as e:
logger.error("Operation execution failed: %s", e)
raise

View File

@@ -35,7 +35,7 @@ def client(spec_file: dict[str, Any]) -> AirflowClient:
def test_init_client_initialization(client: AirflowClient) -> None:
assert isinstance(client.spec, OpenAPI)
assert client.base_url == "http://localhost:8080/api/v1"
assert client.headers["Authorization"] == "Bearer test-token"
assert client.headers["Authorization"] == "Basic test-token"
def test_init_load_spec_from_bytes() -> None:

View File

@@ -43,8 +43,11 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
# Verify path parameter field exists
fields = operation.input_model.__annotations__
assert "path_dag_id" in fields
assert isinstance(fields["path_dag_id"], type(str))
assert "dag_id" in fields
assert str in fields["dag_id"].__args__ # Check if str is in the Union types
# Verify parameter is mapped correctly
assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"]
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
@@ -53,8 +56,11 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
# Verify query parameter field exists
fields = operation.input_model.__annotations__
assert "query_limit" in fields
assert isinstance(fields["query_limit"], type(int))
assert "limit" in fields
assert int in fields["limit"].__args__ # Check if int is in the Union types
# Verify parameter is mapped correctly
assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"]
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
@@ -63,21 +69,11 @@ def test_parse_operation_with_body_params(parser: OperationParser) -> None:
# Verify body fields exist
fields = operation.input_model.__annotations__
assert "body_dag_run_id" in fields
assert isinstance(fields["body_dag_run_id"], type(str))
assert "dag_run_id" in fields
assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types
def test_parse_operation_with_response_model(parser: OperationParser) -> None:
"""Test parsing operation with response model."""
operation = parser.parse_operation("get_dag")
assert operation.response_model is not None
assert issubclass(operation.response_model, BaseModel)
# Test model fields
fields = operation.response_model.__annotations__
assert "dag_id" in fields
assert "is_paused" in fields
# Verify parameter is mapped correctly
assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"]
def test_parse_operation_not_found(parser: OperationParser) -> None:
@@ -118,7 +114,9 @@ def test_map_parameter_schema_nullable(parser: OperationParser) -> None:
}
result = parser._map_parameter_schema(param)
assert isinstance(result["type"], type(str))
# Check that str is in the Union types
assert str in result["type"].__args__
assert None.__class__ in result["type"].__args__ # Check for NoneType
assert not result["required"]

View File

@@ -6,7 +6,7 @@ from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.airflow_tool import AirflowTool
from pydantic import ValidationError
from tests.tools.test_models import TestRequestModel, TestResponseModel
from tests.tools.test_models import TestRequestModel
@pytest.fixture
@@ -20,20 +20,27 @@ def mock_client(mocker):
@pytest.fixture
def operation_details():
"""Create test operation details."""
model = TestRequestModel
# Add parameter mapping to model config
model.model_config["parameter_mapping"] = {
"path": ["path_id"],
"query": ["query_filter"],
"body": ["body_name", "body_value"],
}
return OperationDetails(
operation_id="test_operation",
path="/test/{id}",
path="/test/{path_id}",
method="POST",
parameters={
"path": {
"id": {"type": int, "required": True},
"path_id": {"type": int, "required": True},
},
"query": {
"filter": {"type": str, "required": False},
"query_filter": {"type": str, "required": False},
},
},
input_model=TestRequestModel,
response_model=TestResponseModel,
input_model=model,
)
@@ -49,20 +56,23 @@ async def test_successful_execution(airflow_tool, mock_client):
# Setup mock response
mock_client.execute.return_value = {"item_id": 1, "result": "success"}
# Execute operation
# Execute operation with unified body
result = await airflow_tool.run(
path_params={"id": 123},
query_params={"filter": "test"},
body={"name": "test", "value": 42},
body={
"path_id": 123,
"query_filter": "test",
"body_name": "test",
"body_value": 42,
}
)
# Verify response
assert result == {"item_id": 1, "result": "success"}
mock_client.execute.assert_called_once_with(
operation_id="test_operation",
path_params={"id": 123},
query_params={"filter": "test"},
body={"name": "test", "value": 42},
path_params={"path_id": 123},
query_params={"query_filter": "test"},
body={"body_name": "test", "body_value": 42},
)
@@ -71,8 +81,11 @@ async def test_invalid_path_parameter(airflow_tool):
"""Test validation error for invalid path parameter type."""
with pytest.raises(ValidationError):
await airflow_tool.run(
path_params={"id": "not_an_integer"},
body={"name": "test", "value": 42},
body={
"path_id": "not_an_integer", # Invalid type
"body_name": "test",
"body_value": 42,
}
)
@@ -81,22 +94,29 @@ async def test_invalid_request_body(airflow_tool):
"""Test validation error for invalid request body."""
with pytest.raises(ValidationError):
await airflow_tool.run(
path_params={"id": 123},
body={"name": "test", "value": "not_an_integer"},
body={
"path_id": 123,
"body_name": "test",
"body_value": "not_an_integer", # Invalid type
}
)
@pytest.mark.asyncio
async def test_invalid_response_format(airflow_tool, mock_client):
"""Test error handling for invalid response format."""
# Setup mock response with invalid format
# Setup mock response
mock_client.execute.return_value = {"invalid": "response"}
with pytest.raises(RuntimeError):
await airflow_tool.run(
path_params={"id": 123},
body={"name": "test", "value": 42},
)
# Should not raise any validation error
result = await airflow_tool.run(
body={
"path_id": 123,
"body_name": "test",
"body_value": 42,
}
)
assert result == {"invalid": "response"}
@pytest.mark.asyncio
@@ -107,6 +127,9 @@ async def test_client_error(airflow_tool, mock_client):
with pytest.raises(RuntimeError):
await airflow_tool.run(
path_params={"id": 123},
body={"name": "test", "value": 42},
body={
"path_id": 123,
"body_name": "test",
"body_value": 42,
}
)

View File

@@ -10,10 +10,3 @@ class TestRequestModel(BaseModel):
query_filter: str | None = None
body_name: str
body_value: int
class TestResponseModel(BaseModel):
"""Test response model."""
item_id: int
result: str

View File

@@ -111,7 +111,7 @@ wheels = [
[[package]]
name = "airflow-mcp-server"
version = "0.1.0"
version = "0.2.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },