Restructure for 2 main providers and tools conversion
src/airflow_wingman/llm_client.py
@@ -1,109 +1,145 @@
 """
-Client for making API calls to various LLM providers using their official SDKs.
+Multi-provider LLM client for Airflow Wingman.
+
+This module contains the LLMClient class that supports multiple LLM providers
+(OpenAI, Anthropic, OpenRouter) through a unified interface.
 """
 
-from collections.abc import Generator
+import traceback
+from typing import Any
 
-from anthropic import Anthropic
-from openai import OpenAI
+from airflow.utils.log.logging_mixin import LoggingMixin
+from flask import session
+
+from airflow_wingman.providers import create_llm_provider
+from airflow_wingman.tools import list_airflow_tools
+
+# Create a logger instance
+logger = LoggingMixin().log
 
 
 class LLMClient:
-    def __init__(self, api_key: str):
-        """Initialize the LLM client."""
+    """
+    Multi-provider LLM client for Airflow Wingman.
+
+    This class handles chat completion requests to various LLM providers
+    (OpenAI, Anthropic, OpenRouter) through a unified interface.
+    """
+
+    def __init__(self, provider_name: str, api_key: str, base_url: str | None = None):
+        """
+        Initialize the LLM client.
+
+        Args:
+            provider_name: Name of the provider (openai, anthropic, openrouter)
+            api_key: API key for the provider
+            base_url: Optional base URL for the provider API
+        """
+        self.provider_name = provider_name
         self.api_key = api_key
-        self.openai_client = OpenAI(api_key=api_key)
-        self.anthropic_client = Anthropic(api_key=api_key)
-        self.openrouter_client = OpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=api_key,
-            default_headers={
-                "HTTP-Referer": "Airflow Wingman",  # Required by OpenRouter
-                "X-Title": "Airflow Wingman",  # Required by OpenRouter
-            },
-        )
+        self.base_url = base_url
+        self.provider = create_llm_provider(provider_name, api_key, base_url)
+        self.airflow_tools = []
 
-    def chat_completion(
-        self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
-    ) -> Generator[str, None, None] | dict:
-        """Send a chat completion request to the specified provider.
+    def set_airflow_tools(self, tools: list):
+        """
+        Set the available Airflow tools.
+
+        Args:
+            tools: List of Airflow Tool objects
+        """
+        self.airflow_tools = tools
+
+    def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]:
+        """
+        Send a chat completion request to the LLM provider.
 
         Args:
             messages: List of message dictionaries with 'role' and 'content'
             model: Model identifier
-            provider: Provider identifier (openai, anthropic, openrouter)
             temperature: Sampling temperature (0-1)
             max_tokens: Maximum tokens to generate
-            stream: Whether to stream the response
+            stream: Whether to stream the response (default is True)
 
         Returns:
-            If stream=True, returns a generator yielding response chunks
-            If stream=False, returns the complete response
+            Dictionary with the response content or a generator for streaming
         """
+        # Get provider-specific tool definitions from Airflow tools
+        provider_tools = self.provider.convert_tools(self.airflow_tools)
+
         try:
-            if provider == "openai":
-                return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "anthropic":
-                return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "openrouter":
-                return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
-            else:
-                return {"error": f"Unknown provider: {provider}"}
+            # Make the initial request with tools
+            logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}")
+            response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
+            logger.info(f"Received response from {self.provider_name}")
+
+            # If streaming, return the generator directly
+            if stream:
+                return self.provider.get_streaming_content(response)
+
+            # For non-streaming responses, handle tool calls if present
+            if self.provider.has_tool_calls(response):
+                logger.info("Response contains tool calls")
+
+                # Process tool calls and get results
+                cookie = session.get("airflow_cookie")
+                if not cookie:
+                    error_msg = "No Airflow cookie available"
+                    logger.error(error_msg)
+                    return {"error": error_msg}
+
+                tool_results = self.provider.process_tool_calls(response, cookie)
+
+                # Create a follow-up completion with the tool results
+                logger.info("Making follow-up request with tool results")
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response
+                )
+
+                return {"content": self.provider.get_content(follow_up_response)}
+
+            logger.info("Response does not contain tool calls")
+            return {"content": self.provider.get_content(response)}
 
         except Exception as e:
+            error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}"
             logger.error(error_msg)
             return {"error": f"API request failed: {str(e)}"}
 
-    def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenAI chat completion requests."""
-        response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
-
-    def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle Anthropic chat completion requests."""
-        # Convert messages to Anthropic format
-        system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
-        conversation = []
-        for m in messages:
-            if m["role"] != "system":
-                conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
-
-        response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.delta.text:
-                        yield chunk.delta.text
-
-            return response_generator()
-        else:
-            return {"content": response.content[0].text}
-
-    def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenRouter chat completion requests."""
-        response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "LLMClient":
+        """
+        Create an LLMClient instance from a configuration dictionary.
+
+        Args:
+            config: Configuration dictionary with provider_name, api_key, and optional base_url
+
+        Returns:
+            LLMClient instance
+        """
+        provider_name = config.get("provider_name", "openai")
+        api_key = config.get("api_key")
+        base_url = config.get("base_url")
+
+        if not api_key:
+            raise ValueError("API key is required")
+
+        return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
+
+    def refresh_tools(self, cookie: str) -> None:
+        """
+        Refresh the available Airflow tools.
+
+        Args:
+            cookie: Airflow cookie for authentication
+        """
+        try:
+            logger.info("Refreshing Airflow tools")
+            tools = list_airflow_tools(cookie)
+            self.set_airflow_tools(tools)
+            logger.info(f"Refreshed {len(tools)} Airflow tools")
+        except Exception as e:
+            error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            # Don't raise the exception, just log it
+            # The client will continue to use the existing tools (if any)
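
The restructured client delegates everything provider-specific to self.provider. A minimal usage sketch, assuming a valid Anthropic key and a Flask request context (so session["airflow_cookie"] resolves); the key and cookie values below are placeholders, not part of this commit:

from airflow_wingman.llm_client import LLMClient

client = LLMClient.from_config({"provider_name": "anthropic", "api_key": "sk-ant-..."})
client.refresh_tools(cookie="<airflow-session-cookie>")  # populates self.airflow_tools

result = client.chat_completion(
    messages=[{"role": "user", "content": "Which DAGs are failing?"}],
    model="claude-3-7-sonnet-20250219",
    stream=False,
)
print(result.get("content") or result.get("error"))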
src/airflow_wingman/llms_models.py
@@ -17,14 +17,14 @@ MODELS = {
         "endpoint": "https://api.anthropic.com/v1/messages",
         "models": [
             {
-                "id": "claude-3.7-sonnet",
+                "id": "claude-3-7-sonnet-20250219",
                 "name": "Claude 3.7 Sonnet",
                 "default": True,
                 "context_window": 200000,
                 "description": "Input $3/M tokens, Output $15/M tokens",
             },
             {
-                "id": "claude-3.5-haiku",
+                "id": "claude-3-5-haiku-20241022",
                 "name": "Claude 3.5 Haiku",
                 "default": False,
                 "context_window": 200000,
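
For reference, the default model can be pulled out of this structure like so (a sketch assuming MODELS keys providers by name, as the views changes below suggest):

from airflow_wingman.llms_models import MODELS

# Assumption: MODELS maps provider name -> {"endpoint": ..., "models": [...]}.
default_anthropic_model = next(m["id"] for m in MODELS["anthropic"]["models"] if m.get("default"))
# -> "claude-3-7-sonnet-20250219" after this change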
@@ -8,9 +8,11 @@ INSTRUCTIONS = {
 You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
 The Airflow version being used is >=2.10.
 
-You have access to the following Airflow API tools:
+You have access to Airflow MCP tools that you can use to fetch information and help users understand
+and manage their Airflow environment.
 
-You can use these tools to fetch information and help users understand and manage their Airflow environment.
+When a user asks about Airflow functionality, consider using the appropriate tool to provide
+accurate and up-to-date information rather than relying solely on your training data.
 """
 }
src/airflow_wingman/providers/__init__.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""
Provider factory for Airflow Wingman.

This module contains the factory function to create provider instances
based on the provider name.
"""

from airflow_wingman.providers.anthropic_provider import AnthropicProvider
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.openai_provider import OpenAIProvider


def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider:
    """
    Create a provider instance based on the provider name.

    Args:
        provider_name: Name of the provider (openai, anthropic, openrouter)
        api_key: API key for the provider
        base_url: Optional base URL for the provider API

    Returns:
        Provider instance

    Raises:
        ValueError: If the provider is not supported
    """
    provider_name = provider_name.lower()

    if provider_name == "openai":
        return OpenAIProvider(api_key=api_key, base_url=base_url)
    elif provider_name == "openrouter":
        # OpenRouter uses the OpenAI API format, so we can use the OpenAI provider
        # with a custom base URL
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        return OpenAIProvider(api_key=api_key, base_url=base_url)
    elif provider_name == "anthropic":
        return AnthropicProvider(api_key=api_key)
    else:
        raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter")
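
A quick sketch of how the factory resolves (the API keys are placeholders):

from airflow_wingman.providers import create_llm_provider

provider = create_llm_provider("openrouter", api_key="sk-or-...")
# -> OpenAIProvider with base_url="https://openrouter.ai/api/v1" filled in
provider = create_llm_provider("anthropic", api_key="sk-ant-...")
# -> AnthropicProvider; any unknown name raises ValueError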
src/airflow_wingman/providers/anthropic_provider.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""
Anthropic provider implementation for Airflow Wingman.

This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""

import traceback
from typing import Any

from airflow.utils.log.logging_mixin import LoggingMixin
from anthropic import Anthropic

from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools

logger = LoggingMixin().log


class AnthropicProvider(BaseLLMProvider):
    """
    Anthropic provider implementation.

    This class handles API requests, tool conversion, and response processing
    for the Anthropic API (Claude models).
    """

    def __init__(self, api_key: str):
        """
        Initialize the Anthropic provider.

        Args:
            api_key: API key for Anthropic
        """
        self.api_key = api_key
        self.client = Anthropic(api_key=api_key)

    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert Airflow tools to Anthropic format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of Anthropic tool definitions
        """
        return convert_to_anthropic_tools(airflow_tools)

    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to Anthropic.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in Anthropic format

        Returns:
            Anthropic response object

        Raises:
            Exception: If the API request fails
        """
        # Convert max_tokens to Anthropic's max_tokens parameter (if provided)
        max_tokens_param = max_tokens if max_tokens is not None else 4096

        # Convert messages from ChatML format to Anthropic's format
        anthropic_messages = self._convert_to_anthropic_messages(messages)

        try:
            logger.info(f"Sending chat completion request to Anthropic with model: {model}")

            # Create request parameters
            params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}

            # Add tools if provided
            if tools and len(tools) > 0:
                params["tools"] = tools

            # Make the API request
            response = self.client.messages.create(**params)

            logger.info("Received response from Anthropic")
            return response
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
            raise

    def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Convert messages from ChatML format to Anthropic's format.

        Args:
            messages: List of message dictionaries in ChatML format

        Returns:
            List of message dictionaries in Anthropic format
        """
        anthropic_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            # Map ChatML roles to Anthropic roles
            if role == "system":
                # System messages in Anthropic are handled differently
                # We'll add them as a user message with a special prefix
                anthropic_messages.append({"role": "user", "content": f"<system>\n{content}\n</system>"})
            elif role == "user":
                anthropic_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                anthropic_messages.append({"role": "assistant", "content": content})
            elif role == "tool":
                # Tool messages in ChatML become part of the user message in Anthropic
                # We'll handle this in the follow-up completion
                continue

        return anthropic_messages

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: Anthropic response object

        Returns:
            True if the response contains tool calls, False otherwise
        """
        # Check if any content block is a tool_use block
        if not hasattr(response, "content"):
            return False

        for block in response.content:
            if isinstance(block, dict) and block.get("type") == "tool_use":
                return True

        return False

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: Anthropic response object
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}

        if not self.has_tool_calls(response):
            return results

        # Extract tool_use blocks
        tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]

        for block in tool_use_blocks:
            tool_id = block.get("id")
            tool_name = block.get("name")
            tool_input = block.get("input", {})

            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
                result = execute_airflow_tool(tool_name, tool_input, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    def create_follow_up_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls

        Returns:
            Anthropic response object
        """
        if not original_response or not tool_results:
            return original_response

        # Extract tool_use blocks from the original response
        tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]

        # Create tool result blocks
        tool_result_blocks = []
        for tool_id, result in tool_results.items():
            tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))})

        # Convert original messages to Anthropic format
        anthropic_messages = self._convert_to_anthropic_messages(messages)

        # Add the assistant response with tool use
        anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})

        # Add the user message with tool results
        anthropic_messages.append({"role": "user", "content": tool_result_blocks})

        # Make a second request to get the final response
        logger.info("Making second request with tool results")
        return self.create_chat_completion(
            messages=anthropic_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            tools=None,  # No tools needed for follow-up
        )

    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: Anthropic response object

        Returns:
            Content string from the response
        """
        if not hasattr(response, "content"):
            return ""

        # Combine all text blocks into a single string
        content_parts = []
        for block in response.content:
            if isinstance(block, dict) and block.get("type") == "text":
                content_parts.append(block.get("text", ""))
            elif isinstance(block, str):
                content_parts.append(block)

        return "".join(content_parts)

    def get_streaming_content(self, response: Any) -> Any:
        """
        Get a generator for streaming content from the response.

        Args:
            response: Anthropic streaming response object

        Returns:
            Generator yielding content chunks
        """

        def generate():
            for chunk in response:
                logger.debug(f"Chunk type: {type(chunk)}")

                # Handle different types of chunks from Anthropic API
                content = None
                if hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                        content = chunk.delta.text
                elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                    content = chunk.delta.text
                elif hasattr(chunk, "content") and chunk.content:
                    for block in chunk.content:
                        if isinstance(block, dict) and block.get("type") == "text":
                            content = block.get("text", "")

                if content:
                    # Don't do any newline replacement here
                    yield content

        return generate()
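
For reviewers: the follow-up flow in create_follow_up_completion ends up sending a message list shaped like the following. This is a hand-written sketch of the wire format; the id, tool name, and values are illustrative, not produced by running this code:

anthropic_messages = [
    {"role": "user", "content": "How many DAGs are failing?"},
    # Assistant turn replays the tool_use block from the first response:
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_01", "name": "get_dags", "input": {"state": "failed"}},
    ]},
    # User turn carries the matching tool_result:
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01", "content": '{"total_entries": 2}'},
    ]},
]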
src/airflow_wingman/providers/base.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""
Base provider interface for Airflow Wingman.

This module contains the base provider interface that all provider implementations
must adhere to. It defines the methods required for tool conversion, API requests,
and response processing.
"""

from abc import ABC, abstractmethod
from typing import Any


class BaseLLMProvider(ABC):
    """
    Base provider interface for LLM providers.

    This abstract class defines the methods that all provider implementations
    must implement to support tool integration.
    """

    @abstractmethod
    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert internal tool representation to provider format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of provider-specific tool definitions
        """
        pass

    @abstractmethod
    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in provider format

        Returns:
            Provider-specific response object
        """
        pass

    @abstractmethod
    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: Provider-specific response object

        Returns:
            True if the response contains tool calls, False otherwise
        """
        pass

    @abstractmethod
    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: Provider-specific response object
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        pass

    @abstractmethod
    def create_follow_up_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls

        Returns:
            Provider-specific response object
        """
        pass

    @abstractmethod
    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: Provider-specific response object

        Returns:
            Content string from the response
        """
        pass

    @abstractmethod
    def get_streaming_content(self, response: Any) -> Any:
        """
        Get a generator for streaming content from the response.

        Args:
            response: Provider-specific response object

        Returns:
            Generator yielding content chunks
        """
        pass
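
Any additional backend only has to fill in these seven methods. A compressed sketch with a toy provider, for illustration only (not part of this commit):

from typing import Any

from airflow_wingman.providers.base import BaseLLMProvider


class EchoProvider(BaseLLMProvider):
    """Toy provider: echoes the last user message, never touches the network."""

    def convert_tools(self, airflow_tools: list) -> list:
        return []

    def create_chat_completion(self, messages, model, temperature=0.7, max_tokens=None, stream=False, tools=None) -> Any:
        return {"echo": messages[-1]["content"]}

    def has_tool_calls(self, response: Any) -> bool:
        return False

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        return {}

    def create_follow_up_completion(self, messages, model, temperature=0.7, max_tokens=None, tool_results=None, original_response=None) -> Any:
        return original_response

    def get_content(self, response: Any) -> str:
        return response["echo"]

    def get_streaming_content(self, response: Any) -> Any:
        yield response["echo"]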
src/airflow_wingman/providers/openai_provider.py (new file, 224 lines)
@@ -0,0 +1,224 @@
"""
OpenAI provider implementation for Airflow Wingman.

This module contains the OpenAI provider implementation that handles
API requests, tool conversion, and response processing for OpenAI.
"""

import json
import traceback
from typing import Any

from airflow.utils.log.logging_mixin import LoggingMixin
from openai import OpenAI

from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools

# Create a logger instance
logger = LoggingMixin().log


class OpenAIProvider(BaseLLMProvider):
    """
    OpenAI provider implementation.

    This class handles API requests, tool conversion, and response processing
    for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the OpenAI provider.

        Args:
            api_key: API key for OpenAI
            base_url: Optional base URL for the API (used for OpenRouter)
        """
        self.api_key = api_key
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert Airflow tools to OpenAI format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of OpenAI tool definitions
        """
        return convert_to_openai_tools(airflow_tools)

    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to OpenAI.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object

        Raises:
            Exception: If the API request fails
        """
        # Only include tools if we have any
        has_tools = tools is not None and len(tools) > 0
        tool_choice = "auto" if has_tools else None

        try:
            logger.info(f"Sending chat completion request to OpenAI with model: {model}")
            response = self.client.chat.completions.create(
                model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
            )
            logger.info("Received response from OpenAI")
            return response
        except Exception as e:
            # If the API call fails due to tools not being supported, retry without tools
            error_msg = str(e)
            logger.warning(f"Error in OpenAI API call: {error_msg}")
            if "tools" in error_msg.lower():
                logger.info("Retrying without tools")
                response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
                return response
            else:
                logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
                raise

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: OpenAI response object

        Returns:
            True if the response contains tool calls, False otherwise
        """
        message = response.choices[0].message
        return hasattr(message, "tool_calls") and message.tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: OpenAI response object
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}
        message = response.choices[0].message

        if not self.has_tool_calls(response):
            return results

        for tool_call in message.tool_calls:
            tool_id = tool_call.id
            function_name = tool_call.function.name
            arguments = json.loads(tool_call.function.arguments)

            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
                result = execute_airflow_tool(function_name, arguments, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    def create_follow_up_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls

        Returns:
            OpenAI response object
        """
        if not original_response or not tool_results:
            return original_response

        # Get the original message with tool calls
        original_message = original_response.choices[0].message

        # Create a new message with the tool calls
        assistant_message = {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
        }

        # Create tool result messages
        tool_messages = []
        for tool_call_id, result in tool_results.items():
            tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))})

        # Add the original messages, assistant message, and tool results
        new_messages = messages + [assistant_message] + tool_messages

        # Make a second request to get the final response
        logger.info("Making second request with tool results")
        return self.create_chat_completion(
            messages=new_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            tools=None,  # No tools needed for follow-up
        )

    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: OpenAI response object

        Returns:
            Content string from the response
        """
        return response.choices[0].message.content

    def get_streaming_content(self, response: Any) -> Any:
        """
        Get a generator for streaming content from the response.

        Args:
            response: OpenAI streaming response object

        Returns:
            Generator yielding content chunks
        """

        def generate():
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    # Don't do any newline replacement here
                    content = chunk.choices[0].delta.content
                    yield content

        return generate()
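
The follow-up request built above has this wire shape (hand-written sketch; the call id and arguments are illustrative):

new_messages = [
    {"role": "user", "content": "Pause example_dag"},
    # Assistant message replaying the tool call from the first response:
    {"role": "assistant", "content": None, "tool_calls": [
        {"id": "call_abc123", "type": "function",
         "function": {"name": "patch_dag", "arguments": '{"dag_id": "example_dag", "is_paused": true}'}},
    ]},
    # One tool message per executed call:
    {"role": "tool", "tool_call_id": "call_abc123", "content": '{"is_paused": true}'},
]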
@@ -150,6 +150,7 @@
     border: 1px solid #e9ecef;
     border-radius: 15px 15px 15px 0;
     padding: 10px 15px;
+    white-space: pre-wrap;
 }
 
 #chat-messages::after {
@@ -349,6 +350,8 @@ document.addEventListener('DOMContentLoaded', function() {
                 if (line.startsWith('data: ')) {
                     const content = line.slice(6);
                     if (content) {
+                        // Use textContent to properly handle newlines
+                        console.log('Received chunk:', JSON.stringify(content)); // Debug
                         currentMessageDiv.textContent += content;
                         fullResponse += content;
                         chatMessages.scrollTop = chatMessages.scrollHeight;
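
Since the endpoint emits plain SSE lines plus a [DONE] sentinel (see the views changes below), it can be smoke-tested outside the browser. A sketch using requests; the URL and payload values are assumptions about the deployment, not defined in this diff:

import requests

resp = requests.post(
    "http://localhost:8080/wingman/chat",  # hypothetical route; adjust to your deployment
    json={"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}], "api_key": "sk-...", "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        chunk = line[len("data: "):]
        if chunk == "[DONE]":
            break
        print(chunk, end="")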
src/airflow_wingman/tools/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""
Tools module for Airflow Wingman.

This module contains the tools used by Airflow Wingman to interact with Airflow.
"""

from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools
from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools

__all__ = [
    "convert_to_openai_tools",
    "convert_to_anthropic_tools",
    "list_airflow_tools",
    "execute_airflow_tool",
]
src/airflow_wingman/tools/conversion.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""
Conversion utilities for Airflow Wingman tools.

This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""

from typing import Any


def convert_to_openai_tools(airflow_tools: list) -> list:
    """
    Convert Airflow tools to OpenAI tool definitions.

    Args:
        airflow_tools: List of Airflow tools from MCP server

    Returns:
        List of OpenAI tool definitions
    """
    openai_tools = []

    for tool in airflow_tools:
        # Initialize the OpenAI tool structure
        openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}}

        # Extract parameters directly from inputSchema if available
        if hasattr(tool, "inputSchema") and tool.inputSchema:
            # Set the type and required fields directly from the schema
            if "type" in tool.inputSchema:
                openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"]

            # Add required parameters if specified
            if "required" in tool.inputSchema:
                openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"]

            # Add properties from the input schema
            if "properties" in tool.inputSchema:
                for param_name, param_info in tool.inputSchema["properties"].items():
                    # Create parameter definition
                    param_def = {}

                    # Handle different schema constructs
                    if "anyOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "anyOf")
                    elif "oneOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "oneOf")
                    elif "allOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "allOf")
                    elif "type" in param_info:
                        param_def["type"] = param_info["type"]
                        # Add format if available
                        if "format" in param_info:
                            param_def["format"] = param_info["format"]
                    else:
                        param_def["type"] = "string"  # Default type

                    # Add description from title or param name
                    param_def["description"] = param_info.get("description", param_info.get("title", param_name))

                    # Add enum values if available
                    if "enum" in param_info:
                        param_def["enum"] = param_info["enum"]

                    # Add default value if available
                    if "default" in param_info and param_info["default"] is not None:
                        param_def["default"] = param_info["default"]

                    # Add to properties
                    openai_tool["function"]["parameters"]["properties"][param_name] = param_def

        openai_tools.append(openai_tool)

    return openai_tools


def convert_to_anthropic_tools(airflow_tools: list) -> list:
    """
    Convert Airflow tools to Anthropic tool definitions.

    Args:
        airflow_tools: List of Airflow tools from MCP server

    Returns:
        List of Anthropic tool definitions
    """
    anthropic_tools = []

    for tool in airflow_tools:
        # Initialize the Anthropic tool structure
        anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}}

        # Extract parameters directly from inputSchema if available
        if hasattr(tool, "inputSchema") and tool.inputSchema:
            # Copy the input schema directly as Anthropic's format is similar to JSON Schema
            anthropic_tool["input_schema"] = tool.inputSchema
        else:
            # Create a minimal schema if none exists
            anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []}

        anthropic_tools.append(anthropic_tool)

    return anthropic_tools


def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None:
    """
    Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf.

    Args:
        param_def: Parameter definition to update
        param_info: Parameter info from the schema
        construct_type: Type of construct (anyOf, oneOf, allOf)
    """
    # Get the list of schemas from the construct
    schemas = param_info[construct_type]

    # Find the first schema with a type
    for schema in schemas:
        if "type" in schema:
            param_def["type"] = schema["type"]

            # Add format if available
            if "format" in schema:
                param_def["format"] = schema["format"]

            # Add enum values if available
            if "enum" in schema:
                param_def["enum"] = schema["enum"]

            # Add default value if available
            if "default" in schema and schema["default"] is not None:
                param_def["default"] = schema["default"]

            break

    # If no type was found, default to string
    if "type" not in param_def:
        param_def["type"] = "string"
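
To sanity-check the converters, a hypothetical MCP tool (SimpleNamespace as a stand-in for a real tool object; all field values invented for illustration) round-trips like this:

from types import SimpleNamespace

from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools

tool = SimpleNamespace(
    name="get_dag",
    description="Fetch a single DAG by id",
    inputSchema={"type": "object", "required": ["dag_id"], "properties": {"dag_id": {"type": "string", "title": "Dag Id"}}},
)

print(convert_to_openai_tools([tool])[0]["function"]["parameters"])
# {'type': 'object', 'properties': {'dag_id': {'type': 'string', 'description': 'Dag Id'}}, 'required': ['dag_id']}
print(convert_to_anthropic_tools([tool])[0]["input_schema"] is tool.inputSchema)
# True: the schema passes straight through to Anthropic's input_schema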
src/airflow_wingman/tools/execution.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""
Tool execution module for Airflow Wingman.

This module contains functions to list and execute Airflow tools.
"""

import asyncio
import json
import traceback

from airflow import configuration
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

# Create a logger instance
logger = LoggingMixin().log


async def _list_airflow_tools_async(cookie: str) -> list:
    """
    Async implementation to list available Airflow tools.

    Args:
        cookie: Cookie for authentication

    Returns:
        List of available Airflow tools
    """
    try:
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)

        # Get available tools
        logger.info("Getting Airflow tools...")
        tools = await get_airflow_tools(config=config, mode="safe")
        logger.info(f"Got {len(tools)} tools")
        return tools
    except Exception as e:
        error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return []


def list_airflow_tools(cookie: str) -> list:
    """
    Synchronous wrapper to list available Airflow tools.

    Args:
        cookie: Cookie for authentication

    Returns:
        List of available Airflow tools
    """
    return asyncio.run(_list_airflow_tools_async(cookie))


async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str:
    """
    Async implementation to execute an Airflow tool.

    Args:
        tool_name: Name of the tool to execute
        arguments: Arguments to pass to the tool
        cookie: Cookie for authentication

    Returns:
        Result of the tool execution as a string
    """
    try:
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)

        # Get the tool
        logger.info(f"Getting tool: {tool_name}")
        tool = await get_tool(config=config, tool_name=tool_name)

        if not tool:
            error_msg = f"Tool not found: {tool_name}"
            logger.error(error_msg)
            return json.dumps({"error": error_msg})

        # Execute the tool
        logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
        result = await tool.run(arguments)

        # Convert result to string
        if isinstance(result, dict | list):
            result_str = json.dumps(result, indent=2)
        else:
            result_str = str(result)

        logger.info(f"Tool execution result: {result_str[:100]}...")
        return result_str
    except Exception as e:
        error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return json.dumps({"error": error_msg})


def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
    """
    Synchronous wrapper to execute an Airflow tool.

    Args:
        tool_name: Name of the tool to execute
        arguments: Arguments to pass to the tool
        cookie: Cookie for authentication

    Returns:
        Result of the tool execution as a string
    """
    return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie))
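
The synchronous wrappers spin up a fresh event loop per call via asyncio.run, so they must not be invoked from a thread that is already running a loop. From ordinary Flask view code the usage is a one-liner (the cookie value is a placeholder):

from airflow_wingman.tools import execute_airflow_tool

result_json = execute_airflow_tool("get_dags", {"limit": 5}, cookie="<airflow-session-cookie>")
print(result_json)  # JSON string on success, or '{"error": ...}' on failure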
src/airflow_wingman/views.py
@@ -1,6 +1,6 @@
 """Views for Airflow Wingman plugin."""
 
-from flask import Response, request, stream_with_context
+from flask import Response, request, session
 from flask.json import jsonify
 from flask_appbuilder import BaseView as AppBuilderBaseView, expose
 
@@ -8,6 +8,7 @@ from airflow_wingman.llm_client import LLMClient
 from airflow_wingman.llms_models import MODELS
 from airflow_wingman.notes import INTERFACE_MESSAGES
 from airflow_wingman.prompt_engineering import prepare_messages
+from airflow_wingman.tools import list_airflow_tools
 
 
 class WingmanView(AppBuilderBaseView):
@@ -28,8 +29,32 @@ class WingmanView(AppBuilderBaseView):
         try:
             data = self._validate_chat_request(request.get_json())
 
-            # Create a new client for this request
-            client = LLMClient(data["api_key"])
+            if data.get("cookie"):
+                session["airflow_cookie"] = data["cookie"]
+
+            # Get available Airflow tools using the stored cookie
+            airflow_tools = []
+            if session.get("airflow_cookie"):
+                try:
+                    airflow_tools = list_airflow_tools(session["airflow_cookie"])
+                except Exception as e:
+                    # Log the error but continue without tools
+                    print(f"Error fetching Airflow tools: {str(e)}")
+
+            # Prepare messages with Airflow tools included in the prompt
+            data["messages"] = prepare_messages(data["messages"])
+
+            # Get provider name from request or use default
+            provider_name = data.get("provider", "openai")
+
+            # Get base URL from models configuration based on provider
+            base_url = MODELS.get(provider_name, {}).get("endpoint")
+
+            # Create a new client for this request with the appropriate provider
+            client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url)
+
+            # Set the Airflow tools for the client to use
+            client.set_airflow_tools(airflow_tools)
 
             if data["stream"]:
                 return self._handle_streaming_response(client, data)
@@ -46,39 +71,49 @@ class WingmanView(AppBuilderBaseView):
         if not data:
             raise ValueError("No data provided")
 
-        required_fields = ["provider", "model", "messages", "api_key"]
+        required_fields = ["model", "messages", "api_key"]
         missing = [f for f in required_fields if not data.get(f)]
         if missing:
             raise ValueError(f"Missing required fields: {', '.join(missing)}")
 
-        # Prepare messages with system instruction while maintaining history
-        messages = data["messages"]
-        messages = prepare_messages(messages)
+        # Validate provider if provided
+        provider = data.get("provider", "openai")
+        if provider not in MODELS:
+            raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}")
 
         return {
-            "provider": data["provider"],
             "model": data["model"],
-            "messages": messages,
+            "messages": data["messages"],
             "api_key": data["api_key"],
-            "stream": data.get("stream", False),
+            "stream": data.get("stream", True),
             "temperature": data.get("temperature", 0.7),
            "max_tokens": data.get("max_tokens"),
+            "cookie": data.get("cookie"),
+            "provider": provider,
+            "base_url": data.get("base_url"),
         }
 
     def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
        """Handle streaming response."""
         try:
-
-            def generate():
-                for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
-                    yield f"data: {chunk}\n\n"
-
-            response = Response(stream_with_context(generate()), mimetype="text/event-stream")
-            response.headers["Content-Type"] = "text/event-stream"
-            response.headers["Cache-Control"] = "no-cache"
-            response.headers["Connection"] = "keep-alive"
-            return response
+            # Get the streaming generator from the client
+            generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
+
+            def stream_response():
+                # Send SSE format for each chunk
+                for chunk in generator:
+                    if chunk:
+                        yield f"data: {chunk}\n\n"
+
+                # Signal end of stream
+                yield "data: [DONE]\n\n"
+
+            return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
         except Exception as e:
             # If streaming fails, return error
             return jsonify({"error": str(e)}), 500
 
     def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
         """Handle regular response."""
-        response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
+        response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
         return jsonify(response)