refactor: update AirflowClient to use httpx for async requests and enhance tests for concurrency
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -9,35 +10,60 @@ from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
def mock_openapi_response(*args, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_multiple_clients_concurrent():
|
||||
"""Test initializing two AirflowClients concurrently to verify async power."""
|
||||
|
||||
def json(self):
|
||||
return {"openapi": "3.0.0", "info": {"title": "Airflow API", "version": "1.0.0"}, "paths": {}}
|
||||
async def mock_get(self, url, *args, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
|
||||
return MockResponse()
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
def json(self):
|
||||
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||
|
||||
return MockResponse()
|
||||
|
||||
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||
|
||||
async def create_and_check():
|
||||
async with AirflowClient(base_url="http://localhost:8080", auth_token="token") as client:
|
||||
assert client.base_url == "http://localhost:8080"
|
||||
assert client.headers["Authorization"] == "Bearer token"
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
|
||||
# Run two clients concurrently
|
||||
await asyncio.gather(create_and_check(), create_and_check())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with patch("airflow_mcp_server.client.airflow_client.requests.get", side_effect=mock_openapi_response):
|
||||
return AirflowClient(
|
||||
base_url="http://localhost:8080/api/v1",
|
||||
auth_token="test-token",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_initialization():
|
||||
async def mock_get(self, url, *args, **kwargs):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
def test_init_client_initialization(client):
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
assert client.base_url == "http://localhost:8080/api/v1"
|
||||
assert client.headers["Authorization"] == "Bearer test-token"
|
||||
def json(self):
|
||||
return {"openapi": "3.1.0", "info": {"title": "Airflow API", "version": "2.0.0"}, "paths": {}}
|
||||
|
||||
return MockResponse()
|
||||
|
||||
with patch("httpx.AsyncClient.get", new=mock_get):
|
||||
async with AirflowClient(base_url="http://localhost:8080", auth_token="test-token") as client:
|
||||
assert client.base_url == "http://localhost:8080"
|
||||
assert client.headers["Authorization"] == "Bearer test-token"
|
||||
assert isinstance(client.spec, OpenAPI)
|
||||
|
||||
|
||||
def test_init_client_missing_auth():
|
||||
with pytest.raises(ValueError, match="auth_token"):
|
||||
AirflowClient(
|
||||
base_url="http://localhost:8080/api/v1",
|
||||
base_url="http://localhost:8080",
|
||||
auth_token=None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user