Clean up for only MCP Server

2025-02-24 16:50:08 +00:00
parent 5d199ba154
commit 16cd3f48fe
52 changed files with 66 additions and 1317 deletions

0
tests/__init__.py Normal file

@@ -0,0 +1,186 @@
import logging
from importlib import resources
from pathlib import Path
from typing import Any
import aiohttp
import pytest
import yaml
from aioresponses import aioresponses
from airflow_mcp_server.client.airflow_client import AirflowClient
from openapi_core import OpenAPI
logging.basicConfig(level=logging.DEBUG)
def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
@pytest.fixture
def client() -> AirflowClient:
with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f:
spec = yaml.safe_load(f)
return AirflowClient(
spec_path=spec,
base_url="http://localhost:8080/api/v1",
auth_token="test-token",
)
def test_init_client_initialization(client: AirflowClient) -> None:
assert isinstance(client.spec, OpenAPI)
assert client.base_url == "http://localhost:8080/api/v1"
assert client.headers["Authorization"] == "Basic test-token"
def test_init_load_spec_from_bytes() -> None:
spec_bytes = yaml.dump(create_valid_spec()).encode()
client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_load_spec_from_path(tmp_path: Path) -> None:
spec_file = tmp_path / "test_spec.yaml"
spec_file.write_text(yaml.dump(create_valid_spec()))
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_invalid_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
def test_init_missing_paths_in_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
def test_ops_get_operation(client: AirflowClient) -> None:
path, method, operation = client._get_operation("get_dags")
assert path == "/dags"
assert method == "get"
assert operation.operation_id == "get_dags"
path, method, operation = client._get_operation("get_dag")
assert path == "/dags/{dag_id}"
assert method == "get"
assert operation.operation_id == "get_dag"
def test_ops_nonexistent_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
client._get_operation("nonexistent")
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError):
client._get_operation("GET_DAGS")
@pytest.mark.asyncio
async def test_exec_without_context() -> None:
client = AirflowClient(
spec_path=create_valid_spec(),
base_url="http://test",
auth_token="test",
)
with pytest.raises(RuntimeError, match="Client not in async context"):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_get_dags(client: AirflowClient) -> None:
expected_response = {
"dags": [
{
"dag_id": "test_dag",
"is_active": True,
"is_paused": False,
}
],
"total_entries": 1,
}
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags?limit=100",
status=200,
payload=expected_response,
)
response = await client.execute("get_dags", query_params={"limit": 100})
assert response == expected_response
@pytest.mark.asyncio
async def test_exec_get_dag(client: AirflowClient) -> None:
expected_response = {
"dag_id": "test_dag",
"is_active": True,
"is_paused": False,
}
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags/test_dag",
status=200,
payload=expected_response,
)
response = await client.execute(
"get_dag",
path_params={"dag_id": "test_dag"},
)
assert response == expected_response
@pytest.mark.asyncio
async def test_exec_invalid_params(client: AirflowClient) -> None:
with pytest.raises(ValueError):
async with client:
# Test with missing required parameter
await client.execute("get_dag", path_params={})
with pytest.raises(ValueError):
async with client:
# Test with invalid parameter name
await client.execute("get_dag", path_params={"invalid": "value"})
@pytest.mark.asyncio
async def test_exec_timeout(client: AirflowClient) -> None:
with aioresponses() as mock:
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
async with client:
with pytest.raises(aiohttp.ClientError):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_error_response(client: AirflowClient) -> None:
with aioresponses() as mock:
async with client:
mock.get(
"http://localhost:8080/api/v1/dags",
status=403,
body="Forbidden",
)
with pytest.raises(aiohttp.ClientResponseError):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_session_management(client: AirflowClient) -> None:
async with client:
with aioresponses() as mock:
mock.get(
"http://localhost:8080/api/v1/dags",
status=200,
payload={"dags": []},
)
await client.execute("get_dags")
with pytest.raises(RuntimeError):
await client.execute("get_dags")
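
An illustrative sketch (not part of the diff) of the client flow these tests exercise. The constructor arguments, the async-context requirement, and execute() with query_params are taken from the assertions above; everything else, such as loading the spec from a local Path, is an assumption.

import asyncio
from pathlib import Path

from airflow_mcp_server.client.airflow_client import AirflowClient


async def main() -> None:
    # spec_path accepts a parsed dict, raw bytes, or a Path per the tests above
    client = AirflowClient(
        spec_path=Path("v1.yaml"),
        base_url="http://localhost:8080/api/v1",
        auth_token="test-token",
    )
    async with client:  # execute() outside this context raises RuntimeError
        dags = await client.execute("get_dags", query_params={"limit": 100})
        print(dags["total_entries"])


asyncio.run(main())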

58
tests/conftest.py Normal file

@@ -0,0 +1,58 @@
"""Test configuration and shared fixtures."""
import pytest
@pytest.fixture
def mock_spec_file():
"""Mock OpenAPI spec file for testing."""
mock_spec = {
"openapi": "3.0.0",
"info": {"title": "Airflow API", "version": "1.0.0"},
"paths": {
"/api/v1/dags": {
"get": {
"operationId": "get_dags",
"responses": {
"200": {
"description": "List of DAGs",
"content": {
"application/json": {"schema": {"type": "object", "properties": {"dags": {"type": "array", "items": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}}
},
}
},
}
},
"/api/v1/dags/{dag_id}": {
"get": {
"operationId": "get_dag",
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
"responses": {"200": {"description": "Successful response", "content": {"application/json": {"schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}}}}}},
},
"post": {
"operationId": "post_dag_run",
"parameters": [{"name": "dag_id", "in": "path", "required": True, "schema": {"type": "string"}}],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"conf": {"type": "object"},
"dag_run_id": {"type": "string"},
},
}
}
}
},
"responses": {
"200": {
"description": "Successful response",
"content": {"application/json": {"schema": {"type": "object", "properties": {"dag_run_id": {"type": "string"}, "state": {"type": "string"}}}}},
}
},
},
},
},
}
return mock_spec
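
A hedged example (not part of the diff) of how the mock_spec_file fixture could be consumed: the set comparison below only restates the operation IDs declared in the dictionary above, and the test name itself is hypothetical.

def test_mock_spec_lists_expected_operations(mock_spec_file):
    # Collect every operationId declared in the mock spec defined above
    operation_ids = {
        operation["operationId"]
        for path_item in mock_spec_file["paths"].values()
        for operation in path_item.values()
    }
    assert operation_ids == {"get_dags", "get_dag", "post_dag_run"}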


@@ -1,83 +0,0 @@
"""Example DAGs test. This test ensures that all Dags have tags, retries set to two, and no import errors. This is an example pytest and may not be fit the context of your DAGs. Feel free to add and remove tests."""
import os
import logging
from contextlib import contextmanager
import pytest
from airflow.models import DagBag
@contextmanager
def suppress_logging(namespace):
logger = logging.getLogger(namespace)
old_value = logger.disabled
logger.disabled = True
try:
yield
finally:
logger.disabled = old_value
def get_import_errors():
"""
Generate a tuple for import errors in the dag bag
"""
with suppress_logging("airflow"):
dag_bag = DagBag(include_examples=False)
def strip_path_prefix(path):
return os.path.relpath(path, os.environ.get("AIRFLOW_HOME"))
# prepend "(None,None)" to ensure that a test object is always created even if it's a no op.
return [(None, None)] + [
(strip_path_prefix(k), v.strip()) for k, v in dag_bag.import_errors.items()
]
def get_dags():
"""
Generate a tuple of dag_id, <DAG objects> in the DagBag
"""
with suppress_logging("airflow"):
dag_bag = DagBag(include_examples=False)
def strip_path_prefix(path):
return os.path.relpath(path, os.environ.get("AIRFLOW_HOME"))
return [(k, v, strip_path_prefix(v.fileloc)) for k, v in dag_bag.dags.items()]
@pytest.mark.parametrize(
"rel_path,rv", get_import_errors(), ids=[x[0] for x in get_import_errors()]
)
def test_file_imports(rel_path, rv):
"""Test for import errors on a file"""
if rel_path and rv:
raise Exception(f"{rel_path} failed to import with message \n {rv}")
APPROVED_TAGS = {}
@pytest.mark.parametrize(
"dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_tags(dag_id, dag, fileloc):
"""
test if a DAG is tagged and if those TAGs are in the approved list
"""
assert dag.tags, f"{dag_id} in {fileloc} has no tags"
if APPROVED_TAGS:
assert not set(dag.tags) - APPROVED_TAGS
@pytest.mark.parametrize(
"dag_id,dag, fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_retries(dag_id, dag, fileloc):
"""
test if a DAG has retries set
"""
assert (
dag.default_args.get("retries", None) >= 2
), f"{dag_id} in {fileloc} must have task retries >= 2."


@@ -0,0 +1,174 @@
import logging
from importlib import resources
from typing import Any
import pytest
from airflow_mcp_server.parser.operation_parser import OperationDetails, OperationParser
from pydantic import BaseModel
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def spec_file():
"""Get content of the v1.yaml spec file."""
with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
return f.read()
@pytest.fixture
def parser(spec_file) -> OperationParser:
"""Create OperationParser instance."""
return OperationParser(spec_path=spec_file)
def test_parse_operation_basic(parser: OperationParser) -> None:
"""Test basic operation parsing."""
operation = parser.parse_operation("get_dags")
assert isinstance(operation, OperationDetails)
assert operation.operation_id == "get_dags"
assert operation.path == "/dags"
assert operation.method == "get"
assert isinstance(operation.parameters, dict)
def test_parse_operation_with_path_params(parser: OperationParser) -> None:
"""Test parsing operation with path parameters."""
operation = parser.parse_operation("get_dag")
assert operation.path == "/dags/{dag_id}"
assert isinstance(operation.input_model, type(BaseModel))
# Verify path parameter field exists
fields = operation.input_model.__annotations__
assert "dag_id" in fields
assert str in fields["dag_id"].__args__ # Check if str is in the Union types
# Verify parameter is mapped correctly
assert "dag_id" in operation.input_model.model_config["parameter_mapping"]["path"]
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
"""Test parsing operation with query parameters."""
operation = parser.parse_operation("get_dags")
# Verify query parameter field exists
fields = operation.input_model.__annotations__
assert "limit" in fields
assert int in fields["limit"].__args__ # Check if int is in the Union types
# Verify parameter is mapped correctly
assert "limit" in operation.input_model.model_config["parameter_mapping"]["query"]
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
"""Test parsing operation with request body."""
operation = parser.parse_operation("post_dag_run")
# Verify body fields exist
fields = operation.input_model.__annotations__
assert "dag_run_id" in fields
assert str in fields["dag_run_id"].__args__ # Check if str is in the Union types
# Verify parameter is mapped correctly
assert "dag_run_id" in operation.input_model.model_config["parameter_mapping"]["body"]
def test_parse_operation_not_found(parser: OperationParser) -> None:
"""Test error handling for non-existent operation."""
with pytest.raises(ValueError, match="Operation invalid_op not found in spec"):
parser.parse_operation("invalid_op")
def test_extract_parameters_empty(parser: OperationParser) -> None:
"""Test parameter extraction with no parameters."""
params = parser.extract_parameters({})
assert isinstance(params, dict)
assert "path" in params
assert "query" in params
assert "header" in params
assert all(isinstance(v, dict) for v in params.values())
def test_map_parameter_schema_array(parser: OperationParser) -> None:
"""Test mapping array parameter schema."""
param: dict[str, Any] = {
"name": "tags",
"in": "query",
"schema": {"type": "array", "items": {"type": "string"}},
}
result = parser._map_parameter_schema(param)
assert isinstance(result["type"], type(list))
def test_map_parameter_schema_nullable(parser: OperationParser) -> None:
"""Test mapping nullable parameter schema."""
param: dict[str, Any] = {
"name": "test",
"in": "query",
"schema": {"type": "string", "nullable": True},
}
result = parser._map_parameter_schema(param)
# Check that str is in the Union types
assert str in result["type"].__args__
assert None.__class__ in result["type"].__args__ # Check for NoneType
assert not result["required"]
def test_create_model_invalid_schema(parser: OperationParser) -> None:
"""Test error handling for invalid schema."""
with pytest.raises(ValueError, match="Schema must be an object type"):
parser._create_model("Test", {"type": "string"})
def test_create_model_nested_objects(parser: OperationParser) -> None:
"""Test creating model with nested objects."""
schema = {
"type": "object",
"properties": {"nested": {"type": "object", "properties": {"field": {"type": "string"}}}},
}
model = parser._create_model("Test", schema)
assert issubclass(model, BaseModel)
fields = model.__annotations__
assert "nested" in fields
assert issubclass(fields["nested"], BaseModel)
nested_fields = fields["nested"].__annotations__
assert "field" in nested_fields
assert isinstance(nested_fields["field"], type(str))
def test_parse_operation_with_allof_body(parser: OperationParser) -> None:
"""Test parsing operation with allOf schema in request body."""
operation = parser.parse_operation("test_connection")
assert isinstance(operation, OperationDetails)
assert operation.operation_id == "test_connection"
assert operation.path == "/connections/test"
assert operation.method == "post"
# Verify input model includes fields from allOf schema
fields = operation.input_model.__annotations__
assert "connection_id" in fields, "Missing connection_id from ConnectionCollectionItem"
assert str in fields["connection_id"].__args__, "connection_id should be a string"
assert "password" in fields, "Missing password from Connection"
assert str in fields["password"].__args__, "password should be a string"
assert "connection_schema" in fields, "Missing schema field (aliased as connection_schema)"
assert str in fields["connection_schema"].__args__, "connection_schema should be a string"
# Verify parameter mapping
mapping = operation.input_model.model_config["parameter_mapping"]
assert "body" in mapping
assert "connection_id" in mapping["body"]
assert "password" in mapping["body"]
assert "connection_schema" in mapping["body"]
# Verify alias configuration
model_fields = operation.input_model.model_fields
assert "connection_schema" in model_fields
assert model_fields["connection_schema"].alias == "schema", "connection_schema should alias to schema"
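
For orientation, a sketch (not part of the diff) of the parser flow these tests rely on. parse_operation(), the generated input_model, and the parameter_mapping entry in model_config come from the assertions above; constructing the model with only dag_id is an assumption.

from importlib import resources

from airflow_mcp_server.parser.operation_parser import OperationParser

with resources.files("tests.client").joinpath("v1.yaml").open("rb") as f:
    parser = OperationParser(spec_path=f.read())

operation = parser.parse_operation("get_dag")
params = operation.input_model(dag_id="example_dag")  # validated by the generated Pydantic model
mapping = operation.input_model.model_config["parameter_mapping"]
print(operation.path, mapping["path"])  # "/dags/{dag_id}" and a mapping that includes "dag_id"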

0
tests/tools/__init__.py Normal file

@@ -0,0 +1,135 @@
"""Tests for AirflowTool."""
import pytest
from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationDetails
from airflow_mcp_server.tools.airflow_tool import AirflowTool
from pydantic import ValidationError
from tests.tools.test_models import TestRequestModel
@pytest.fixture
def mock_client(mocker):
"""Create mock Airflow client."""
client = mocker.Mock(spec=AirflowClient)
client.execute = mocker.AsyncMock()
return client
@pytest.fixture
def operation_details():
"""Create test operation details."""
model = TestRequestModel
# Add parameter mapping to model config
model.model_config["parameter_mapping"] = {
"path": ["path_id"],
"query": ["query_filter"],
"body": ["body_name", "body_value"],
}
return OperationDetails(
operation_id="test_operation",
path="/test/{path_id}",
method="POST",
parameters={
"path": {
"path_id": {"type": int, "required": True},
},
"query": {
"query_filter": {"type": str, "required": False},
},
},
input_model=model,
)
@pytest.fixture
def airflow_tool(mock_client, operation_details):
"""Create AirflowTool instance for testing."""
return AirflowTool(operation_details, mock_client)
@pytest.mark.asyncio
async def test_successful_execution(airflow_tool, mock_client):
"""Test successful operation execution with valid parameters."""
# Setup mock response
mock_client.execute.return_value = {"item_id": 1, "result": "success"}
# Execute operation with unified body
result = await airflow_tool.run(
body={
"path_id": 123,
"query_filter": "test",
"body_name": "test",
"body_value": 42,
}
)
# Verify response
assert result == {"item_id": 1, "result": "success"}
mock_client.execute.assert_called_once_with(
operation_id="test_operation",
path_params={"path_id": 123},
query_params={"query_filter": "test"},
body={"body_name": "test", "body_value": 42},
)
@pytest.mark.asyncio
async def test_invalid_path_parameter(airflow_tool):
"""Test validation error for invalid path parameter type."""
with pytest.raises(ValidationError):
await airflow_tool.run(
body={
"path_id": "not_an_integer", # Invalid type
"body_name": "test",
"body_value": 42,
}
)
@pytest.mark.asyncio
async def test_invalid_request_body(airflow_tool):
"""Test validation error for invalid request body."""
with pytest.raises(ValidationError):
await airflow_tool.run(
body={
"path_id": 123,
"body_name": "test",
"body_value": "not_an_integer", # Invalid type
}
)
@pytest.mark.asyncio
async def test_invalid_response_format(airflow_tool, mock_client):
"""Test error handling for invalid response format."""
# Setup mock response
mock_client.execute.return_value = {"invalid": "response"}
# Should not raise any validation error
result = await airflow_tool.run(
body={
"path_id": 123,
"body_name": "test",
"body_value": 42,
}
)
assert result == {"invalid": "response"}
@pytest.mark.asyncio
async def test_client_error(airflow_tool, mock_client):
"""Test error handling for client execution failure."""
# Setup mock to raise exception
mock_client.execute.side_effect = RuntimeError("API Error")
with pytest.raises(RuntimeError):
await airflow_tool.run(
body={
"path_id": 123,
"body_name": "test",
"body_value": 42,
}
)
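
As a closing sketch (not part of the diff), the pieces above might be wired together roughly as follows. AirflowTool(operation_details, client) and run(body=...) match the tests; the end-to-end composition against a real Airflow instance, and the spec being read from a local v1.yaml, are assumptions.

import asyncio
from pathlib import Path
from typing import Any

from airflow_mcp_server.client.airflow_client import AirflowClient
from airflow_mcp_server.parser.operation_parser import OperationParser
from airflow_mcp_server.tools.airflow_tool import AirflowTool


async def trigger_dag_run(spec: bytes) -> Any:
    parser = OperationParser(spec_path=spec)
    client = AirflowClient(spec_path=spec, base_url="http://localhost:8080/api/v1", auth_token="test-token")
    tool = AirflowTool(parser.parse_operation("post_dag_run"), client)
    async with client:
        # path, query and body fields travel in one unified body dict, as in the tests above
        return await tool.run(body={"dag_id": "example_dag", "dag_run_id": "manual__1"})


result = asyncio.run(trigger_dag_run(Path("v1.yaml").read_bytes()))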


@@ -0,0 +1,12 @@
"""Test models for Airflow tool tests."""
from pydantic import BaseModel
class TestRequestModel(BaseModel):
"""Test request model."""
path_id: int
query_filter: str | None = None
body_name: str
body_value: int
