Clean up for only MCP Server
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
186
tests/client/test_airflow_client.py
Normal file
186
tests/client/test_airflow_client.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import logging
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import yaml
|
||||
from aioresponses import aioresponses
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from openapi_core import OpenAPI
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> AirflowClient:
|
||||
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
|
||||
spec = yaml.safe_load(f)
|
||||
return AirflowClient(
|
||||
spec_path=spec,
|
||||
base_url="http://localhost:8080/api/v1",
|
||||
auth_token="test-token",
|
||||
)
|
||||
|
||||
|
||||
def test_init_client_initialization(client: AirflowClient) -> None:
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
assert client.base_url == "http://localhost:8080/api/v1"
|
||||
assert client.headers["Authorization"] == "Basic test-token"
|
||||
|
||||
|
||||
def test_init_load_spec_from_bytes() -> None:
|
||||
spec_bytes = yaml.dump(create_valid_spec()).encode()
|
||||
client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
|
||||
assert client.raw_spec is not None
|
||||
|
||||
|
||||
def test_init_load_spec_from_path(tmp_path: Path) -> None:
|
||||
spec_file = tmp_path / "test_spec.yaml"
|
||||
spec_file.write_text(yaml.dump(create_valid_spec()))
|
||||
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
|
||||
assert client.raw_spec is not None
|
||||
|
||||
|
||||
def test_init_invalid_spec() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
|
||||
|
||||
|
||||
def test_init_missing_paths_in_spec() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
|
||||
|
||||
|
||||
def test_ops_get_operation(client: AirflowClient) -> None:
|
||||
path, method, operation = client._get_operation("get_dags")
|
||||
assert path == "/dags"
|
||||
assert method == "get"
|
||||
assert operation.operation_id == "get_dags"
|
||||
|
||||
path, method, operation = client._get_operation("get_dag")
|
||||
assert path == "/dags/{dag_id}"
|
||||
assert method == "get"
|
||||
assert operation.operation_id == "get_dag"
|
||||
|
||||
|
||||
def test_ops_nonexistent_operation(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
|
||||
client._get_operation("nonexistent")
|
||||
|
||||
|
||||
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
client._get_operation("GET_DAGS")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_without_context() -> None:
|
||||
client = AirflowClient(
|
||||
spec_path=create_valid_spec(),
|
||||
base_url="http://test",
|
||||
auth_token="test",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="Client not in async context"):
|
||||
await client.execute("get_dags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_get_dags(client: AirflowClient) -> None:
|
||||
expected_response = {
|
||||
"dags": [
|
||||
{
|
||||
"dag_id": "test_dag",
|
||||
"is_active": True,
|
||||
"is_paused": False,
|
||||
}
|
||||
],
|
||||
"total_entries": 1,
|
||||
}
|
||||
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags?limit=100",
|
||||
status=200,
|
||||
payload=expected_response,
|
||||
)
|
||||
response = await client.execute("get_dags", query_params={"limit": 100})
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_get_dag(client: AirflowClient) -> None:
|
||||
expected_response = {
|
||||
"dag_id": "test_dag",
|
||||
"is_active": True,
|
||||
"is_paused": False,
|
||||
}
|
||||
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags/test_dag",
|
||||
status=200,
|
||||
payload=expected_response,
|
||||
)
|
||||
response = await client.execute(
|
||||
"get_dag",
|
||||
path_params={"dag_id": "test_dag"},
|
||||
)
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_invalid_params(client: AirflowClient) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
async with client:
|
||||
# Test with missing required parameter
|
||||
await client.execute("get_dag", path_params={})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with client:
|
||||
# Test with invalid parameter name
|
||||
await client.execute("get_dag", path_params={"invalid": "value"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_timeout(client: AirflowClient) -> None:
|
||||
with aioresponses() as mock:
|
||||
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
|
||||
async with client:
|
||||
with pytest.raises(aiohttp.ClientError):
|
||||
await client.execute("get_dags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_error_response(client: AirflowClient) -> None:
|
||||
with aioresponses() as mock:
|
||||
async with client:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags",
|
||||
status=403,
|
||||
body="Forbidden",
|
||||
)
|
||||
with pytest.raises(aiohttp.ClientResponseError):
|
||||
await client.execute("get_dags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_session_management(client: AirflowClient) -> None:
|
||||
async with client:
|
||||
with aioresponses() as mock:
|
||||
mock.get(
|
||||
"http://localhost:8080/api/v1/dags",
|
||||
status=200,
|
||||
payload={"dags": []},
|
||||
)
|
||||
await client.execute("get_dags")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await client.execute("get_dags")
|
||||
58
tests/conftest.py
Normal file
58
tests/conftest.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Test configuration and shared fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_spec_file():
|
||||
"""Mock OpenAPI spec file for testing."""
|
||||
mock_spec = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Airflow API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/api/v1/dags": {
|
||||
"get": {
|
||||
"operationId": "get_dags",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "List of DAGs",
|
||||
"content": {
|
||||
"application/json": {"schema": {"type": "object", "properties": {"dags": {"type": "array", "items": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/api/v1/dags/{dag_id}": {
|
||||
"get": {
|
||||
"operationId": "get_dag",
|
||||
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
|
||||
"responses": {"200": {"description": "Successful response", "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}},
|
||||
},
|
||||
"post": {
|
||||
"operationId": "post_dag_run",
|
||||
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"conf": {"type": "object"},
|
||||
"dag_run_id": {"type": "string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {"application/json": {"schema": {"type": "object", "properties": {"dag_run_id": {"type": "string"}, "state": {"type": "string"}}}}},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return mock_spec
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Example DAGs test. This test ensures that all Dags have tags, retries set to two, and no import errors. This is an example pytest and may not be fit the context of your DAGs. Feel free to add and remove tests."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
import pytest
|
||||
from airflow.models import DagBag
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_logging(namespace):
|
||||
logger = logging.getLogger(namespace)
|
||||
old_value = logger.disabled
|
||||
logger.disabled = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.disabled = old_value
|
||||
|
||||
|
||||
def get_import_errors():
|
||||
"""
|
||||
Generate a tuple for import errors in the dag bag
|
||||
"""
|
||||
with suppress_logging("airflow"):
|
||||
dag_bag = DagBag(include_examples=False)
|
||||
|
||||
def strip_path_prefix(path):
|
||||
return os.path.relpath(path, os.environ.get("AIRFLOW_HOME"))
|
||||
|
||||
# prepend "(None,None)" to ensure that a test object is always created even if it's a no op.
|
||||
return [(None, None)] + [
|
||||
(strip_path_prefix(k), v.strip()) for k, v in dag_bag.import_errors.items()
|
||||
]
|
||||
|
||||
|
||||
def get_dags():
|
||||
"""
|
||||
Generate a tuple of dag_id, <DAG objects> in the DagBag
|
||||
"""
|
||||
with suppress_logging("airflow"):
|
||||
dag_bag = DagBag(include_examples=False)
|
||||
|
||||
def strip_path_prefix(path):
|
||||
return os.path.relpath(path, os.environ.get("AIRFLOW_HOME"))
|
||||
|
||||
return [(k, v, strip_path_prefix(v.fileloc)) for k, v in dag_bag.dags.items()]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rel_path,rv", get_import_errors(), ids=[x[0] for x in get_import_errors()]
|
||||
)
|
||||
def test_file_imports(rel_path, rv):
|
||||
"""Test for import errors on a file"""
|
||||
if rel_path and rv:
|
||||
raise Exception(f"{rel_path} failed to import with message \n {rv}")
|
||||
|
||||
|
||||
APPROVED_TAGS = {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
|
||||
)
|
||||
def test_dag_tags(dag_id, dag, fileloc):
|
||||
"""
|
||||
test if a DAG is tagged and if those TAGs are in the approved list
|
||||
"""
|
||||
assert dag.tags, f"{dag_id} in {fileloc} has no tags"
|
||||
if APPROVED_TAGS:
|
||||
assert not set(dag.tags) - APPROVED_TAGS
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dag_id,dag, fileloc", get_dags(), ids=[x[2] for x in get_dags()]
|
||||
)
|
||||
def test_dag_retries(dag_id, dag, fileloc):
|
||||
"""
|
||||
test if a DAG has retries set
|
||||
"""
|
||||
assert (
|
||||
dag.default_args.get("retries", None) >= 2
|
||||
), f"{dag_id} in {fileloc} must have task retries >= 2."
|
||||
174
tests/parser/test_operation_parser.py
Normal file
174
tests/parser/test_operation_parser.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import logging
|
||||
from importlib import resources
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spec_file():
|
||||
"""Get content of the v1.yaml spec file."""
|
||||
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(spec_file) -> OperationParser:
|
||||
"""Create OperationParser instance."""
|
||||
return OperationParser(spec_path=spec_file)
|
||||
|
||||
|
||||
def test_parse_operation_basic(parser: OperationParser) -> None:
|
||||
"""Test basic operation parsing."""
|
||||
operation = parser.parse_operation("get_dags")
|
||||
|
||||
assert isinstance(operation, OperationDetails)
|
||||
assert operation.operation_id == "get_dags"
|
||||
assert operation.path == "/dags"
|
||||
assert operation.method == "get"
|
||||
assert isinstance(operation.parameters, dict)
|
||||
|
||||
|
||||
def test_parse_operation_with_path_params(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with path parameters."""
|
||||
operation = parser.parse_operation("get_dag")
|
||||
|
||||
assert operation.path == "/dags/{dag_id}"
|
||||
assert isinstance(operation.input_model, type(BaseModel))
|
||||
|
||||
# Verify path parameter field exists
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "dag_id" in fields
|
||||
assert str in fields["dag_id"].__args__ # Check if str is in the Union types
|
||||
|
||||
# Verify parameter is mapped correctly
|
||||
assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"]
|
||||
|
||||
|
||||
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with query parameters."""
|
||||
operation = parser.parse_operation("get_dags")
|
||||
|
||||
# Verify query parameter field exists
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "limit" in fields
|
||||
assert int in fields["limit"].__args__ # Check if int is in the Union types
|
||||
|
||||
# Verify parameter is mapped correctly
|
||||
assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"]
|
||||
|
||||
|
||||
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with request body."""
|
||||
operation = parser.parse_operation("post_dag_run")
|
||||
|
||||
# Verify body fields exist
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "dag_run_id" in fields
|
||||
assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types
|
||||
|
||||
# Verify parameter is mapped correctly
|
||||
assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"]
|
||||
|
||||
|
||||
def test_parse_operation_not_found(parser: OperationParser) -> None:
|
||||
"""Test error handling for non-existent operation."""
|
||||
with pytest.raises(ValueError, match="Operation invalid_op not found in spec"):
|
||||
parser.parse_operation("invalid_op")
|
||||
|
||||
|
||||
def test_extract_parameters_empty(parser: OperationParser) -> None:
|
||||
"""Test parameter extraction with no parameters."""
|
||||
params = parser.extract_parameters({})
|
||||
|
||||
assert isinstance(params, dict)
|
||||
assert "path" in params
|
||||
assert "query" in params
|
||||
assert "header" in params
|
||||
assert all(isinstance(v, dict) for v in params.values())
|
||||
|
||||
|
||||
def test_map_parameter_schema_array(parser: OperationParser) -> None:
|
||||
"""Test mapping array parameter schema."""
|
||||
param: dict[str, Any] = {
|
||||
"name": "tags",
|
||||
"in": "query",
|
||||
"schema": {"type": "array", "items": {"type": "string"}},
|
||||
}
|
||||
|
||||
result = parser._map_parameter_schema(param)
|
||||
assert isinstance(result["type"], type(list))
|
||||
|
||||
|
||||
def test_map_parameter_schema_nullable(parser: OperationParser) -> None:
|
||||
"""Test mapping nullable parameter schema."""
|
||||
param: dict[str, Any] = {
|
||||
"name": "test",
|
||||
"in": "query",
|
||||
"schema": {"type": "string", "nullable": True},
|
||||
}
|
||||
|
||||
result = parser._map_parameter_schema(param)
|
||||
# Check that str is in the Union types
|
||||
assert str in result["type"].__args__
|
||||
assert None.__class__ in result["type"].__args__ # Check for NoneType
|
||||
assert not result["required"]
|
||||
|
||||
|
||||
def test_create_model_invalid_schema(parser: OperationParser) -> None:
|
||||
"""Test error handling for invalid schema."""
|
||||
with pytest.raises(ValueError, match="Schema must be an object type"):
|
||||
parser._create_model("Test", {"type": "string"})
|
||||
|
||||
|
||||
def test_create_model_nested_objects(parser: OperationParser) -> None:
|
||||
"""Test creating model with nested objects."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"nested": {"type": "object", "properties": {"field": {"type": "string"}}}},
|
||||
}
|
||||
|
||||
model = parser._create_model("Test", schema)
|
||||
assert issubclass(model, BaseModel)
|
||||
fields = model.__annotations__
|
||||
assert "nested" in fields
|
||||
assert issubclass(fields["nested"], BaseModel)
|
||||
nested_fields = fields["nested"].__annotations__
|
||||
assert "field" in nested_fields
|
||||
assert isinstance(nested_fields["field"], type(str))
|
||||
|
||||
|
||||
def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
|
||||
"""Test parsing operation with allOf schema in request body."""
|
||||
operation = parser.parse_operation("test_connection")
|
||||
|
||||
assert isinstance(operation, OperationDetails)
|
||||
assert operation.operation_id == "test_connection"
|
||||
assert operation.path == "/connections/test"
|
||||
assert operation.method == "post"
|
||||
|
||||
# Verify input model includes fields from allOf schema
|
||||
fields = operation.input_model.__annotations__
|
||||
assert "connection_id" in fields, "Missing connection_id from ConnectionCollectionItem"
|
||||
assert str in fields["connection_id"].__args__, "connection_id should be a string"
|
||||
assert "password" in fields, "Missing password from Connection"
|
||||
assert str in fields["password"].__args__, "password should be a string"
|
||||
assert "connection_schema" in fields, "Missing schema field (aliased as connection_schema)"
|
||||
assert str in fields["connection_schema"].__args__, "connection_schema should be a string"
|
||||
|
||||
# Verify parameter mapping
|
||||
mapping = operation.input_model.model_config["parameter_mapping"]
|
||||
assert "body" in mapping
|
||||
assert "connection_id" in mapping["body"]
|
||||
assert "password" in mapping["body"]
|
||||
assert "connection_schema" in mapping["body"]
|
||||
|
||||
# Verify alias configuration
|
||||
model_fields = operation.input_model.model_fields
|
||||
assert "connection_schema" in model_fields
|
||||
assert model_fields["connection_schema"].alias == "schema", "connection_schema should alias to schema"
|
||||
0
tests/tools/__init__.py
Normal file
0
tests/tools/__init__.py
Normal file
135
tests/tools/test_airflow_tool.py
Normal file
135
tests/tools/test_airflow_tool.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Tests for AirflowTool."""
|
||||
|
||||
import pytest
|
||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||
from airflow_mcp_server.tools.airflow_tool import AirflowTool
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tests.tools.test_models import TestRequestModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client(mocker):
|
||||
"""Create mock Airflow client."""
|
||||
client = mocker.Mock(spec=AirflowClient)
|
||||
client.execute = mocker.AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def operation_details():
|
||||
"""Create test operation details."""
|
||||
model = TestRequestModel
|
||||
# Add parameter mapping to model config
|
||||
model.model_config["parameter_mapping"] = {
|
||||
"path": ["path_id"],
|
||||
"query": ["query_filter"],
|
||||
"body": ["body_name", "body_value"],
|
||||
}
|
||||
|
||||
return OperationDetails(
|
||||
operation_id="test_operation",
|
||||
path="/test/{path_id}",
|
||||
method="POST",
|
||||
parameters={
|
||||
"path": {
|
||||
"path_id": {"type": int, "required": True},
|
||||
},
|
||||
"query": {
|
||||
"query_filter": {"type": str, "required": False},
|
||||
},
|
||||
},
|
||||
input_model=model,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def airflow_tool(mock_client, operation_details):
|
||||
"""Create AirflowTool instance for testing."""
|
||||
return AirflowTool(operation_details, mock_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_execution(airflow_tool, mock_client):
|
||||
"""Test successful operation execution with valid parameters."""
|
||||
# Setup mock response
|
||||
mock_client.execute.return_value = {"item_id": 1, "result": "success"}
|
||||
|
||||
# Execute operation with unified body
|
||||
result = await airflow_tool.run(
|
||||
body={
|
||||
"path_id": 123,
|
||||
"query_filter": "test",
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert result == {"item_id": 1, "result": "success"}
|
||||
mock_client.execute.assert_called_once_with(
|
||||
operation_id="test_operation",
|
||||
path_params={"path_id": 123},
|
||||
query_params={"query_filter": "test"},
|
||||
body={"body_name": "test", "body_value": 42},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_path_parameter(airflow_tool):
|
||||
"""Test validation error for invalid path parameter type."""
|
||||
with pytest.raises(ValidationError):
|
||||
await airflow_tool.run(
|
||||
body={
|
||||
"path_id": "not_an_integer", # Invalid type
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_request_body(airflow_tool):
|
||||
"""Test validation error for invalid request body."""
|
||||
with pytest.raises(ValidationError):
|
||||
await airflow_tool.run(
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": "not_an_integer", # Invalid type
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_response_format(airflow_tool, mock_client):
|
||||
"""Test error handling for invalid response format."""
|
||||
# Setup mock response
|
||||
mock_client.execute.return_value = {"invalid": "response"}
|
||||
|
||||
# Should not raise any validation error
|
||||
result = await airflow_tool.run(
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
assert result == {"invalid": "response"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_error(airflow_tool, mock_client):
|
||||
"""Test error handling for client execution failure."""
|
||||
# Setup mock to raise exception
|
||||
mock_client.execute.side_effect = RuntimeError("API Error")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await airflow_tool.run(
|
||||
body={
|
||||
"path_id": 123,
|
||||
"body_name": "test",
|
||||
"body_value": 42,
|
||||
}
|
||||
)
|
||||
12
tests/tools/test_models.py
Normal file
12
tests/tools/test_models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Test models for Airflow tool tests."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestRequestModel(BaseModel):
|
||||
"""Test request model."""
|
||||
|
||||
path_id: int
|
||||
query_filter: str | None = None
|
||||
body_name: str
|
||||
body_value: int
|
||||
0
tests/tools/test_tool_manager.py
Normal file
0
tests/tools/test_tool_manager.py
Normal file
Reference in New Issue
Block a user