diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0611dcb..961f288 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.6 + rev: v0.11.8 hooks: - id: ruff args: [--fix] diff --git a/src/airflow_mcp_server/client/airflow_client.py b/src/airflow_mcp_server/client/airflow_client.py index a75394f..7ec9afe 100644 --- a/src/airflow_mcp_server/client/airflow_client.py +++ b/src/airflow_mcp_server/client/airflow_client.py @@ -1,7 +1,7 @@ import logging import re -import requests +import httpx from jsonschema_path import SchemaPath from openapi_core import OpenAPI from openapi_core.validation.request.validators import V31RequestValidator @@ -25,7 +25,7 @@ def convert_dict_keys(d: dict) -> dict: class AirflowClient: - """Client for interacting with Airflow API.""" + """Async client for interacting with Airflow API.""" def __init__( self, @@ -48,12 +48,25 @@ class AirflowClient: self.base_url = base_url self.auth_token = auth_token self.headers = {"Authorization": f"Bearer {self.auth_token}"} + self._client: httpx.AsyncClient | None = None + self.raw_spec = None + self.spec = None + self._paths = None + self._validator = None - # Fetch OpenAPI spec from endpoint + async def __aenter__(self): + self._client = httpx.AsyncClient(headers=self.headers) + await self._initialize_spec() + return self + + async def __aexit__(self, exc_type, exc, tb): + if self._client: + await self._client.aclose() + self._client = None + + async def _initialize_spec(self): openapi_url = f"{self.base_url.rstrip('/')}/openapi.json" - self.raw_spec = self._fetch_openapi_spec(openapi_url) - - # Validate spec has required fields + self.raw_spec = await self._fetch_openapi_spec(openapi_url) if not isinstance(self.raw_spec, dict): raise ValueError("OpenAPI spec must be a dictionary") required_fields = ["openapi", "info", "paths"] @@ -70,10 +83,12 @@ class AirflowClient: schema_path = SchemaPath.from_dict(self.raw_spec) self._validator = V31RequestValidator(schema_path) - def _fetch_openapi_spec(self, url: str) -> dict: + async def _fetch_openapi_spec(self, url: str) -> dict: + if not self._client: + self._client = httpx.AsyncClient(headers=self.headers) try: - response = requests.get(url, headers=self.headers) + response = await self._client.get(url) response.raise_for_status() - except requests.RequestException as e: + except httpx.RequestError as e: raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}") return response.json() diff --git a/tests/client/test_airflow_client.py b/tests/client/test_airflow_client.py index b6f1f8c..4ee21c2 100644 --- a/tests/client/test_airflow_client.py +++ b/tests/client/test_airflow_client.py @@ -1,3 +1,4 @@ +import asyncio import logging from unittest.mock import patch @@ -9,35 +10,60 @@ from airflow_mcp_server.client.airflow_client import AirflowClient logging.basicConfig(level=logging.DEBUG) -def mock_openapi_response(*args, **kwargs): - class MockResponse: - def __init__(self): - self.status_code = 200 +@pytest.mark.asyncio +async def test_async_multiple_clients_concurrent(): + """Test initializing two AirflowClients concurrently to verify async power.""" - def json(self): - return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": {}} + async def mock_get(self, url, *args, **kwargs): + class MockResponse: + def __init__(self): + self.status_code = 200 - return MockResponse() + def raise_for_status(self): + pass + + def json(self): + return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}} + + return MockResponse() + + with patch("httpx.AsyncClient.get", new=mock_get): + + async def create_and_check(): + async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client: + assert client.base_url == "http://localhost:8080" + assert client.headers["Authorization"] == "Bearer token" + assert isinstance(client.spec, OpenAPI) + + # Run two clients concurrently + await asyncio.gather(create_and_check(), create_and_check()) -@pytest.fixture -def client(): - with patch("airflow_mcp_server.client.airflow_client.requests.get", side_effect=mock_openapi_response): - return AirflowClient( - base_url="http://localhost:8080/api/v1", - auth_token="test-token", - ) +@pytest.mark.asyncio +async def test_async_client_initialization(): + async def mock_get(self, url, *args, **kwargs): + class MockResponse: + def __init__(self): + self.status_code = 200 + def raise_for_status(self): + pass -def test_init_client_initialization(client): - assert isinstance(client.spec, OpenAPI) - assert client.base_url == "http://localhost:8080/api/v1" - assert client.headers["Authorization"] == "Bearer test-token" + def json(self): + return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}} + + return MockResponse() + + with patch("httpx.AsyncClient.get", new=mock_get): + async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client: + assert client.base_url == "http://localhost:8080" + assert client.headers["Authorization"] == "Bearer test-token" + assert isinstance(client.spec, OpenAPI) def test_init_client_missing_auth(): with pytest.raises(ValueError, match="auth_token"): AirflowClient( - base_url="http://localhost:8080/api/v1", + base_url="http://localhost:8080", auth_token=None, )