Restructure for 2 main providers and tools conversion

This commit is contained in:
2025-02-26 19:51:49 +00:00
parent a8a3d6d1a1
commit 3904bfc644
13 changed files with 1126 additions and 100 deletions

View File

@@ -68,7 +68,8 @@ lint.select = [
lint.ignore = [
"C416", # Unnecessary list comprehension - rewrite as a generator expression
"C408", # Unnecessary `dict` call - rewrite as a literal
"ISC001" # Single line implicit string concatenation
"ISC001", # Single line implicit string concatenation
"C901"
]
lint.fixable = ["ALL"]

View File

@@ -1,109 +1,145 @@
"""
Client for making API calls to various LLM providers using their official SDKs.
Multi-provider LLM client for Airflow Wingman.
This module contains the LLMClient class that supports multiple LLM providers
(OpenAI, Anthropic, OpenRouter) through a unified interface.
"""
from collections.abc import Generator
import traceback
from typing import Any
from anthropic import Anthropic
from openai import OpenAI
from airflow.utils.log.logging_mixin import LoggingMixin
from flask import session
from airflow_wingman.providers import create_llm_provider
from airflow_wingman.tools import list_airflow_tools
# Create a logger instance
logger = LoggingMixin().log
class LLMClient:
def __init__(self, api_key: str):
"""Initialize the LLM client.
"""
Multi-provider LLM client for Airflow Wingman.
This class handles chat completion requests to various LLM providers
(OpenAI, Anthropic, OpenRouter) through a unified interface.
"""
def __init__(self, provider_name: str, api_key: str, base_url: str | None = None):
"""
Initialize the LLM client.
Args:
provider_name: Name of the provider (openai, anthropic, openrouter)
api_key: API key for the provider
base_url: Optional base URL for the provider API
"""
self.provider_name = provider_name
self.api_key = api_key
self.openai_client = OpenAI(api_key=api_key)
self.anthropic_client = Anthropic(api_key=api_key)
self.openrouter_client = OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=api_key,
default_headers={
"HTTP-Referer": "Airflow Wingman", # Required by OpenRouter
"X-Title": "Airflow Wingman", # Required by OpenRouter
},
)
self.base_url = base_url
self.provider = create_llm_provider(provider_name, api_key, base_url)
self.airflow_tools = []
def chat_completion(
self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
) -> Generator[str, None, None] | dict:
"""Send a chat completion request to the specified provider.
def set_airflow_tools(self, tools: list):
"""
Set the available Airflow tools.
Args:
tools: List of Airflow Tool objects
"""
self.airflow_tools = tools
def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]:
"""
Send a chat completion request to the LLM provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
provider: Provider identifier (openai, anthropic, openrouter)
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
stream: Whether to stream the response (default is False)
Returns:
If stream=True, returns a generator yielding response chunks
If stream=False, returns the complete response
Dictionary with the response content or a generator for streaming
"""
# Get provider-specific tool definitions from Airflow tools
provider_tools = self.provider.convert_tools(self.airflow_tools)
try:
if provider == "openai":
return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "anthropic":
return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "openrouter":
return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
# Make the initial request with tools
logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}")
response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
logger.info(f"Received response from {self.provider_name}")
# If streaming, return the generator directly
if stream:
return self.provider.get_streaming_content(response)
# For non-streaming responses, handle tool calls if present
if self.provider.has_tool_calls(response):
logger.info("Response contains tool calls")
# Process tool calls and get results
cookie = session.get("airflow_cookie")
if not cookie:
error_msg = "No Airflow cookie available"
logger.error(error_msg)
return {"error": error_msg}
tool_results = self.provider.process_tool_calls(response, cookie)
# Create a follow-up completion with the tool results
logger.info("Making follow-up request with tool results")
follow_up_response = self.provider.create_follow_up_completion(
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response
)
return {"content": self.provider.get_content(follow_up_response)}
else:
return {"error": f"Unknown provider: {provider}"}
logger.info("Response does not contain tool calls")
return {"content": self.provider.get_content(response)}
except Exception as e:
error_msg = f"Error in {self.provider_name} API call: {str(e)}\\n{traceback.format_exc()}"
logger.error(error_msg)
return {"error": f"API request failed: {str(e)}"}
def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenAI chat completion requests."""
response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
@classmethod
def from_config(cls, config: dict[str, Any]) -> "LLMClient":
"""
Create an LLMClient instance from a configuration dictionary.
if stream:
Args:
config: Configuration dictionary with provider_name, api_key, and optional base_url
def response_generator():
for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
Returns:
LLMClient instance
"""
provider_name = config.get("provider_name", "openai")
api_key = config.get("api_key")
base_url = config.get("base_url")
return response_generator()
else:
return {"content": response.choices[0].message.content}
if not api_key:
raise ValueError("API key is required")
def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle Anthropic chat completion requests."""
# Convert messages to Anthropic format
system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
conversation = []
for m in messages:
if m["role"] != "system":
conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
def refresh_tools(self, cookie: str) -> None:
"""
Refresh the available Airflow tools.
if stream:
def response_generator():
for chunk in response:
if chunk.delta.text:
yield chunk.delta.text
return response_generator()
else:
return {"content": response.content[0].text}
def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenRouter chat completion requests."""
response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
def response_generator():
for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
return response_generator()
else:
return {"content": response.choices[0].message.content}
Args:
cookie: Airflow cookie for authentication
"""
try:
logger.info("Refreshing Airflow tools")
tools = list_airflow_tools(cookie)
self.set_airflow_tools(tools)
logger.info(f"Refreshed {len(tools)} Airflow tools")
except Exception as e:
error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
# Don't raise the exception, just log it
# The client will continue to use the existing tools (if any)
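For orientation, a minimal usage sketch (not part of the commit) of the restructured client outside the Flask view; the provider name, API key, cookie and model id below are placeholders, and tool execution still expects the Airflow cookie in the Flask session:

from airflow_wingman.llm_client import LLMClient

client = LLMClient.from_config({
    "provider_name": "anthropic",
    "api_key": "sk-ant-placeholder",  # placeholder key
})
client.refresh_tools(cookie="session=placeholder")  # failures are only logged; existing tools are kept

result = client.chat_completion(
    messages=[{"role": "user", "content": "List my DAGs"}],
    model="claude-3-7-sonnet-20250219",
    stream=False,
)
print(result.get("content") or result.get("error"))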

View File

@@ -17,14 +17,14 @@ MODELS = {
"endpoint": "https://api.anthropic.com/v1/messages",
"models": [
{
"id": "claude-3.7-sonnet",
"id": "claude-3-7-sonnet-20250219",
"name": "Claude 3.7 Sonnet",
"default": True,
"context_window": 200000,
"description": "Input $3/M tokens, Output $15/M tokens",
},
{
"id": "claude-3.5-haiku",
"id": "claude-3-5-haiku-20241022",
"name": "Claude 3.5 Haiku",
"default": False,
"context_window": 200000,

View File

@@ -8,9 +8,11 @@ INSTRUCTIONS = {
You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
The Airflow version being used is >=2.10.
You have access to the following Airflow API tools:
You have access to Airflow MCP tools that you can use to fetch information and help users understand
and manage their Airflow environment.
You can use these tools to fetch information and help users understand and manage their Airflow environment.
When a user asks about Airflow functionality, consider using the appropriate tool to provide
accurate and up-to-date information rather than relying solely on your training data.
"""
}

View File

@@ -0,0 +1,41 @@
"""
Provider factory for Airflow Wingman.
This module contains the factory function to create provider instances
based on the provider name.
"""
from airflow_wingman.providers.anthropic_provider import AnthropicProvider
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.openai_provider import OpenAIProvider
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider:
"""
Create a provider instance based on the provider name.
Args:
provider_name: Name of the provider (openai, anthropic, openrouter)
api_key: API key for the provider
base_url: Optional base URL for the provider API
Returns:
Provider instance
Raises:
ValueError: If the provider is not supported
"""
provider_name = provider_name.lower()
if provider_name == "openai":
return OpenAIProvider(api_key=api_key, base_url=base_url)
elif provider_name == "openrouter":
# OpenRouter uses the OpenAI API format, so we can use the OpenAI provider
# with a custom base URL
if not base_url:
base_url = "https://openrouter.ai/api/v1"
return OpenAIProvider(api_key=api_key, base_url=base_url)
elif provider_name == "anthropic":
return AnthropicProvider(api_key=api_key)
else:
raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter")

View File

@@ -0,0 +1,288 @@
"""
Anthropic provider implementation for Airflow Wingman.
This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""
import traceback
from typing import Any
from airflow.utils.log.logging_mixin import LoggingMixin
from anthropic import Anthropic
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools
logger = LoggingMixin().log
class AnthropicProvider(BaseLLMProvider):
"""
Anthropic provider implementation.
This class handles API requests, tool conversion, and response processing
for the Anthropic API (Claude models).
"""
def __init__(self, api_key: str):
"""
Initialize the Anthropic provider.
Args:
api_key: API key for Anthropic
"""
self.api_key = api_key
self.client = Anthropic(api_key=api_key)
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert Airflow tools to Anthropic format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of Anthropic tool definitions
"""
return convert_to_anthropic_tools(airflow_tools)
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to Anthropic.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in Anthropic format
Returns:
Anthropic response object
Raises:
Exception: If the API request fails
"""
# Convert max_tokens to Anthropic's max_tokens parameter (if provided)
max_tokens_param = max_tokens if max_tokens is not None else 4096
# Convert messages from ChatML format to Anthropic's format
anthropic_messages = self._convert_to_anthropic_messages(messages)
try:
logger.info(f"Sending chat completion request to Anthropic with model: {model}")
# Create request parameters
params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}
# Add tools if provided
if tools and len(tools) > 0:
params["tools"] = tools
# Make the API request
response = self.client.messages.create(**params)
logger.info("Received response from Anthropic")
return response
except Exception as e:
error_msg = str(e)
logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
raise
def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert messages from ChatML format to Anthropic's format.
Args:
messages: List of message dictionaries in ChatML format
Returns:
List of message dictionaries in Anthropic format
"""
anthropic_messages = []
for message in messages:
role = message["role"]
content = message["content"]
# Map ChatML roles to Anthropic roles
if role == "system":
# System messages in Anthropic are handled differently
# We'll add them as a user message with a special prefix
anthropic_messages.append({"role": "user", "content": f"<system>\n{content}\n</system>"})
elif role == "user":
anthropic_messages.append({"role": "user", "content": content})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": content})
elif role == "tool":
# Tool messages in ChatML become part of the user message in Anthropic
# We'll handle this in the follow-up completion
continue
return anthropic_messages
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: Anthropic response object
Returns:
True if the response contains tool calls, False otherwise
"""
# Check if any content block is a tool_use block
if not hasattr(response, "content"):
return False
for block in response.content:
if isinstance(block, dict) and block.get("type") == "tool_use":
return True
return False
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: Anthropic response object
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
if not self.has_tool_calls(response):
return results
# Extract tool_use blocks
tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
for block in tool_use_blocks:
tool_id = block.get("id")
tool_name = block.get("name")
tool_input = block.get("input", {})
try:
# Execute the Airflow tool with the provided arguments and cookie
logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
result = execute_airflow_tool(tool_name, tool_input, cookie)
logger.info(f"Tool execution result: {result}")
results[tool_id] = {"status": "success", "result": result}
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
Returns:
Anthropic response object
"""
if not original_response or not tool_results:
return original_response
# Extract tool_use blocks from the original response
tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Create tool result blocks
tool_result_blocks = []
for tool_id, result in tool_results.items():
tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))})
# Convert original messages to Anthropic format
anthropic_messages = self._convert_to_anthropic_messages(messages)
# Add the assistant response with tool use
anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
# Add the user message with tool results
anthropic_messages.append({"role": "user", "content": tool_result_blocks})
# Make a second request to get the final response
logger.info("Making second request with tool results")
return self.create_chat_completion(
messages=anthropic_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
tools=None, # No tools needed for follow-up
)
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: Anthropic response object
Returns:
Content string from the response
"""
if not hasattr(response, "content"):
return ""
# Combine all text blocks into a single string
content_parts = []
for block in response.content:
if isinstance(block, dict) and block.get("type") == "text":
content_parts.append(block.get("text", ""))
elif isinstance(block, str):
content_parts.append(block)
return "".join(content_parts)
def get_streaming_content(self, response: Any) -> Any:
"""
Get a generator for streaming content from the response.
Args:
response: Anthropic streaming response object
Returns:
Generator yielding content chunks
"""
def generate():
for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
# Handle different types of chunks from Anthropic API
content = None
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content = chunk.delta.text
elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content = chunk.delta.text
elif hasattr(chunk, "content") and chunk.content:
for block in chunk.content:
if isinstance(block, dict) and block.get("type") == "text":
content = block.get("text", "")
if content:
# Don't do any newline replacement here
yield content
return generate()
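For clarity, the follow-up request assembled above ends up with roughly this message layout for a single hypothetical get_dag call (ids and values are illustrative, not from the commit):

follow_up_messages = [
    {"role": "user", "content": "Is example_dag paused?"},
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_01", "name": "get_dag",
         "input": {"dag_id": "example_dag"}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01",
         "content": "{\"is_paused\": false}"},
    ]},
]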

View File

@@ -0,0 +1,125 @@
"""
Base provider interface for Airflow Wingman.
This module contains the base provider interface that all provider implementations
must adhere to. It defines the methods required for tool conversion, API requests,
and response processing.
"""
from abc import ABC, abstractmethod
from typing import Any
class BaseLLMProvider(ABC):
"""
Base provider interface for LLM providers.
This abstract class defines the methods that all provider implementations
must implement to support tool integration.
"""
@abstractmethod
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert internal tool representation to provider format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of provider-specific tool definitions
"""
pass
@abstractmethod
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in provider format
Returns:
Provider-specific response object
"""
pass
@abstractmethod
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: Provider-specific response object
Returns:
True if the response contains tool calls, False otherwise
"""
pass
@abstractmethod
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: Provider-specific response object
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
pass
@abstractmethod
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
Returns:
Provider-specific response object
"""
pass
@abstractmethod
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: Provider-specific response object
Returns:
Content string from the response
"""
pass
@abstractmethod
def get_streaming_content(self, response: Any) -> Any:
"""
Get a generator for streaming content from the response.
Args:
response: Provider-specific response object
Returns:
Generator yielding content chunks
"""
pass
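To illustrate the contract, a toy provider that satisfies the interface without calling any API (hypothetical, for demonstration only):

from typing import Any

from airflow_wingman.providers.base import BaseLLMProvider


class EchoProvider(BaseLLMProvider):
    """Echoes the last user message; every abstract method is implemented."""

    def convert_tools(self, airflow_tools: list) -> list:
        return []

    def create_chat_completion(self, messages, model, temperature=0.7, max_tokens=None, stream=False, tools=None) -> Any:
        return {"content": messages[-1]["content"]}

    def has_tool_calls(self, response: Any) -> bool:
        return False

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        return {}

    def create_follow_up_completion(self, messages, model, temperature=0.7, max_tokens=None, tool_results=None, original_response=None) -> Any:
        return original_response

    def get_content(self, response: Any) -> str:
        return response["content"]

    def get_streaming_content(self, response: Any) -> Any:
        yield response["content"]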

View File

@@ -0,0 +1,224 @@
"""
OpenAI provider implementation for Airflow Wingman.
This module contains the OpenAI provider implementation that handles
API requests, tool conversion, and response processing for OpenAI.
"""
import json
import traceback
from typing import Any
from airflow.utils.log.logging_mixin import LoggingMixin
from openai import OpenAI
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools
# Create a logger instance
logger = LoggingMixin().log
class OpenAIProvider(BaseLLMProvider):
"""
OpenAI provider implementation.
This class handles API requests, tool conversion, and response processing
for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
Initialize the OpenAI provider.
Args:
api_key: API key for OpenAI
base_url: Optional base URL for the API (used for OpenRouter)
"""
self.api_key = api_key
self.client = OpenAI(api_key=api_key, base_url=base_url)
def convert_tools(self, airflow_tools: list) -> list:
"""
Convert Airflow tools to OpenAI format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of OpenAI tool definitions
"""
return convert_to_openai_tools(airflow_tools)
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to OpenAI.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in OpenAI format
Returns:
OpenAI response object
Raises:
Exception: If the API request fails
"""
# Only include tools if we have any
has_tools = tools is not None and len(tools) > 0
tool_choice = "auto" if has_tools else None
try:
logger.info(f"Sending chat completion request to OpenAI with model: {model}")
response = self.client.chat.completions.create(
model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
)
logger.info("Received response from OpenAI")
return response
except Exception as e:
# If the API call fails due to tools not being supported, retry without tools
error_msg = str(e)
logger.warning(f"Error in OpenAI API call: {error_msg}")
if "tools" in error_msg.lower():
logger.info("Retrying without tools")
response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
return response
else:
logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
raise
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: OpenAI response object
Returns:
True if the response contains tool calls, False otherwise
"""
message = response.choices[0].message
return hasattr(message, "tool_calls") and message.tool_calls
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: OpenAI response object
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
message = response.choices[0].message
if not self.has_tool_calls(response):
return results
for tool_call in message.tool_calls:
tool_id = tool_call.id
function_name = tool_call.function.name
arguments = json.loads(tool_call.function.arguments)
try:
# Execute the Airflow tool with the provided arguments and cookie
logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
result = execute_airflow_tool(function_name, arguments, cookie)
logger.info(f"Tool execution result: {result}")
results[tool_id] = {"status": "success", "result": result}
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
Returns:
OpenAI response object
"""
if not original_response or not tool_results:
return original_response
# Get the original message with tool calls
original_message = original_response.choices[0].message
# Create a new message with the tool calls
assistant_message = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
}
# Create tool result messages
tool_messages = []
for tool_call_id, result in tool_results.items():
tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))})
# Add the original messages, assistant message, and tool results
new_messages = messages + [assistant_message] + tool_messages
# Make a second request to get the final response
logger.info("Making second request with tool results")
return self.create_chat_completion(
messages=new_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
tools=None, # No tools needed for follow-up
)
def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: OpenAI response object
Returns:
Content string from the response
"""
return response.choices[0].message.content
def get_streaming_content(self, response: Any) -> Any:
"""
Get a generator for streaming content from the response.
Args:
response: OpenAI streaming response object
Returns:
Generator yielding content chunks
"""
def generate():
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
# Don't do any newline replacement here
content = chunk.choices[0].delta.content
yield content
return generate()
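Similarly, the follow-up request here appends an assistant message that replays the tool calls plus one tool message per result; a hypothetical single-call example (ids, values, and original_messages are illustrative):

new_messages = original_messages + [
    {"role": "assistant", "content": None, "tool_calls": [
        {"id": "call_01", "type": "function",
         "function": {"name": "get_dag", "arguments": "{\"dag_id\": \"example_dag\"}"}},
    ]},
    {"role": "tool", "tool_call_id": "call_01",
     "content": "{\"is_paused\": false}"},
]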

View File

@@ -150,6 +150,7 @@
border: 1px solid #e9ecef;
border-radius: 15px 15px 15px 0;
padding: 10px 15px;
white-space: pre-wrap;
}
#chat-messages::after {
@@ -349,6 +350,8 @@ document.addEventListener('DOMContentLoaded', function() {
if (line.startsWith('data: ')) {
const content = line.slice(6);
if (content) {
// Use textContent to properly handle newlines
console.log('Received chunk:', JSON.stringify(content)); // Debug
currentMessageDiv.textContent += content;
fullResponse += content;
chatMessages.scrollTop = chatMessages.scrollHeight;

View File

@@ -0,0 +1,15 @@
"""
Tools module for Airflow Wingman.
This module contains the tools used by Airflow Wingman to interact with Airflow.
"""
from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools
from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools
__all__ = [
"convert_to_openai_tools",
"convert_to_anthropic_tools",
"list_airflow_tools",
"execute_airflow_tool",
]

View File

@@ -0,0 +1,139 @@
"""
Conversion utilities for Airflow Wingman tools.
This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""
from typing import Any
def convert_to_openai_tools(airflow_tools: list) -> list:
"""
Convert Airflow tools to OpenAI tool definitions.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of OpenAI tool definitions
"""
openai_tools = []
for tool in airflow_tools:
# Initialize the OpenAI tool structure
openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}}
# Extract parameters directly from inputSchema if available
if hasattr(tool, "inputSchema") and tool.inputSchema:
# Set the type and required fields directly from the schema
if "type" in tool.inputSchema:
openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"]
# Add required parameters if specified
if "required" in tool.inputSchema:
openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"]
# Add properties from the input schema
if "properties" in tool.inputSchema:
for param_name, param_info in tool.inputSchema["properties"].items():
# Create parameter definition
param_def = {}
# Handle different schema constructs
if "anyOf" in param_info:
_handle_schema_construct(param_def, param_info, "anyOf")
elif "oneOf" in param_info:
_handle_schema_construct(param_def, param_info, "oneOf")
elif "allOf" in param_info:
_handle_schema_construct(param_def, param_info, "allOf")
elif "type" in param_info:
param_def["type"] = param_info["type"]
# Add format if available
if "format" in param_info:
param_def["format"] = param_info["format"]
else:
param_def["type"] = "string" # Default type
# Add description from title or param name
param_def["description"] = param_info.get("description", param_info.get("title", param_name))
# Add enum values if available
if "enum" in param_info:
param_def["enum"] = param_info["enum"]
# Add default value if available
if "default" in param_info and param_info["default"] is not None:
param_def["default"] = param_info["default"]
# Add to properties
openai_tool["function"]["parameters"]["properties"][param_name] = param_def
openai_tools.append(openai_tool)
return openai_tools
def convert_to_anthropic_tools(airflow_tools: list) -> list:
"""
Convert Airflow tools to Anthropic tool definitions.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of Anthropic tool definitions
"""
anthropic_tools = []
for tool in airflow_tools:
# Initialize the Anthropic tool structure
anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}}
# Extract parameters directly from inputSchema if available
if hasattr(tool, "inputSchema") and tool.inputSchema:
# Copy the input schema directly as Anthropic's format is similar to JSON Schema
anthropic_tool["input_schema"] = tool.inputSchema
else:
# Create a minimal schema if none exists
anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None:
"""
Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf.
Args:
param_def: Parameter definition to update
param_info: Parameter info from the schema
construct_type: Type of construct (anyOf, oneOf, allOf)
"""
# Get the list of schemas from the construct
schemas = param_info[construct_type]
# Find the first schema with a type
for schema in schemas:
if "type" in schema:
param_def["type"] = schema["type"]
# Add format if available
if "format" in schema:
param_def["format"] = schema["format"]
# Add enum values if available
if "enum" in schema:
param_def["enum"] = schema["enum"]
# Add default value if available
if "default" in schema and schema["default"] is not None:
param_def["default"] = schema["default"]
break
# If no type was found, default to string
if "type" not in param_def:
param_def["type"] = "string"

View File

@@ -0,0 +1,117 @@
"""
Tool execution module for Airflow Wingman.
This module contains functions to list and execute Airflow tools.
"""
import asyncio
import json
import traceback
from airflow import configuration
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool
# Create a logger instance
logger = LoggingMixin().log
async def _list_airflow_tools_async(cookie: str) -> list:
"""
Async implementation to list available Airflow tools.
Args:
cookie: Cookie for authentication
Returns:
List of available Airflow tools
"""
try:
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
# Get available tools
logger.info("Getting Airflow tools...")
tools = await get_airflow_tools(config=config, mode="safe")
logger.info(f"Got {len(tools)} tools")
return tools
except Exception as e:
error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return []
def list_airflow_tools(cookie: str) -> list:
"""
Synchronous wrapper to list available Airflow tools.
Args:
cookie: Cookie for authentication
Returns:
List of available Airflow tools
"""
return asyncio.run(_list_airflow_tools_async(cookie))
async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str:
"""
Async implementation to execute an Airflow tool.
Args:
tool_name: Name of the tool to execute
arguments: Arguments to pass to the tool
cookie: Cookie for authentication
Returns:
Result of the tool execution as a string
"""
try:
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
# Get the tool
logger.info(f"Getting tool: {tool_name}")
tool = await get_tool(config=config, tool_name=tool_name)
if not tool:
error_msg = f"Tool not found: {tool_name}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
# Execute the tool
logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
result = await tool.run(arguments)
# Convert result to string
if isinstance(result, dict | list):
result_str = json.dumps(result, indent=2)
else:
result_str = str(result)
logger.info(f"Tool execution result: {result_str[:100]}...")
return result_str
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
"""
Synchronous wrapper to execute an Airflow tool.
Args:
tool_name: Name of the tool to execute
arguments: Arguments to pass to the tool
cookie: Cookie for authentication
Returns:
Result of the tool execution as a string
"""
return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie))
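A hedged usage sketch (cookie and tool name are placeholders); note that both wrappers call asyncio.run, so they must not be invoked from code that is already running inside an event loop:

from airflow_wingman.tools import execute_airflow_tool, list_airflow_tools

cookie = "session=placeholder"        # in the plugin this comes from the Flask session
tools = list_airflow_tools(cookie)    # returns [] on failure; errors are logged
if tools:
    result = execute_airflow_tool(tools[0].name, {}, cookie)
    print(result)                     # JSON string, or a JSON {"error": ...} payload on failure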

View File

@@ -1,6 +1,6 @@
"""Views for Airflow Wingman plugin."""
from flask import Response, request, stream_with_context
from flask import Response, request, session
from flask.json import jsonify
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
@@ -8,6 +8,7 @@ from airflow_wingman.llm_client import LLMClient
from airflow_wingman.llms_models import MODELS
from airflow_wingman.notes import INTERFACE_MESSAGES
from airflow_wingman.prompt_engineering import prepare_messages
from airflow_wingman.tools import list_airflow_tools
class WingmanView(AppBuilderBaseView):
@@ -28,8 +29,32 @@ class WingmanView(AppBuilderBaseView):
try:
data = self._validate_chat_request(request.get_json())
# Create a new client for this request
client = LLMClient(data["api_key"])
if data.get("cookie"):
session["airflow_cookie"] = data["cookie"]
# Get available Airflow tools using the stored cookie
airflow_tools = []
if session.get("airflow_cookie"):
try:
airflow_tools = list_airflow_tools(session["airflow_cookie"])
except Exception as e:
# Log the error but continue without tools
print(f"Error fetching Airflow tools: {str(e)}")
# Prepare messages with Airflow tools included in the prompt
data["messages"] = prepare_messages(data["messages"])
# Get provider name from request or use default
provider_name = data.get("provider", "openai")
# Get base URL from models configuration based on provider
base_url = MODELS.get(provider_name, {}).get("endpoint")
# Create a new client for this request with the appropriate provider
client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url)
# Set the Airflow tools for the client to use
client.set_airflow_tools(airflow_tools)
if data["stream"]:
return self._handle_streaming_response(client, data)
@@ -46,39 +71,49 @@ class WingmanView(AppBuilderBaseView):
if not data:
raise ValueError("No data provided")
required_fields = ["provider", "model", "messages", "api_key"]
required_fields = ["model", "messages", "api_key"]
missing = [f for f in required_fields if not data.get(f)]
if missing:
raise ValueError(f"Missing required fields: {', '.join(missing)}")
# Prepare messages with system instruction while maintaining history
messages = data["messages"]
messages = prepare_messages(messages)
# Validate provider if provided
provider = data.get("provider", "openai")
if provider not in MODELS:
raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}")
return {
"provider": data["provider"],
"model": data["model"],
"messages": messages,
"messages": data["messages"],
"api_key": data["api_key"],
"stream": data.get("stream", False),
"stream": data.get("stream", True),
"temperature": data.get("temperature", 0.7),
"max_tokens": data.get("max_tokens"),
"cookie": data.get("cookie"),
"provider": provider,
"base_url": data.get("base_url"),
}
def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
"""Handle streaming response."""
try:
# Get the streaming generator from the client
generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
def generate():
for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
yield f"data: {chunk}\n\n"
def stream_response():
# Send SSE format for each chunk
for chunk in generator:
if chunk:
yield f"data: {chunk}\n\n"
response = Response(stream_with_context(generate()), mimetype="text/event-stream")
response.headers["Content-Type"] = "text/event-stream"
response.headers["Cache-Control"] = "no-cache"
response.headers["Connection"] = "keep-alive"
return response
# Signal end of stream
yield "data: [DONE]\n\n"
return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
except Exception as e:
# If streaming fails, return error
return jsonify({"error": str(e)}), 500
def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
"""Handle regular response."""
response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
return jsonify(response)
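For reference, a sketch of consuming the resulting SSE stream from a plain HTTP client; the route path and payload values are assumptions, since the URL registration is outside this diff:

import requests

resp = requests.post(
    "http://localhost:8080/wingman/chat",   # hypothetical route
    json={
        "model": "gpt-4o-mini",             # placeholder model id
        "messages": [{"role": "user", "content": "hi"}],
        "api_key": "sk-placeholder",
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        chunk = line[len("data: "):]
        if chunk == "[DONE]":
            break
        print(chunk, end="")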