Merge pull request #2 from abhishekbhakat/1-llm-chat-to-have-access-to-tools

1 llm chat to have access to tools
Committed by GitHub on 2025-03-02 18:29:50 +00:00
16 changed files with 2227 additions and 380 deletions

.gitignore (vendored): 2 changes
View File

@@ -174,3 +174,5 @@ cython_debug/
 # Local Resources
 plugins_reference/
 astro/
+node_modules/

View File

@@ -9,9 +9,10 @@ authors = [
     {name = "Abhishek Bhakat", email = "abhishek.bhakat@hotmail.com"}
 ]
 dependencies = [
+    "airflow-mcp-server>=0.4.0",
+    "anthropic>=0.46.0",
     "apache-airflow>=2.10.0",
     "openai>=1.64.0",
-    "anthropic>=0.46.0"
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -31,6 +32,13 @@ Issues = "https://github.com/abhishekbhakat/airflow-wingman/issues"
 [project.entry-points."airflow.plugins"]
 wingman = "airflow_wingman:WingmanPlugin"
 
+[project.optional-dependencies]
+dev = [
+    "build>=1.2.2",
+    "pre-commit>=4.0.1",
+    "ruff>=0.9.2"
+]
+
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
@@ -60,7 +68,8 @@ lint.select = [
 lint.ignore = [
     "C416", # Unnecessary list comprehension - rewrite as a generator expression
     "C408", # Unnecessary `dict` call - rewrite as a literal
-    "ISC001" # Single line implicit string concatenation
+    "ISC001", # Single line implicit string concatenation
+    "C901"
 ]
 lint.fixable = ["ALL"]

View File

@@ -1,109 +1,266 @@
-"""
-Client for making API calls to various LLM providers using their official SDKs.
-"""
-from collections.abc import Generator
-
-from anthropic import Anthropic
-from openai import OpenAI
-
-
-class LLMClient:
-    def __init__(self, api_key: str):
-        """Initialize the LLM client.
-
-        Args:
-            api_key: API key for the provider
-        """
-        self.api_key = api_key
-        self.openai_client = OpenAI(api_key=api_key)
-        self.anthropic_client = Anthropic(api_key=api_key)
-        self.openrouter_client = OpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=api_key,
-            default_headers={
-                "HTTP-Referer": "Airflow Wingman",  # Required by OpenRouter
-                "X-Title": "Airflow Wingman",  # Required by OpenRouter
-            },
-        )
-
-    def chat_completion(
-        self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
-    ) -> Generator[str, None, None] | dict:
-        """Send a chat completion request to the specified provider.
-
-        Args:
-            messages: List of message dictionaries with 'role' and 'content'
-            model: Model identifier
-            provider: Provider identifier (openai, anthropic, openrouter)
-            temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens to generate
-            stream: Whether to stream the response
-
-        Returns:
-            If stream=True, returns a generator yielding response chunks
-            If stream=False, returns the complete response
-        """
-        try:
-            if provider == "openai":
-                return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "anthropic":
-                return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "openrouter":
-                return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
-            else:
-                return {"error": f"Unknown provider: {provider}"}
-        except Exception as e:
-            return {"error": f"API request failed: {str(e)}"}
-
-    def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenAI chat completion requests."""
-        response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
-
-    def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle Anthropic chat completion requests."""
-        # Convert messages to Anthropic format
-        system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
-        conversation = []
-        for m in messages:
-            if m["role"] != "system":
-                conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
-        response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.delta.text:
-                        yield chunk.delta.text
-
-            return response_generator()
-        else:
-            return {"content": response.content[0].text}
-
-    def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenRouter chat completion requests."""
-        response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
+"""
+Multi-provider LLM client for Airflow Wingman.
+This module contains the LLMClient class that supports multiple LLM providers
+(OpenAI, Anthropic, OpenRouter) through a unified interface.
+"""
+import logging
+import traceback
+from typing import Any
+
+from flask import session
+
+from airflow_wingman.providers import create_llm_provider
+from airflow_wingman.tools import list_airflow_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
+
+
+class LLMClient:
+    """
+    Multi-provider LLM client for Airflow Wingman.
+    This class handles chat completion requests to various LLM providers
+    (OpenAI, Anthropic, OpenRouter) through a unified interface.
+    """
+
+    def __init__(self, provider_name: str, api_key: str, base_url: str | None = None):
+        """
+        Initialize the LLM client.
+
+        Args:
+            provider_name: Name of the provider (openai, anthropic, openrouter)
+            api_key: API key for the provider
+            base_url: Optional base URL for the provider API
+        """
+        self.provider_name = provider_name
+        self.api_key = api_key
+        self.base_url = base_url
+        self.provider = create_llm_provider(provider_name, api_key, base_url)
+        self.airflow_tools = []
+
+    def set_airflow_tools(self, tools: list):
+        """
+        Set the available Airflow tools.
+
+        Args:
+            tools: List of Airflow Tool objects
+        """
+        self.airflow_tools = tools
+
+    def chat_completion(
+        self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False
+    ) -> dict[str, Any] | tuple[Any, Any]:
+        """
+        Send a chat completion request to the LLM provider.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-1)
+            max_tokens: Maximum tokens to generate
+            stream: Whether to stream the response (default is True)
+            return_response_obj: If True and streaming, returns both the response object and generator
+
+        Returns:
+            If stream=False: Dictionary with the response content
+            If stream=True and return_response_obj=False: Generator for streaming
+            If stream=True and return_response_obj=True: Tuple of (response_obj, generator)
+        """
+        # Get provider-specific tool definitions from Airflow tools
+        provider_tools = self.provider.convert_tools(self.airflow_tools)
+        try:
+            # Make the initial request with tools
+            logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}")
+            response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
+            logger.info(f"Received response from {self.provider_name}")
+
+            # If streaming, handle based on return_response_obj flag
+            if stream:
+                logger.info(f"Using streaming response from {self.provider_name}")
+                streaming_content = self.provider.get_streaming_content(response)
+                return streaming_content
+
+            # For non-streaming responses, handle tool calls if present
+            if self.provider.has_tool_calls(response):
+                logger.info("Response contains tool calls")
+                # Process tool calls and get results
+                cookie = session.get("airflow_cookie")
+                if not cookie:
+                    error_msg = "No Airflow cookie available"
+                    logger.error(error_msg)
+                    return {"error": error_msg}
+
+                tool_results = self.provider.process_tool_calls(response, cookie)
+
+                # Create a follow-up completion with the tool results
+                logger.info("Making follow-up request with tool results")
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response, tools=provider_tools
+                )
+                content = self.provider.get_content(follow_up_response)
+                logger.info(f"Final content from {self.provider_name} with tool calls COMPLETE RESPONSE START >>>")
+                logger.info(content)
+                logger.info("<<< COMPLETE RESPONSE END")
+                return {"content": content}
+            else:
+                logger.info("Response does not contain tool calls")
+                content = self.provider.get_content(response)
+                logger.info(f"Final content from {self.provider_name} without tool calls COMPLETE RESPONSE START >>>")
+                logger.info(content)
+                logger.info("<<< COMPLETE RESPONSE END")
+                return {"content": content}
+        except Exception as e:
+            error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            return {"error": f"API request failed: {str(e)}"}
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "LLMClient":
+        """
+        Create an LLMClient instance from a configuration dictionary.
+
+        Args:
+            config: Configuration dictionary with provider_name, api_key, and optional base_url
+
+        Returns:
+            LLMClient instance
+        """
+        provider_name = config.get("provider_name", "openai")
+        api_key = config.get("api_key")
+        base_url = config.get("base_url")
+        if not api_key:
+            raise ValueError("API key is required")
+        return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
+
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None, stream=True):
+        """
+        Process tool calls recursively from a response and make follow-up requests until
+        there are no more tool calls or max_iterations is reached.
+        Returns a generator for streaming the final follow-up response.
+
+        Args:
+            response: The original response object containing tool calls
+            messages: List of message dictionaries with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-1)
+            max_tokens: Maximum tokens to generate
+            max_iterations: Maximum number of tool call iterations to prevent infinite loops
+            cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)
+            stream: Whether to stream the response
+
+        Returns:
+            Generator for streaming the final follow-up response
+        """
+        try:
+            iteration = 0
+            current_response = response
+
+            # Check if we have a cookie
+            if not cookie:
+                error_msg = "No Airflow cookie available"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+                return
+
+            # Process tool calls recursively until there are no more or max_iterations is reached
+            while self.provider.has_tool_calls(current_response) and iteration < max_iterations:
+                iteration += 1
+                logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}")
+
+                # Process tool calls and get results
+                tool_results = self.provider.process_tool_calls(current_response, cookie)
+
+                # Make follow-up request with tool results
+                logger.info(f"Making follow-up request with tool results (iteration {iteration})")
+
+                # Always stream follow-up requests to ensure consistent behavior
+                # This ensures we get streaming responses from the provider
+                should_stream = True
+                logger.info(f"Setting should_stream=True for follow-up request (iteration {iteration})")
+
+                # Get provider-specific tool definitions from Airflow tools
+                provider_tools = self.provider.convert_tools(self.airflow_tools)
+
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages,
+                    model=model,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    tool_results=tool_results,
+                    original_response=current_response,
+                    stream=should_stream,
+                    tools=provider_tools,
+                )
+
+                # Check if this follow-up response has more tool calls
+                if not self.provider.has_tool_calls(follow_up_response):
+                    logger.info(f"No more tool calls after iteration {iteration}")
+                    # Final response - always yield content in a streaming fashion
+                    # Since we're always streaming now, we can directly yield chunks from the streaming generator
+                    chunk_count = 0
+                    for chunk in self.provider.get_streaming_content(follow_up_response):
+                        chunk_count += 1
+                        # logger.info(f"Yielding chunk {chunk_count} from streaming generator: {chunk[:50] if chunk else 'Empty chunk'}...")
+                        yield chunk
+                    logger.info(f"Finished yielding {chunk_count} chunks from streaming generator")
+
+                # Update current_response for the next iteration
+                current_response = follow_up_response
+
+            # If we've reached max_iterations and still have tool calls, log a warning
+            if iteration == max_iterations and self.provider.has_tool_calls(current_response):
+                logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
+                # Stream the final response even if it has tool calls
+                if not should_stream:
+                    # If we didn't stream this response, convert it to a single chunk
+                    content = self.provider.get_content(follow_up_response)
+                    logger.info(f"Yielding complete content as a single chunk (max iterations): {content[:100]}...")
+                    yield content
+                    logger.info("Finished yielding complete content (max iterations)")
+                else:
+                    # Yield chunks from the streaming generator
+                    logger.info("Starting to yield chunks from streaming generator (max iterations reached)")
+                    chunk_count = 0
+                    for chunk in self.provider.get_streaming_content(follow_up_response):
+                        chunk_count += 1
+                        logger.info(f"Yielding chunk {chunk_count} from streaming generator (max iterations)")
+                        yield chunk
+                    logger.info(f"Finished yielding {chunk_count} chunks from streaming generator (max iterations)")
+
+            # If we didn't process any tool calls (shouldn't happen), return an error
+            if iteration == 0:
+                error_msg = "No tool calls found in response"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+        except Exception as e:
+            error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            yield f"Error: {str(e)}"
+
+    def refresh_tools(self, cookie: str) -> None:
+        """
+        Refresh the available Airflow tools.
+
+        Args:
+            cookie: Airflow cookie for authentication
+        """
+        try:
+            logger.info("Refreshing Airflow tools")
+            tools = list_airflow_tools(cookie)
+            self.set_airflow_tools(tools)
+            logger.info(f"Refreshed {len(tools)} Airflow tools")
+        except Exception as e:
+            error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            # Don't raise the exception, just log it
+            # The client will continue to use the existing tools (if any)

View File

@@ -17,14 +17,14 @@ MODELS = {
         "endpoint": "https://api.anthropic.com/v1/messages",
         "models": [
             {
-                "id": "claude-3.5-sonnet",
-                "name": "Claude 3.5 Sonnet",
+                "id": "claude-3-7-sonnet-20250219",
+                "name": "Claude 3.7 Sonnet",
                 "default": True,
                 "context_window": 200000,
                 "description": "Input $3/M tokens, Output $15/M tokens",
             },
             {
-                "id": "claude-3.5-haiku",
+                "id": "claude-3-5-haiku-20241022",
                 "name": "Claude 3.5 Haiku",
                 "default": False,
                 "context_window": 200000,

View File

@@ -8,9 +8,11 @@ INSTRUCTIONS = {
 You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
 The Airflow version being used is >=2.10.
 
-You have access to the following Airflow API tools:
+You have access to Airflow MCP tools that you can use to fetch information and help users understand
+and manage their Airflow environment.
 
-You can use these tools to fetch information and help users understand and manage their Airflow environment.
+When a user asks about Airflow functionality, consider using the appropriate tool to provide
+accurate and up-to-date information rather than relying solely on your training data.
 """
 }

View File

@@ -0,0 +1,41 @@
"""
Provider factory for Airflow Wingman.
This module contains the factory function to create provider instances
based on the provider name.
"""
from airflow_wingman.providers.anthropic_provider import AnthropicProvider
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.openai_provider import OpenAIProvider
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider:
"""
Create a provider instance based on the provider name.
Args:
provider_name: Name of the provider (openai, anthropic, openrouter)
api_key: API key for the provider
base_url: Optional base URL for the provider API
Returns:
Provider instance
Raises:
ValueError: If the provider is not supported
"""
provider_name = provider_name.lower()
if provider_name == "openai":
return OpenAIProvider(api_key=api_key, base_url=base_url)
elif provider_name == "openrouter":
# OpenRouter uses the OpenAI API format, so we can use the OpenAI provider
# with a custom base URL
if not base_url:
base_url = "https://openrouter.ai/api/v1"
return OpenAIProvider(api_key=api_key, base_url=base_url)
elif provider_name == "anthropic":
return AnthropicProvider(api_key=api_key)
else:
raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter")
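
A short usage sketch of the factory (API keys are placeholders):

from airflow_wingman.providers import create_llm_provider

openrouter = create_llm_provider("openrouter", api_key="sk-or-...")  # OpenAI-compatible provider with the OpenRouter base URL
anthropic = create_llm_provider("anthropic", api_key="sk-ant-...")

try:
    create_llm_provider("mistral", api_key="...")
except ValueError as err:
    print(err)  # Unsupported provider: mistral. Supported providers: openai, anthropic, openrouter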

View File

@@ -0,0 +1,458 @@
"""
Anthropic provider implementation for Airflow Wingman.
This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""
import json
import logging
import traceback
from typing import Any
from anthropic import Anthropic
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools
# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")
class AnthropicProvider(BaseLLMProvider):
"""
Anthropic provider implementation.
This class handles API requests, tool conversion, and response processing
for the Anthropic API (Claude models).
"""
def __init__(self, api_key: str):
"""
Initialize the Anthropic provider.
Args:
api_key: API key for Anthropic
"""
self.api_key = api_key
self.client = Anthropic(api_key=api_key)
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert Airflow tools to Anthropic format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of Anthropic tool definitions
"""
return convert_to_anthropic_tools(airflow_tools)
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to Anthropic.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in Anthropic format
Returns:
Anthropic response object
Raises:
Exception: If the API request fails
"""
# Convert max_tokens to Anthropic's max_tokens parameter (if provided)
max_tokens_param = max_tokens if max_tokens is not None else 4096
# Convert messages from ChatML format to Anthropic's format
anthropic_messages = self._convert_to_anthropic_messages(messages)
try:
logger.info(f"Sending chat completion request to Anthropic with model: {model}")
# Create request parameters
params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}
# Add tools if provided
if tools and len(tools) > 0:
params["tools"] = tools
else:
logger.warning("No tools included in request")
# Log the full request parameters (with sensitive information redacted)
log_params = params.copy()
logger.info(f"Request parameters: {json.dumps(log_params)}")
# Make the API request
response = self.client.messages.create(**params)
logger.info("Received response from Anthropic")
# Log the response (with sensitive information redacted)
logger.info(f"Anthropic response type: {type(response).__name__}")
# Log as much information as possible
if hasattr(response, "json"):
if callable(response.json):
# If json is a method, call it
try:
logger.info(f"Anthropic response json: {json.dumps(response.model_dump_json())}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json(): {str(json_err)}")
else:
# If json is a property, use it directly
try:
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json: {str(json_err)}")
# Log response attributes
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
logger.info(f"Anthropic response attributes: {response_attrs}")
return response
except Exception as e:
error_msg = str(e)
logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
raise
def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert messages from ChatML format to Anthropic's format.
Args:
messages: List of message dictionaries in ChatML format
Returns:
List of message dictionaries in Anthropic format
"""
anthropic_messages = []
for message in messages:
role = message["role"]
content = message["content"]
# Map ChatML roles to Anthropic roles
if role == "system":
# System messages in Anthropic are handled differently
# We'll add them as a user message with a special prefix
anthropic_messages.append({"role": "user", "content": f"<system>\n{content}\n</system>"})
elif role == "user":
anthropic_messages.append({"role": "user", "content": content})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": content})
elif role == "tool":
# Tool messages in ChatML become part of the user message in Anthropic
# We'll handle this in the follow-up completion
continue
return anthropic_messages
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
logger.info(f"Checking for tool calls in response of type: {type(response)}")
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse):
logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
if response.tool_call is not None:
logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
return True
else:
logger.info("StreamingResponse has None tool_call")
else:
logger.info("Response is not a StreamingResponse")
# Check if any content block is a tool_use block (for non-streaming responses)
if hasattr(response, "content"):
logger.info(f"Response has content attribute with {len(response.content)} blocks")
for i, block in enumerate(response.content):
logger.info(f"Checking content block {i}: {type(block)}")
if isinstance(block, dict) and block.get("type") == "tool_use":
logger.info(f"Found tool_use block: {block}")
return True
else:
logger.info("Response does not have content attribute")
logger.info("No tool calls found in response")
return False
def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.
Args:
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
List of tool call objects in a standardized format
"""
tool_calls = []
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Extracting tool calls from response content")
for block in response.content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
tool_calls.append(tool_call)
return tool_calls
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: Anthropic response object or generator with tool_call attribute
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
if not self.has_tool_calls(response):
return results
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")
for tool_call in tool_calls:
# Extract tool details - handle both formats (generator's tool_call and content block)
if isinstance(tool_call, dict) and "id" in tool_call:
# This is from the generator's tool_call attribute
tool_id = tool_call.get("id")
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
else:
# This is from the content blocks
tool_id = tool_call.get("id")
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
try:
# Execute the Airflow tool with the provided arguments and cookie
logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
result = execute_airflow_tool(tool_name, tool_input, cookie)
logger.info(f"Tool execution result: {result}")
results[tool_id] = {"status": "success", "result": result}
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
def create_follow_up_completion(
self,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
tool_results: dict[str, Any] = None,
original_response: Any = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
stream: Whether to stream the response
tools: List of tool definitions in Anthropic format
Returns:
Anthropic response object or generator if streaming
"""
if not original_response or not tool_results:
return original_response
# Extract tool call from the StreamingResponse or content blocks from Anthropic response
tool_use_blocks = []
if isinstance(original_response, StreamingResponse) and original_response.tool_call:
# For StreamingResponse, create a tool_use block from the tool_call
logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
tool_call = original_response.tool_call
tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
elif hasattr(original_response, "content"):
# For regular Anthropic response, extract from content blocks
logger.info("Extracting tool_use blocks from response content")
tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Create tool result blocks
tool_result_blocks = []
for tool_id, result in tool_results.items():
tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))})
# Convert original messages to Anthropic format
anthropic_messages = self._convert_to_anthropic_messages(messages)
# Add the assistant response with tool use
anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
# Add the user message with tool results
anthropic_messages.append({"role": "user", "content": tool_result_blocks})
# Make a second request to get the final response
logger.info(f"Making second request with tool results (stream={stream})")
return self.create_chat_completion(
messages=anthropic_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
tools=tools,
)
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: Anthropic response object
Returns:
Content string from the response
"""
if not hasattr(response, "content"):
return ""
# Combine all text blocks into a single string
content_parts = []
for block in response.content:
if isinstance(block, dict) and block.get("type") == "text":
content_parts.append(block.get("text", ""))
elif isinstance(block, str):
content_parts.append(block)
return "".join(content_parts)
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.
Args:
response: Anthropic streaming response object
Returns:
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting Anthropic streaming response processing")
# Track only the first tool call detected during streaming
tool_call = None
tool_use_detected = False
# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)
def generate():
nonlocal tool_call, tool_use_detected
for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
# Check for content_block_start events with type "tool_use"
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
if chunk.content_block.type == "tool_use":
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
tool_use_detected = True
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
# We don't signal to the frontend during streaming
# The tool will only be executed after streaming ends
continue
# Handle content_block_delta events for tool_use (input updates)
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
# Update the current tool call input
if tool_call:
try:
# Try to parse the partial JSON and update the input
partial_input = json.loads(chunk.delta.partial_json)
tool_call["input"].update(partial_input)
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
except json.JSONDecodeError:
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
continue
# Handle content_block_stop events for tool_use
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
logger.info("Tool use block completed")
# Log the complete tool call for debugging
if tool_call:
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
continue
# Handle message_delta events with stop_reason "tool_use"
if hasattr(chunk, "type") and chunk.type == "message_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
if chunk.delta.stop_reason == "tool_use":
logger.info("Message stopped due to tool use")
# Update the StreamingResponse object's tool_call attribute one last time
if tool_call:
streaming_response.tool_call = tool_call
continue
# Handle regular content chunks
content = None
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content = chunk.delta.text
elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content = chunk.delta.text
elif hasattr(chunk, "content") and chunk.content:
for block in chunk.content:
if isinstance(block, dict) and block.get("type") == "text":
content = block.get("text", "")
if content:
# Don't do any newline replacement here
yield content
# Create the generator
gen = generate()
# Set the generator in the StreamingResponse object
streaming_response.generator = gen
# Return the StreamingResponse object
return streaming_response
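
For reference, the follow-up request assembled by create_follow_up_completion above appends an assistant turn containing the tool_use block and a user turn containing the matching tool_result block. With placeholder ids, tool name, and payloads, the two appended messages look like:

assistant_turn = {
    "role": "assistant",
    "content": [
        {"type": "tool_use", "id": "toolu_01", "name": "list_dags", "input": {"limit": 10}},
    ],
}
user_turn = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01", "content": '{"dags": [...]}'},
    ],
}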

View File

@@ -0,0 +1,216 @@
"""
Base provider interface for Airflow Wingman.
This module contains the base provider interface that all provider implementations
must adhere to. It defines the methods required for tool conversion, API requests,
and response processing.
"""
import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator
from typing import Any, Generic, TypeVar
T = TypeVar("T")
class StreamingResponse(Generic[T]):
"""
Wrapper for streaming responses that can hold tool call information.
This class wraps a generator and provides an iterator interface while also
storing tool call information. This allows us to associate metadata with
a generator without modifying the generator itself.
"""
def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
"""
Initialize the streaming response.
Args:
generator: The underlying generator yielding content chunks
tool_call: Optional tool call information detected during streaming
"""
self.generator = generator
self.tool_call = tool_call
def __iter__(self) -> Iterator[T]:
"""
Return self as iterator.
"""
return self
def __next__(self) -> T:
"""
Get the next item from the generator.
"""
return next(self.generator)
class BaseLLMProvider(ABC):
"""
Base provider interface for LLM providers.
This abstract class defines the methods that all provider implementations
must implement to support tool integration.
"""
@abstractmethod
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert internal tool representation to provider format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of provider-specific tool definitions
"""
pass
@abstractmethod
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in provider format
Returns:
Provider-specific response object
"""
pass
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: Provider-specific response object or StreamingResponse
Returns:
True if the response contains tool calls, False otherwise
"""
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
return True
# Provider-specific implementation should handle other cases
return False
def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.
Args:
response: Provider-specific response object or StreamingResponse
Returns:
List of tool call objects in a standardized format
"""
tool_calls = []
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
tool_calls.append(response.tool_call)
# Provider-specific implementation should handle other cases
return tool_calls
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: Provider-specific response object or StreamingResponse
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
tool_calls = self.get_tool_calls(response)
results = {}
for tool_call in tool_calls:
tool_name = tool_call.get("name", "")
tool_input = tool_call.get("input", {})
tool_id = tool_call.get("id", "")
try:
import logging
logger = logging.getLogger(__name__)
logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")
from airflow_wingman.tools import execute_airflow_tool
result = execute_airflow_tool(tool_name, tool_input, cookie)
logger.info(f"Tool result: {json.dumps(result)}")
results[tool_id] = {
"name": tool_name,
"input": tool_input,
"output": result,
}
except Exception as e:
import traceback
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
@abstractmethod
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
Returns:
Provider-specific response object
"""
pass
@abstractmethod
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: Provider-specific response object
Returns:
Content string from the response
"""
pass
@abstractmethod
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a StreamingResponse for streaming content from the response.
Args:
response: Provider-specific response object
Returns:
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
pass
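
A minimal sketch of how the StreamingResponse wrapper behaves, using the import path the provider modules use; it iterates like the generator it wraps while carrying tool-call metadata alongside:

from airflow_wingman.providers.base import StreamingResponse

def chunks():
    yield "Hello, "
    yield "world"

resp = StreamingResponse(generator=chunks())
for piece in resp:  # delegates to the wrapped generator
    print(piece, end="")
print()
print(resp.tool_call)  # None until a provider records a tool call during streaming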

View File

@@ -0,0 +1,354 @@
"""
OpenAI provider implementation for Airflow Wingman.
This module contains the OpenAI provider implementation that handles
API requests, tool conversion, and response processing for OpenAI.
"""
import json
import logging
import traceback
from typing import Any
from openai import OpenAI
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools
# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")
class OpenAIProvider(BaseLLMProvider):
"""
OpenAI provider implementation.
This class handles API requests, tool conversion, and response processing
for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
Initialize the OpenAI provider.
Args:
api_key: API key for OpenAI
base_url: Optional base URL for the API (used for OpenRouter)
"""
self.api_key = api_key
self.client = OpenAI(api_key=api_key, base_url=base_url)
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert Airflow tools to OpenAI format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of OpenAI tool definitions
"""
return convert_to_openai_tools(airflow_tools)
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to OpenAI.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in OpenAI format
Returns:
OpenAI response object
Raises:
Exception: If the API request fails
"""
# Only include tools if we have any
has_tools = tools is not None and len(tools) > 0
tool_choice = "auto" if has_tools else None
try:
logger.info(f"Sending chat completion request to OpenAI with model: {model}")
# Log information about tools
if not has_tools:
logger.warning("No tools included in request")
# Log request parameters
request_params = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
"tools": tools if has_tools else None,
"tool_choice": tool_choice,
}
logger.info(f"Request parameters: {json.dumps(request_params)}")
response = self.client.chat.completions.create(
model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
)
logger.info("Received response from OpenAI")
return response
except Exception as e:
# If the API call fails due to tools not being supported, retry without tools
error_msg = str(e)
logger.warning(f"Error in OpenAI API call: {error_msg}")
if "tools" in error_msg.lower():
logger.info("Retrying without tools")
response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
return response
else:
logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
raise
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: OpenAI response object or StreamingResponse with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
return True
# For non-streaming responses
if hasattr(response, "choices") and len(response.choices) > 0:
message = response.choices[0].message
return hasattr(message, "tool_calls") and message.tool_calls
return False
def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.
Args:
response: OpenAI response object or StreamingResponse with tool_call attribute
Returns:
List of tool call objects in a standardized format
"""
tool_calls = []
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
return tool_calls
# For non-streaming responses
if hasattr(response, "choices") and len(response.choices) > 0:
message = response.choices[0].message
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
tool_calls.append(standardized_tool_call)
logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
return tool_calls
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: OpenAI response object or StreamingResponse with tool_call attribute
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
if not self.has_tool_calls(response):
return results
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")
for tool_call in tool_calls:
tool_id = tool_call["id"]
function_name = tool_call["name"]
arguments = tool_call["input"]
try:
# Execute the Airflow tool with the provided arguments and cookie
logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
result = execute_airflow_tool(function_name, arguments, cookie)
logger.info(f"Tool execution result: {result}")
results[tool_id] = {"status": "success", "result": result}
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
Returns:
OpenAI response object
"""
if not original_response or not tool_results:
return original_response
# Get the original message with tool calls
original_message = original_response.choices[0].message
# Create a new message with the tool calls
assistant_message = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
}
# Create tool result messages
tool_messages = []
for tool_call_id, result in tool_results.items():
tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))})
# Add the original messages, assistant message, and tool results
new_messages = messages + [assistant_message] + tool_messages
# Make a second request to get the final response
logger.info("Making second request with tool results")
return self.create_chat_completion(
messages=new_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
tools=None, # No tools needed for follow-up
)
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: OpenAI response object
Returns:
Content string from the response
"""
return response.choices[0].message.content
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.
Args:
response: OpenAI streaming response object
Returns:
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting OpenAI streaming response processing")
# Track only the first tool call detected during streaming
tool_call = None
tool_use_detected = False
current_tool_call = None
# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)
def generate():
nonlocal tool_call, tool_use_detected, current_tool_call
for chunk in response:
# Check for tool call in the delta
if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
# Tool call detected
if not tool_use_detected:
tool_use_detected = True
logger.info("Tool call detected in streaming response")
# Initialize the tool call
delta_tool_call = chunk.choices[0].delta.tool_calls[0]
current_tool_call = {
"id": getattr(delta_tool_call, "id", ""),
"name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
"input": {},
}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = current_tool_call
else:
# Update the existing tool call
delta_tool_call = chunk.choices[0].delta.tool_calls[0]
# Update the tool call ID if it's provided in this chunk
if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
current_tool_call["id"] = delta_tool_call.id
# Update the function name if it's provided in this chunk
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
current_tool_call["name"] = delta_tool_call.function.name
# Update the arguments if they're provided in this chunk
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
try:
# Try to parse the arguments JSON
arguments = json.loads(delta_tool_call.function.arguments)
if isinstance(arguments, dict):
current_tool_call["input"].update(arguments)
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = current_tool_call
except json.JSONDecodeError:
# If the arguments are not valid JSON, just log a warning
logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
# Skip yielding content for tool call chunks
continue
# For the final chunk, set the tool_call attribute
if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
logger.info("Streaming response finished with tool_calls reason")
if current_tool_call:
tool_call = current_tool_call
logger.info(f"Final tool call: {json.dumps(tool_call)}")
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
continue
# Handle regular content chunks
if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
# Create the generator
gen = generate()
# Set the generator in the StreamingResponse object
streaming_response.generator = gen
# Return the StreamingResponse object
return streaming_response
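
For reference, the follow-up messages assembled above follow the OpenAI tool-calling format: an assistant message echoing the tool calls plus one role "tool" message per result. With placeholder ids and arguments, they look like:

assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {"id": "call_01", "type": "function", "function": {"name": "list_dags", "arguments": '{"limit": 10}'}},
    ],
}
tool_message = {"role": "tool", "tool_call_id": "call_01", "content": '{"dags": [...]}'}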

View File

@@ -0,0 +1,99 @@
/* Provider and model selection styling */
.provider-section {
margin-bottom: 20px;
}
.provider-name {
font-size: 16px;
font-weight: bold;
margin-bottom: 10px;
color: #666;
}
.model-option {
margin-left: 15px;
margin-bottom: 8px;
}
.model-option label {
display: block;
cursor: pointer;
}
/* Message styling */
.message {
margin-bottom: 15px;
max-width: 80%;
clear: both;
}
.message p {
margin-top: 0.5em;
margin-bottom: 0.5em;
}
.message-user {
float: right;
background-color: #f0f7ff;
border: 1px solid #d1e6ff;
border-radius: 15px 15px 0 15px;
padding: 10px 15px;
}
.message pre {
margin-top: 0.5em;
margin-bottom: 0.5em;
padding: 0.5em;
}
.message-assistant {
float: left;
background-color: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 15px 15px 15px 0;
padding: 10px 15px;
}
.message code {
padding: 0.1em 0.3em;
}
#chat-messages::after {
content: "";
clear: both;
display: table;
}
/* Scrollbar styling */
.panel-body::-webkit-scrollbar {
width: 8px;
}
.panel-body::-webkit-scrollbar-track {
background: #f1f1f1;
}
.panel-body::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.panel-body::-webkit-scrollbar-thumb:hover {
background: #555;
}
/* Processing indicator styling */
.processing-indicator {
display: none;
background-color: #f0f8ff;
padding: 8px 12px;
border-radius: 4px;
margin: 8px 0;
font-style: italic;
}
.processing-indicator.visible {
display: block;
}
.pre-formatted {
font-family: monospace;
line-height: 1.2;
}

View File

@@ -0,0 +1,346 @@
document.addEventListener('DOMContentLoaded', function() {
// Initialize tooltips
document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
});
// Handle model selection and model name input
const modelNameInput = document.getElementById('modelName');
const modelRadios = document.querySelectorAll('input[name="model"]');
modelRadios.forEach(function(radio) {
radio.addEventListener('change', function() {
const provider = this.value.split(':')[0];
const modelName = this.getAttribute('data-model-name');
console.log('Selected provider:', provider);
console.log('Model name:', modelName);
if (provider === 'openrouter') {
console.log('Enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
});
});
// Set initial state based on default selection
const defaultSelected = document.querySelector('input[name="model"]:checked');
if (defaultSelected) {
const provider = defaultSelected.value.split(':')[0];
const modelName = defaultSelected.getAttribute('data-model-name');
console.log('Initial provider:', provider);
console.log('Initial model name:', modelName);
if (provider === 'openrouter') {
console.log('Initially enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Initially disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
}
const messageInput = document.getElementById('message-input');
const sendButton = document.getElementById('send-button');
const refreshButton = document.getElementById('refresh-button');
const chatMessages = document.getElementById('chat-messages');
let currentMessageDiv = null;
let messageHistory = [];
// Create a processing indicator element
const processingIndicator = document.createElement('div');
processingIndicator.className = 'processing-indicator';
processingIndicator.textContent = 'Processing tool calls...';
chatMessages.appendChild(processingIndicator);
function clearChat() {
// Clear the chat messages
chatMessages.innerHTML = '';
// Add back the processing indicator
chatMessages.appendChild(processingIndicator);
// Reset message history
messageHistory = [];
// Clear the input field
messageInput.value = '';
// Enable input if it was disabled
messageInput.disabled = false;
sendButton.disabled = false;
}
function addMessage(content, isUser) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
messageDiv.classList.add('pre-formatted');
// Use marked.js to render markdown
try {
// Configure marked options
marked.use({
breaks: true, // Add line breaks on single newlines
gfm: true, // Use GitHub Flavored Markdown
headerIds: false, // Don't add IDs to headers
mangle: false, // Don't mangle email addresses
});
// Render markdown to HTML
messageDiv.innerHTML = marked.parse(content);
} catch (e) {
console.error('Error rendering markdown:', e);
// Fallback to innerText if markdown parsing fails
messageDiv.innerText = content;
}
chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
return messageDiv;
}
function showProcessingIndicator() {
processingIndicator.classList.add('visible');
chatMessages.scrollTop = chatMessages.scrollHeight;
// Disable send button and input field during tool processing
sendButton.disabled = true;
messageInput.disabled = true;
}
function hideProcessingIndicator() {
processingIndicator.classList.remove('visible');
// Re-enable send button and input field after tool processing
sendButton.disabled = false;
messageInput.disabled = false;
}
async function sendMessage() {
const message = messageInput.value.trim();
if (!message) return;
// Get selected model
const selectedModel = document.querySelector('input[name="model"]:checked');
if (!selectedModel) {
alert('Please select a model');
return;
}
const [provider, modelId] = selectedModel.value.split(':');
const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
// Clear input and add user message
messageInput.value = '';
addMessage(message, true);
// Add user message to history
messageHistory.push({
role: 'user',
content: message
});
// Use full message history for the request
const messages = [...messageHistory];
// Create assistant message div
currentMessageDiv = addMessage('', false);
// Get API key
const apiKey = document.getElementById('api-key').value.trim();
if (!apiKey) {
alert('Please enter an API key');
return;
}
// Disable input while processing
messageInput.disabled = true;
sendButton.disabled = true;
// Get CSRF token
const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
if (!csrfToken) {
alert('CSRF token not found. Please refresh the page.');
// Re-enable input before returning so the user is not left locked out
messageInput.disabled = false;
sendButton.disabled = false;
return;
}
// Create request data
const requestData = {
provider: provider,
model: modelName,
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.4,
};
console.log('Sending request:', {...requestData, api_key: '***'});
try {
// Send request
const response = await fetch('/wingman/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
body: JSON.stringify(requestData)
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || 'Failed to get response');
}
// Process the streaming response
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
while (true) {
const { value, done } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.trim() === '') continue;
if (line.startsWith('data: ')) {
const content = line.slice(6); // Remove 'data: ' prefix
// Check for special events or end marker
if (content === '[DONE]') {
console.log('Stream complete');
// Add assistant's response to history
if (fullResponse) {
messageHistory.push({
role: 'assistant',
content: fullResponse
});
}
continue;
}
// Try to parse as JSON for special events
try {
const parsed = JSON.parse(content);
if (parsed.event === 'tool_processing_start') {
console.log('Tool processing started');
showProcessingIndicator();
continue;
}
if (parsed.event === 'replace_content') {
console.log('Replacing content due to tool call');
// Clear the current assistant message content
if (currentMessageDiv) {
currentMessageDiv.innerHTML = '';
fullResponse = ''; // Reset the full response
}
continue;
}
if (parsed.event === 'tool_processing_complete') {
console.log('Tool processing completed');
hideProcessingIndicator();
continue;
}
// Handle follow-up response event
if (parsed.event === 'follow_up_response' && parsed.content) {
console.log('Received follow-up response');
// Add this follow-up response to message history
messageHistory.push({
role: 'assistant',
content: parsed.content
});
// Create a new message div for the follow-up response
// The addMessage function already handles markdown rendering
addMessage(parsed.content, false);
continue;
}
// Handle the complete response event
if (parsed.event === 'complete_response') {
console.log('Received complete response from backend');
// Use the complete response from the backend
fullResponse = parsed.content;
// Use marked.js to render markdown
try {
// Configure marked options
marked.use({
breaks: true, // Add line breaks on single newlines
gfm: true, // Use GitHub Flavored Markdown
headerIds: false, // Don't add IDs to headers
mangle: false, // Don't mangle email addresses
});
// Render markdown to HTML
currentMessageDiv.innerHTML = marked.parse(fullResponse);
} catch (e) {
console.error('Error rendering markdown:', e);
// Fallback to innerText if markdown parsing fails
currentMessageDiv.innerText = fullResponse;
}
continue;
}
// If we have JSON that's not a special event, it might be content
currentMessageDiv.textContent += JSON.stringify(parsed);
fullResponse += JSON.stringify(parsed);
} catch (e) {
// Not JSON, handle as normal content
// console.log('Received chunk:', JSON.stringify(content));
// Add to full response
fullResponse += content;
// Create a properly formatted display
if (!currentMessageDiv.classList.contains('pre-formatted')) {
currentMessageDiv.classList.add('pre-formatted');
}
// Always rebuild the entire content from the full response
currentMessageDiv.innerHTML = marked.parse(fullResponse);
}
// Scroll to bottom
chatMessages.scrollTop = chatMessages.scrollHeight;
}
}
}
} catch (error) {
console.error('Error:', error);
if (currentMessageDiv) {
currentMessageDiv.textContent = `Error: ${error.message}`;
currentMessageDiv.style.color = 'red';
}
} finally {
// Always re-enable input and hide indicators
messageInput.disabled = false;
sendButton.disabled = false;
hideProcessingIndicator();
}
}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
sendMessage();
}
});
refreshButton.addEventListener('click', clearChat);
});
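For reference, the stream above can also be consumed outside the browser. The sketch below is illustrative only: it assumes a locally running webserver, a valid Airflow session cookie, and a CSRF token taken from the rendered page; the endpoint, payload fields, and event names mirror the request built in sendMessage, while the model name and API key are placeholders.

# Hypothetical standalone consumer of the /wingman/chat SSE stream (illustrative sketch).
import json

import requests

payload = {
    "provider": "openai",      # or "anthropic" / "openrouter"
    "model": "gpt-4o-mini",    # placeholder model name
    "messages": [{"role": "user", "content": "List my DAGs"}],
    "api_key": "sk-...",       # placeholder API key
    "stream": True,
    "temperature": 0.4,
}

with requests.post(
    "http://localhost:8080/wingman/chat",
    json=payload,
    stream=True,
    cookies={"session": "<airflow-session-cookie>"},     # assumed cookie-based auth
    headers={"X-CSRFToken": "<csrf-token-from-page>"},   # assumed; normally read from the page meta tag
) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        content = raw[len("data: "):]
        if content == "[DONE]":
            break
        try:
            event = json.loads(content)
        except ValueError:
            print(content, end="")  # plain streamed token
            continue
        # Special events emitted during tool processing (see stream_response in views.py)
        print(f"\n[{event.get('event')}]")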

View File

@@ -3,6 +3,7 @@
{% block head_meta %} {% block head_meta %}
{{ super() }} {{ super() }}
<meta name="csrf-token" content="{{ csrf_token() }}"> <meta name="csrf-token" content="{{ csrf_token() }}">
<link rel="stylesheet" href="{{ url_for('wingman.static', filename='css/wingman_chat.css') }}">
{% endblock %} {% endblock %}
{% block content %} {% block content %}
@@ -77,25 +78,7 @@
</div> </div>
</div> </div>
<style>
.provider-section {
margin-bottom: 20px;
}
.provider-name {
font-size: 16px;
font-weight: bold;
margin-bottom: 10px;
color: #666;
}
.model-option {
margin-left: 15px;
margin-bottom: 8px;
}
.model-option label {
display: block;
cursor: pointer;
}
</style>
</div> </div>
</div> </div>
@@ -129,256 +112,6 @@
</div> </div>
</div> </div>
<style> <script src="https://cdn.jsdelivr.net/npm/marked@9.1.6/marked.min.js"></script>
.message { <script src="{{ url_for('wingman.static', filename='js/wingman_chat.js') }}"></script>
margin-bottom: 15px;
max-width: 80%;
clear: both;
}
.message-user {
float: right;
background-color: #f0f7ff;
border: 1px solid #d1e6ff;
border-radius: 15px 15px 0 15px;
padding: 10px 15px;
}
.message-assistant {
float: left;
background-color: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 15px 15px 15px 0;
padding: 10px 15px;
}
#chat-messages::after {
content: "";
clear: both;
display: table;
}
.panel-body::-webkit-scrollbar {
width: 8px;
}
.panel-body::-webkit-scrollbar-track {
background: #f1f1f1;
}
.panel-body::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.panel-body::-webkit-scrollbar-thumb:hover {
background: #555;
}
</style>
<script>
document.addEventListener('DOMContentLoaded', function() {
// Add title attributes for tooltips
document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
});
// Handle model selection and model name input
const modelNameInput = document.getElementById('modelName');
const modelRadios = document.querySelectorAll('input[name="model"]');
modelRadios.forEach(function(radio) {
radio.addEventListener('change', function() {
const provider = this.value.split(':')[0]; // Get provider from value instead of data attribute
const modelName = this.getAttribute('data-model-name');
console.log('Selected provider:', provider);
console.log('Model name:', modelName);
if (provider === 'openrouter') {
console.log('Enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
});
});
// Set initial state based on default selection
const defaultSelected = document.querySelector('input[name="model"]:checked');
if (defaultSelected) {
const provider = defaultSelected.value.split(':')[0]; // Get provider from value instead of data attribute
const modelName = defaultSelected.getAttribute('data-model-name');
console.log('Initial provider:', provider);
console.log('Initial model name:', modelName);
if (provider === 'openrouter') {
console.log('Initially enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Initially disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
}
const messageInput = document.getElementById('message-input');
const sendButton = document.getElementById('send-button');
const refreshButton = document.getElementById('refresh-button');
const chatMessages = document.getElementById('chat-messages');
let currentMessageDiv = null;
let messageHistory = [];
function clearChat() {
// Clear the chat messages
chatMessages.innerHTML = '';
// Reset message history
messageHistory = [];
// Clear the input field
messageInput.value = '';
// Enable input if it was disabled
messageInput.disabled = false;
sendButton.disabled = false;
}
function addMessage(content, isUser) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
messageDiv.textContent = content;
chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
return messageDiv;
}
async function sendMessage() {
const message = messageInput.value.trim();
if (!message) return;
// Get selected model
const selectedModel = document.querySelector('input[name="model"]:checked');
if (!selectedModel) {
alert('Please select a model');
return;
}
const [provider, modelId] = selectedModel.value.split(':');
const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
// Clear input and add user message
messageInput.value = '';
addMessage(message, true);
try {
// Add user message to history
messageHistory.push({
role: 'user',
content: message
});
// Use full message history for the request
const messages = [...messageHistory];
// Create assistant message div
currentMessageDiv = addMessage('', false);
// Get API key
const apiKey = document.getElementById('api-key').value.trim();
if (!apiKey) {
alert('Please enter an API key');
return;
}
// Debug log the request
const requestData = {
provider: provider,
model: modelName,
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.7
};
console.log('Sending request:', {...requestData, api_key: '***'});
// Get CSRF token
const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
if (!csrfToken) {
throw new Error('CSRF token not found. Please refresh the page.');
}
// Send request
const response = await fetch('/wingman/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
body: JSON.stringify({
provider: provider,
model: modelName,
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.7
})
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || 'Failed to get response');
}
// Handle streaming response
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
while (true) {
const { value, done } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const content = line.slice(6);
if (content) {
currentMessageDiv.textContent += content;
fullResponse += content;
chatMessages.scrollTop = chatMessages.scrollHeight;
}
}
}
}
// Add assistant's response to history
if (fullResponse) {
messageHistory.push({
role: 'assistant',
content: fullResponse
});
}
} catch (error) {
console.error('Error:', error);
currentMessageDiv.textContent = `Error: ${error.message}`;
currentMessageDiv.style.color = 'red';
}
}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
sendMessage();
}
});
refreshButton.addEventListener('click', clearChat);
});
</script>
{% endblock %} {% endblock %}

View File

@@ -0,0 +1,15 @@
"""
Tools module for Airflow Wingman.
This module contains the tools used by Airflow Wingman to interact with Airflow.
"""
from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools
from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools
__all__ = [
"convert_to_openai_tools",
"convert_to_anthropic_tools",
"list_airflow_tools",
"execute_airflow_tool",
]

View File

@@ -0,0 +1,143 @@
"""
Conversion utilities for Airflow Wingman tools.
This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""
import logging
from typing import Any
def convert_to_openai_tools(airflow_tools: list) -> list:
"""
Convert Airflow tools to OpenAI tool definitions.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of OpenAI tool definitions
"""
openai_tools = []
for tool in airflow_tools:
# Initialize the OpenAI tool structure
openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}}
# Extract parameters directly from inputSchema if available
if hasattr(tool, "inputSchema") and tool.inputSchema:
# Set the type and required fields directly from the schema
if "type" in tool.inputSchema:
openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"]
# Add required parameters if specified
if "required" in tool.inputSchema:
openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"]
# Add properties from the input schema
if "properties" in tool.inputSchema:
for param_name, param_info in tool.inputSchema["properties"].items():
# Create parameter definition
param_def = {}
# Handle different schema constructs
if "anyOf" in param_info:
_handle_schema_construct(param_def, param_info, "anyOf")
elif "oneOf" in param_info:
_handle_schema_construct(param_def, param_info, "oneOf")
elif "allOf" in param_info:
_handle_schema_construct(param_def, param_info, "allOf")
elif "type" in param_info:
param_def["type"] = param_info["type"]
# Add format if available
if "format" in param_info:
param_def["format"] = param_info["format"]
else:
param_def["type"] = "string" # Default type
# Add description from title or param name
param_def["description"] = param_info.get("description", param_info.get("title", param_name))
# Add enum values if available
if "enum" in param_info:
param_def["enum"] = param_info["enum"]
# Add default value if available
if "default" in param_info and param_info["default"] is not None:
param_def["default"] = param_info["default"]
# Add to properties
openai_tool["function"]["parameters"]["properties"][param_name] = param_def
openai_tools.append(openai_tool)
return openai_tools
def convert_to_anthropic_tools(airflow_tools: list) -> list:
"""
Convert Airflow tools to Anthropic tool definitions.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of Anthropic tool definitions
"""
logger = logging.getLogger("airflow.plugins.wingman")
logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format")
anthropic_tools = []
for tool in airflow_tools:
# Initialize the Anthropic tool structure
anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}}
# Extract parameters directly from inputSchema if available
if hasattr(tool, "inputSchema") and tool.inputSchema:
# Copy the input schema directly as Anthropic's format is similar to JSON Schema
anthropic_tool["input_schema"] = tool.inputSchema
else:
# Create a minimal schema if none exists
anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []}
anthropic_tools.append(anthropic_tool)
logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format")
return anthropic_tools
def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None:
"""
Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf.
Args:
param_def: Parameter definition to update
param_info: Parameter info from the schema
construct_type: Type of construct (anyOf, oneOf, allOf)
"""
# Get the list of schemas from the construct
schemas = param_info[construct_type]
# Find the first schema with a type
for schema in schemas:
if "type" in schema:
param_def["type"] = schema["type"]
# Add format if available
if "format" in schema:
param_def["format"] = schema["format"]
# Add enum values if available
if "enum" in schema:
param_def["enum"] = schema["enum"]
# Add default value if available
if "default" in schema and schema["default"] is not None:
param_def["default"] = schema["default"]
break
# If no type was found, default to string
if "type" not in param_def:
param_def["type"] = "string"

View File

@@ -0,0 +1,153 @@
"""
Tool execution module for Airflow Wingman.
This module contains functions to list and execute Airflow tools.
"""
import asyncio
import json
import logging
import traceback
from airflow import configuration
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")
async def _list_airflow_tools_async(cookie: str) -> list:
"""
Async implementation to list available Airflow tools.
Args:
cookie: Cookie for authentication
Returns:
List of available Airflow tools
"""
try:
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
# Format the cookie properly if it doesn't already have the 'session=' prefix
formatted_cookie = cookie
if cookie and not cookie.startswith("session="):
formatted_cookie = f"session={cookie}"
logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)
# Get available tools
logger.info("Getting Airflow tools...")
tools = await get_airflow_tools(config=config, mode="safe")
logger.info(f"Got {len(tools)} tools")
return tools
except Exception as e:
error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return []
def list_airflow_tools(cookie: str) -> list:
"""
Synchronous wrapper to list available Airflow tools.
Args:
cookie: Cookie for authentication
Returns:
List of available Airflow tools
"""
return asyncio.run(_list_airflow_tools_async(cookie))
async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str:
"""
Async implementation to execute an Airflow tool.
Args:
tool_name: Name of the tool to execute
arguments: Arguments to pass to the tool
cookie: Cookie for authentication
Returns:
Result of the tool execution as a string
"""
try:
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
# Format the cookie properly if it doesn't already have the 'session=' prefix
formatted_cookie = cookie
if cookie and not cookie.startswith("session="):
formatted_cookie = f"session={cookie}"
logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)
# Get the tool
logger.info(f"Getting tool: {tool_name}")
tool = await get_tool(config=config, name=tool_name)
if not tool:
error_msg = f"Tool not found: {tool_name}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
# Execute the tool - ensure the client is in an async context
logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
# The AirflowClient needs to be used as an async context manager
# to properly initialize its session
async with tool.client as client:  # noqa: F841
# Now the client has a _session attribute and is in an async context
result = await tool.run(arguments)
# Convert result to string
if isinstance(result, dict | list):
result_str = json.dumps(result, indent=2)
else:
result_str = str(result)
logger.info(f"Tool execution result: {result_str[:100]}...")
return result_str
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
"""
Synchronous wrapper to execute an Airflow tool.
Args:
tool_name: Name of the tool to execute
arguments: Arguments to pass to the tool
cookie: Cookie for authentication
Returns:
Result of the tool execution as a string
"""
# Create a new event loop for this execution
# This ensures we're always in a clean async context
loop = asyncio.new_event_loop()
try:
# Set the event loop for this thread
asyncio.set_event_loop(loop)
# Run the async function in the new event loop
result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
return result
except Exception as e:
error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
finally:
# Always close the loop to free resources
loop.close()
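A minimal sketch of driving these wrappers from a script, assuming the plugin is installed alongside a running Airflow webserver and a valid session cookie is available; the tool choice and arguments are placeholders.

# Illustrative sketch: listing tools and executing one via the synchronous wrappers.
from airflow_wingman.tools import execute_airflow_tool, list_airflow_tools

cookie = "<airflow-session-cookie>"  # assumed value of the webserver session cookie

tools = list_airflow_tools(cookie)
print([t.name for t in tools])

if tools:
    # Real calls must pass arguments matching the tool's inputSchema.
    result = execute_airflow_tool(tools[0].name, {}, cookie)
    print(result)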

View File

@@ -1,6 +1,9 @@
"""Views for Airflow Wingman plugin.""" """Views for Airflow Wingman plugin."""
from flask import Response, request, stream_with_context import json
import logging
from flask import Response, request, session
from flask.json import jsonify from flask.json import jsonify
from flask_appbuilder import BaseView as AppBuilderBaseView, expose from flask_appbuilder import BaseView as AppBuilderBaseView, expose
@@ -8,6 +11,10 @@ from airflow_wingman.llm_client import LLMClient
from airflow_wingman.llms_models import MODELS from airflow_wingman.llms_models import MODELS
from airflow_wingman.notes import INTERFACE_MESSAGES from airflow_wingman.notes import INTERFACE_MESSAGES
from airflow_wingman.prompt_engineering import prepare_messages from airflow_wingman.prompt_engineering import prepare_messages
from airflow_wingman.tools import list_airflow_tools
# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")
class WingmanView(AppBuilderBaseView): class WingmanView(AppBuilderBaseView):
@@ -28,8 +35,43 @@ class WingmanView(AppBuilderBaseView):
try: try:
data = self._validate_chat_request(request.get_json()) data = self._validate_chat_request(request.get_json())
# Create a new client for this request if data.get("cookie"):
client = LLMClient(data["api_key"]) session["airflow_cookie"] = data["cookie"]
# Get available Airflow tools using the stored cookie
airflow_tools = []
airflow_cookie = request.cookies.get("session")
if airflow_cookie:
try:
airflow_tools = list_airflow_tools(airflow_cookie)
logger.info(f"Loaded {len(airflow_tools)} Airflow tools")
if len(airflow_tools) > 0:
logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}")
else:
logger.warning("No Airflow tools were loaded")
except Exception as e:
# Log the error but continue without tools
logger.error(f"Error fetching Airflow tools: {str(e)}")
# Prepare messages with Airflow tools included in the prompt
data["messages"] = prepare_messages(data["messages"])
# Get provider name from request or use default
provider_name = data.get("provider", "openai")
# Get base URL from models configuration based on provider
base_url = MODELS.get(provider_name, {}).get("endpoint")
# Log the request parameters (excluding API key for security)
safe_data = {k: v for k, v in data.items() if k != "api_key"}
logger.info(f"Chat request: provider={provider_name}, model={data.get('model')}, stream={data.get('stream')}")
logger.info(f"Request parameters: {json.dumps(safe_data)[:200]}...")
# Create a new client for this request with the appropriate provider
client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url)
# Set the Airflow tools for the client to use
client.set_airflow_tools(airflow_tools)
if data["stream"]: if data["stream"]:
return self._handle_streaming_response(client, data) return self._handle_streaming_response(client, data)
@@ -46,39 +88,116 @@ class WingmanView(AppBuilderBaseView):
if not data: if not data:
raise ValueError("No data provided") raise ValueError("No data provided")
required_fields = ["provider", "model", "messages", "api_key"] required_fields = ["model", "messages", "api_key"]
missing = [f for f in required_fields if not data.get(f)] missing = [f for f in required_fields if not data.get(f)]
if missing: if missing:
raise ValueError(f"Missing required fields: {', '.join(missing)}") raise ValueError(f"Missing required fields: {', '.join(missing)}")
# Prepare messages with system instruction while maintaining history # Validate provider if provided
messages = data["messages"] provider = data.get("provider", "openai")
messages = prepare_messages(messages) if provider not in MODELS:
raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}")
return { return {
"provider": data["provider"],
"model": data["model"], "model": data["model"],
"messages": messages, "messages": data["messages"],
"api_key": data["api_key"], "api_key": data["api_key"],
"stream": data.get("stream", False), "stream": data.get("stream", True),
"temperature": data.get("temperature", 0.7), "temperature": data.get("temperature", 0.4),
"max_tokens": data.get("max_tokens"), "max_tokens": data.get("max_tokens"),
"cookie": data.get("cookie"),
"provider": provider,
"base_url": data.get("base_url"),
} }
def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response: def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
"""Handle streaming response.""" """Handle streaming response."""
try:
logger.info("Beginning streaming response")
# Get the cookie at the beginning of the request handler
airflow_cookie = request.cookies.get("session")
logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")
def generate(): # Use the enhanced chat_completion method with return_response_obj=True
for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True): streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
yield f"data: {chunk}\n\n"
response = Response(stream_with_context(generate()), mimetype="text/event-stream") def stream_response(cookie=airflow_cookie):
response.headers["Content-Type"] = "text/event-stream" complete_response = ""
response.headers["Cache-Control"] = "no-cache"
response.headers["Connection"] = "keep-alive" # Stream the initial response
return response for chunk in streaming_response:
if chunk:
complete_response += chunk
yield f"data: {chunk}\n\n"
# Log the complete assembled response
logger.info("COMPLETE RESPONSE START >>>")
logger.info(complete_response)
logger.info("<<< COMPLETE RESPONSE END")
# Check for tool calls and make follow-up if needed
has_tool_calls = client.provider.has_tool_calls(streaming_response)
logger.info(f"Has tool calls: {has_tool_calls}")
if has_tool_calls:
# Signal tool processing start - frontend should disable send button
yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
# Signal to replace content - frontend should clear the current message
yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"
logger.info("Response contains tool calls, making follow-up request")
logger.info(f"Using cookie from closure: {cookie is not None}")
# Process tool calls and get follow-up response (handles recursive tool calls)
# Always stream the follow-up response for consistent handling
follow_up_response = client.process_tool_calls_and_follow_up(
streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie, stream=True
)
# Collect the follow-up response
follow_up_complete_response = ""
for chunk in follow_up_response:
if chunk:
follow_up_complete_response += chunk
# Send the follow-up response as a single event
if follow_up_complete_response:
follow_up_event = json.dumps({'event': 'follow_up_response', 'content': follow_up_complete_response})
logger.info(f"Follow-up event created with length: {len(follow_up_event)}")
data_line = f"data: {follow_up_event}\n\n"
logger.info(f"Yielding data line with length: {len(data_line)}")
yield data_line
# Log the complete follow-up response
logger.info("FOLLOW-UP RESPONSE START >>>")
logger.info(follow_up_complete_response)
logger.info("<<< FOLLOW-UP RESPONSE END")
# Signal tool processing complete - frontend can re-enable send button
yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
# Send the complete response as a special event (for compatibility with existing code)
complete_event = json.dumps({"event": "complete_response", "content": complete_response})
yield f"data: {complete_event}\n\n"
# Signal end of stream
yield "data: [DONE]\n\n"
return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
return jsonify({"error": str(e)}), 500
def _handle_regular_response(self, client: LLMClient, data: dict) -> Response: def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
"""Handle regular response.""" """Handle regular response."""
response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) try:
return jsonify(response) logger.info("Beginning regular (non-streaming) response")
response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
logger.info("COMPLETE RESPONSE START >>>")
logger.info(f"Response to frontend: {json.dumps(response)}")
logger.info("<<< COMPLETE RESPONSE END")
return jsonify(response)
except Exception as e:
logger.error(f"Regular response error: {str(e)}")
return jsonify({"error": str(e)}), 500