fix airflow client

This commit is contained in:
2025-02-13 11:38:55 +00:00
parent b87fe176bd
commit 0758f5c17e
2 changed files with 124 additions and 39 deletions

View File

@@ -6,7 +6,10 @@ from typing import Any, BinaryIO, TextIO
import aiohttp import aiohttp
import yaml import yaml
from jsonschema_path import SchemaPath
from openapi_core import OpenAPI from openapi_core import OpenAPI
from openapi_core.validation.request.validators import V31RequestValidator
from openapi_spec_validator import validate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,6 +65,18 @@ class AirflowClient:
else: else:
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object") raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
# Validate spec has required fields
if not isinstance(self.raw_spec, dict):
raise ValueError("OpenAPI spec must be a dictionary")
required_fields = ["openapi", "info", "paths"]
for field in required_fields:
if field not in self.raw_spec:
raise ValueError(f"OpenAPI spec missing required field: {field}")
# Validate OpenAPI spec format
validate(self.raw_spec)
# Initialize OpenAPI spec # Initialize OpenAPI spec
self.spec = OpenAPI.from_dict(self.raw_spec) self.spec = OpenAPI.from_dict(self.raw_spec)
logger.debug("OpenAPI spec loaded successfully") logger.debug("OpenAPI spec loaded successfully")
@@ -75,6 +90,10 @@ class AirflowClient:
self._paths = self.raw_spec["paths"] self._paths = self.raw_spec["paths"]
logger.debug("Using raw spec paths") logger.debug("Using raw spec paths")
# Initialize request validator with schema path
schema_path = SchemaPath.from_dict(self.raw_spec)
self._validator = V31RequestValidator(schema_path)
# API configuration # API configuration
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
self.headers = { self.headers = {
@@ -85,16 +104,14 @@ class AirflowClient:
except Exception as e: except Exception as e:
logger.error("Failed to initialize AirflowClient: %s", e) logger.error("Failed to initialize AirflowClient: %s", e)
raise raise ValueError(f"Failed to initialize client: {e}")
async def __aenter__(self) -> "AirflowClient": async def __aenter__(self) -> "AirflowClient":
"""Enter async context, creating session."""
self._session = aiohttp.ClientSession(headers=self.headers) self._session = aiohttp.ClientSession(headers=self.headers)
return self return self
async def __aexit__(self, *exc) -> None: async def __aexit__(self, *exc) -> None:
"""Exit async context, closing session.""" if hasattr(self, "_session"):
if self._session:
await self._session.close() await self._session.close()
delattr(self, "_session") delattr(self, "_session")
@@ -135,6 +152,23 @@ class AirflowClient:
logger.error("Error getting operation %s: %s", operation_id, e) logger.error("Error getting operation %s: %s", operation_id, e)
raise raise
def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
if not params:
params = {}
# Extract path parameter names from the path
path_params = set(re.findall(r"{([^}]+)}", path))
# Check for missing required parameters
missing_params = path_params - set(params.keys())
if missing_params:
raise ValueError(f"Missing required path parameters: {missing_params}")
# Check for invalid parameters
invalid_params = set(params.keys()) - path_params
if invalid_params:
raise ValueError(f"Invalid path parameters: {invalid_params}")
async def execute( async def execute(
self, self,
operation_id: str, operation_id: str,
@@ -165,6 +199,9 @@ class AirflowClient:
# Get operation details # Get operation details
path, method, _ = self._get_operation(operation_id) path, method, _ = self._get_operation(operation_id)
# Validate path parameters
self._validate_path_params(path, path_params)
# Format URL # Format URL
if path_params: if path_params:
path = path.format(**path_params) path = path.format(**path_params)
@@ -182,6 +219,9 @@ class AirflowClient:
response.raise_for_status() response.raise_for_status()
return await response.json() return await response.json()
except Exception as e: except aiohttp.ClientError as e:
logger.error("Error executing operation %s: %s", operation_id, e) logger.error("Error executing operation %s: %s", operation_id, e)
raise raise
except Exception as e:
logger.error("Error executing operation %s: %s", operation_id, e)
raise ValueError(f"Failed to execute operation: {e}")

View File

@@ -1,26 +1,30 @@
import logging import logging
from importlib import resources from importlib import resources
from pathlib import Path
from typing import Any
import aiohttp import aiohttp
import pytest import pytest
import yaml import yaml
from aioresponses import aioresponses from aioresponses import aioresponses
from airflow_mcp_server.client import AirflowClient from airflow_mcp_server.client.airflow_client import AirflowClient
from openapi_core import OpenAPI from openapi_core import OpenAPI
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.DEBUG)
def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
@pytest.fixture @pytest.fixture
def spec_file(): def spec_file() -> dict[str, Any]:
"""Get content of the v1.yaml spec file."""
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f: with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
return yaml.safe_load(f) return yaml.safe_load(f)
@pytest.fixture @pytest.fixture
def client(spec_file): def client(spec_file: dict[str, Any]) -> AirflowClient:
"""Create test client with the actual spec."""
return AirflowClient( return AirflowClient(
spec_path=spec_file, spec_path=spec_file,
base_url="http://localhost:8080/api/v1", base_url="http://localhost:8080/api/v1",
@@ -28,46 +32,70 @@ def client(spec_file):
) )
def test_client_initialization(client): def test_init_client_initialization(client: AirflowClient) -> None:
"""Test client initialization and spec loading."""
assert isinstance(client.spec, OpenAPI) assert isinstance(client.spec, OpenAPI)
assert client.base_url == "http://localhost:8080/api/v1" assert client.base_url == "http://localhost:8080/api/v1"
assert client.headers["Authorization"] == "Bearer test-token" assert client.headers["Authorization"] == "Bearer test-token"
def test_get_operation(client): def test_init_load_spec_from_bytes() -> None:
"""Test operation lookup from spec.""" spec_bytes = yaml.dump(create_valid_spec()).encode()
# Test get_dags operation client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_load_spec_from_path(tmp_path: Path) -> None:
spec_file = tmp_path / "test_spec.yaml"
spec_file.write_text(yaml.dump(create_valid_spec()))
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
assert client.raw_spec is not None
def test_init_invalid_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
def test_init_missing_paths_in_spec() -> None:
with pytest.raises(ValueError):
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
def test_ops_get_operation(client: AirflowClient) -> None:
path, method, operation = client._get_operation("get_dags") path, method, operation = client._get_operation("get_dags")
assert path == "/dags" assert path == "/dags"
assert method == "get" assert method == "get"
assert operation.operation_id == "get_dags" assert operation.operation_id == "get_dags"
# Test get_dag operation
path, method, operation = client._get_operation("get_dag") path, method, operation = client._get_operation("get_dag")
assert path == "/dags/{dag_id}" assert path == "/dags/{dag_id}"
assert method == "get" assert method == "get"
assert operation.operation_id == "get_dag" assert operation.operation_id == "get_dag"
# Note: asyncio_mode is configured in pyproject.toml def test_ops_nonexistent_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
client._get_operation("nonexistent")
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
with pytest.raises(ValueError):
client._get_operation("GET_DAGS")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_without_context(): async def test_exec_without_context(spec_file: dict[str, Any]) -> None:
"""Test error when executing outside async context."""
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
spec_content = yaml.safe_load(f)
client = AirflowClient( client = AirflowClient(
spec_path=spec_content, spec_path=spec_file,
base_url="http://test", base_url="http://test",
auth_token="test", auth_token="test",
) )
with pytest.raises((RuntimeError, AttributeError)): with pytest.raises(RuntimeError, match="Client not in async context"):
await client.execute("get_dags") await client.execute("get_dags")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_get_dags(client): async def test_exec_get_dags(client: AirflowClient) -> None:
"""Test DAG list retrieval."""
expected_response = { expected_response = {
"dags": [ "dags": [
{ {
@@ -91,8 +119,7 @@ async def test_execute_get_dags(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_get_dag(client): async def test_exec_get_dag(client: AirflowClient) -> None:
"""Test single DAG retrieval with path parameters."""
expected_response = { expected_response = {
"dag_id": "test_dag", "dag_id": "test_dag",
"is_active": True, "is_active": True,
@@ -114,8 +141,29 @@ async def test_execute_get_dag(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_error_response(client): async def test_exec_invalid_params(client: AirflowClient) -> None:
"""Test error handling for failed requests.""" with pytest.raises(ValueError):
async with client:
# Test with missing required parameter
await client.execute("get_dag", path_params={})
with pytest.raises(ValueError):
async with client:
# Test with invalid parameter name
await client.execute("get_dag", path_params={"invalid": "value"})
@pytest.mark.asyncio
async def test_exec_timeout(client: AirflowClient) -> None:
with aioresponses() as mock:
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
async with client:
with pytest.raises(aiohttp.ClientError):
await client.execute("get_dags")
@pytest.mark.asyncio
async def test_exec_error_response(client: AirflowClient) -> None:
with aioresponses() as mock: with aioresponses() as mock:
async with client: async with client:
mock.get( mock.get(
@@ -123,15 +171,13 @@ async def test_execute_error_response(client):
status=403, status=403,
body="Forbidden", body="Forbidden",
) )
with pytest.raises(aiohttp.ClientError): with pytest.raises(aiohttp.ClientResponseError):
await client.execute("get_dags") await client.execute("get_dags")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_session_management(client): async def test_exec_session_management(client: AirflowClient) -> None:
"""Test proper session creation and cleanup."""
async with client: async with client:
# Should work inside context
with aioresponses() as mock: with aioresponses() as mock:
mock.get( mock.get(
"http://localhost:8080/api/v1/dags", "http://localhost:8080/api/v1/dags",
@@ -140,6 +186,5 @@ async def test_session_management(client):
) )
await client.execute("get_dags") await client.execute("get_dags")
# Should fail after context exit
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await client.execute("get_dags") await client.execute("get_dags")