refactor: update AirflowClient to use httpx for async requests and enhance tests for concurrency
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.6
|
rev: v0.11.8
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import requests
|
import httpx
|
||||||
from jsonschema_path import SchemaPath
|
from jsonschema_path import SchemaPath
|
||||||
from openapi_core import OpenAPI
|
from openapi_core import OpenAPI
|
||||||
from openapi_core.validation.request.validators import V31RequestValidator
|
from openapi_core.validation.request.validators import V31RequestValidator
|
||||||
@@ -25,7 +25,7 @@ def convert_dict_keys(d: dict) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
class AirflowClient:
|
class AirflowClient:
|
||||||
"""Client for interacting with Airflow API."""
|
"""Async client for interacting with Airflow API."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -48,12 +48,25 @@ class AirflowClient:
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.auth_token = auth_token
|
self.auth_token = auth_token
|
||||||
self.headers = {"Authorization": f"Bearer {self.auth_token}"}
|
self.headers = {"Authorization": f"Bearer {self.auth_token}"}
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self.raw_spec = None
|
||||||
|
self.spec = None
|
||||||
|
self._paths = None
|
||||||
|
self._validator = None
|
||||||
|
|
||||||
# Fetch OpenAPI spec from endpoint
|
async def __aenter__(self):
|
||||||
|
self._client = httpx.AsyncClient(headers=self.headers)
|
||||||
|
await self._initialize_spec()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
if self._client:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
async def _initialize_spec(self):
|
||||||
openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
|
openapi_url = f"{self.base_url.rstrip('/')}/openapi.json"
|
||||||
self.raw_spec = self._fetch_openapi_spec(openapi_url)
|
self.raw_spec = await self._fetch_openapi_spec(openapi_url)
|
||||||
|
|
||||||
# Validate spec has required fields
|
|
||||||
if not isinstance(self.raw_spec, dict):
|
if not isinstance(self.raw_spec, dict):
|
||||||
raise ValueError("OpenAPI spec must be a dictionary")
|
raise ValueError("OpenAPI spec must be a dictionary")
|
||||||
required_fields = ["openapi", "info", "paths"]
|
required_fields = ["openapi", "info", "paths"]
|
||||||
@@ -70,10 +83,12 @@ class AirflowClient:
|
|||||||
schema_path = SchemaPath.from_dict(self.raw_spec)
|
schema_path = SchemaPath.from_dict(self.raw_spec)
|
||||||
self._validator = V31RequestValidator(schema_path)
|
self._validator = V31RequestValidator(schema_path)
|
||||||
|
|
||||||
def _fetch_openapi_spec(self, url: str) -> dict:
|
async def _fetch_openapi_spec(self, url: str) -> dict:
|
||||||
|
if not self._client:
|
||||||
|
self._client = httpx.AsyncClient(headers=self.headers)
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, headers=self.headers)
|
response = await self._client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.RequestException as e:
|
except httpx.RequestError as e:
|
||||||
raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
|
raise ValueError(f"Failed to fetch OpenAPI spec from {url}: {e}")
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -9,35 +10,60 @@ from airflow_mcp_server.client.airflow_client import AirflowClient
|
|||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def mock_openapi_response(*args, **kwargs):
|
@pytest.mark.asyncio
|
||||||
class MockResponse:
|
async def test_async_multiple_clients_concurrent():
|
||||||
def __init__(self):
|
"""Test initializing two AirflowClients concurrently to verify async power."""
|
||||||
self.status_code = 200
|
|
||||||
|
|
||||||
def json(self):
|
async def mock_get(self, url, *args, **kwargs):
|
||||||
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": {}}
|
class MockResponse:
|
||||||
|
def __init__(self):
|
||||||
|
self.status_code = 200
|
||||||
|
|
||||||
return MockResponse()
|
def raise_for_status(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||||
|
|
||||||
|
return MockResponse()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||||
|
|
||||||
|
async def create_and_check():
|
||||||
|
async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
|
||||||
|
assert client.base_url == "http://localhost:8080"
|
||||||
|
assert client.headers["Authorization"] == "Bearer token"
|
||||||
|
assert isinstance(client.spec, OpenAPI)
|
||||||
|
|
||||||
|
# Run two clients concurrently
|
||||||
|
await asyncio.gather(create_and_check(), create_and_check())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.mark.asyncio
|
||||||
def client():
|
async def test_async_client_initialization():
|
||||||
with patch("airflow_mcp_server.client.airflow_client.requests.get", side_effect=mock_openapi_response):
|
async def mock_get(self, url, *args, **kwargs):
|
||||||
return AirflowClient(
|
class MockResponse:
|
||||||
base_url="http://localhost:8080/api/v1",
|
def __init__(self):
|
||||||
auth_token="test-token",
|
self.status_code = 200
|
||||||
)
|
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_init_client_initialization(client):
|
def json(self):
|
||||||
assert isinstance(client.spec, OpenAPI)
|
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||||
assert client.base_url == "http://localhost:8080/api/v1"
|
|
||||||
assert client.headers["Authorization"] == "Bearer test-token"
|
return MockResponse()
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||||
|
async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
|
||||||
|
assert client.base_url == "http://localhost:8080"
|
||||||
|
assert client.headers["Authorization"] == "Bearer test-token"
|
||||||
|
assert isinstance(client.spec, OpenAPI)
|
||||||
|
|
||||||
|
|
||||||
def test_init_client_missing_auth():
|
def test_init_client_missing_auth():
|
||||||
with pytest.raises(ValueError, match="auth_token"):
|
with pytest.raises(ValueError, match="auth_token"):
|
||||||
AirflowClient(
|
AirflowClient(
|
||||||
base_url="http://localhost:8080/api/v1",
|
base_url="http://localhost:8080",
|
||||||
auth_token=None,
|
auth_token=None,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user