diff --git a/.gitignore b/.gitignore
index 85d6678..ac2d8da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,5 @@ cython_debug/
# Local Resources
plugins_reference/
astro/
+
+node_modules/
diff --git a/pyproject.toml b/pyproject.toml
index 4431ddc..931b776 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,9 +9,10 @@ authors = [
{name = "Abhishek Bhakat", email = "abhishek.bhakat@hotmail.com"}
]
dependencies = [
+ "airflow-mcp-server>=0.4.0",
+ "anthropic>=0.46.0",
"apache-airflow>=2.10.0",
"openai>=1.64.0",
- "anthropic>=0.46.0"
]
classifiers = [
"Development Status :: 3 - Alpha",
@@ -31,6 +32,13 @@ Issues = "https://github.com/abhishekbhakat/airflow-wingman/issues"
[project.entry-points."airflow.plugins"]
wingman = "airflow_wingman:WingmanPlugin"
+[project.optional-dependencies]
+dev = [
+ "build>=1.2.2",
+ "pre-commit>=4.0.1",
+ "ruff>=0.9.2"
+]
+
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
@@ -60,7 +68,8 @@ lint.select = [
lint.ignore = [
"C416", # Unnecessary list comprehension - rewrite as a generator expression
"C408", # Unnecessary `dict` call - rewrite as a literal
- "ISC001" # Single line implicit string concatenation
+ "ISC001", # Single line implicit string concatenation
+    "C901" # Function is too complex
]
lint.fixable = ["ALL"]
diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py
index dd58370..5cd1071 100644
--- a/src/airflow_wingman/llm_client.py
+++ b/src/airflow_wingman/llm_client.py
@@ -1,109 +1,266 @@
"""
-Client for making API calls to various LLM providers using their official SDKs.
+Multi-provider LLM client for Airflow Wingman.
+
+This module contains the LLMClient class that supports multiple LLM providers
+(OpenAI, Anthropic, OpenRouter) through a unified interface.
"""
-from collections.abc import Generator
+import logging
+import traceback
+from typing import Any
-from anthropic import Anthropic
-from openai import OpenAI
+from flask import session
+
+from airflow_wingman.providers import create_llm_provider
+from airflow_wingman.tools import list_airflow_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
class LLMClient:
- def __init__(self, api_key: str):
- """Initialize the LLM client.
+ """
+ Multi-provider LLM client for Airflow Wingman.
+
+ This class handles chat completion requests to various LLM providers
+ (OpenAI, Anthropic, OpenRouter) through a unified interface.
+ """
+
+ def __init__(self, provider_name: str, api_key: str, base_url: str | None = None):
+ """
+ Initialize the LLM client.
Args:
+ provider_name: Name of the provider (openai, anthropic, openrouter)
api_key: API key for the provider
+ base_url: Optional base URL for the provider API
"""
+ self.provider_name = provider_name
self.api_key = api_key
- self.openai_client = OpenAI(api_key=api_key)
- self.anthropic_client = Anthropic(api_key=api_key)
- self.openrouter_client = OpenAI(
- base_url="https://openrouter.ai/api/v1",
- api_key=api_key,
- default_headers={
- "HTTP-Referer": "Airflow Wingman", # Required by OpenRouter
- "X-Title": "Airflow Wingman", # Required by OpenRouter
- },
- )
+ self.base_url = base_url
+ self.provider = create_llm_provider(provider_name, api_key, base_url)
+ self.airflow_tools = []
+
+ def set_airflow_tools(self, tools: list):
+ """
+ Set the available Airflow tools.
+
+ Args:
+ tools: List of Airflow Tool objects
+ """
+ self.airflow_tools = tools
def chat_completion(
- self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
- ) -> Generator[str, None, None] | dict:
- """Send a chat completion request to the specified provider.
+ self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False
+ ) -> dict[str, Any] | tuple[Any, Any]:
+ """
+ Send a chat completion request to the LLM provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
- provider: Provider identifier (openai, anthropic, openrouter)
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
+ stream: Whether to stream the response (default is True)
+ return_response_obj: If True and streaming, returns both the response object and generator
+
+ Returns:
+ If stream=False: Dictionary with the response content
+ If stream=True and return_response_obj=False: Generator for streaming
+ If stream=True and return_response_obj=True: Tuple of (response_obj, generator)
+ """
+ # Get provider-specific tool definitions from Airflow tools
+ provider_tools = self.provider.convert_tools(self.airflow_tools)
+
+ try:
+ # Make the initial request with tools
+ logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}")
+ response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
+ logger.info(f"Received response from {self.provider_name}")
+
+            # If streaming, handle based on return_response_obj flag
+            if stream:
+                logger.info(f"Using streaming response from {self.provider_name}")
+                streaming_content = self.provider.get_streaming_content(response)
+                if return_response_obj:
+                    return response, streaming_content
+                return streaming_content
+
+ # For non-streaming responses, handle tool calls if present
+ if self.provider.has_tool_calls(response):
+ logger.info("Response contains tool calls")
+
+ # Process tool calls and get results
+ cookie = session.get("airflow_cookie")
+ if not cookie:
+ error_msg = "No Airflow cookie available"
+ logger.error(error_msg)
+ return {"error": error_msg}
+
+ tool_results = self.provider.process_tool_calls(response, cookie)
+
+ # Create a follow-up completion with the tool results
+ logger.info("Making follow-up request with tool results")
+ follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response, stream=False, tools=provider_tools
+ )
+
+ content = self.provider.get_content(follow_up_response)
+ logger.info(f"Final content from {self.provider_name} with tool calls COMPLETE RESPONSE START >>>")
+ logger.info(content)
+ logger.info("<<< COMPLETE RESPONSE END")
+ return {"content": content}
+ else:
+ logger.info("Response does not contain tool calls")
+ content = self.provider.get_content(response)
+ logger.info(f"Final content from {self.provider_name} without tool calls COMPLETE RESPONSE START >>>")
+ logger.info(content)
+ logger.info("<<< COMPLETE RESPONSE END")
+ return {"content": content}
+
+ except Exception as e:
+ error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ return {"error": f"API request failed: {str(e)}"}
+
+ @classmethod
+ def from_config(cls, config: dict[str, Any]) -> "LLMClient":
+ """
+ Create an LLMClient instance from a configuration dictionary.
+
+ Args:
+ config: Configuration dictionary with provider_name, api_key, and optional base_url
+
+ Returns:
+ LLMClient instance
+ """
+ provider_name = config.get("provider_name", "openai")
+ api_key = config.get("api_key")
+ base_url = config.get("base_url")
+
+ if not api_key:
+ raise ValueError("API key is required")
+
+ return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
+
+ def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None, stream=True):
+ """
+ Process tool calls recursively from a response and make follow-up requests until
+ there are no more tool calls or max_iterations is reached.
+ Returns a generator for streaming the final follow-up response.
+
+ Args:
+ response: The original response object containing tool calls
+ messages: List of message dictionaries with 'role' and 'content'
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ max_iterations: Maximum number of tool call iterations to prevent infinite loops
+ cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)
stream: Whether to stream the response
Returns:
- If stream=True, returns a generator yielding response chunks
- If stream=False, returns the complete response
+ Generator for streaming the final follow-up response
"""
try:
- if provider == "openai":
- return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
- elif provider == "anthropic":
- return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
- elif provider == "openrouter":
- return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
- else:
- return {"error": f"Unknown provider: {provider}"}
+ iteration = 0
+ current_response = response
+
+ # Check if we have a cookie
+ if not cookie:
+ error_msg = "No Airflow cookie available"
+ logger.error(error_msg)
+ yield f"Error: {error_msg}"
+ return
+
+ # Process tool calls recursively until there are no more or max_iterations is reached
+ while self.provider.has_tool_calls(current_response) and iteration < max_iterations:
+ iteration += 1
+ logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}")
+
+ # Process tool calls and get results
+ tool_results = self.provider.process_tool_calls(current_response, cookie)
+
+ # Make follow-up request with tool results
+ logger.info(f"Making follow-up request with tool results (iteration {iteration})")
+
+                # Always stream follow-up requests so the caller sees a consistent streaming interface
+                should_stream = True
+                logger.info(f"Setting should_stream=True for follow-up request (iteration {iteration})")
+
+ # Get provider-specific tool definitions from Airflow tools
+ provider_tools = self.provider.convert_tools(self.airflow_tools)
+
+ follow_up_response = self.provider.create_follow_up_completion(
+ messages=messages,
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ tool_results=tool_results,
+ original_response=current_response,
+ stream=should_stream,
+ tools=provider_tools,
+ )
+
+ # Check if this follow-up response has more tool calls
+ if not self.provider.has_tool_calls(follow_up_response):
+ logger.info(f"No more tool calls after iteration {iteration}")
+ # Final response - always yield content in a streaming fashion
+ # Since we're always streaming now, we can directly yield chunks from the streaming generator
+ chunk_count = 0
+ for chunk in self.provider.get_streaming_content(follow_up_response):
+ chunk_count += 1
+ # logger.info(f"Yielding chunk {chunk_count} from streaming generator: {chunk[:50] if chunk else 'Empty chunk'}...")
+ yield chunk
+ logger.info(f"Finished yielding {chunk_count} chunks from streaming generator")
+
+ # Update current_response for the next iteration
+ current_response = follow_up_response
+
+ # If we've reached max_iterations and still have tool calls, log a warning
+ if iteration == max_iterations and self.provider.has_tool_calls(current_response):
+ logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
+                # Follow-up requests are always streamed, so yield chunks from the streaming generator
+                logger.info("Starting to yield chunks from streaming generator (max iterations reached)")
+                chunk_count = 0
+                for chunk in self.provider.get_streaming_content(follow_up_response):
+                    chunk_count += 1
+                    logger.info(f"Yielding chunk {chunk_count} from streaming generator (max iterations)")
+                    yield chunk
+                logger.info(f"Finished yielding {chunk_count} chunks from streaming generator (max iterations)")
+
+ # If we didn't process any tool calls (shouldn't happen), return an error
+ if iteration == 0:
+ error_msg = "No tool calls found in response"
+ logger.error(error_msg)
+ yield f"Error: {error_msg}"
+
except Exception as e:
- return {"error": f"API request failed: {str(e)}"}
+ error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ yield f"Error: {str(e)}"
- def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
- """Handle OpenAI chat completion requests."""
- response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+ def refresh_tools(self, cookie: str) -> None:
+ """
+ Refresh the available Airflow tools.
- if stream:
-
- def response_generator():
- for chunk in response:
- if chunk.choices[0].delta.content:
- yield chunk.choices[0].delta.content
-
- return response_generator()
- else:
- return {"content": response.choices[0].message.content}
-
- def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
- """Handle Anthropic chat completion requests."""
- # Convert messages to Anthropic format
- system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
- conversation = []
- for m in messages:
- if m["role"] != "system":
- conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
-
- response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
- if stream:
-
- def response_generator():
- for chunk in response:
- if chunk.delta.text:
- yield chunk.delta.text
-
- return response_generator()
- else:
- return {"content": response.content[0].text}
-
- def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
- """Handle OpenRouter chat completion requests."""
- response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
- if stream:
-
- def response_generator():
- for chunk in response:
- if chunk.choices[0].delta.content:
- yield chunk.choices[0].delta.content
-
- return response_generator()
- else:
- return {"content": response.choices[0].message.content}
+ Args:
+ cookie: Airflow cookie for authentication
+ """
+ try:
+ logger.info("Refreshing Airflow tools")
+ tools = list_airflow_tools(cookie)
+ self.set_airflow_tools(tools)
+ logger.info(f"Refreshed {len(tools)} Airflow tools")
+ except Exception as e:
+ error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ # Don't raise the exception, just log it
+ # The client will continue to use the existing tools (if any)
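
Reviewer note: a minimal usage sketch of the new multi-provider client (illustrative only, not part of this patch; the API key, cookie value and message are placeholders):

    from airflow_wingman.llm_client import LLMClient

    client = LLMClient.from_config({"provider_name": "anthropic", "api_key": "sk-...", "base_url": None})
    client.refresh_tools(cookie="<airflow session cookie>")  # populates client.airflow_tools via the MCP server

    # stream defaults to True, so this returns the provider's StreamingResponse wrapper
    stream = client.chat_completion(
        messages=[{"role": "user", "content": "List my DAGs"}],
        model="claude-3-7-sonnet-20250219",
    )
    for chunk in stream:
        print(chunk, end="")
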
diff --git a/src/airflow_wingman/llms_models.py b/src/airflow_wingman/llms_models.py
index c73cd46..ff6f2b9 100644
--- a/src/airflow_wingman/llms_models.py
+++ b/src/airflow_wingman/llms_models.py
@@ -17,14 +17,14 @@ MODELS = {
"endpoint": "https://api.anthropic.com/v1/messages",
"models": [
{
- "id": "claude-3.5-sonnet",
- "name": "Claude 3.5 Sonnet",
+ "id": "claude-3-7-sonnet-20250219",
+ "name": "Claude 3.7 Sonnet",
"default": True,
"context_window": 200000,
"description": "Input $3/M tokens, Output $15/M tokens",
},
{
- "id": "claude-3.5-haiku",
+ "id": "claude-3-5-haiku-20241022",
"name": "Claude 3.5 Haiku",
"default": False,
"context_window": 200000,
diff --git a/src/airflow_wingman/prompt_engineering.py b/src/airflow_wingman/prompt_engineering.py
index 31b5d8e..74a67e5 100644
--- a/src/airflow_wingman/prompt_engineering.py
+++ b/src/airflow_wingman/prompt_engineering.py
@@ -8,9 +8,11 @@ INSTRUCTIONS = {
You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
The Airflow version being used is >=2.10.
-You have access to the following Airflow API tools:
+You have access to Airflow MCP tools that you can use to fetch information and help users understand
+and manage their Airflow environment.
-You can use these tools to fetch information and help users understand and manage their Airflow environment.
+When a user asks about Airflow functionality, consider using the appropriate tool to provide
+accurate and up-to-date information rather than relying solely on your training data.
"""
}
diff --git a/src/airflow_wingman/providers/__init__.py b/src/airflow_wingman/providers/__init__.py
new file mode 100644
index 0000000..828726c
--- /dev/null
+++ b/src/airflow_wingman/providers/__init__.py
@@ -0,0 +1,41 @@
+"""
+Provider factory for Airflow Wingman.
+
+This module contains the factory function to create provider instances
+based on the provider name.
+"""
+
+from airflow_wingman.providers.anthropic_provider import AnthropicProvider
+from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.openai_provider import OpenAIProvider
+
+
+def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider:
+ """
+ Create a provider instance based on the provider name.
+
+ Args:
+ provider_name: Name of the provider (openai, anthropic, openrouter)
+ api_key: API key for the provider
+ base_url: Optional base URL for the provider API
+
+ Returns:
+ Provider instance
+
+ Raises:
+ ValueError: If the provider is not supported
+ """
+ provider_name = provider_name.lower()
+
+ if provider_name == "openai":
+ return OpenAIProvider(api_key=api_key, base_url=base_url)
+ elif provider_name == "openrouter":
+ # OpenRouter uses the OpenAI API format, so we can use the OpenAI provider
+ # with a custom base URL
+ if not base_url:
+ base_url = "https://openrouter.ai/api/v1"
+ return OpenAIProvider(api_key=api_key, base_url=base_url)
+ elif provider_name == "anthropic":
+ return AnthropicProvider(api_key=api_key)
+ else:
+ raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter")
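
A short sketch of the factory's dispatch behaviour described above (illustrative only, not part of this patch; the API keys are placeholders):

    from airflow_wingman.providers import create_llm_provider

    # "openrouter" reuses the OpenAI-compatible provider with the default OpenRouter base URL
    provider = create_llm_provider("openrouter", api_key="sk-or-...")
    print(type(provider).__name__)  # OpenAIProvider

    # Unknown providers raise ValueError("Unsupported provider: ...")
    # create_llm_provider("mistral", api_key="...")
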
diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py
new file mode 100644
index 0000000..c02441f
--- /dev/null
+++ b/src/airflow_wingman/providers/anthropic_provider.py
@@ -0,0 +1,458 @@
+"""
+Anthropic provider implementation for Airflow Wingman.
+
+This module contains the Anthropic provider implementation that handles
+API requests, tool conversion, and response processing for Anthropic's Claude models.
+"""
+
+import json
+import logging
+import traceback
+from typing import Any
+
+from anthropic import Anthropic
+
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
+from airflow_wingman.tools import execute_airflow_tool
+from airflow_wingman.tools.conversion import convert_to_anthropic_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
+
+
+class AnthropicProvider(BaseLLMProvider):
+ """
+ Anthropic provider implementation.
+
+ This class handles API requests, tool conversion, and response processing
+ for the Anthropic API (Claude models).
+ """
+
+ def __init__(self, api_key: str):
+ """
+ Initialize the Anthropic provider.
+
+ Args:
+ api_key: API key for Anthropic
+ """
+ self.api_key = api_key
+ self.client = Anthropic(api_key=api_key)
+
+ def convert_tools(self, airflow_tools: list) -> list:
+ """
+ Convert Airflow tools to Anthropic format.
+
+ Args:
+ airflow_tools: List of Airflow tools from MCP server
+
+ Returns:
+ List of Anthropic tool definitions
+ """
+ return convert_to_anthropic_tools(airflow_tools)
+
+ def create_chat_completion(
+ self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+ ) -> Any:
+ """
+ Make API request to Anthropic.
+
+ Args:
+ messages: List of message dictionaries with 'role' and 'content'
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ stream: Whether to stream the response
+ tools: List of tool definitions in Anthropic format
+
+ Returns:
+ Anthropic response object
+
+ Raises:
+ Exception: If the API request fails
+ """
+        # Anthropic requires max_tokens, so default to 4096 when the caller does not provide one
+ max_tokens_param = max_tokens if max_tokens is not None else 4096
+
+ # Convert messages from ChatML format to Anthropic's format
+ anthropic_messages = self._convert_to_anthropic_messages(messages)
+
+ try:
+ logger.info(f"Sending chat completion request to Anthropic with model: {model}")
+
+ # Create request parameters
+ params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}
+
+ # Add tools if provided
+ if tools and len(tools) > 0:
+ params["tools"] = tools
+ else:
+ logger.warning("No tools included in request")
+
+            # Log the full request parameters (the API key is not part of params)
+ log_params = params.copy()
+ logger.info(f"Request parameters: {json.dumps(log_params)}")
+
+ # Make the API request
+ response = self.client.messages.create(**params)
+
+ logger.info("Received response from Anthropic")
+            # Log basic information about the response for debugging
+ logger.info(f"Anthropic response type: {type(response).__name__}")
+
+ # Log as much information as possible
+ if hasattr(response, "json"):
+ if callable(response.json):
+ # If json is a method, call it
+ try:
+                        logger.info(f"Anthropic response json: {response.model_dump_json()}")
+ except Exception as json_err:
+ logger.warning(f"Could not serialize response.json(): {str(json_err)}")
+ else:
+ # If json is a property, use it directly
+ try:
+ logger.info(f"Anthropic response json: {json.dumps(response.json)}")
+ except Exception as json_err:
+ logger.warning(f"Could not serialize response.json: {str(json_err)}")
+
+ # Log response attributes
+ response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
+ logger.info(f"Anthropic response attributes: {response_attrs}")
+
+ return response
+ except Exception as e:
+ error_msg = str(e)
+ logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
+ raise
+
+ def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """
+ Convert messages from ChatML format to Anthropic's format.
+
+ Args:
+ messages: List of message dictionaries in ChatML format
+
+ Returns:
+ List of message dictionaries in Anthropic format
+ """
+ anthropic_messages = []
+
+ for message in messages:
+ role = message["role"]
+ content = message["content"]
+
+ # Map ChatML roles to Anthropic roles
+ if role == "system":
+                # Anthropic's Messages API handles system prompts separately from the messages list,
+                # so fold the system message into a user turn here
+ anthropic_messages.append({"role": "user", "content": f"\n{content}\n"})
+ elif role == "user":
+ anthropic_messages.append({"role": "user", "content": content})
+ elif role == "assistant":
+ anthropic_messages.append({"role": "assistant", "content": content})
+ elif role == "tool":
+ # Tool messages in ChatML become part of the user message in Anthropic
+ # We'll handle this in the follow-up completion
+ continue
+
+ return anthropic_messages
+
+ def has_tool_calls(self, response: Any) -> bool:
+ """
+ Check if the response contains tool calls.
+
+ Args:
+ response: Anthropic response object or StreamingResponse with tool_call attribute
+
+ Returns:
+ True if the response contains tool calls, False otherwise
+ """
+ logger.info(f"Checking for tool calls in response of type: {type(response)}")
+
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse):
+ logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
+ if response.tool_call is not None:
+ logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
+ return True
+ else:
+ logger.info("StreamingResponse has None tool_call")
+ else:
+ logger.info("Response is not a StreamingResponse")
+
+ # Check if any content block is a tool_use block (for non-streaming responses)
+ if hasattr(response, "content"):
+ logger.info(f"Response has content attribute with {len(response.content)} blocks")
+ for i, block in enumerate(response.content):
+ logger.info(f"Checking content block {i}: {type(block)}")
+ if isinstance(block, dict) and block.get("type") == "tool_use":
+ logger.info(f"Found tool_use block: {block}")
+ return True
+ else:
+ logger.info("Response does not have content attribute")
+
+ logger.info("No tool calls found in response")
+ return False
+
+ def get_tool_calls(self, response: Any) -> list:
+ """
+ Extract tool calls from the response.
+
+ Args:
+ response: Anthropic response object or StreamingResponse with tool_call attribute
+
+ Returns:
+ List of tool call objects in a standardized format
+ """
+ tool_calls = []
+
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse) and response.tool_call is not None:
+ logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+ tool_calls.append(response.tool_call)
+ # Otherwise, extract tool calls from response content (for non-streaming responses)
+ elif hasattr(response, "content"):
+ logger.info("Extracting tool calls from response content")
+ for block in response.content:
+ if isinstance(block, dict) and block.get("type") == "tool_use":
+ tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
+ tool_calls.append(tool_call)
+
+ return tool_calls
+
+ def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
+ """
+ Process tool calls from the response.
+
+ Args:
+ response: Anthropic response object or generator with tool_call attribute
+ cookie: Airflow cookie for authentication
+
+ Returns:
+ Dictionary mapping tool call IDs to results
+ """
+ results = {}
+
+ if not self.has_tool_calls(response):
+ return results
+
+ # Get tool calls using the standardized method
+ tool_calls = self.get_tool_calls(response)
+ logger.info(f"Processing {len(tool_calls)} tool calls")
+
+ for tool_call in tool_calls:
+            # get_tool_calls() already returns a standardized dict for both the
+            # streaming (generator tool_call) and non-streaming (content block) cases
+            tool_id = tool_call.get("id")
+            tool_name = tool_call.get("name")
+            tool_input = tool_call.get("input", {})
+
+ try:
+ # Execute the Airflow tool with the provided arguments and cookie
+ logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
+ result = execute_airflow_tool(tool_name, tool_input, cookie)
+ logger.info(f"Tool execution result: {result}")
+ results[tool_id] = {"status": "success", "result": result}
+ except Exception as e:
+ error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ results[tool_id] = {"status": "error", "message": error_msg}
+
+ return results
+
+ def create_follow_up_completion(
+ self,
+ messages: list[dict[str, Any]],
+ model: str,
+ temperature: float = 0.4,
+ max_tokens: int | None = None,
+ tool_results: dict[str, Any] = None,
+ original_response: Any = None,
+ stream: bool = True,
+ tools: list[dict[str, Any]] | None = None,
+ ) -> Any:
+ """
+ Create a follow-up completion with tool results.
+
+ Args:
+ messages: Original messages
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ tool_results: Results of tool executions
+ original_response: Original response with tool calls
+ stream: Whether to stream the response
+ tools: List of tool definitions in Anthropic format
+
+ Returns:
+ Anthropic response object or generator if streaming
+ """
+ if not original_response or not tool_results:
+ return original_response
+
+ # Extract tool call from the StreamingResponse or content blocks from Anthropic response
+ tool_use_blocks = []
+ if isinstance(original_response, StreamingResponse) and original_response.tool_call:
+ # For StreamingResponse, create a tool_use block from the tool_call
+ logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
+ tool_call = original_response.tool_call
+ tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
+ elif hasattr(original_response, "content"):
+ # For regular Anthropic response, extract from content blocks
+ logger.info("Extracting tool_use blocks from response content")
+ tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+
+ # Create tool result blocks
+ tool_result_blocks = []
+ for tool_id, result in tool_results.items():
+ tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))})
+
+ # Convert original messages to Anthropic format
+ anthropic_messages = self._convert_to_anthropic_messages(messages)
+
+ # Add the assistant response with tool use
+ anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
+
+ # Add the user message with tool results
+ anthropic_messages.append({"role": "user", "content": tool_result_blocks})
+
+ # Make a second request to get the final response
+ logger.info(f"Making second request with tool results (stream={stream})")
+ return self.create_chat_completion(
+ messages=anthropic_messages,
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ stream=stream,
+ tools=tools,
+ )
+
+ def get_content(self, response: Any) -> str:
+ """
+ Extract content from the response.
+
+ Args:
+ response: Anthropic response object
+
+ Returns:
+ Content string from the response
+ """
+ if not hasattr(response, "content"):
+ return ""
+
+ # Combine all text blocks into a single string
+ content_parts = []
+ for block in response.content:
+            if isinstance(block, dict) and block.get("type") == "text":
+                content_parts.append(block.get("text", ""))
+            elif hasattr(block, "text"):
+                # Anthropic SDK responses contain TextBlock objects rather than dicts
+                content_parts.append(block.text)
+            elif isinstance(block, str):
+                content_parts.append(block)
+
+ return "".join(content_parts)
+
+ def get_streaming_content(self, response: Any) -> StreamingResponse:
+ """
+        Get a StreamingResponse for streaming content from the response.
+
+ Args:
+ response: Anthropic streaming response object
+
+ Returns:
+ StreamingResponse object wrapping a generator that yields content chunks
+ and can also store tool call information detected during streaming
+ """
+ logger.info("Starting Anthropic streaming response processing")
+
+ # Track only the first tool call detected during streaming
+ tool_call = None
+        tool_use_detected = False
+        input_json_buffer = ""  # accumulates streamed partial-JSON fragments for the tool input
+
+ # Create the StreamingResponse object first
+ streaming_response = StreamingResponse(generator=None, tool_call=None)
+
+ def generate():
+            nonlocal tool_call, tool_use_detected, input_json_buffer
+
+ for chunk in response:
+ logger.debug(f"Chunk type: {type(chunk)}")
+                logger.debug(f"Chunk content: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else str(chunk)}")
+
+ # Check for content_block_start events with type "tool_use"
+ if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
+ if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
+ if chunk.content_block.type == "tool_use":
+                            logger.info(f"Tool use detected in streaming response: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else str(chunk)}")
+ tool_use_detected = True
+ tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
+ # Update the StreamingResponse object's tool_call attribute
+ streaming_response.tool_call = tool_call
+ # We don't signal to the frontend during streaming
+ # The tool will only be executed after streaming ends
+ continue
+
+                # Handle content_block_delta events for tool_use (input updates)
+                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
+                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
+                        if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
+                            logger.info(f"Tool use input update: {chunk.delta.partial_json}")
+                            # Accumulate the fragment: each partial_json piece is generally not
+                            # valid JSON on its own, so the buffer is parsed once the block stops
+                            input_json_buffer += chunk.delta.partial_json
+                            continue
+
+                # Handle content_block_stop events for tool_use
+                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
+                    logger.info("Tool use block completed")
+                    if tool_call:
+                        # Parse the accumulated input JSON now that the block is complete
+                        if input_json_buffer:
+                            try:
+                                tool_call["input"] = json.loads(input_json_buffer)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse accumulated tool input: {input_json_buffer}")
+                        logger.info(f"Completed tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
+                    continue
+
+ # Handle message_delta events with stop_reason "tool_use"
+ if hasattr(chunk, "type") and chunk.type == "message_delta":
+ if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
+ if chunk.delta.stop_reason == "tool_use":
+ logger.info("Message stopped due to tool use")
+ # Update the StreamingResponse object's tool_call attribute one last time
+ if tool_call:
+ streaming_response.tool_call = tool_call
+ continue
+
+ # Handle regular content chunks
+ content = None
+ if hasattr(chunk, "type") and chunk.type == "content_block_delta":
+ if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
+ content = chunk.delta.text
+ elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
+ content = chunk.delta.text
+ elif hasattr(chunk, "content") and chunk.content:
+ for block in chunk.content:
+ if isinstance(block, dict) and block.get("type") == "text":
+ content = block.get("text", "")
+
+ if content:
+ # Don't do any newline replacement here
+ yield content
+
+ # Create the generator
+ gen = generate()
+
+ # Set the generator in the StreamingResponse object
+ streaming_response.generator = gen
+
+ # Return the StreamingResponse object
+ return streaming_response
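
Reviewer note: a hedged sketch of the streaming contract implemented above (illustrative only, not part of this patch; `messages`, `tools`, `render` and `airflow_cookie` are assumed to exist in the caller, and the API key is a placeholder):

    from airflow_wingman.providers.anthropic_provider import AnthropicProvider

    provider = AnthropicProvider(api_key="sk-ant-...")
    response = provider.create_chat_completion(messages=messages, model="claude-3-7-sonnet-20250219", stream=True, tools=tools)
    streaming = provider.get_streaming_content(response)

    for chunk in streaming:
        render(chunk)  # hypothetical sink, e.g. write to the chat SSE stream

    # After the generator is exhausted, any tool_use block recorded during streaming
    # can be executed and fed back via process_tool_calls / create_follow_up_completion.
    if streaming.tool_call is not None:
        results = provider.process_tool_calls(streaming, cookie=airflow_cookie)
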
diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py
new file mode 100644
index 0000000..18195dc
--- /dev/null
+++ b/src/airflow_wingman/providers/base.py
@@ -0,0 +1,216 @@
+"""
+Base provider interface for Airflow Wingman.
+
+This module contains the base provider interface that all provider implementations
+must adhere to. It defines the methods required for tool conversion, API requests,
+and response processing.
+"""
+
+import json
+import logging
+import traceback
+from abc import ABC, abstractmethod
+from collections.abc import Generator, Iterator
+from typing import Any, Generic, TypeVar
+
+# Create a properly namespaced logger for the Airflow plugin (matches the provider modules)
+logger = logging.getLogger("airflow.plugins.wingman")
+
+T = TypeVar("T")
+
+
+class StreamingResponse(Generic[T]):
+ """
+ Wrapper for streaming responses that can hold tool call information.
+
+ This class wraps a generator and provides an iterator interface while also
+ storing tool call information. This allows us to associate metadata with
+ a generator without modifying the generator itself.
+ """
+
+ def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
+ """
+ Initialize the streaming response.
+
+ Args:
+ generator: The underlying generator yielding content chunks
+ tool_call: Optional tool call information detected during streaming
+ """
+ self.generator = generator
+ self.tool_call = tool_call
+
+ def __iter__(self) -> Iterator[T]:
+ """
+ Return self as iterator.
+ """
+ return self
+
+ def __next__(self) -> T:
+ """
+ Get the next item from the generator.
+ """
+ return next(self.generator)
+
+
+class BaseLLMProvider(ABC):
+ """
+ Base provider interface for LLM providers.
+
+ This abstract class defines the methods that all provider implementations
+ must implement to support tool integration.
+ """
+
+ @abstractmethod
+ def convert_tools(self, airflow_tools: list) -> list:
+ """
+ Convert internal tool representation to provider format.
+
+ Args:
+ airflow_tools: List of Airflow tools from MCP server
+
+ Returns:
+ List of provider-specific tool definitions
+ """
+ pass
+
+ @abstractmethod
+ def create_chat_completion(
+ self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+ ) -> Any:
+ """
+ Make API request to provider.
+
+ Args:
+ messages: List of message dictionaries with 'role' and 'content'
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ stream: Whether to stream the response
+ tools: List of tool definitions in provider format
+
+ Returns:
+ Provider-specific response object
+ """
+ pass
+
+ def has_tool_calls(self, response: Any) -> bool:
+ """
+ Check if the response contains tool calls.
+
+ Args:
+ response: Provider-specific response object or StreamingResponse
+
+ Returns:
+ True if the response contains tool calls, False otherwise
+ """
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse) and response.tool_call is not None:
+ return True
+
+ # Provider-specific implementation should handle other cases
+ return False
+
+ def get_tool_calls(self, response: Any) -> list:
+ """
+ Extract tool calls from the response.
+
+ Args:
+ response: Provider-specific response object or StreamingResponse
+
+ Returns:
+ List of tool call objects in a standardized format
+ """
+ tool_calls = []
+
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse) and response.tool_call is not None:
+ tool_calls.append(response.tool_call)
+
+ # Provider-specific implementation should handle other cases
+ return tool_calls
+
+ def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
+ """
+ Process tool calls from the response.
+
+ Args:
+ response: Provider-specific response object or StreamingResponse
+ cookie: Airflow cookie for authentication
+
+ Returns:
+ Dictionary mapping tool call IDs to results
+ """
+ tool_calls = self.get_tool_calls(response)
+ results = {}
+
+ for tool_call in tool_calls:
+ tool_name = tool_call.get("name", "")
+ tool_input = tool_call.get("input", {})
+ tool_id = tool_call.get("id", "")
+
+            try:
+                logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")
+
+                # Deferred import; the concrete providers import this helper at module level
+                from airflow_wingman.tools import execute_airflow_tool
+
+                result = execute_airflow_tool(tool_name, tool_input, cookie)
+
+                logger.info(f"Tool result: {json.dumps(result)}")
+                # Use the same result shape as the concrete providers so that
+                # create_follow_up_completion can read result.get("result", ...)
+                results[tool_id] = {"status": "success", "result": result}
+ except Exception as e:
+ error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ results[tool_id] = {"status": "error", "message": error_msg}
+
+ return results
+
+ @abstractmethod
+ def create_follow_up_completion(
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
+ ) -> Any:
+ """
+ Create a follow-up completion with tool results.
+
+ Args:
+ messages: Original messages
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ tool_results: Results of tool executions
+            original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in provider format
+
+ Returns:
+ Provider-specific response object
+ """
+ pass
+
+ @abstractmethod
+ def get_content(self, response: Any) -> str:
+ """
+ Extract content from the response.
+
+ Args:
+ response: Provider-specific response object
+
+ Returns:
+ Content string from the response
+ """
+ pass
+
+ @abstractmethod
+ def get_streaming_content(self, response: Any) -> StreamingResponse:
+ """
+ Get a StreamingResponse for streaming content from the response.
+
+ Args:
+ response: Provider-specific response object
+
+ Returns:
+ StreamingResponse object wrapping a generator that yields content chunks
+ and can also store tool call information detected during streaming
+ """
+ pass
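
Reviewer note: a compact sketch of why the StreamingResponse wrapper exists (illustrative only, not part of this patch): it lets a provider attach metadata discovered mid-stream without changing the generator protocol its callers iterate over. The tool call values below are placeholders.

    from airflow_wingman.providers.base import StreamingResponse

    def chunks():
        yield "Hello, "
        yield "world"

    wrapped = StreamingResponse(generator=chunks())
    for piece in wrapped:        # behaves like the underlying generator
        print(piece, end="")

    # A provider can set this at any point while its inner generator runs;
    # callers read it once iteration has finished.
    wrapped.tool_call = {"id": "call_0", "name": "list_dags", "input": {}}
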
diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py
new file mode 100644
index 0000000..d1847da
--- /dev/null
+++ b/src/airflow_wingman/providers/openai_provider.py
@@ -0,0 +1,354 @@
+"""
+OpenAI provider implementation for Airflow Wingman.
+
+This module contains the OpenAI provider implementation that handles
+API requests, tool conversion, and response processing for OpenAI.
+"""
+
+import json
+import logging
+import traceback
+from typing import Any
+
+from openai import OpenAI
+
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
+from airflow_wingman.tools import execute_airflow_tool
+from airflow_wingman.tools.conversion import convert_to_openai_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
+
+
+class OpenAIProvider(BaseLLMProvider):
+ """
+ OpenAI provider implementation.
+
+ This class handles API requests, tool conversion, and response processing
+ for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
+ """
+
+ def __init__(self, api_key: str, base_url: str | None = None):
+ """
+ Initialize the OpenAI provider.
+
+ Args:
+ api_key: API key for OpenAI
+ base_url: Optional base URL for the API (used for OpenRouter)
+ """
+ self.api_key = api_key
+ self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+ def convert_tools(self, airflow_tools: list) -> list:
+ """
+ Convert Airflow tools to OpenAI format.
+
+ Args:
+ airflow_tools: List of Airflow tools from MCP server
+
+ Returns:
+ List of OpenAI tool definitions
+ """
+ return convert_to_openai_tools(airflow_tools)
+
+ def create_chat_completion(
+ self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+ ) -> Any:
+ """
+ Make API request to OpenAI.
+
+ Args:
+ messages: List of message dictionaries with 'role' and 'content'
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ stream: Whether to stream the response
+ tools: List of tool definitions in OpenAI format
+
+ Returns:
+ OpenAI response object
+
+ Raises:
+ Exception: If the API request fails
+ """
+ # Only include tools if we have any
+ has_tools = tools is not None and len(tools) > 0
+ tool_choice = "auto" if has_tools else None
+
+ try:
+ logger.info(f"Sending chat completion request to OpenAI with model: {model}")
+
+ # Log information about tools
+ if not has_tools:
+ logger.warning("No tools included in request")
+
+ # Log request parameters
+ request_params = {
+ "model": model,
+ "messages": messages,
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ "stream": stream,
+ "tools": tools if has_tools else None,
+ "tool_choice": tool_choice,
+ }
+ logger.info(f"Request parameters: {json.dumps(request_params)}")
+
+ response = self.client.chat.completions.create(
+ model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
+ )
+ logger.info("Received response from OpenAI")
+ return response
+ except Exception as e:
+ # If the API call fails due to tools not being supported, retry without tools
+ error_msg = str(e)
+ logger.warning(f"Error in OpenAI API call: {error_msg}")
+ if "tools" in error_msg.lower():
+ logger.info("Retrying without tools")
+ response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+ return response
+ else:
+ logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
+ raise
+
+ def has_tool_calls(self, response: Any) -> bool:
+ """
+ Check if the response contains tool calls.
+
+ Args:
+ response: OpenAI response object or StreamingResponse with tool_call attribute
+
+ Returns:
+ True if the response contains tool calls, False otherwise
+ """
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse) and response.tool_call is not None:
+ return True
+
+ # For non-streaming responses
+ if hasattr(response, "choices") and len(response.choices) > 0:
+ message = response.choices[0].message
+            return bool(hasattr(message, "tool_calls") and message.tool_calls)
+
+ return False
+
+ def get_tool_calls(self, response: Any) -> list:
+ """
+ Extract tool calls from the response.
+
+ Args:
+ response: OpenAI response object or StreamingResponse with tool_call attribute
+
+ Returns:
+ List of tool call objects in a standardized format
+ """
+ tool_calls = []
+
+ # Check if response is a StreamingResponse with a tool_call attribute
+ if isinstance(response, StreamingResponse) and response.tool_call is not None:
+ logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+ tool_calls.append(response.tool_call)
+ return tool_calls
+
+ # For non-streaming responses
+ if hasattr(response, "choices") and len(response.choices) > 0:
+ message = response.choices[0].message
+ if hasattr(message, "tool_calls") and message.tool_calls:
+ for tool_call in message.tool_calls:
+ standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
+ tool_calls.append(standardized_tool_call)
+
+ logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
+ return tool_calls
+
+ def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
+ """
+ Process tool calls from the response.
+
+ Args:
+ response: OpenAI response object or StreamingResponse with tool_call attribute
+ cookie: Airflow cookie for authentication
+
+ Returns:
+ Dictionary mapping tool call IDs to results
+ """
+ results = {}
+
+ if not self.has_tool_calls(response):
+ return results
+
+ # Get tool calls using the standardized method
+ tool_calls = self.get_tool_calls(response)
+ logger.info(f"Processing {len(tool_calls)} tool calls")
+
+ for tool_call in tool_calls:
+ tool_id = tool_call["id"]
+ function_name = tool_call["name"]
+ arguments = tool_call["input"]
+
+ try:
+ # Execute the Airflow tool with the provided arguments and cookie
+ logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
+ result = execute_airflow_tool(function_name, arguments, cookie)
+ logger.info(f"Tool execution result: {result}")
+ results[tool_id] = {"status": "success", "result": result}
+ except Exception as e:
+ error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ results[tool_id] = {"status": "error", "message": error_msg}
+
+ return results
+
+ def create_follow_up_completion(
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
+ ) -> Any:
+ """
+ Create a follow-up completion with tool results.
+
+ Args:
+ messages: Original messages
+ model: Model identifier
+ temperature: Sampling temperature (0-1)
+ max_tokens: Maximum tokens to generate
+ tool_results: Results of tool executions
+            original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in OpenAI format
+
+ Returns:
+ OpenAI response object
+ """
+ if not original_response or not tool_results:
+ return original_response
+
+        # Rebuild the assistant message that issued the tool calls. get_tool_calls()
+        # returns the same standardized format for streaming and non-streaming responses,
+        # so this also works when original_response is a StreamingResponse.
+        original_tool_calls = self.get_tool_calls(original_response)
+
+        assistant_message = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [{"id": tc["id"], "type": "function", "function": {"name": tc["name"], "arguments": json.dumps(tc["input"])}} for tc in original_tool_calls],
+        }
+
+ # Create tool result messages
+ tool_messages = []
+ for tool_call_id, result in tool_results.items():
+ tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))})
+
+ # Add the original messages, assistant message, and tool results
+ new_messages = messages + [assistant_message] + tool_messages
+
+ # Make a second request to get the final response
+ logger.info("Making second request with tool results")
+ return self.create_chat_completion(
+ messages=new_messages,
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+            stream=stream,
+            tools=tools,
+ )
+
+ def get_content(self, response: Any) -> str:
+ """
+ Extract content from the response.
+
+ Args:
+ response: OpenAI response object
+
+ Returns:
+ Content string from the response
+ """
+ return response.choices[0].message.content
+
+ def get_streaming_content(self, response: Any) -> StreamingResponse:
+ """
+        Get a StreamingResponse for streaming content from the response.
+
+ Args:
+ response: OpenAI streaming response object
+
+ Returns:
+ StreamingResponse object wrapping a generator that yields content chunks
+ and can also store tool call information detected during streaming
+ """
+ logger.info("Starting OpenAI streaming response processing")
+
+ # Track only the first tool call detected during streaming
+ tool_call = None
+ tool_use_detected = False
+        current_tool_call = None
+        arguments_buffer = ""  # accumulates streamed function-argument JSON fragments
+
+ # Create the StreamingResponse object first
+ streaming_response = StreamingResponse(generator=None, tool_call=None)
+
+ def generate():
+            nonlocal tool_call, tool_use_detected, current_tool_call, arguments_buffer
+
+ for chunk in response:
+ # Check for tool call in the delta
+ if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
+ # Tool call detected
+ if not tool_use_detected:
+ tool_use_detected = True
+ logger.info("Tool call detected in streaming response")
+
+ # Initialize the tool call
+ delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+ current_tool_call = {
+ "id": getattr(delta_tool_call, "id", ""),
+ "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
+ "input": {},
+                        }
+                        # Capture any argument fragment included in this first chunk
+                        if hasattr(delta_tool_call, "function") and getattr(delta_tool_call.function, "arguments", None):
+                            arguments_buffer += delta_tool_call.function.arguments
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = current_tool_call
+ else:
+ # Update the existing tool call
+ delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+
+ # Update the tool call ID if it's provided in this chunk
+ if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
+ current_tool_call["id"] = delta_tool_call.id
+
+ # Update the function name if it's provided in this chunk
+ if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
+ current_tool_call["name"] = delta_tool_call.function.name
+
+                        # Accumulate argument fragments; OpenAI streams the arguments as partial
+                        # JSON, so the buffer is only parsed once the tool call completes
+                        if hasattr(delta_tool_call, "function") and getattr(delta_tool_call.function, "arguments", None) and current_tool_call:
+                            arguments_buffer += delta_tool_call.function.arguments
+
+ # Skip yielding content for tool call chunks
+ continue
+
+                # For the final chunk, parse the accumulated arguments and set the tool_call attribute
+                if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
+                    logger.info("Streaming response finished with tool_calls reason")
+                    if current_tool_call:
+                        if arguments_buffer:
+                            try:
+                                current_tool_call["input"] = json.loads(arguments_buffer)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse accumulated tool arguments: {arguments_buffer}")
+                        tool_call = current_tool_call
+                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
+                    continue
+
+ # Handle regular content chunks
+ if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
+ content = chunk.choices[0].delta.content
+ yield content
+
+ # Create the generator
+ gen = generate()
+
+ # Set the generator in the StreamingResponse object
+ streaming_response.generator = gen
+
+ # Return the StreamingResponse object
+ return streaming_response
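
Reviewer note: the follow-up request built by create_follow_up_completion above follows OpenAI's standard tool-calling message sequence. A hedged sketch of the resulting messages list (illustrative only, not part of this patch; the tool name, IDs and values are placeholders):

    new_messages = [
        {"role": "user", "content": "How many DAGs failed today?"},
        {"role": "assistant", "content": None, "tool_calls": [
            {"id": "call_0", "type": "function",
             "function": {"name": "list_dag_runs", "arguments": "{\"state\": \"failed\"}"}},
        ]},
        {"role": "tool", "tool_call_id": "call_0", "content": "3 failed runs"},
    ]
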
diff --git a/src/airflow_wingman/static/css/wingman_chat.css b/src/airflow_wingman/static/css/wingman_chat.css
new file mode 100644
index 0000000..1d6cdce
--- /dev/null
+++ b/src/airflow_wingman/static/css/wingman_chat.css
@@ -0,0 +1,99 @@
+/* Provider and model selection styling */
+.provider-section {
+ margin-bottom: 20px;
+}
+.provider-name {
+ font-size: 16px;
+ font-weight: bold;
+ margin-bottom: 10px;
+ color: #666;
+}
+.model-option {
+ margin-left: 15px;
+ margin-bottom: 8px;
+}
+.model-option label {
+ display: block;
+ cursor: pointer;
+}
+
+/* Message styling */
+.message {
+ margin-bottom: 15px;
+ max-width: 80%;
+ clear: both;
+}
+
+.message p {
+ margin-top: 0.5em;
+ margin-bottom: 0.5em;
+}
+
+.message-user {
+ float: right;
+ background-color: #f0f7ff;
+ border: 1px solid #d1e6ff;
+ border-radius: 15px 15px 0 15px;
+ padding: 10px 15px;
+}
+
+.message pre {
+ margin-top: 0.5em;
+ margin-bottom: 0.5em;
+ padding: 0.5em;
+}
+
+.message-assistant {
+ float: left;
+ background-color: #f8f9fa;
+ border: 1px solid #e9ecef;
+ border-radius: 15px 15px 15px 0;
+ padding: 10px 15px;
+}
+
+.message code {
+ padding: 0.1em 0.3em;
+}
+
+#chat-messages::after {
+ content: "";
+ clear: both;
+ display: table;
+}
+
+/* Scrollbar styling */
+.panel-body::-webkit-scrollbar {
+ width: 8px;
+}
+
+.panel-body::-webkit-scrollbar-track {
+ background: #f1f1f1;
+}
+
+.panel-body::-webkit-scrollbar-thumb {
+ background: #888;
+ border-radius: 4px;
+}
+
+.panel-body::-webkit-scrollbar-thumb:hover {
+ background: #555;
+}
+
+/* Processing indicator styling */
+.processing-indicator {
+ display: none;
+ background-color: #f0f8ff;
+ padding: 8px 12px;
+ border-radius: 4px;
+ margin: 8px 0;
+ font-style: italic;
+}
+
+.processing-indicator.visible {
+ display: block;
+}
+
+.pre-formatted {
+ font-family: monospace;
+ line-height: 1.2;
+}
diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js
new file mode 100644
index 0000000..e12f53d
--- /dev/null
+++ b/src/airflow_wingman/static/js/wingman_chat.js
@@ -0,0 +1,346 @@
+document.addEventListener('DOMContentLoaded', function() {
+ // Initialize tooltips
+ document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
+ el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
+ });
+
+ // Handle model selection and model name input
+ const modelNameInput = document.getElementById('modelName');
+ const modelRadios = document.querySelectorAll('input[name="model"]');
+
+ modelRadios.forEach(function(radio) {
+ radio.addEventListener('change', function() {
+ const provider = this.value.split(':')[0];
+ const modelName = this.getAttribute('data-model-name');
+ console.log('Selected provider:', provider);
+ console.log('Model name:', modelName);
+
+ if (provider === 'openrouter') {
+ console.log('Enabling model name input');
+ modelNameInput.disabled = false;
+ modelNameInput.value = '';
+ modelNameInput.placeholder = 'Enter model name for OpenRouter';
+ } else {
+ console.log('Disabling model name input');
+ modelNameInput.disabled = true;
+ modelNameInput.value = modelName;
+ }
+ });
+ });
+
+ // Set initial state based on default selection
+ const defaultSelected = document.querySelector('input[name="model"]:checked');
+ if (defaultSelected) {
+ const provider = defaultSelected.value.split(':')[0];
+ const modelName = defaultSelected.getAttribute('data-model-name');
+ console.log('Initial provider:', provider);
+ console.log('Initial model name:', modelName);
+
+ if (provider === 'openrouter') {
+ console.log('Initially enabling model name input');
+ modelNameInput.disabled = false;
+ modelNameInput.value = '';
+ modelNameInput.placeholder = 'Enter model name for OpenRouter';
+ } else {
+ console.log('Initially disabling model name input');
+ modelNameInput.disabled = true;
+ modelNameInput.value = modelName;
+ }
+ }
+
+ const messageInput = document.getElementById('message-input');
+ const sendButton = document.getElementById('send-button');
+ const refreshButton = document.getElementById('refresh-button');
+ const chatMessages = document.getElementById('chat-messages');
+
+ let currentMessageDiv = null;
+ let messageHistory = [];
+
+ // Create a processing indicator element
+ const processingIndicator = document.createElement('div');
+ processingIndicator.className = 'processing-indicator';
+ processingIndicator.textContent = 'Processing tool calls...';
+ chatMessages.appendChild(processingIndicator);
+
+ function clearChat() {
+ // Clear the chat messages
+ chatMessages.innerHTML = '';
+ // Add back the processing indicator
+ chatMessages.appendChild(processingIndicator);
+ // Reset message history
+ messageHistory = [];
+ // Clear the input field
+ messageInput.value = '';
+ // Enable input if it was disabled
+ messageInput.disabled = false;
+ sendButton.disabled = false;
+ }
+
+ function addMessage(content, isUser) {
+ const messageDiv = document.createElement('div');
+ messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
+
+ messageDiv.classList.add('pre-formatted');
+
+ // Use marked.js to render markdown
+ try {
+ // Configure marked options
+ marked.use({
+ breaks: true, // Add line breaks on single newlines
+ gfm: true, // Use GitHub Flavored Markdown
+ headerIds: false, // Don't add IDs to headers
+ mangle: false, // Don't mangle email addresses
+ });
+
+ // Render markdown to HTML
+ messageDiv.innerHTML = marked.parse(content);
+ } catch (e) {
+ console.error('Error rendering markdown:', e);
+ // Fallback to innerText if markdown parsing fails
+ messageDiv.innerText = content;
+ }
+
+ chatMessages.appendChild(messageDiv);
+ chatMessages.scrollTop = chatMessages.scrollHeight;
+ return messageDiv;
+ }
+
+ function showProcessingIndicator() {
+ processingIndicator.classList.add('visible');
+ chatMessages.scrollTop = chatMessages.scrollHeight;
+
+ // Disable send button and input field during tool processing
+ sendButton.disabled = true;
+ messageInput.disabled = true;
+ }
+
+ function hideProcessingIndicator() {
+ processingIndicator.classList.remove('visible');
+
+ // Re-enable send button and input field after tool processing
+ sendButton.disabled = false;
+ messageInput.disabled = false;
+ }
+
+ async function sendMessage() {
+ const message = messageInput.value.trim();
+ if (!message) return;
+
+ // Get selected model
+ const selectedModel = document.querySelector('input[name="model"]:checked');
+ if (!selectedModel) {
+ alert('Please select a model');
+ return;
+ }
+
+ const [provider, modelId] = selectedModel.value.split(':');
+ const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
+
+ // Clear input and add user message
+ messageInput.value = '';
+ addMessage(message, true);
+
+ // Add user message to history
+ messageHistory.push({
+ role: 'user',
+ content: message
+ });
+
+ // Use full message history for the request
+ const messages = [...messageHistory];
+
+ // Create assistant message div
+ currentMessageDiv = addMessage('', false);
+
+ // Get API key
+ const apiKey = document.getElementById('api-key').value.trim();
+ if (!apiKey) {
+ alert('Please enter an API key');
+ return;
+ }
+
+ // Disable input while processing
+ messageInput.disabled = true;
+ sendButton.disabled = true;
+
+        // Get CSRF token
+        const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
+        if (!csrfToken) {
+            alert('CSRF token not found. Please refresh the page.');
+            // Re-enable the inputs before bailing out so the user is not locked out
+            messageInput.disabled = false;
+            sendButton.disabled = false;
+            return;
+        }
+
+ // Create request data
+ const requestData = {
+ provider: provider,
+ model: modelName,
+ messages: messages,
+ api_key: apiKey,
+ stream: true,
+ temperature: 0.4,
+ };
+ console.log('Sending request:', {...requestData, api_key: '***'});
+
+ try {
+ // Send request
+ const response = await fetch('/wingman/chat', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'X-CSRFToken': csrfToken
+ },
+ body: JSON.stringify(requestData)
+ });
+
+ if (!response.ok) {
+ const error = await response.json();
+ throw new Error(error.error || 'Failed to get response');
+ }
+
+ // Process the streaming response
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let fullResponse = '';
+
+ while (true) {
+ const { value, done } = await reader.read();
+ if (done) break;
+
+ const chunk = decoder.decode(value);
+ const lines = chunk.split('\n');
+
+ for (const line of lines) {
+ if (line.trim() === '') continue;
+
+ if (line.startsWith('data: ')) {
+ const content = line.slice(6); // Remove 'data: ' prefix
+
+ // Check for special events or end marker
+ if (content === '[DONE]') {
+ console.log('Stream complete');
+
+ // Add assistant's response to history
+ if (fullResponse) {
+ messageHistory.push({
+ role: 'assistant',
+ content: fullResponse
+ });
+ }
+ continue;
+ }
+
+ // Try to parse as JSON for special events
+ try {
+ const parsed = JSON.parse(content);
+
+ if (parsed.event === 'tool_processing_start') {
+ console.log('Tool processing started');
+ showProcessingIndicator();
+ continue;
+ }
+
+                        if (parsed.event === 'replace_content') {
+                            console.log('Replacing content due to tool call');
+                            // Clear the assistant message that is currently streaming;
+                            // the follow-up response will be rendered in its place
+                            if (currentMessageDiv) {
+                                currentMessageDiv.innerHTML = '';
+                            }
+                            fullResponse = ''; // Reset the full response
+                            continue;
+                        }
+
+ if (parsed.event === 'tool_processing_complete') {
+ console.log('Tool processing completed');
+ hideProcessingIndicator();
+ continue;
+ }
+
+ // Handle follow-up response event
+ if (parsed.event === 'follow_up_response' && parsed.content) {
+ console.log('Received follow-up response');
+
+ // Add this follow-up response to message history
+ messageHistory.push({
+ role: 'assistant',
+ content: parsed.content
+ });
+
+ // Create a new message div for the follow-up response
+ // The addMessage function already handles markdown rendering
+ addMessage(parsed.content, false);
+ continue;
+ }
+
+ // Handle the complete response event
+ if (parsed.event === 'complete_response') {
+ console.log('Received complete response from backend');
+ // Use the complete response from the backend
+ fullResponse = parsed.content;
+
+ // Use marked.js to render markdown
+ try {
+ // Configure marked options
+ marked.use({
+ breaks: true, // Add line breaks on single newlines
+ gfm: true, // Use GitHub Flavored Markdown
+ headerIds: false, // Don't add IDs to headers
+ mangle: false, // Don't mangle email addresses
+ });
+
+ // Render markdown to HTML
+ currentMessageDiv.innerHTML = marked.parse(fullResponse);
+ } catch (e) {
+ console.error('Error rendering markdown:', e);
+ // Fallback to innerText if markdown parsing fails
+ currentMessageDiv.innerText = fullResponse;
+ }
+ continue;
+ }
+
+ // If we have JSON that's not a special event, it might be content
+ currentMessageDiv.textContent += JSON.stringify(parsed);
+ fullResponse += JSON.stringify(parsed);
+ } catch (e) {
+ // Not JSON, handle as normal content
+ // console.log('Received chunk:', JSON.stringify(content));
+
+ // Add to full response
+ fullResponse += content;
+
+ // Create a properly formatted display
+ if (!currentMessageDiv.classList.contains('pre-formatted')) {
+ currentMessageDiv.classList.add('pre-formatted');
+ }
+
+ // Always rebuild the entire content from the full response
+ currentMessageDiv.innerHTML = marked.parse(fullResponse);
+ }
+ // Scroll to bottom
+ chatMessages.scrollTop = chatMessages.scrollHeight;
+ }
+ }
+ }
+ } catch (error) {
+ console.error('Error:', error);
+ if (currentMessageDiv) {
+ currentMessageDiv.textContent = `Error: ${error.message}`;
+ currentMessageDiv.style.color = 'red';
+ }
+ } finally {
+ // Always re-enable input and hide indicators
+ messageInput.disabled = false;
+ sendButton.disabled = false;
+ hideProcessingIndicator();
+ }
+ }
+
+ sendButton.addEventListener('click', sendMessage);
+ messageInput.addEventListener('keypress', function(e) {
+ if (e.key === 'Enter') {
+ sendMessage();
+ }
+ });
+
+ refreshButton.addEventListener('click', clearChat);
+});
diff --git a/src/airflow_wingman/templates/wingman_chat.html b/src/airflow_wingman/templates/wingman_chat.html
index 72657f0..c1da377 100644
--- a/src/airflow_wingman/templates/wingman_chat.html
+++ b/src/airflow_wingman/templates/wingman_chat.html
@@ -3,6 +3,7 @@
{% block head_meta %}
{{ super() }}
+
{% endblock %}
{% block content %}
@@ -77,25 +78,7 @@
-
+
@@ -129,256 +112,6 @@
-
-
-
+
+
{% endblock %}
diff --git a/src/airflow_wingman/tools/__init__.py b/src/airflow_wingman/tools/__init__.py
new file mode 100644
index 0000000..44e9360
--- /dev/null
+++ b/src/airflow_wingman/tools/__init__.py
@@ -0,0 +1,15 @@
+"""
+Tools module for Airflow Wingman.
+
+This module contains the tools used by Airflow Wingman to interact with Airflow.
+"""
+
+from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools
+from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools
+
+__all__ = [
+ "convert_to_openai_tools",
+ "convert_to_anthropic_tools",
+ "list_airflow_tools",
+ "execute_airflow_tool",
+]
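+
+# Consumers import from the package root rather than the submodules, e.g. views.py does:
+#   from airflow_wingman.tools import list_airflow_tools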
diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py
new file mode 100644
index 0000000..09010e2
--- /dev/null
+++ b/src/airflow_wingman/tools/conversion.py
@@ -0,0 +1,143 @@
+"""
+Conversion utilities for Airflow Wingman tools.
+
+This module contains functions to convert between different tool formats
+for various LLM providers (OpenAI, Anthropic, etc.).
+"""
+
+import logging
+from typing import Any
+
+
+def convert_to_openai_tools(airflow_tools: list) -> list:
+ """
+ Convert Airflow tools to OpenAI tool definitions.
+
+ Args:
+ airflow_tools: List of Airflow tools from MCP server
+
+ Returns:
+ List of OpenAI tool definitions
+ """
+ openai_tools = []
+
+ for tool in airflow_tools:
+ # Initialize the OpenAI tool structure
+ openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}}
+
+ # Extract parameters directly from inputSchema if available
+ if hasattr(tool, "inputSchema") and tool.inputSchema:
+ # Set the type and required fields directly from the schema
+ if "type" in tool.inputSchema:
+ openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"]
+
+ # Add required parameters if specified
+ if "required" in tool.inputSchema:
+ openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"]
+
+ # Add properties from the input schema
+ if "properties" in tool.inputSchema:
+ for param_name, param_info in tool.inputSchema["properties"].items():
+ # Create parameter definition
+ param_def = {}
+
+ # Handle different schema constructs
+ if "anyOf" in param_info:
+ _handle_schema_construct(param_def, param_info, "anyOf")
+ elif "oneOf" in param_info:
+ _handle_schema_construct(param_def, param_info, "oneOf")
+ elif "allOf" in param_info:
+ _handle_schema_construct(param_def, param_info, "allOf")
+ elif "type" in param_info:
+ param_def["type"] = param_info["type"]
+ # Add format if available
+ if "format" in param_info:
+ param_def["format"] = param_info["format"]
+ else:
+ param_def["type"] = "string" # Default type
+
+ # Add description from title or param name
+ param_def["description"] = param_info.get("description", param_info.get("title", param_name))
+
+ # Add enum values if available
+ if "enum" in param_info:
+ param_def["enum"] = param_info["enum"]
+
+ # Add default value if available
+ if "default" in param_info and param_info["default"] is not None:
+ param_def["default"] = param_info["default"]
+
+ # Add to properties
+ openai_tool["function"]["parameters"]["properties"][param_name] = param_def
+
+ openai_tools.append(openai_tool)
+
+ return openai_tools
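+
+# Illustrative sketch only: for a hypothetical MCP tool named "get_dag" with no
+# description and an inputSchema of
+#   {"type": "object", "properties": {"dag_id": {"type": "string"}}, "required": ["dag_id"]}
+# the conversion above would produce roughly:
+#   {
+#       "type": "function",
+#       "function": {
+#           "name": "get_dag",
+#           "description": "get_dag",
+#           "parameters": {
+#               "type": "object",
+#               "properties": {"dag_id": {"type": "string", "description": "dag_id"}},
+#               "required": ["dag_id"],
+#           },
+#       },
+#   }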
+
+
+def convert_to_anthropic_tools(airflow_tools: list) -> list:
+ """
+ Convert Airflow tools to Anthropic tool definitions.
+
+ Args:
+ airflow_tools: List of Airflow tools from MCP server
+
+ Returns:
+ List of Anthropic tool definitions
+ """
+ logger = logging.getLogger("airflow.plugins.wingman")
+ logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format")
+ anthropic_tools = []
+
+ for tool in airflow_tools:
+ # Initialize the Anthropic tool structure
+ anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}}
+
+ # Extract parameters directly from inputSchema if available
+ if hasattr(tool, "inputSchema") and tool.inputSchema:
+ # Copy the input schema directly as Anthropic's format is similar to JSON Schema
+ anthropic_tool["input_schema"] = tool.inputSchema
+ else:
+ # Create a minimal schema if none exists
+ anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []}
+
+ anthropic_tools.append(anthropic_tool)
+
+ logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format")
+ return anthropic_tools
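+
+# For the same hypothetical "get_dag" tool, the Anthropic definition carries the MCP
+# inputSchema through unchanged (illustrative sketch only):
+#   {
+#       "name": "get_dag",
+#       "description": "get_dag",
+#       "input_schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}, "required": ["dag_id"]},
+#   }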
+
+
+def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None:
+ """
+ Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf.
+
+ Args:
+ param_def: Parameter definition to update
+ param_info: Parameter info from the schema
+ construct_type: Type of construct (anyOf, oneOf, allOf)
+ """
+ # Get the list of schemas from the construct
+ schemas = param_info[construct_type]
+
+ # Find the first schema with a type
+ for schema in schemas:
+ if "type" in schema:
+ param_def["type"] = schema["type"]
+
+ # Add format if available
+ if "format" in schema:
+ param_def["format"] = schema["format"]
+
+ # Add enum values if available
+ if "enum" in schema:
+ param_def["enum"] = schema["enum"]
+
+ # Add default value if available
+ if "default" in schema and schema["default"] is not None:
+ param_def["default"] = schema["default"]
+
+ break
+
+ # If no type was found, default to string
+ if "type" not in param_def:
+ param_def["type"] = "string"
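+
+# Illustrative sketch only: a nullable parameter described in the MCP schema as
+#   {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "limit"}
+# is flattened by this helper to {"type": "integer"}: the first member carrying a
+# "type" wins, and if no member carries one the definition falls back to "string".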
diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py
new file mode 100644
index 0000000..5dc88b8
--- /dev/null
+++ b/src/airflow_wingman/tools/execution.py
@@ -0,0 +1,153 @@
+"""
+Tool execution module for Airflow Wingman.
+
+This module contains functions to list and execute Airflow tools.
+"""
+
+import asyncio
+import json
+import logging
+import traceback
+
+from airflow import configuration
+from airflow_mcp_server.config import AirflowConfig
+from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
+
+
+async def _list_airflow_tools_async(cookie: str) -> list:
+ """
+ Async implementation to list available Airflow tools.
+
+ Args:
+ cookie: Cookie for authentication
+
+ Returns:
+ List of available Airflow tools
+ """
+ try:
+ # Set up configuration
+ base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
+ logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
+
+ # Format the cookie properly if it doesn't already have the 'session=' prefix
+ formatted_cookie = cookie
+ if cookie and not cookie.startswith("session="):
+ formatted_cookie = f"session={cookie}"
+ logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+ config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)
+
+ # Get available tools
+ logger.info("Getting Airflow tools...")
+ tools = await get_airflow_tools(config=config, mode="safe")
+ logger.info(f"Got {len(tools)} tools")
+ return tools
+ except Exception as e:
+ error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ return []
+
+
+def list_airflow_tools(cookie: str) -> list:
+ """
+ Synchronous wrapper to list available Airflow tools.
+
+ Args:
+ cookie: Cookie for authentication
+
+ Returns:
+ List of available Airflow tools
+ """
+ return asyncio.run(_list_airflow_tools_async(cookie))
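+
+# Usage sketch, mirroring the call in views.py: the Flask session cookie from the
+# incoming request is passed straight through, e.g.
+#   tools = list_airflow_tools(request.cookies.get("session"))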
+
+
+async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str:
+ """
+ Async implementation to execute an Airflow tool.
+
+ Args:
+ tool_name: Name of the tool to execute
+ arguments: Arguments to pass to the tool
+ cookie: Cookie for authentication
+
+ Returns:
+ Result of the tool execution as a string
+ """
+ try:
+ # Set up configuration
+ base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
+ logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
+
+ # Format the cookie properly if it doesn't already have the 'session=' prefix
+ formatted_cookie = cookie
+ if cookie and not cookie.startswith("session="):
+ formatted_cookie = f"session={cookie}"
+ logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+ config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)
+
+ # Get the tool
+ logger.info(f"Getting tool: {tool_name}")
+ tool = await get_tool(config=config, name=tool_name)
+
+ if not tool:
+ error_msg = f"Tool not found: {tool_name}"
+ logger.error(error_msg)
+ return json.dumps({"error": error_msg})
+
+ # Execute the tool - ensure the client is in an async context
+ logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
+
+ # The AirflowClient needs to be used as an async context manager
+ # to properly initialize its session
+        async with tool.client as client:  # noqa: F841
+ # Now the client has a _session attribute and is in an async context
+ result = await tool.run(arguments)
+
+ # Convert result to string
+ if isinstance(result, dict | list):
+ result_str = json.dumps(result, indent=2)
+ else:
+ result_str = str(result)
+
+ logger.info(f"Tool execution result: {result_str[:100]}...")
+ return result_str
+ except Exception as e:
+ error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ return json.dumps({"error": error_msg})
+
+
+def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
+ """
+ Synchronous wrapper to execute an Airflow tool.
+
+ Args:
+ tool_name: Name of the tool to execute
+ arguments: Arguments to pass to the tool
+ cookie: Cookie for authentication
+
+ Returns:
+ Result of the tool execution as a string
+ """
+ # Create a new event loop for this execution
+ # This ensures we're always in a clean async context
+ loop = asyncio.new_event_loop()
+
+ try:
+ # Set the event loop for this thread
+ asyncio.set_event_loop(loop)
+
+ # Run the async function in the new event loop
+ result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
+ return result
+ except Exception as e:
+ error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
+ logger.error(error_msg)
+ return json.dumps({"error": error_msg})
+ finally:
+ # Always close the loop to free resources
+ loop.close()
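+
+# Usage sketch with a hypothetical tool name and arguments (real names come from the
+# MCP server's safe tool list):
+#   result_json = execute_airflow_tool("get_dag", {"dag_id": "example_dag"}, cookie)
+#   # result_json is always a string; failures come back as a JSON object {"error": "..."}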
diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py
index 34158d7..12825fc 100644
--- a/src/airflow_wingman/views.py
+++ b/src/airflow_wingman/views.py
@@ -1,6 +1,9 @@
"""Views for Airflow Wingman plugin."""
-from flask import Response, request, stream_with_context
+import json
+import logging
+
+from flask import Response, request, session
from flask.json import jsonify
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
@@ -8,6 +11,10 @@ from airflow_wingman.llm_client import LLMClient
from airflow_wingman.llms_models import MODELS
from airflow_wingman.notes import INTERFACE_MESSAGES
from airflow_wingman.prompt_engineering import prepare_messages
+from airflow_wingman.tools import list_airflow_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
class WingmanView(AppBuilderBaseView):
@@ -28,8 +35,43 @@ class WingmanView(AppBuilderBaseView):
try:
data = self._validate_chat_request(request.get_json())
- # Create a new client for this request
- client = LLMClient(data["api_key"])
+ if data.get("cookie"):
+ session["airflow_cookie"] = data["cookie"]
+
+            # Get available Airflow tools using the session cookie from the incoming request
+ airflow_tools = []
+ airflow_cookie = request.cookies.get("session")
+ if airflow_cookie:
+ try:
+ airflow_tools = list_airflow_tools(airflow_cookie)
+ logger.info(f"Loaded {len(airflow_tools)} Airflow tools")
+ if len(airflow_tools) > 0:
+ logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}")
+ else:
+ logger.warning("No Airflow tools were loaded")
+ except Exception as e:
+ # Log the error but continue without tools
+ logger.error(f"Error fetching Airflow tools: {str(e)}")
+
+ # Prepare messages with Airflow tools included in the prompt
+ data["messages"] = prepare_messages(data["messages"])
+
+ # Get provider name from request or use default
+ provider_name = data.get("provider", "openai")
+
+ # Get base URL from models configuration based on provider
+ base_url = MODELS.get(provider_name, {}).get("endpoint")
+
+ # Log the request parameters (excluding API key for security)
+ safe_data = {k: v for k, v in data.items() if k != "api_key"}
+ logger.info(f"Chat request: provider={provider_name}, model={data.get('model')}, stream={data.get('stream')}")
+ logger.info(f"Request parameters: {json.dumps(safe_data)[:200]}...")
+
+ # Create a new client for this request with the appropriate provider
+ client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url)
+
+ # Set the Airflow tools for the client to use
+ client.set_airflow_tools(airflow_tools)
if data["stream"]:
return self._handle_streaming_response(client, data)
@@ -46,39 +88,116 @@ class WingmanView(AppBuilderBaseView):
if not data:
raise ValueError("No data provided")
- required_fields = ["provider", "model", "messages", "api_key"]
+ required_fields = ["model", "messages", "api_key"]
missing = [f for f in required_fields if not data.get(f)]
if missing:
raise ValueError(f"Missing required fields: {', '.join(missing)}")
- # Prepare messages with system instruction while maintaining history
- messages = data["messages"]
- messages = prepare_messages(messages)
+ # Validate provider if provided
+ provider = data.get("provider", "openai")
+ if provider not in MODELS:
+ raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}")
return {
- "provider": data["provider"],
"model": data["model"],
- "messages": messages,
+ "messages": data["messages"],
"api_key": data["api_key"],
- "stream": data.get("stream", False),
- "temperature": data.get("temperature", 0.7),
+ "stream": data.get("stream", True),
+ "temperature": data.get("temperature", 0.4),
"max_tokens": data.get("max_tokens"),
+ "cookie": data.get("cookie"),
+ "provider": provider,
+ "base_url": data.get("base_url"),
}
def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
"""Handle streaming response."""
+ try:
+ logger.info("Beginning streaming response")
+ # Get the cookie at the beginning of the request handler
+ airflow_cookie = request.cookies.get("session")
+ logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")
- def generate():
- for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
- yield f"data: {chunk}\n\n"
+            # Request a streaming chat completion; the provider returns a StreamingResponse
+            # object that can be iterated for chunks and later inspected for tool calls
+ streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
- response = Response(stream_with_context(generate()), mimetype="text/event-stream")
- response.headers["Content-Type"] = "text/event-stream"
- response.headers["Cache-Control"] = "no-cache"
- response.headers["Connection"] = "keep-alive"
- return response
+ def stream_response(cookie=airflow_cookie):
+ complete_response = ""
+
+ # Stream the initial response
+ for chunk in streaming_response:
+ if chunk:
+ complete_response += chunk
+ yield f"data: {chunk}\n\n"
+
+ # Log the complete assembled response
+ logger.info("COMPLETE RESPONSE START >>>")
+ logger.info(complete_response)
+ logger.info("<<< COMPLETE RESPONSE END")
+
+ # Check for tool calls and make follow-up if needed
+ has_tool_calls = client.provider.has_tool_calls(streaming_response)
+ logger.info(f"Has tool calls: {has_tool_calls}")
+ if has_tool_calls:
+ # Signal tool processing start - frontend should disable send button
+ yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
+
+ # Signal to replace content - frontend should clear the current message
+ yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"
+
+ logger.info("Response contains tool calls, making follow-up request")
+ logger.info(f"Using cookie from closure: {cookie is not None}")
+
+ # Process tool calls and get follow-up response (handles recursive tool calls)
+ # Always stream the follow-up response for consistent handling
+ follow_up_response = client.process_tool_calls_and_follow_up(
+ streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie, stream=True
+ )
+
+ # Collect the follow-up response
+ follow_up_complete_response = ""
+ for chunk in follow_up_response:
+ if chunk:
+ follow_up_complete_response += chunk
+
+ # Send the follow-up response as a single event
+ if follow_up_complete_response:
+ follow_up_event = json.dumps({'event': 'follow_up_response', 'content': follow_up_complete_response})
+ logger.info(f"Follow-up event created with length: {len(follow_up_event)}")
+ data_line = f"data: {follow_up_event}\n\n"
+ logger.info(f"Yielding data line with length: {len(data_line)}")
+ yield data_line
+
+ # Log the complete follow-up response
+ logger.info("FOLLOW-UP RESPONSE START >>>")
+ logger.info(follow_up_complete_response)
+ logger.info("<<< FOLLOW-UP RESPONSE END")
+
+ # Signal tool processing complete - frontend can re-enable send button
+ yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
+
+ # Send the complete response as a special event (for compatibility with existing code)
+ complete_event = json.dumps({"event": "complete_response", "content": complete_response})
+ yield f"data: {complete_event}\n\n"
+
+ # Signal end of stream
+                yield "data: [DONE]\n\n"
+
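+            # The SSE stream consumed by wingman_chat.js therefore carries, in order: raw
+            # text chunks; then, only when tool calls are detected, the JSON events
+            # 'tool_processing_start', 'replace_content', 'follow_up_response' and
+            # 'tool_processing_complete'; then a 'complete_response' event with the
+            # assembled first answer; and finally the '[DONE]' marker.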
+ return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
+ except Exception as e:
+ logger.error(f"Streaming error: {str(e)}")
+ return jsonify({"error": str(e)}), 500
def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
"""Handle regular response."""
- response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
- return jsonify(response)
+ try:
+ logger.info("Beginning regular (non-streaming) response")
+ response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
+ logger.info("COMPLETE RESPONSE START >>>")
+ logger.info(f"Response to frontend: {json.dumps(response)}")
+ logger.info("<<< COMPLETE RESPONSE END")
+
+ return jsonify(response)
+ except Exception as e:
+ logger.error(f"Regular response error: {str(e)}")
+ return jsonify({"error": str(e)}), 500