diff --git a/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py b/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py index c64e28d..1bebca1 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py +++ b/airflow-mcp-server/src/airflow_mcp_server/client/airflow_client.py @@ -6,7 +6,10 @@ from typing import Any, BinaryIO, TextIO import aiohttp import yaml +from jsonschema_path import SchemaPath from openapi_core import OpenAPI +from openapi_core.validation.request.validators import V31RequestValidator +from openapi_spec_validator import validate logger = logging.getLogger(__name__) @@ -62,6 +65,18 @@ class AirflowClient: else: raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object") + # Validate spec has required fields + if not isinstance(self.raw_spec, dict): + raise ValueError("OpenAPI spec must be a dictionary") + + required_fields = ["openapi", "info", "paths"] + for field in required_fields: + if field not in self.raw_spec: + raise ValueError(f"OpenAPI spec missing required field: {field}") + + # Validate OpenAPI spec format + validate(self.raw_spec) + # Initialize OpenAPI spec self.spec = OpenAPI.from_dict(self.raw_spec) logger.debug("OpenAPI spec loaded successfully") @@ -75,6 +90,10 @@ class AirflowClient: self._paths = self.raw_spec["paths"] logger.debug("Using raw spec paths") + # Initialize request validator with schema path + schema_path = SchemaPath.from_dict(self.raw_spec) + self._validator = V31RequestValidator(schema_path) + # API configuration self.base_url = base_url.rstrip("/") self.headers = { @@ -85,16 +104,14 @@ class AirflowClient: except Exception as e: logger.error("Failed to initialize AirflowClient: %s", e) - raise + raise ValueError(f"Failed to initialize client: {e}") async def __aenter__(self) -> "AirflowClient": - """Enter async context, creating session.""" self._session = aiohttp.ClientSession(headers=self.headers) return self async def __aexit__(self, *exc) -> None: - """Exit async context, closing session.""" - if self._session: + if hasattr(self, "_session"): await self._session.close() delattr(self, "_session") @@ -135,6 +152,23 @@ class AirflowClient: logger.error("Error getting operation %s: %s", operation_id, e) raise + def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None: + if not params: + params = {} + + # Extract path parameter names from the path + path_params = set(re.findall(r"{([^}]+)}", path)) + + # Check for missing required parameters + missing_params = path_params - set(params.keys()) + if missing_params: + raise ValueError(f"Missing required path parameters: {missing_params}") + + # Check for invalid parameters + invalid_params = set(params.keys()) - path_params + if invalid_params: + raise ValueError(f"Invalid path parameters: {invalid_params}") + async def execute( self, operation_id: str, @@ -165,6 +199,9 @@ class AirflowClient: # Get operation details path, method, _ = self._get_operation(operation_id) + # Validate path parameters + self._validate_path_params(path, path_params) + # Format URL if path_params: path = path.format(**path_params) @@ -182,6 +219,9 @@ class AirflowClient: response.raise_for_status() return await response.json() - except Exception as e: + except aiohttp.ClientError as e: logger.error("Error executing operation %s: %s", operation_id, e) raise + except Exception as e: + logger.error("Error executing operation %s: %s", operation_id, e) + raise ValueError(f"Failed to execute operation: {e}") diff --git a/airflow-mcp-server/tests/client/test_airflow_client.py b/airflow-mcp-server/tests/client/test_airflow_client.py index 34a6f9c..41cd497 100644 --- a/airflow-mcp-server/tests/client/test_airflow_client.py +++ b/airflow-mcp-server/tests/client/test_airflow_client.py @@ -1,26 +1,30 @@ import logging from importlib import resources +from pathlib import Path +from typing import Any import aiohttp import pytest import yaml from aioresponses import aioresponses -from airflow_mcp_server.client import AirflowClient +from airflow_mcp_server.client.airflow_client import AirflowClient from openapi_core import OpenAPI -logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logging.basicConfig(level=logging.DEBUG) + + +def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]: + return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}} @pytest.fixture -def spec_file(): - """Get content of the v1.yaml spec file.""" +def spec_file() -> dict[str, Any]: with resources.files("tests.client").joinpath("v1.yaml").open("r") as f: return yaml.safe_load(f) @pytest.fixture -def client(spec_file): - """Create test client with the actual spec.""" +def client(spec_file: dict[str, Any]) -> AirflowClient: return AirflowClient( spec_path=spec_file, base_url="http://localhost:8080/api/v1", @@ -28,46 +32,70 @@ def client(spec_file): ) -def test_client_initialization(client): - """Test client initialization and spec loading.""" +def test_init_client_initialization(client: AirflowClient) -> None: assert isinstance(client.spec, OpenAPI) assert client.base_url == "http://localhost:8080/api/v1" assert client.headers["Authorization"] == "Bearer test-token" -def test_get_operation(client): - """Test operation lookup from spec.""" - # Test get_dags operation +def test_init_load_spec_from_bytes() -> None: + spec_bytes = yaml.dump(create_valid_spec()).encode() + client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test") + assert client.raw_spec is not None + + +def test_init_load_spec_from_path(tmp_path: Path) -> None: + spec_file = tmp_path / "test_spec.yaml" + spec_file.write_text(yaml.dump(create_valid_spec())) + client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test") + assert client.raw_spec is not None + + +def test_init_invalid_spec() -> None: + with pytest.raises(ValueError): + AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test") + + +def test_init_missing_paths_in_spec() -> None: + with pytest.raises(ValueError): + AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test") + + +def test_ops_get_operation(client: AirflowClient) -> None: path, method, operation = client._get_operation("get_dags") assert path == "/dags" assert method == "get" assert operation.operation_id == "get_dags" - # Test get_dag operation path, method, operation = client._get_operation("get_dag") assert path == "/dags/{dag_id}" assert method == "get" assert operation.operation_id == "get_dag" -# Note: asyncio_mode is configured in pyproject.toml +def test_ops_nonexistent_operation(client: AirflowClient) -> None: + with pytest.raises(ValueError, match="Operation nonexistent not found in spec"): + client._get_operation("nonexistent") + + +def test_ops_case_sensitive_operation(client: AirflowClient) -> None: + with pytest.raises(ValueError): + client._get_operation("GET_DAGS") + + @pytest.mark.asyncio -async def test_execute_without_context(): - """Test error when executing outside async context.""" - with resources.files("tests.client").joinpath("v1.yaml").open("r") as f: - spec_content = yaml.safe_load(f) - client = AirflowClient( - spec_path=spec_content, - base_url="http://test", - auth_token="test", - ) - with pytest.raises((RuntimeError, AttributeError)): +async def test_exec_without_context(spec_file: dict[str, Any]) -> None: + client = AirflowClient( + spec_path=spec_file, + base_url="http://test", + auth_token="test", + ) + with pytest.raises(RuntimeError, match="Client not in async context"): await client.execute("get_dags") @pytest.mark.asyncio -async def test_execute_get_dags(client): - """Test DAG list retrieval.""" +async def test_exec_get_dags(client: AirflowClient) -> None: expected_response = { "dags": [ { @@ -91,8 +119,7 @@ async def test_execute_get_dags(client): @pytest.mark.asyncio -async def test_execute_get_dag(client): - """Test single DAG retrieval with path parameters.""" +async def test_exec_get_dag(client: AirflowClient) -> None: expected_response = { "dag_id": "test_dag", "is_active": True, @@ -114,8 +141,29 @@ async def test_execute_get_dag(client): @pytest.mark.asyncio -async def test_execute_error_response(client): - """Test error handling for failed requests.""" +async def test_exec_invalid_params(client: AirflowClient) -> None: + with pytest.raises(ValueError): + async with client: + # Test with missing required parameter + await client.execute("get_dag", path_params={}) + + with pytest.raises(ValueError): + async with client: + # Test with invalid parameter name + await client.execute("get_dag", path_params={"invalid": "value"}) + + +@pytest.mark.asyncio +async def test_exec_timeout(client: AirflowClient) -> None: + with aioresponses() as mock: + mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout")) + async with client: + with pytest.raises(aiohttp.ClientError): + await client.execute("get_dags") + + +@pytest.mark.asyncio +async def test_exec_error_response(client: AirflowClient) -> None: with aioresponses() as mock: async with client: mock.get( @@ -123,15 +171,13 @@ async def test_execute_error_response(client): status=403, body="Forbidden", ) - with pytest.raises(aiohttp.ClientError): + with pytest.raises(aiohttp.ClientResponseError): await client.execute("get_dags") @pytest.mark.asyncio -async def test_session_management(client): - """Test proper session creation and cleanup.""" +async def test_exec_session_management(client: AirflowClient) -> None: async with client: - # Should work inside context with aioresponses() as mock: mock.get( "http://localhost:8080/api/v1/dags", @@ -140,6 +186,5 @@ async def test_session_management(client): ) await client.execute("get_dags") - # Should fail after context exit with pytest.raises(RuntimeError): await client.execute("get_dags")