From 2b652c592634e9b992f46e46fd94e622dc41cfcc Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Tue, 25 Feb 2025 02:29:16 +0000 Subject: [PATCH] support cookies --- pyproject.toml | 3 ++- .../client/airflow_client.py | 21 ++++++++++------ src/airflow_mcp_server/server.py | 13 +++++++--- src/airflow_mcp_server/server_safe.py | 13 +++++++--- src/airflow_mcp_server/server_unsafe.py | 13 +++++++--- src/airflow_mcp_server/tools/tool_manager.py | 24 ++++++++++++++---- tests/client/test_airflow_client.py | 25 +++++++++++++++++++ 7 files changed, 90 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3f292a..b8dc0ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,8 @@ build-backend = "hatchling.build" exclude = [ "*", "!src/**", - "!pyproject.toml" + "!pyproject.toml", + "!assets/**" ] [tool.hatch.build.targets.wheel] diff --git a/src/airflow_mcp_server/client/airflow_client.py b/src/airflow_mcp_server/client/airflow_client.py index e9f247c..1014a3e 100644 --- a/src/airflow_mcp_server/client/airflow_client.py +++ b/src/airflow_mcp_server/client/airflow_client.py @@ -35,18 +35,22 @@ class AirflowClient: self, spec_path: Path | str | dict | bytes | BinaryIO | TextIO, base_url: str, - auth_token: str, + auth_token: str | None = None, + cookie: str | None = None, ) -> None: """Initialize Airflow client. Args: spec_path: OpenAPI spec as file path, dict, bytes, or file object base_url: Base URL for API - auth_token: Authentication token + auth_token: Authentication token (optional if cookie is provided) + cookie: Session cookie (optional if auth_token is provided) Raises: - ValueError: If spec_path is invalid or spec cannot be loaded + ValueError: If spec_path is invalid or spec cannot be loaded or if neither auth_token nor cookie is provided """ + if not auth_token and not cookie: + raise ValueError("Either auth_token or cookie must be provided") try: # Load and parse OpenAPI spec if isinstance(spec_path, dict): @@ -96,10 +100,13 @@ class AirflowClient: # API configuration self.base_url = base_url.rstrip("/") - self.headers = { - "Authorization": f"Basic {auth_token}", - "Accept": "application/json", - } + self.headers = {"Accept": "application/json"} + + # Set authentication header based on what was provided + if auth_token: + self.headers["Authorization"] = f"Basic {auth_token}" + elif cookie: + self.headers["Cookie"] = cookie except Exception as e: logger.error("Failed to initialize AirflowClient: %s", e) diff --git a/src/airflow_mcp_server/server.py b/src/airflow_mcp_server/server.py index 09933d2..87d3a94 100644 --- a/src/airflow_mcp_server/server.py +++ b/src/airflow_mcp_server/server.py @@ -22,9 +22,16 @@ logger = logging.getLogger(__name__) async def serve() -> None: """Start MCP server.""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") + # Check for AIRFLOW_BASE_URL which is always required + if "AIRFLOW_BASE_URL" not in os.environ: + raise ValueError("Missing required environment variable: AIRFLOW_BASE_URL") + + # Check for either AUTH_TOKEN or COOKIE + has_auth_token = "AUTH_TOKEN" in os.environ + has_cookie = "COOKIE" in os.environ + + if not has_auth_token and not has_cookie: + raise ValueError("Either AUTH_TOKEN or COOKIE environment variable must be provided") server = Server("airflow-mcp-server") diff --git a/src/airflow_mcp_server/server_safe.py b/src/airflow_mcp_server/server_safe.py index bf81b3e..6be1f36 100644 --- a/src/airflow_mcp_server/server_safe.py +++ b/src/airflow_mcp_server/server_safe.py @@ -13,9 +13,16 @@ logger = logging.getLogger(__name__) async def serve() -> None: """Start MCP server in safe mode (read-only operations).""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") + # Check for AIRFLOW_BASE_URL which is always required + if "AIRFLOW_BASE_URL" not in os.environ: + raise ValueError("Missing required environment variable: AIRFLOW_BASE_URL") + + # Check for either AUTH_TOKEN or COOKIE + has_auth_token = "AUTH_TOKEN" in os.environ + has_cookie = "COOKIE" in os.environ + + if not has_auth_token and not has_cookie: + raise ValueError("Either AUTH_TOKEN or COOKIE environment variable must be provided") server = Server("airflow-mcp-server-safe") diff --git a/src/airflow_mcp_server/server_unsafe.py b/src/airflow_mcp_server/server_unsafe.py index bcc7932..b33b46c 100644 --- a/src/airflow_mcp_server/server_unsafe.py +++ b/src/airflow_mcp_server/server_unsafe.py @@ -13,9 +13,16 @@ logger = logging.getLogger(__name__) async def serve() -> None: """Start MCP server in unsafe mode (all operations).""" - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - if not all(var in os.environ for var in required_vars): - raise ValueError(f"Missing required environment variables: {required_vars}") + # Check for AIRFLOW_BASE_URL which is always required + if "AIRFLOW_BASE_URL" not in os.environ: + raise ValueError("Missing required environment variable: AIRFLOW_BASE_URL") + + # Check for either AUTH_TOKEN or COOKIE + has_auth_token = "AUTH_TOKEN" in os.environ + has_cookie = "COOKIE" in os.environ + + if not has_auth_token and not has_cookie: + raise ValueError("Either AUTH_TOKEN or COOKIE environment variable must be provided") server = Server("airflow-mcp-server-unsafe") diff --git a/src/airflow_mcp_server/tools/tool_manager.py b/src/airflow_mcp_server/tools/tool_manager.py index ee70e86..9eb9cbe 100644 --- a/src/airflow_mcp_server/tools/tool_manager.py +++ b/src/airflow_mcp_server/tools/tool_manager.py @@ -32,12 +32,26 @@ def _initialize_client() -> AirflowClient: except Exception as e: raise ValueError("Default OpenAPI spec not found in package resources") from e - required_vars = ["AIRFLOW_BASE_URL", "AUTH_TOKEN"] - missing_vars = [var for var in required_vars if var not in os.environ] - if missing_vars: - raise ValueError(f"Missing required environment variables: {missing_vars}") + # Check for base URL + if "AIRFLOW_BASE_URL" not in os.environ: + raise ValueError("Missing required environment variable: AIRFLOW_BASE_URL") - return AirflowClient(spec_path=spec_path, base_url=os.environ["AIRFLOW_BASE_URL"], auth_token=os.environ["AUTH_TOKEN"]) + # Check for either AUTH_TOKEN or COOKIE + has_auth_token = "AUTH_TOKEN" in os.environ + has_cookie = "COOKIE" in os.environ + + if not has_auth_token and not has_cookie: + raise ValueError("Either AUTH_TOKEN or COOKIE environment variable must be provided") + + # Initialize client with appropriate authentication method + client_args = {"spec_path": spec_path, "base_url": os.environ["AIRFLOW_BASE_URL"]} + + if has_auth_token: + client_args["auth_token"] = os.environ["AUTH_TOKEN"] + elif has_cookie: + client_args["cookie"] = os.environ["COOKIE"] + + return AirflowClient(**client_args) async def _initialize_tools() -> None: diff --git a/tests/client/test_airflow_client.py b/tests/client/test_airflow_client.py index 3fdc03d..0b4460c 100644 --- a/tests/client/test_airflow_client.py +++ b/tests/client/test_airflow_client.py @@ -32,6 +32,31 @@ def test_init_client_initialization(client: AirflowClient) -> None: assert isinstance(client.spec, OpenAPI) assert client.base_url == "http://localhost:8080/api/v1" assert client.headers["Authorization"] == "Basic test-token" + assert "Cookie" not in client.headers + + +def test_init_client_with_cookie() -> None: + with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f: + spec = yaml.safe_load(f) + client = AirflowClient( + spec_path=spec, + base_url="http://localhost:8080/api/v1", + cookie="session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM", + ) + assert isinstance(client.spec, OpenAPI) + assert client.base_url == "http://localhost:8080/api/v1" + assert "Authorization" not in client.headers + assert client.headers["Cookie"] == "session=b18e8c5e-92f5-4d1e-a8f2-7c1b62110cae.vmX5kqDq5TdvT9BzTlypMVclAwM" + + +def test_init_client_missing_auth() -> None: + with resources.files("airflow_mcp_server.resources").joinpath("v1.yaml").open("rb") as f: + spec = yaml.safe_load(f) + with pytest.raises(ValueError, match="Either auth_token or cookie must be provided"): + AirflowClient( + spec_path=spec, + base_url="http://localhost:8080/api/v1", + ) def test_init_load_spec_from_bytes() -> None: