fix airflow client
This commit is contained in:
@@ -6,7 +6,10 @@ from typing import Any, BinaryIO, TextIO
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import yaml
|
import yaml
|
||||||
|
from jsonschema_path import SchemaPath
|
||||||
from openapi_core import OpenAPI
|
from openapi_core import OpenAPI
|
||||||
|
from openapi_core.validation.request.validators import V31RequestValidator
|
||||||
|
from openapi_spec_validator import validate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,6 +65,18 @@ class AirflowClient:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
|
raise ValueError("Invalid spec_path type. Expected Path, str, dict, bytes or file-like object")
|
||||||
|
|
||||||
|
# Validate spec has required fields
|
||||||
|
if not isinstance(self.raw_spec, dict):
|
||||||
|
raise ValueError("OpenAPI spec must be a dictionary")
|
||||||
|
|
||||||
|
required_fields = ["openapi", "info", "paths"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in self.raw_spec:
|
||||||
|
raise ValueError(f"OpenAPI spec missing required field: {field}")
|
||||||
|
|
||||||
|
# Validate OpenAPI spec format
|
||||||
|
validate(self.raw_spec)
|
||||||
|
|
||||||
# Initialize OpenAPI spec
|
# Initialize OpenAPI spec
|
||||||
self.spec = OpenAPI.from_dict(self.raw_spec)
|
self.spec = OpenAPI.from_dict(self.raw_spec)
|
||||||
logger.debug("OpenAPI spec loaded successfully")
|
logger.debug("OpenAPI spec loaded successfully")
|
||||||
@@ -75,6 +90,10 @@ class AirflowClient:
|
|||||||
self._paths = self.raw_spec["paths"]
|
self._paths = self.raw_spec["paths"]
|
||||||
logger.debug("Using raw spec paths")
|
logger.debug("Using raw spec paths")
|
||||||
|
|
||||||
|
# Initialize request validator with schema path
|
||||||
|
schema_path = SchemaPath.from_dict(self.raw_spec)
|
||||||
|
self._validator = V31RequestValidator(schema_path)
|
||||||
|
|
||||||
# API configuration
|
# API configuration
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.headers = {
|
self.headers = {
|
||||||
@@ -85,16 +104,14 @@ class AirflowClient:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize AirflowClient: %s", e)
|
logger.error("Failed to initialize AirflowClient: %s", e)
|
||||||
raise
|
raise ValueError(f"Failed to initialize client: {e}")
|
||||||
|
|
||||||
async def __aenter__(self) -> "AirflowClient":
|
async def __aenter__(self) -> "AirflowClient":
|
||||||
"""Enter async context, creating session."""
|
|
||||||
self._session = aiohttp.ClientSession(headers=self.headers)
|
self._session = aiohttp.ClientSession(headers=self.headers)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *exc) -> None:
|
async def __aexit__(self, *exc) -> None:
|
||||||
"""Exit async context, closing session."""
|
if hasattr(self, "_session"):
|
||||||
if self._session:
|
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
delattr(self, "_session")
|
delattr(self, "_session")
|
||||||
|
|
||||||
@@ -135,6 +152,23 @@ class AirflowClient:
|
|||||||
logger.error("Error getting operation %s: %s", operation_id, e)
|
logger.error("Error getting operation %s: %s", operation_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _validate_path_params(self, path: str, params: dict[str, Any] | None) -> None:
|
||||||
|
if not params:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
# Extract path parameter names from the path
|
||||||
|
path_params = set(re.findall(r"{([^}]+)}", path))
|
||||||
|
|
||||||
|
# Check for missing required parameters
|
||||||
|
missing_params = path_params - set(params.keys())
|
||||||
|
if missing_params:
|
||||||
|
raise ValueError(f"Missing required path parameters: {missing_params}")
|
||||||
|
|
||||||
|
# Check for invalid parameters
|
||||||
|
invalid_params = set(params.keys()) - path_params
|
||||||
|
if invalid_params:
|
||||||
|
raise ValueError(f"Invalid path parameters: {invalid_params}")
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
operation_id: str,
|
operation_id: str,
|
||||||
@@ -165,6 +199,9 @@ class AirflowClient:
|
|||||||
# Get operation details
|
# Get operation details
|
||||||
path, method, _ = self._get_operation(operation_id)
|
path, method, _ = self._get_operation(operation_id)
|
||||||
|
|
||||||
|
# Validate path parameters
|
||||||
|
self._validate_path_params(path, path_params)
|
||||||
|
|
||||||
# Format URL
|
# Format URL
|
||||||
if path_params:
|
if path_params:
|
||||||
path = path.format(**path_params)
|
path = path.format(**path_params)
|
||||||
@@ -182,6 +219,9 @@ class AirflowClient:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
except Exception as e:
|
except aiohttp.ClientError as e:
|
||||||
logger.error("Error executing operation %s: %s", operation_id, e)
|
logger.error("Error executing operation %s: %s", operation_id, e)
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error executing operation %s: %s", operation_id, e)
|
||||||
|
raise ValueError(f"Failed to execute operation: {e}")
|
||||||
|
|||||||
@@ -1,26 +1,30 @@
|
|||||||
import logging
|
import logging
|
||||||
from importlib import resources
|
from importlib import resources
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from aioresponses import aioresponses
|
from aioresponses import aioresponses
|
||||||
from airflow_mcp_server.client import AirflowClient
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
from openapi_core import OpenAPI
|
from openapi_core import OpenAPI
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
def create_valid_spec(paths: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": paths or {}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def spec_file():
|
def spec_file() -> dict[str, Any]:
|
||||||
"""Get content of the v1.yaml spec file."""
|
|
||||||
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
|
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(spec_file):
|
def client(spec_file: dict[str, Any]) -> AirflowClient:
|
||||||
"""Create test client with the actual spec."""
|
|
||||||
return AirflowClient(
|
return AirflowClient(
|
||||||
spec_path=spec_file,
|
spec_path=spec_file,
|
||||||
base_url="http://localhost:8080/api/v1",
|
base_url="http://localhost:8080/api/v1",
|
||||||
@@ -28,46 +32,70 @@ def client(spec_file):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_client_initialization(client):
|
def test_init_client_initialization(client: AirflowClient) -> None:
|
||||||
"""Test client initialization and spec loading."""
|
|
||||||
assert isinstance(client.spec, OpenAPI)
|
assert isinstance(client.spec, OpenAPI)
|
||||||
assert client.base_url == "http://localhost:8080/api/v1"
|
assert client.base_url == "http://localhost:8080/api/v1"
|
||||||
assert client.headers["Authorization"] == "Bearer test-token"
|
assert client.headers["Authorization"] == "Bearer test-token"
|
||||||
|
|
||||||
|
|
||||||
def test_get_operation(client):
|
def test_init_load_spec_from_bytes() -> None:
|
||||||
"""Test operation lookup from spec."""
|
spec_bytes = yaml.dump(create_valid_spec()).encode()
|
||||||
# Test get_dags operation
|
client = AirflowClient(spec_path=spec_bytes, base_url="http://test", auth_token="test")
|
||||||
|
assert client.raw_spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_load_spec_from_path(tmp_path: Path) -> None:
|
||||||
|
spec_file = tmp_path / "test_spec.yaml"
|
||||||
|
spec_file.write_text(yaml.dump(create_valid_spec()))
|
||||||
|
client = AirflowClient(spec_path=spec_file, base_url="http://test", auth_token="test")
|
||||||
|
assert client.raw_spec is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_invalid_spec() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AirflowClient(spec_path={"invalid": "spec"}, base_url="http://test", auth_token="test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_missing_paths_in_spec() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AirflowClient(spec_path={"openapi": "3.0.0"}, base_url="http://test", auth_token="test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ops_get_operation(client: AirflowClient) -> None:
|
||||||
path, method, operation = client._get_operation("get_dags")
|
path, method, operation = client._get_operation("get_dags")
|
||||||
assert path == "/dags"
|
assert path == "/dags"
|
||||||
assert method == "get"
|
assert method == "get"
|
||||||
assert operation.operation_id == "get_dags"
|
assert operation.operation_id == "get_dags"
|
||||||
|
|
||||||
# Test get_dag operation
|
|
||||||
path, method, operation = client._get_operation("get_dag")
|
path, method, operation = client._get_operation("get_dag")
|
||||||
assert path == "/dags/{dag_id}"
|
assert path == "/dags/{dag_id}"
|
||||||
assert method == "get"
|
assert method == "get"
|
||||||
assert operation.operation_id == "get_dag"
|
assert operation.operation_id == "get_dag"
|
||||||
|
|
||||||
|
|
||||||
# Note: asyncio_mode is configured in pyproject.toml
|
def test_ops_nonexistent_operation(client: AirflowClient) -> None:
|
||||||
|
with pytest.raises(ValueError, match="Operation nonexistent not found in spec"):
|
||||||
|
client._get_operation("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ops_case_sensitive_operation(client: AirflowClient) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
client._get_operation("GET_DAGS")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_without_context():
|
async def test_exec_without_context(spec_file: dict[str, Any]) -> None:
|
||||||
"""Test error when executing outside async context."""
|
client = AirflowClient(
|
||||||
with resources.files("tests.client").joinpath("v1.yaml").open("r") as f:
|
spec_path=spec_file,
|
||||||
spec_content = yaml.safe_load(f)
|
base_url="http://test",
|
||||||
client = AirflowClient(
|
auth_token="test",
|
||||||
spec_path=spec_content,
|
)
|
||||||
base_url="http://test",
|
with pytest.raises(RuntimeError, match="Client not in async context"):
|
||||||
auth_token="test",
|
|
||||||
)
|
|
||||||
with pytest.raises((RuntimeError, AttributeError)):
|
|
||||||
await client.execute("get_dags")
|
await client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_get_dags(client):
|
async def test_exec_get_dags(client: AirflowClient) -> None:
|
||||||
"""Test DAG list retrieval."""
|
|
||||||
expected_response = {
|
expected_response = {
|
||||||
"dags": [
|
"dags": [
|
||||||
{
|
{
|
||||||
@@ -91,8 +119,7 @@ async def test_execute_get_dags(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_get_dag(client):
|
async def test_exec_get_dag(client: AirflowClient) -> None:
|
||||||
"""Test single DAG retrieval with path parameters."""
|
|
||||||
expected_response = {
|
expected_response = {
|
||||||
"dag_id": "test_dag",
|
"dag_id": "test_dag",
|
||||||
"is_active": True,
|
"is_active": True,
|
||||||
@@ -114,8 +141,29 @@ async def test_execute_get_dag(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_error_response(client):
|
async def test_exec_invalid_params(client: AirflowClient) -> None:
|
||||||
"""Test error handling for failed requests."""
|
with pytest.raises(ValueError):
|
||||||
|
async with client:
|
||||||
|
# Test with missing required parameter
|
||||||
|
await client.execute("get_dag", path_params={})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with client:
|
||||||
|
# Test with invalid parameter name
|
||||||
|
await client.execute("get_dag", path_params={"invalid": "value"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_timeout(client: AirflowClient) -> None:
|
||||||
|
with aioresponses() as mock:
|
||||||
|
mock.get("http://localhost:8080/api/v1/dags", exception=aiohttp.ClientError("Timeout"))
|
||||||
|
async with client:
|
||||||
|
with pytest.raises(aiohttp.ClientError):
|
||||||
|
await client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_error_response(client: AirflowClient) -> None:
|
||||||
with aioresponses() as mock:
|
with aioresponses() as mock:
|
||||||
async with client:
|
async with client:
|
||||||
mock.get(
|
mock.get(
|
||||||
@@ -123,15 +171,13 @@ async def test_execute_error_response(client):
|
|||||||
status=403,
|
status=403,
|
||||||
body="Forbidden",
|
body="Forbidden",
|
||||||
)
|
)
|
||||||
with pytest.raises(aiohttp.ClientError):
|
with pytest.raises(aiohttp.ClientResponseError):
|
||||||
await client.execute("get_dags")
|
await client.execute("get_dags")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_session_management(client):
|
async def test_exec_session_management(client: AirflowClient) -> None:
|
||||||
"""Test proper session creation and cleanup."""
|
|
||||||
async with client:
|
async with client:
|
||||||
# Should work inside context
|
|
||||||
with aioresponses() as mock:
|
with aioresponses() as mock:
|
||||||
mock.get(
|
mock.get(
|
||||||
"http://localhost:8080/api/v1/dags",
|
"http://localhost:8080/api/v1/dags",
|
||||||
@@ -140,6 +186,5 @@ async def test_session_management(client):
|
|||||||
)
|
)
|
||||||
await client.execute("get_dags")
|
await client.execute("get_dags")
|
||||||
|
|
||||||
# Should fail after context exit
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
await client.execute("get_dags")
|
await client.execute("get_dags")
|
||||||
|
|||||||
Reference in New Issue
Block a user