fix tests
This commit is contained in:
@@ -20,7 +20,6 @@ class OperationDetails:
|
||||
method: str
|
||||
parameters: dict[str, Any]
|
||||
input_model: type[BaseModel]
|
||||
response_model: type[BaseModel] | None = None
|
||||
|
||||
|
||||
class OperationParser:
|
||||
@@ -83,7 +82,6 @@ class OperationParser:
|
||||
operation["path_item"] = path_item
|
||||
|
||||
parameters = self.extract_parameters(operation)
|
||||
response_model = self._parse_response_model(operation)
|
||||
|
||||
# Get request body schema if present
|
||||
body_schema = None
|
||||
@@ -97,14 +95,7 @@ class OperationParser:
|
||||
# Create unified input model
|
||||
input_model = self._create_input_model(operation_id, parameters, body_schema)
|
||||
|
||||
return OperationDetails(
|
||||
operation_id=operation_id,
|
||||
path=str(path),
|
||||
method=method,
|
||||
parameters=parameters,
|
||||
input_model=input_model,
|
||||
response_model=response_model,
|
||||
)
|
||||
return OperationDetails(operation_id=operation_id, path=str(path), method=method, parameters=parameters, input_model=input_model)
|
||||
|
||||
raise ValueError(f"Operation {operation_id} not found in spec")
|
||||
|
||||
@@ -124,24 +115,21 @@ class OperationParser:
|
||||
|
||||
# Add path parameters
|
||||
for name, schema in parameters.get("path", {}).items():
|
||||
field_type = schema["type"]
|
||||
required = schema.get("required", True) # Path parameters are required by default
|
||||
fields[name] = (field_type, ... if required else None)
|
||||
field_type = schema["type"] # Use the mapped type from parameter schema
|
||||
fields[name] = (field_type | None, None) # Make all optional
|
||||
parameter_mapping["path"].append(name)
|
||||
|
||||
# Add query parameters
|
||||
for name, schema in parameters.get("query", {}).items():
|
||||
field_type = schema["type"]
|
||||
required = schema.get("required", False) # Query parameters are optional by default
|
||||
fields[name] = (field_type, ... if required else None)
|
||||
field_type = schema["type"] # Use the mapped type from parameter schema
|
||||
fields[name] = (field_type | None, None) # Make all optional
|
||||
parameter_mapping["query"].append(name)
|
||||
|
||||
# Add body fields if present
|
||||
if body_schema and body_schema.get("type") == "object":
|
||||
for prop_name, prop_schema in body_schema.get("properties", {}).items():
|
||||
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"))
|
||||
required = prop_name in body_schema.get("required", [])
|
||||
fields[prop_name] = (field_type, ... if required else None)
|
||||
field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema)
|
||||
fields[prop_name] = (field_type | None, None) # Make all optional
|
||||
parameter_mapping["body"].append(prop_name)
|
||||
|
||||
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
|
||||
@@ -224,8 +212,12 @@ class OperationParser:
|
||||
if "$ref" in schema:
|
||||
schema = self._resolve_ref(schema["$ref"])
|
||||
|
||||
# Get the type and format from schema
|
||||
openapi_type = schema.get("type", "string")
|
||||
format_type = schema.get("format")
|
||||
|
||||
return {
|
||||
"type": self._map_type(schema.get("type", "string")),
|
||||
"type": self._map_type(openapi_type, format_type, schema), # Pass format_type and full schema
|
||||
"required": param.get("required", False),
|
||||
"default": schema.get("default"),
|
||||
"description": param.get("description"),
|
||||
@@ -291,7 +283,7 @@ class OperationParser:
|
||||
logger.error("Failed to create response model: %s", e)
|
||||
return None
|
||||
|
||||
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
|
||||
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]: # noqa C901
|
||||
"""Create Pydantic model from schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||
@@ -55,40 +55,27 @@ class AirflowTool(BaseTools):
|
||||
) -> Any:
|
||||
"""Execute the operation with provided parameters."""
|
||||
try:
|
||||
mapping = self.operation.input_model.model_config["parameter_mapping"]
|
||||
body = body or {}
|
||||
path_params = {k: body[k] for k in mapping.get("path", []) if k in body}
|
||||
query_params = {k: body[k] for k in mapping.get("query", []) if k in body}
|
||||
body_params = {k: body[k] for k in mapping.get("body", []) if k in body}
|
||||
# Validate input
|
||||
validated_input = self.operation.input_model(**(body or {}))
|
||||
validated_body = validated_input.model_dump(exclude_none=True) # Only include non-None values
|
||||
|
||||
# Execute operation
|
||||
mapping = self.operation.input_model.model_config["parameter_mapping"]
|
||||
path_params = {k: validated_body[k] for k in mapping.get("path", []) if k in validated_body}
|
||||
query_params = {k: validated_body[k] for k in mapping.get("query", []) if k in validated_body}
|
||||
body_params = {k: validated_body[k] for k in mapping.get("body", []) if k in validated_body}
|
||||
|
||||
# Execute operation and return raw response
|
||||
response = await self.client.execute(
|
||||
operation_id=self.operation.operation_id,
|
||||
path_params=path_params,
|
||||
query_params=query_params,
|
||||
body=body_params,
|
||||
path_params=path_params or None,
|
||||
query_params=query_params or None,
|
||||
body=body_params or None,
|
||||
)
|
||||
|
||||
logger.debug("Raw response: %s", response)
|
||||
|
||||
# Validate response if model exists
|
||||
if self.operation.response_model and isinstance(response, dict):
|
||||
try:
|
||||
logger.debug("Response model schema: %s", self.operation.response_model.model_json_schema())
|
||||
model: type[BaseModel] = self.operation.response_model
|
||||
logger.debug("Attempting to validate response with model: %s", model.__name__)
|
||||
validated_response = model(**response)
|
||||
logger.debug("Response validation successful")
|
||||
result = validated_response.model_dump()
|
||||
logger.debug("Final response after model_dump: %s", result)
|
||||
return result
|
||||
except ValidationError as e:
|
||||
logger.error("Response validation failed: %s", e)
|
||||
logger.error("Validation error details: %s", e.errors())
|
||||
raise RuntimeError(f"Invalid response format: {e}")
|
||||
|
||||
return response
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Operation execution failed: %s", e)
|
||||
raise
|
||||
|
||||
@@ -35,7 +35,7 @@ def client(spec_file: dict[str, Any]) -> AirflowClient:
|
||||
def test_init_client_initialization(client: AirflowClient) -> None:
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
assert client.base_url == "http://localhost:8080/api/v1"
|
||||
assert client.headers["Authorization"] == "Bearer test-token"
|
||||
assert client.headers["Authorization"] == "Basic test-token"
|
||||
|
||||
|
||||
def test_init_load_spec_from_bytes() -> None:
|
||||
|
||||
@@ -43,8 +43,11 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
|
||||
|
||||
# Verify path parameter field exists
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "path_dag_id" in fields
|
||||
assert isinstance(fields["path_dag_id"], type(str))
|
||||
assert "dag_id" in fields
|
||||
assert str in fields["dag_id"].__args__ # Check if str is in the Union types
|
||||
|
||||
# Verify parameter is mapped correctly
|
||||
assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"]
|
||||
|
||||
|
||||
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
||||
@@ -53,8 +56,11 @@ def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
||||
|
||||
# Verify query parameter field exists
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "query_limit" in fields
|
||||
assert isinstance(fields["query_limit"], type(int))
|
||||
assert "limit" in fields
|
||||
assert int in fields["limit"].__args__ # Check if int is in the Union types
|
||||
|
||||
# Verify parameter is mapped correctly
|
||||
assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"]
|
||||
|
||||
|
||||
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
|
||||
@@ -63,21 +69,11 @@ def test_parse_operation_with_body_params(parser: OperationParser) -> None:
|
||||
|
||||
# Verify body fields exist
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "body_dag_run_id" in fields
|
||||
assert isinstance(fields["body_dag_run_id"], type(str))
|
||||
assert "dag_run_id" in fields
|
||||
assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types
|
||||
|
||||
|
||||
def test_parse_operation_with_response_model(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with response model."""
|
||||
operation = parser.parse_operation("get_dag")
|
||||
|
||||
assert operation.response_model is not None
|
||||
assert issubclass(operation.response_model, BaseModel)
|
||||
|
||||
# Test model fields
|
||||
fields = operation.response_model.__annotations__
|
||||
assert "dag_id" in fields
|
||||
assert "is_paused" in fields
|
||||
# Verify parameter is mapped correctly
|
||||
assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"]
|
||||
|
||||
|
||||
def test_parse_operation_not_found(parser: OperationParser) -> None:
|
||||
@@ -118,7 +114,9 @@ def test_map_parameter_schema_nullable(parser: OperationParser) -> None:
|
||||
}
|
||||
|
||||
result = parser._map_parameter_schema(param)
|
||||
assert isinstance(result["type"], type(str))
|
||||
# Check that str is in the Union types
|
||||
assert str in result["type"].__args__
|
||||
assert None.__class__ in result["type"].__args__ # Check for NoneType
|
||||
assert not result["required"]
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tests.tools.test_models import TestRequestModel, TestResponseModel
|
||||
from tests.tools.test_models import TestRequestModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,20 +20,27 @@ def mock_client(mocker):
|
||||
@pytest.fixture
|
||||
def operation_details():
|
||||
"""Create test operation details."""
|
||||
model = TestRequestModel
|
||||
# Add parameter mapping to model config
|
||||
model.model_config["parameter_mapping"] = {
|
||||
"path": ["path_id"],
|
||||
"query": ["query_filter"],
|
||||
"body": ["body_name", "body_value"],
|
||||
}
|
||||
|
||||
return OperationDetails(
|
||||
operation_id="test_operation",
|
||||
path="/test/{id}",
|
||||
path="/test/{path_id}",
|
||||
method="POST",
|
||||
parameters={
|
||||
"path": {
|
||||
"id": {"type": int, "required": True},
|
||||
"path_id": {"type": int, "required": True},
|
||||
},
|
||||
"query": {
|
||||
"filter": {"type": str, "required": False},
|
||||
"query_filter": {"type": str, "required": False},
|
||||
},
|
||||
},
|
||||
input_model=TestRequestModel,
|
||||
response_model=TestResponseModel,
|
||||
input_model=model,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,20 +56,23 @@ async def test_successful_execution(airflow_tool, mock_client):
|
||||
# Setup mock response
|
||||
mock_client.execute.return_value = {"item_id": 1, "result": "success"}
|
||||
|
||||
# Execute operation
|
||||
# Execute operation with unified body
|
||||
result = await airflow_tool.run(
|
||||
path_params={"id": 123},
|
||||
query_params={"filter": "test"},
|
||||
body={"name": "test", "value": 42},
|
||||
body={
|
||||
"path_id": 123,
|
||||
"query_filter": "test",
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert result == {"item_id": 1, "result": "success"}
|
||||
mock_client.execute.assert_called_once_with(
|
||||
operation_id="test_operation",
|
||||
path_params={"id": 123},
|
||||
query_params={"filter": "test"},
|
||||
body={"name": "test", "value": 42},
|
||||
path_params={"path_id": 123},
|
||||
query_params={"query_filter": "test"},
|
||||
body={"body_name": "test", "body_value": 42},
|
||||
)
|
||||
|
||||
|
||||
@@ -71,8 +81,11 @@ async def test_invalid_path_parameter(airflow_tool):
|
||||
"""Test validation error for invalid path parameter type."""
|
||||
with pytest.raises(ValidationError):
|
||||
await airflow_tool.run(
|
||||
path_params={"id": "not_an_integer"},
|
||||
body={"name": "test", "value": 42},
|
||||
body={
|
||||
"path_id": "not_an_integer", # Invalid type
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -81,22 +94,29 @@ async def test_invalid_request_body(airflow_tool):
|
||||
"""Test validation error for invalid request body."""
|
||||
with pytest.raises(ValidationError):
|
||||
await airflow_tool.run(
|
||||
path_params={"id": 123},
|
||||
body={"name": "test", "value": "not_an_integer"},
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": "not_an_integer", # Invalid type
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_response_format(airflow_tool, mock_client):
|
||||
"""Test error handling for invalid response format."""
|
||||
# Setup mock response with invalid format
|
||||
# Setup mock response
|
||||
mock_client.execute.return_value = {"invalid": "response"}
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await airflow_tool.run(
|
||||
path_params={"id": 123},
|
||||
body={"name": "test", "value": 42},
|
||||
# Should not raise any validation error
|
||||
result = await airflow_tool.run(
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
assert result == {"invalid": "response"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -107,6 +127,9 @@ async def test_client_error(airflow_tool, mock_client):
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await airflow_tool.run(
|
||||
path_params={"id": 123},
|
||||
body={"name": "test", "value": 42},
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -10,10 +10,3 @@ class TestRequestModel(BaseModel):
|
||||
query_filter: str | None = None
|
||||
body_name: str
|
||||
body_value: int
|
||||
|
||||
|
||||
class TestResponseModel(BaseModel):
|
||||
"""Test response model."""
|
||||
|
||||
item_id: int
|
||||
result: str
|
||||
|
||||
2
airflow-mcp-server/uv.lock
generated
2
airflow-mcp-server/uv.lock
generated
@@ -111,7 +111,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "airflow-mcp-server"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
|
||||
Reference in New Issue
Block a user