Merge pull request #2 from abhishekbhakat/1-llm-chat-to-have-access-to-tools
1 llm chat to have access to tools
.gitignore (vendored, 2 changes)
@@ -174,3 +174,5 @@ cython_debug/
 # Local Resources
 plugins_reference/
 astro/
+
+node_modules/
pyproject.toml
@@ -9,9 +9,10 @@ authors = [
     {name = "Abhishek Bhakat", email = "abhishek.bhakat@hotmail.com"}
 ]
 dependencies = [
+    "airflow-mcp-server>=0.4.0",
+    "anthropic>=0.46.0",
     "apache-airflow>=2.10.0",
     "openai>=1.64.0",
-    "anthropic>=0.46.0"
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -31,6 +32,13 @@ Issues = "https://github.com/abhishekbhakat/airflow-wingman/issues"
 [project.entry-points."airflow.plugins"]
 wingman = "airflow_wingman:WingmanPlugin"
 
+[project.optional-dependencies]
+dev = [
+    "build>=1.2.2",
+    "pre-commit>=4.0.1",
+    "ruff>=0.9.2"
+]
+
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
@@ -60,7 +68,8 @@ lint.select = [
 lint.ignore = [
     "C416", # Unnecessary list comprehension - rewrite as a generator expression
     "C408", # Unnecessary `dict` call - rewrite as a literal
-    "ISC001" # Single line implicit string concatenation
+    "ISC001", # Single line implicit string concatenation
+    "C901"
 ]
 
 lint.fixable = ["ALL"]
@@ -1,109 +1,266 @@
-"""
-Client for making API calls to various LLM providers using their official SDKs.
-"""
-
-from collections.abc import Generator
-
-from anthropic import Anthropic
-from openai import OpenAI
-
-
-class LLMClient:
-    def __init__(self, api_key: str):
-        """Initialize the LLM client.
-
-        Args:
-            api_key: API key for the provider
-        """
-        self.api_key = api_key
-        self.openai_client = OpenAI(api_key=api_key)
-        self.anthropic_client = Anthropic(api_key=api_key)
-        self.openrouter_client = OpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=api_key,
-            default_headers={
-                "HTTP-Referer": "Airflow Wingman",  # Required by OpenRouter
-                "X-Title": "Airflow Wingman",  # Required by OpenRouter
-            },
-        )
-
-    def chat_completion(
-        self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
-    ) -> Generator[str, None, None] | dict:
-        """Send a chat completion request to the specified provider.
-
-        Args:
-            messages: List of message dictionaries with 'role' and 'content'
-            model: Model identifier
-            provider: Provider identifier (openai, anthropic, openrouter)
-            temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens to generate
-            stream: Whether to stream the response
-
-        Returns:
-            If stream=True, returns a generator yielding response chunks
-            If stream=False, returns the complete response
-        """
-        try:
-            if provider == "openai":
-                return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "anthropic":
-                return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
-            elif provider == "openrouter":
-                return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
-            else:
-                return {"error": f"Unknown provider: {provider}"}
-        except Exception as e:
-            return {"error": f"API request failed: {str(e)}"}
-
-    def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenAI chat completion requests."""
-        response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
-
-    def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle Anthropic chat completion requests."""
-        # Convert messages to Anthropic format
-        system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
-        conversation = []
-        for m in messages:
-            if m["role"] != "system":
-                conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
-
-        response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.delta.text:
-                        yield chunk.delta.text
-
-            return response_generator()
-        else:
-            return {"content": response.content[0].text}
-
-    def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
-        """Handle OpenRouter chat completion requests."""
-        response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
-
-        if stream:
-
-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
-            return response_generator()
-        else:
-            return {"content": response.choices[0].message.content}
+"""
+Multi-provider LLM client for Airflow Wingman.
+
+This module contains the LLMClient class that supports multiple LLM providers
+(OpenAI, Anthropic, OpenRouter) through a unified interface.
+"""
+
+import logging
+import traceback
+from typing import Any
+
+from flask import session
+
+from airflow_wingman.providers import create_llm_provider
+from airflow_wingman.tools import list_airflow_tools
+
+# Create a properly namespaced logger for the Airflow plugin
+logger = logging.getLogger("airflow.plugins.wingman")
+
+
+class LLMClient:
+    """
+    Multi-provider LLM client for Airflow Wingman.
+
+    This class handles chat completion requests to various LLM providers
+    (OpenAI, Anthropic, OpenRouter) through a unified interface.
+    """
+
+    def __init__(self, provider_name: str, api_key: str, base_url: str | None = None):
+        """
+        Initialize the LLM client.
+
+        Args:
+            provider_name: Name of the provider (openai, anthropic, openrouter)
+            api_key: API key for the provider
+            base_url: Optional base URL for the provider API
+        """
+        self.provider_name = provider_name
+        self.api_key = api_key
+        self.base_url = base_url
+        self.provider = create_llm_provider(provider_name, api_key, base_url)
+        self.airflow_tools = []
+
+    def set_airflow_tools(self, tools: list):
+        """
+        Set the available Airflow tools.
+
+        Args:
+            tools: List of Airflow Tool objects
+        """
+        self.airflow_tools = tools
+
+    def chat_completion(
+        self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False
+    ) -> dict[str, Any] | tuple[Any, Any]:
+        """
+        Send a chat completion request to the LLM provider.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-1)
+            max_tokens: Maximum tokens to generate
+            stream: Whether to stream the response (default is True)
+            return_response_obj: If True and streaming, returns both the response object and generator
+
+        Returns:
+            If stream=False: Dictionary with the response content
+            If stream=True and return_response_obj=False: Generator for streaming
+            If stream=True and return_response_obj=True: Tuple of (response_obj, generator)
+        """
+        # Get provider-specific tool definitions from Airflow tools
+        provider_tools = self.provider.convert_tools(self.airflow_tools)
+
+        try:
+            # Make the initial request with tools
+            logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}")
+            response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
+            logger.info(f"Received response from {self.provider_name}")
+
+            # If streaming, handle based on return_response_obj flag
+            if stream:
+                logger.info(f"Using streaming response from {self.provider_name}")
+                streaming_content = self.provider.get_streaming_content(response)
+                return streaming_content
+
+            # For non-streaming responses, handle tool calls if present
+            if self.provider.has_tool_calls(response):
+                logger.info("Response contains tool calls")
+
+                # Process tool calls and get results
+                cookie = session.get("airflow_cookie")
+                if not cookie:
+                    error_msg = "No Airflow cookie available"
+                    logger.error(error_msg)
+                    return {"error": error_msg}
+
+                tool_results = self.provider.process_tool_calls(response, cookie)
+
+                # Create a follow-up completion with the tool results
+                logger.info("Making follow-up request with tool results")
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response, tools=provider_tools
+                )
+
+                content = self.provider.get_content(follow_up_response)
+                logger.info(f"Final content from {self.provider_name} with tool calls COMPLETE RESPONSE START >>>")
+                logger.info(content)
+                logger.info("<<< COMPLETE RESPONSE END")
+                return {"content": content}
+            else:
+                logger.info("Response does not contain tool calls")
+                content = self.provider.get_content(response)
+                logger.info(f"Final content from {self.provider_name} without tool calls COMPLETE RESPONSE START >>>")
+                logger.info(content)
+                logger.info("<<< COMPLETE RESPONSE END")
+                return {"content": content}
+
+        except Exception as e:
+            error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            return {"error": f"API request failed: {str(e)}"}
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "LLMClient":
+        """
+        Create an LLMClient instance from a configuration dictionary.
+
+        Args:
+            config: Configuration dictionary with provider_name, api_key, and optional base_url
+
+        Returns:
+            LLMClient instance
+        """
+        provider_name = config.get("provider_name", "openai")
+        api_key = config.get("api_key")
+        base_url = config.get("base_url")
+
+        if not api_key:
+            raise ValueError("API key is required")
+
+        return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
+
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None, stream=True):
+        """
+        Process tool calls recursively from a response and make follow-up requests until
+        there are no more tool calls or max_iterations is reached.
+        Returns a generator for streaming the final follow-up response.
+
+        Args:
+            response: The original response object containing tool calls
+            messages: List of message dictionaries with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-1)
+            max_tokens: Maximum tokens to generate
+            max_iterations: Maximum number of tool call iterations to prevent infinite loops
+            cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)
+            stream: Whether to stream the response
+
+        Returns:
+            Generator for streaming the final follow-up response
+        """
+        try:
+            iteration = 0
+            current_response = response
+
+            # Check if we have a cookie
+            if not cookie:
+                error_msg = "No Airflow cookie available"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+                return
+
+            # Process tool calls recursively until there are no more or max_iterations is reached
+            while self.provider.has_tool_calls(current_response) and iteration < max_iterations:
+                iteration += 1
+                logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}")
+
+                # Process tool calls and get results
+                tool_results = self.provider.process_tool_calls(current_response, cookie)
+
+                # Make follow-up request with tool results
+                logger.info(f"Making follow-up request with tool results (iteration {iteration})")
+
+                # Always stream follow-up requests to ensure consistent behavior
+                # This ensures we get streaming responses from the provider
+                should_stream = True
+                logger.info(f"Setting should_stream=True for follow-up request (iteration {iteration})")
+
+                # Get provider-specific tool definitions from Airflow tools
+                provider_tools = self.provider.convert_tools(self.airflow_tools)
+
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages,
+                    model=model,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    tool_results=tool_results,
+                    original_response=current_response,
+                    stream=should_stream,
+                    tools=provider_tools,
+                )
+
+                # Check if this follow-up response has more tool calls
+                if not self.provider.has_tool_calls(follow_up_response):
+                    logger.info(f"No more tool calls after iteration {iteration}")
+                    # Final response - always yield content in a streaming fashion
+                    # Since we're always streaming now, we can directly yield chunks from the streaming generator
+                    chunk_count = 0
+                    for chunk in self.provider.get_streaming_content(follow_up_response):
+                        chunk_count += 1
+                        # logger.info(f"Yielding chunk {chunk_count} from streaming generator: {chunk[:50] if chunk else 'Empty chunk'}...")
+                        yield chunk
+                    logger.info(f"Finished yielding {chunk_count} chunks from streaming generator")
+
+                # Update current_response for the next iteration
+                current_response = follow_up_response
+
+            # If we've reached max_iterations and still have tool calls, log a warning
+            if iteration == max_iterations and self.provider.has_tool_calls(current_response):
+                logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
+                # Stream the final response even if it has tool calls
+                if not should_stream:
+                    # If we didn't stream this response, convert it to a single chunk
+                    content = self.provider.get_content(follow_up_response)
+                    logger.info(f"Yielding complete content as a single chunk (max iterations): {content[:100]}...")
+                    yield content
+                    logger.info("Finished yielding complete content (max iterations)")
+                else:
+                    # Yield chunks from the streaming generator
+                    logger.info("Starting to yield chunks from streaming generator (max iterations reached)")
+                    chunk_count = 0
+                    for chunk in self.provider.get_streaming_content(follow_up_response):
+                        chunk_count += 1
+                        logger.info(f"Yielding chunk {chunk_count} from streaming generator (max iterations)")
+                        yield chunk
+                    logger.info(f"Finished yielding {chunk_count} chunks from streaming generator (max iterations)")
+
+            # If we didn't process any tool calls (shouldn't happen), return an error
+            if iteration == 0:
+                error_msg = "No tool calls found in response"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+
+        except Exception as e:
+            error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            yield f"Error: {str(e)}"
+
+    def refresh_tools(self, cookie: str) -> None:
+        """
+        Refresh the available Airflow tools.
+
+        Args:
+            cookie: Airflow cookie for authentication
+        """
+        try:
+            logger.info("Refreshing Airflow tools")
+            tools = list_airflow_tools(cookie)
+            self.set_airflow_tools(tools)
+            logger.info(f"Refreshed {len(tools)} Airflow tools")
+        except Exception as e:
+            error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            # Don't raise the exception, just log it
+            # The client will continue to use the existing tools (if any)
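To see the new interface end to end, here is a minimal usage sketch of the rewritten client. The module path, cookie value, and model ID are illustrative assumptions; only the `LLMClient` methods come from the diff above.

```python
# Hypothetical usage sketch (not part of the commit); module path assumed.
from airflow_wingman.llm_client import LLMClient

airflow_cookie = "session=..."  # placeholder Airflow session cookie

client = LLMClient(provider_name="anthropic", api_key="sk-ant-...")
client.refresh_tools(cookie=airflow_cookie)  # fetch the MCP tool list once

# stream defaults to True, so chat_completion returns a streaming object
for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "Which DAG runs failed today?"}],
    model="claude-3-7-sonnet-20250219",
):
    print(chunk, end="", flush=True)
```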
@@ -17,14 +17,14 @@ MODELS = {
         "endpoint": "https://api.anthropic.com/v1/messages",
         "models": [
             {
-                "id": "claude-3.5-sonnet",
-                "name": "Claude 3.5 Sonnet",
+                "id": "claude-3-7-sonnet-20250219",
+                "name": "Claude 3.7 Sonnet",
                 "default": True,
                 "context_window": 200000,
                 "description": "Input $3/M tokens, Output $15/M tokens",
             },
             {
-                "id": "claude-3.5-haiku",
+                "id": "claude-3-5-haiku-20241022",
                 "name": "Claude 3.5 Haiku",
                 "default": False,
                 "context_window": 200000,
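Since the `default` flag is what selects a model, a small sketch of how the mapping can be queried follows; the top-level "anthropic" key is an assumption, as the hunk only shows the nested entries.

```python
# Minimal sketch; the surrounding MODELS structure is partly assumed.
MODELS = {
    "anthropic": {
        "endpoint": "https://api.anthropic.com/v1/messages",
        "models": [
            {"id": "claude-3-7-sonnet-20250219", "name": "Claude 3.7 Sonnet", "default": True},
            {"id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "default": False},
        ],
    }
}

default_model = next(m for m in MODELS["anthropic"]["models"] if m["default"])
assert default_model["id"] == "claude-3-7-sonnet-20250219"
```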
@@ -8,9 +8,11 @@ INSTRUCTIONS = {
 You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
 The Airflow version being used is >=2.10.
 
-You have access to the following Airflow API tools:
-
-You can use these tools to fetch information and help users understand and manage their Airflow environment.
+You have access to Airflow MCP tools that you can use to fetch information and help users understand
+and manage their Airflow environment.
+
+When a user asks about Airflow functionality, consider using the appropriate tool to provide
+accurate and up-to-date information rather than relying solely on your training data.
 """
 }
 
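For context, a sketch of how the updated instructions would typically reach a provider: the prompt becomes the system message of the conversation. The "default" key and the user message are illustrative assumptions, not code from the commit.

```python
# Hedged sketch; the real INSTRUCTIONS key and prompt text are longer.
INSTRUCTIONS = {"default": "You have access to Airflow MCP tools ..."}

messages = [
    {"role": "system", "content": INSTRUCTIONS["default"]},
    {"role": "user", "content": "Why is my daily_etl DAG stuck in queued?"},
]
```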
src/airflow_wingman/providers/__init__.py (new file, 41 lines)
"""
Provider factory for Airflow Wingman.

This module contains the factory function to create provider instances
based on the provider name.
"""

from airflow_wingman.providers.anthropic_provider import AnthropicProvider
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.openai_provider import OpenAIProvider


def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider:
    """
    Create a provider instance based on the provider name.

    Args:
        provider_name: Name of the provider (openai, anthropic, openrouter)
        api_key: API key for the provider
        base_url: Optional base URL for the provider API

    Returns:
        Provider instance

    Raises:
        ValueError: If the provider is not supported
    """
    provider_name = provider_name.lower()

    if provider_name == "openai":
        return OpenAIProvider(api_key=api_key, base_url=base_url)
    elif provider_name == "openrouter":
        # OpenRouter uses the OpenAI API format, so we can use the OpenAI provider
        # with a custom base URL
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        return OpenAIProvider(api_key=api_key, base_url=base_url)
    elif provider_name == "anthropic":
        return AnthropicProvider(api_key=api_key)
    else:
        raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter")
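A short sketch of the factory in use: names are normalized to lowercase, and the OpenRouter case reuses `OpenAIProvider` with the default base URL filled in. The API keys below are placeholders.

```python
# Sketch assuming the package is installed; keys are placeholders.
from airflow_wingman.providers import create_llm_provider

openai_provider = create_llm_provider("OpenAI", api_key="sk-...")  # case-insensitive
openrouter_provider = create_llm_provider("openrouter", api_key="sk-or-...")
# openrouter_provider is an OpenAIProvider pointed at https://openrouter.ai/api/v1
```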
src/airflow_wingman/providers/anthropic_provider.py (new file, 458 lines)
"""
Anthropic provider implementation for Airflow Wingman.

This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""

import json
import logging
import traceback
from typing import Any

from anthropic import Anthropic

from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools

# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")


class AnthropicProvider(BaseLLMProvider):
    """
    Anthropic provider implementation.

    This class handles API requests, tool conversion, and response processing
    for the Anthropic API (Claude models).
    """

    def __init__(self, api_key: str):
        """
        Initialize the Anthropic provider.

        Args:
            api_key: API key for Anthropic
        """
        self.api_key = api_key
        self.client = Anthropic(api_key=api_key)

    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert Airflow tools to Anthropic format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of Anthropic tool definitions
        """
        return convert_to_anthropic_tools(airflow_tools)

    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to Anthropic.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in Anthropic format

        Returns:
            Anthropic response object

        Raises:
            Exception: If the API request fails
        """
        # Convert max_tokens to Anthropic's max_tokens parameter (if provided)
        max_tokens_param = max_tokens if max_tokens is not None else 4096

        # Convert messages from ChatML format to Anthropic's format
        anthropic_messages = self._convert_to_anthropic_messages(messages)

        try:
            logger.info(f"Sending chat completion request to Anthropic with model: {model}")

            # Create request parameters
            params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}

            # Add tools if provided
            if tools and len(tools) > 0:
                params["tools"] = tools
            else:
                logger.warning("No tools included in request")

            # Log the full request parameters (with sensitive information redacted)
            log_params = params.copy()
            logger.info(f"Request parameters: {json.dumps(log_params)}")

            # Make the API request
            response = self.client.messages.create(**params)

            logger.info("Received response from Anthropic")
            # Log the response (with sensitive information redacted)
            logger.info(f"Anthropic response type: {type(response).__name__}")

            # Log as much information as possible
            if hasattr(response, "json"):
                if callable(response.json):
                    # If json is a method, call it
                    try:
                        logger.info(f"Anthropic response json: {json.dumps(response.model_dump_json())}")
                    except Exception as json_err:
                        logger.warning(f"Could not serialize response.json(): {str(json_err)}")
                else:
                    # If json is a property, use it directly
                    try:
                        logger.info(f"Anthropic response json: {json.dumps(response.json)}")
                    except Exception as json_err:
                        logger.warning(f"Could not serialize response.json: {str(json_err)}")

            # Log response attributes
            response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
            logger.info(f"Anthropic response attributes: {response_attrs}")

            return response
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
            raise

    def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Convert messages from ChatML format to Anthropic's format.

        Args:
            messages: List of message dictionaries in ChatML format

        Returns:
            List of message dictionaries in Anthropic format
        """
        anthropic_messages = []

        for message in messages:
            role = message["role"]
            content = message["content"]

            # Map ChatML roles to Anthropic roles
            if role == "system":
                # System messages in Anthropic are handled differently
                # We'll add them as a user message with a special prefix
                anthropic_messages.append({"role": "user", "content": f"<system>\n{content}\n</system>"})
            elif role == "user":
                anthropic_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                anthropic_messages.append({"role": "assistant", "content": content})
            elif role == "tool":
                # Tool messages in ChatML become part of the user message in Anthropic
                # We'll handle this in the follow-up completion
                continue

        return anthropic_messages

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: Anthropic response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
        logger.info(f"Checking for tool calls in response of type: {type(response)}")

        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse):
            logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
            if response.tool_call is not None:
                logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
                return True
            else:
                logger.info("StreamingResponse has None tool_call")
        else:
            logger.info("Response is not a StreamingResponse")

        # Check if any content block is a tool_use block (for non-streaming responses)
        if hasattr(response, "content"):
            logger.info(f"Response has content attribute with {len(response.content)} blocks")
            for i, block in enumerate(response.content):
                logger.info(f"Checking content block {i}: {type(block)}")
                if isinstance(block, dict) and block.get("type") == "tool_use":
                    logger.info(f"Found tool_use block: {block}")
                    return True
        else:
            logger.info("Response does not have content attribute")

        logger.info("No tool calls found in response")
        return False

    def get_tool_calls(self, response: Any) -> list:
        """
        Extract tool calls from the response.

        Args:
            response: Anthropic response object or StreamingResponse with tool_call attribute

        Returns:
            List of tool call objects in a standardized format
        """
        tool_calls = []

        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
            tool_calls.append(response.tool_call)
        # Otherwise, extract tool calls from response content (for non-streaming responses)
        elif hasattr(response, "content"):
            logger.info("Extracting tool calls from response content")
            for block in response.content:
                if isinstance(block, dict) and block.get("type") == "tool_use":
                    tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
                    tool_calls.append(tool_call)

        return tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: Anthropic response object or generator with tool_call attribute
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}

        if not self.has_tool_calls(response):
            return results

        # Get tool calls using the standardized method
        tool_calls = self.get_tool_calls(response)
        logger.info(f"Processing {len(tool_calls)} tool calls")

        for tool_call in tool_calls:
            # Extract tool details - handle both formats (generator's tool_call and content block)
            if isinstance(tool_call, dict) and "id" in tool_call:
                # This is from the generator's tool_call attribute
                tool_id = tool_call.get("id")
                tool_name = tool_call.get("name")
                tool_input = tool_call.get("input", {})
            else:
                # This is from the content blocks
                tool_id = tool_call.get("id")
                tool_name = tool_call.get("name")
                tool_input = tool_call.get("input", {})

            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
                result = execute_airflow_tool(tool_name, tool_input, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    def create_follow_up_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        tool_results: dict[str, Any] = None,
        original_response: Any = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls
            stream: Whether to stream the response
            tools: List of tool definitions in Anthropic format

        Returns:
            Anthropic response object or generator if streaming
        """
        if not original_response or not tool_results:
            return original_response

        # Extract tool call from the StreamingResponse or content blocks from Anthropic response
        tool_use_blocks = []
        if isinstance(original_response, StreamingResponse) and original_response.tool_call:
            # For StreamingResponse, create a tool_use block from the tool_call
            logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
            tool_call = original_response.tool_call
            tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
        elif hasattr(original_response, "content"):
            # For regular Anthropic response, extract from content blocks
            logger.info("Extracting tool_use blocks from response content")
            tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]

        # Create tool result blocks
        tool_result_blocks = []
        for tool_id, result in tool_results.items():
            tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))})

        # Convert original messages to Anthropic format
        anthropic_messages = self._convert_to_anthropic_messages(messages)

        # Add the assistant response with tool use
        anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})

        # Add the user message with tool results
        anthropic_messages.append({"role": "user", "content": tool_result_blocks})

        # Make a second request to get the final response
        logger.info(f"Making second request with tool results (stream={stream})")
        return self.create_chat_completion(
            messages=anthropic_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
            tools=tools,
        )

    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: Anthropic response object

        Returns:
            Content string from the response
        """
        if not hasattr(response, "content"):
            return ""

        # Combine all text blocks into a single string
        content_parts = []
        for block in response.content:
            if isinstance(block, dict) and block.get("type") == "text":
                content_parts.append(block.get("text", ""))
            elif isinstance(block, str):
                content_parts.append(block)

        return "".join(content_parts)

    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a generator for streaming content from the response.

        Args:
            response: Anthropic streaming response object

        Returns:
            StreamingResponse object wrapping a generator that yields content chunks
            and can also store tool call information detected during streaming
        """
        logger.info("Starting Anthropic streaming response processing")

        # Track only the first tool call detected during streaming
        tool_call = None
        tool_use_detected = False

        # Create the StreamingResponse object first
        streaming_response = StreamingResponse(generator=None, tool_call=None)

        def generate():
            nonlocal tool_call, tool_use_detected

            for chunk in response:
                logger.debug(f"Chunk type: {type(chunk)}")
                logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")

                # Check for content_block_start events with type "tool_use"
                if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
                    if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
                        if chunk.content_block.type == "tool_use":
                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
                            tool_use_detected = True
                            tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
                            # Update the StreamingResponse object's tool_call attribute
                            streaming_response.tool_call = tool_call
                            # We don't signal to the frontend during streaming
                            # The tool will only be executed after streaming ends
                    continue

                # Handle content_block_delta events for tool_use (input updates)
                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
                        if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
                            logger.info(f"Tool use input update: {chunk.delta.partial_json}")
                            # Update the current tool call input
                            if tool_call:
                                try:
                                    # Try to parse the partial JSON and update the input
                                    partial_input = json.loads(chunk.delta.partial_json)
                                    tool_call["input"].update(partial_input)
                                    # Update the StreamingResponse object's tool_call attribute
                                    streaming_response.tool_call = tool_call
                                except json.JSONDecodeError:
                                    logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
                    continue

                # Handle content_block_stop events for tool_use
                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
                    logger.info("Tool use block completed")
                    # Log the complete tool call for debugging
                    if tool_call:
                        logger.info(f"Completed tool call: {json.dumps(tool_call)}")
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = tool_call
                    continue

                # Handle message_delta events with stop_reason "tool_use"
                if hasattr(chunk, "type") and chunk.type == "message_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
                        if chunk.delta.stop_reason == "tool_use":
                            logger.info("Message stopped due to tool use")
                            # Update the StreamingResponse object's tool_call attribute one last time
                            if tool_call:
                                streaming_response.tool_call = tool_call
                    continue

                # Handle regular content chunks
                content = None
                if hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                        content = chunk.delta.text
                elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                    content = chunk.delta.text
                elif hasattr(chunk, "content") and chunk.content:
                    for block in chunk.content:
                        if isinstance(block, dict) and block.get("type") == "text":
                            content = block.get("text", "")

                if content:
                    # Don't do any newline replacement here
                    yield content

        # Create the generator
        gen = generate()

        # Set the generator in the StreamingResponse object
        streaming_response.generator = gen

        # Return the StreamingResponse object
        return streaming_response
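The central design choice above is that `get_streaming_content` returns a `StreamingResponse` rather than a bare generator: text deltas are yielded to the caller, while any `tool_use` block seen along the way is recorded on the wrapper and acted on only after the stream is drained. A stub illustration of that contract follows; the generator and tool call below are fabricated, not real Anthropic output.

```python
from airflow_wingman.providers.base import StreamingResponse

def fake_stream():
    yield "Checking DAG runs"
    yield " ... done."

resp = StreamingResponse(generator=fake_stream(), tool_call=None)
# In the real provider, this attribute is set while the stream is consumed.
resp.tool_call = {"id": "tc_1", "name": "get_dag_runs", "input": {"dag_id": "daily_etl"}}

for chunk in resp:  # drain the text first
    print(chunk, end="")

if resp.tool_call:  # then execute the tool call recorded en route
    print("\nwould execute:", resp.tool_call["name"])
```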
src/airflow_wingman/providers/base.py (new file, 216 lines)
"""
Base provider interface for Airflow Wingman.

This module contains the base provider interface that all provider implementations
must adhere to. It defines the methods required for tool conversion, API requests,
and response processing.
"""

import json
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class StreamingResponse(Generic[T]):
    """
    Wrapper for streaming responses that can hold tool call information.

    This class wraps a generator and provides an iterator interface while also
    storing tool call information. This allows us to associate metadata with
    a generator without modifying the generator itself.
    """

    def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
        """
        Initialize the streaming response.

        Args:
            generator: The underlying generator yielding content chunks
            tool_call: Optional tool call information detected during streaming
        """
        self.generator = generator
        self.tool_call = tool_call

    def __iter__(self) -> Iterator[T]:
        """
        Return self as iterator.
        """
        return self

    def __next__(self) -> T:
        """
        Get the next item from the generator.
        """
        return next(self.generator)


class BaseLLMProvider(ABC):
    """
    Base provider interface for LLM providers.

    This abstract class defines the methods that all provider implementations
    must implement to support tool integration.
    """

    @abstractmethod
    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert internal tool representation to provider format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of provider-specific tool definitions
        """
        pass

    @abstractmethod
    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in provider format

        Returns:
            Provider-specific response object
        """
        pass

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: Provider-specific response object or StreamingResponse

        Returns:
            True if the response contains tool calls, False otherwise
        """
        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            return True

        # Provider-specific implementation should handle other cases
        return False

    def get_tool_calls(self, response: Any) -> list:
        """
        Extract tool calls from the response.

        Args:
            response: Provider-specific response object or StreamingResponse

        Returns:
            List of tool call objects in a standardized format
        """
        tool_calls = []

        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            tool_calls.append(response.tool_call)

        # Provider-specific implementation should handle other cases
        return tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: Provider-specific response object or StreamingResponse
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        tool_calls = self.get_tool_calls(response)
        results = {}

        for tool_call in tool_calls:
            tool_name = tool_call.get("name", "")
            tool_input = tool_call.get("input", {})
            tool_id = tool_call.get("id", "")

            try:
                import logging

                logger = logging.getLogger(__name__)
                logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")

                from airflow_wingman.tools import execute_airflow_tool

                result = execute_airflow_tool(tool_name, tool_input, cookie)

                logger.info(f"Tool result: {json.dumps(result)}")
                results[tool_id] = {
                    "name": tool_name,
                    "input": tool_input,
                    "output": result,
                }
            except Exception as e:
                import traceback

                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    @abstractmethod
    def create_follow_up_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls

        Returns:
            Provider-specific response object
        """
        pass

    @abstractmethod
    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: Provider-specific response object

        Returns:
            Content string from the response
        """
        pass

    @abstractmethod
    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a StreamingResponse for streaming content from the response.

        Args:
            response: Provider-specific response object

        Returns:
            StreamingResponse object wrapping a generator that yields content chunks
            and can also store tool call information detected during streaming
        """
        pass
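To show what the abstract surface demands of a concrete provider, here is a minimal, hypothetical subclass; `EchoProvider` is not part of the commit and only fills in the five abstract methods with trivial bodies.

```python
from typing import Any

from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse


class EchoProvider(BaseLLMProvider):
    """Toy provider that echoes the last user message."""

    def convert_tools(self, airflow_tools: list) -> list:
        return []  # no tool support

    def create_chat_completion(self, messages, model, temperature=0.4, max_tokens=None, stream=False, tools=None) -> Any:
        return messages[-1]["content"]

    def create_follow_up_completion(self, messages, model, temperature=0.4, max_tokens=None, tool_results=None, original_response=None) -> Any:
        return original_response

    def get_content(self, response: Any) -> str:
        return str(response)

    def get_streaming_content(self, response: Any) -> StreamingResponse:
        return StreamingResponse(generator=iter([str(response)]), tool_call=None)
```

The inherited `has_tool_calls`, `get_tool_calls`, and `process_tool_calls` defaults already handle the `StreamingResponse` case, which is why they are not abstract.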
src/airflow_wingman/providers/openai_provider.py (new file, 354 lines; the capture of this section is truncated below)
"""
OpenAI provider implementation for Airflow Wingman.

This module contains the OpenAI provider implementation that handles
API requests, tool conversion, and response processing for OpenAI.
"""

import json
import logging
import traceback
from typing import Any

from openai import OpenAI

from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools

# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")


class OpenAIProvider(BaseLLMProvider):
    """
    OpenAI provider implementation.

    This class handles API requests, tool conversion, and response processing
    for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the OpenAI provider.

        Args:
            api_key: API key for OpenAI
            base_url: Optional base URL for the API (used for OpenRouter)
        """
        self.api_key = api_key
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert Airflow tools to OpenAI format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of OpenAI tool definitions
        """
        return convert_to_openai_tools(airflow_tools)

    def create_chat_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
    ) -> Any:
        """
        Make API request to OpenAI.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object

        Raises:
            Exception: If the API request fails
        """
        # Only include tools if we have any
        has_tools = tools is not None and len(tools) > 0
        tool_choice = "auto" if has_tools else None

        try:
            logger.info(f"Sending chat completion request to OpenAI with model: {model}")

            # Log information about tools
            if not has_tools:
                logger.warning("No tools included in request")

            # Log request parameters
            request_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
                "tools": tools if has_tools else None,
                "tool_choice": tool_choice,
            }
            logger.info(f"Request parameters: {json.dumps(request_params)}")

            response = self.client.chat.completions.create(
                model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
            )
            logger.info("Received response from OpenAI")
            return response
        except Exception as e:
            # If the API call fails due to tools not being supported, retry without tools
            error_msg = str(e)
            logger.warning(f"Error in OpenAI API call: {error_msg}")
            if "tools" in error_msg.lower():
                logger.info("Retrying without tools")
                response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
                return response
            else:
                logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
                raise

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            return True

        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            return hasattr(message, "tool_calls") and message.tool_calls

        return False

    def get_tool_calls(self, response: Any) -> list:
        """
        Extract tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            List of tool call objects in a standardized format
        """
        tool_calls = []

        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
            tool_calls.append(response.tool_call)
            return tool_calls

        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            if hasattr(message, "tool_calls") and message.tool_calls:
                for tool_call in message.tool_calls:
                    standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
                    tool_calls.append(standardized_tool_call)

        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
        return tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}

        if not self.has_tool_calls(response):
            return results

        # Get tool calls using the standardized method
        tool_calls = self.get_tool_calls(response)
        logger.info(f"Processing {len(tool_calls)} tool calls")

        for tool_call in tool_calls:
            tool_id = tool_call["id"]
            function_name = tool_call["name"]
            arguments = tool_call["input"]

            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
                result = execute_airflow_tool(function_name, arguments, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    def create_follow_up_completion(
        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
|
||||||
|
tool_results: Results of tool executions
|
||||||
|
original_response: Original response with tool calls
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAI response object
|
||||||
|
"""
|
||||||
|
if not original_response or not tool_results:
|
||||||
|
return original_response
|
||||||
|
|
||||||
|
# Get the original message with tool calls
|
||||||
|
original_message = original_response.choices[0].message
|
||||||
|
|
||||||
|
# Create a new message with the tool calls
|
||||||
|
assistant_message = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create tool result messages
|
||||||
|
tool_messages = []
|
||||||
|
for tool_call_id, result in tool_results.items():
|
||||||
|
tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))})
|
||||||
|
|
||||||
|
# Add the original messages, assistant message, and tool results
|
||||||
|
new_messages = messages + [assistant_message] + tool_messages
|
||||||
|
|
||||||
|
# Make a second request to get the final response
|
||||||
|
logger.info("Making second request with tool results")
|
||||||
|
return self.create_chat_completion(
|
||||||
|
messages=new_messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stream=False,
|
||||||
|
tools=None, # No tools needed for follow-up
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_content(self, response: Any) -> str:
|
||||||
|
"""
|
||||||
|
Extract content from the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: OpenAI response object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content string from the response
|
||||||
|
"""
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
def get_streaming_content(self, response: Any) -> StreamingResponse:
|
||||||
|
"""
|
||||||
|
Get a generator for streaming content from the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: OpenAI streaming response object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse object wrapping a generator that yields content chunks
|
||||||
|
and can also store tool call information detected during streaming
|
||||||
|
"""
|
||||||
|
logger.info("Starting OpenAI streaming response processing")
|
||||||
|
|
||||||
|
# Track only the first tool call detected during streaming
|
||||||
|
tool_call = None
|
||||||
|
tool_use_detected = False
|
||||||
|
current_tool_call = None
|
||||||
|
|
||||||
|
# Create the StreamingResponse object first
|
||||||
|
streaming_response = StreamingResponse(generator=None, tool_call=None)
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
nonlocal tool_call, tool_use_detected, current_tool_call
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
# Check for tool call in the delta
|
||||||
|
if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
|
||||||
|
# Tool call detected
|
||||||
|
if not tool_use_detected:
|
||||||
|
tool_use_detected = True
|
||||||
|
logger.info("Tool call detected in streaming response")
|
||||||
|
|
||||||
|
# Initialize the tool call
|
||||||
|
delta_tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||||
|
current_tool_call = {
|
||||||
|
"id": getattr(delta_tool_call, "id", ""),
|
||||||
|
"name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
|
||||||
|
"input": {},
|
||||||
|
}
|
||||||
|
# Update the StreamingResponse object's tool_call attribute
|
||||||
|
streaming_response.tool_call = current_tool_call
|
||||||
|
else:
|
||||||
|
# Update the existing tool call
|
||||||
|
delta_tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||||
|
|
||||||
|
# Update the tool call ID if it's provided in this chunk
|
||||||
|
if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
|
||||||
|
current_tool_call["id"] = delta_tool_call.id
|
||||||
|
|
||||||
|
# Update the function name if it's provided in this chunk
|
||||||
|
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
|
||||||
|
current_tool_call["name"] = delta_tool_call.function.name
|
||||||
|
|
||||||
|
# Update the arguments if they're provided in this chunk
|
||||||
|
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
|
||||||
|
try:
|
||||||
|
# Try to parse the arguments JSON
|
||||||
|
arguments = json.loads(delta_tool_call.function.arguments)
|
||||||
|
if isinstance(arguments, dict):
|
||||||
|
current_tool_call["input"].update(arguments)
|
||||||
|
# Update the StreamingResponse object's tool_call attribute
|
||||||
|
streaming_response.tool_call = current_tool_call
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If the arguments are not valid JSON, just log a warning
|
||||||
|
logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
|
||||||
|
|
||||||
|
# Skip yielding content for tool call chunks
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For the final chunk, set the tool_call attribute
|
||||||
|
if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
|
||||||
|
logger.info("Streaming response finished with tool_calls reason")
|
||||||
|
if current_tool_call:
|
||||||
|
tool_call = current_tool_call
|
||||||
|
logger.info(f"Final tool call: {json.dumps(tool_call)}")
|
||||||
|
# Update the StreamingResponse object's tool_call attribute
|
||||||
|
streaming_response.tool_call = tool_call
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle regular content chunks
|
||||||
|
if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
|
||||||
|
content = chunk.choices[0].delta.content
|
||||||
|
yield content
|
||||||
|
|
||||||
|
# Create the generator
|
||||||
|
gen = generate()
|
||||||
|
|
||||||
|
# Set the generator in the StreamingResponse object
|
||||||
|
streaming_response.generator = gen
|
||||||
|
|
||||||
|
# Return the StreamingResponse object
|
||||||
|
return streaming_response
|
||||||
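For orientation (not part of the commit): a minimal sketch of how a caller might drain this wrapper. The names openai_provider and raw_stream are placeholders for an instance of this provider class and the raw OpenAI streaming response it wraps; the generator yields only text deltas, while tool_call is filled in as a side effect of iteration.

streaming = openai_provider.get_streaming_content(raw_stream)

pieces = []
for text in streaming.generator:  # text deltas only; tool-call chunks are consumed silently
    pieces.append(text)

full_text = "".join(pieces)
if streaming.tool_call is not None:
    # Populated during iteration, once tool-call deltas (or a tool_calls finish_reason) arrive
    print("Tool requested:", streaming.tool_call["name"], streaming.tool_call["input"])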
99  src/airflow_wingman/static/css/wingman_chat.css  Normal file
@@ -0,0 +1,99 @@
/* Provider and model selection styling */
.provider-section {
    margin-bottom: 20px;
}
.provider-name {
    font-size: 16px;
    font-weight: bold;
    margin-bottom: 10px;
    color: #666;
}
.model-option {
    margin-left: 15px;
    margin-bottom: 8px;
}
.model-option label {
    display: block;
    cursor: pointer;
}

/* Message styling */
.message {
    margin-bottom: 15px;
    max-width: 80%;
    clear: both;
}

.message p {
    margin-top: 0.5em;
    margin-bottom: 0.5em;
}

.message-user {
    float: right;
    background-color: #f0f7ff;
    border: 1px solid #d1e6ff;
    border-radius: 15px 15px 0 15px;
    padding: 10px 15px;
}

.message pre {
    margin-top: 0.5em;
    margin-bottom: 0.5em;
    padding: 0.5em;
}

.message-assistant {
    float: left;
    background-color: #f8f9fa;
    border: 1px solid #e9ecef;
    border-radius: 15px 15px 15px 0;
    padding: 10px 15px;
}

.message code {
    padding: 0.1em 0.3em;
}

#chat-messages::after {
    content: "";
    clear: both;
    display: table;
}

/* Scrollbar styling */
.panel-body::-webkit-scrollbar {
    width: 8px;
}

.panel-body::-webkit-scrollbar-track {
    background: #f1f1f1;
}

.panel-body::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 4px;
}

.panel-body::-webkit-scrollbar-thumb:hover {
    background: #555;
}

/* Processing indicator styling */
.processing-indicator {
    display: none;
    background-color: #f0f8ff;
    padding: 8px 12px;
    border-radius: 4px;
    margin: 8px 0;
    font-style: italic;
}

.processing-indicator.visible {
    display: block;
}

.pre-formatted {
    font-family: monospace;
    line-height: 1.2;
}
346  src/airflow_wingman/static/js/wingman_chat.js  Normal file
@@ -0,0 +1,346 @@
document.addEventListener('DOMContentLoaded', function() {
    // Initialize tooltips
    document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
        el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
    });

    // Handle model selection and model name input
    const modelNameInput = document.getElementById('modelName');
    const modelRadios = document.querySelectorAll('input[name="model"]');

    modelRadios.forEach(function(radio) {
        radio.addEventListener('change', function() {
            const provider = this.value.split(':')[0];
            const modelName = this.getAttribute('data-model-name');
            console.log('Selected provider:', provider);
            console.log('Model name:', modelName);

            if (provider === 'openrouter') {
                console.log('Enabling model name input');
                modelNameInput.disabled = false;
                modelNameInput.value = '';
                modelNameInput.placeholder = 'Enter model name for OpenRouter';
            } else {
                console.log('Disabling model name input');
                modelNameInput.disabled = true;
                modelNameInput.value = modelName;
            }
        });
    });

    // Set initial state based on default selection
    const defaultSelected = document.querySelector('input[name="model"]:checked');
    if (defaultSelected) {
        const provider = defaultSelected.value.split(':')[0];
        const modelName = defaultSelected.getAttribute('data-model-name');
        console.log('Initial provider:', provider);
        console.log('Initial model name:', modelName);

        if (provider === 'openrouter') {
            console.log('Initially enabling model name input');
            modelNameInput.disabled = false;
            modelNameInput.value = '';
            modelNameInput.placeholder = 'Enter model name for OpenRouter';
        } else {
            console.log('Initially disabling model name input');
            modelNameInput.disabled = true;
            modelNameInput.value = modelName;
        }
    }

    const messageInput = document.getElementById('message-input');
    const sendButton = document.getElementById('send-button');
    const refreshButton = document.getElementById('refresh-button');
    const chatMessages = document.getElementById('chat-messages');

    let currentMessageDiv = null;
    let messageHistory = [];

    // Create a processing indicator element
    const processingIndicator = document.createElement('div');
    processingIndicator.className = 'processing-indicator';
    processingIndicator.textContent = 'Processing tool calls...';
    chatMessages.appendChild(processingIndicator);

    function clearChat() {
        // Clear the chat messages
        chatMessages.innerHTML = '';
        // Add back the processing indicator
        chatMessages.appendChild(processingIndicator);
        // Reset message history
        messageHistory = [];
        // Clear the input field
        messageInput.value = '';
        // Enable input if it was disabled
        messageInput.disabled = false;
        sendButton.disabled = false;
    }

    function addMessage(content, isUser) {
        const messageDiv = document.createElement('div');
        messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;

        messageDiv.classList.add('pre-formatted');

        // Use marked.js to render markdown
        try {
            // Configure marked options
            marked.use({
                breaks: true,     // Add line breaks on single newlines
                gfm: true,        // Use GitHub Flavored Markdown
                headerIds: false, // Don't add IDs to headers
                mangle: false,    // Don't mangle email addresses
            });

            // Render markdown to HTML
            messageDiv.innerHTML = marked.parse(content);
        } catch (e) {
            console.error('Error rendering markdown:', e);
            // Fall back to innerText if markdown parsing fails
            messageDiv.innerText = content;
        }

        chatMessages.appendChild(messageDiv);
        chatMessages.scrollTop = chatMessages.scrollHeight;
        return messageDiv;
    }

    function showProcessingIndicator() {
        processingIndicator.classList.add('visible');
        chatMessages.scrollTop = chatMessages.scrollHeight;

        // Disable send button and input field during tool processing
        sendButton.disabled = true;
        messageInput.disabled = true;
    }

    function hideProcessingIndicator() {
        processingIndicator.classList.remove('visible');

        // Re-enable send button and input field after tool processing
        sendButton.disabled = false;
        messageInput.disabled = false;
    }

    async function sendMessage() {
        const message = messageInput.value.trim();
        if (!message) return;

        // Get selected model
        const selectedModel = document.querySelector('input[name="model"]:checked');
        if (!selectedModel) {
            alert('Please select a model');
            return;
        }

        const [provider, modelId] = selectedModel.value.split(':');
        const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;

        // Clear input and add user message
        messageInput.value = '';
        addMessage(message, true);

        // Add user message to history
        messageHistory.push({
            role: 'user',
            content: message
        });

        // Use full message history for the request
        const messages = [...messageHistory];

        // Create assistant message div
        currentMessageDiv = addMessage('', false);

        // Get API key
        const apiKey = document.getElementById('api-key').value.trim();
        if (!apiKey) {
            alert('Please enter an API key');
            return;
        }

        // Disable input while processing
        messageInput.disabled = true;
        sendButton.disabled = true;

        // Get CSRF token
        const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
        if (!csrfToken) {
            alert('CSRF token not found. Please refresh the page.');
            // Re-enable input before bailing out
            messageInput.disabled = false;
            sendButton.disabled = false;
            return;
        }

        // Create request data
        const requestData = {
            provider: provider,
            model: modelName,
            messages: messages,
            api_key: apiKey,
            stream: true,
            temperature: 0.4,
        };
        console.log('Sending request:', {...requestData, api_key: '***'});

        try {
            // Send request
            const response = await fetch('/wingman/chat', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    'X-CSRFToken': csrfToken
                },
                body: JSON.stringify(requestData)
            });

            if (!response.ok) {
                const error = await response.json();
                throw new Error(error.error || 'Failed to get response');
            }

            // Process the streaming response
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let fullResponse = '';

            while (true) {
                const { value, done } = await reader.read();
                if (done) break;

                const chunk = decoder.decode(value);
                const lines = chunk.split('\n');

                for (const line of lines) {
                    if (line.trim() === '') continue;

                    if (line.startsWith('data: ')) {
                        const content = line.slice(6); // Remove 'data: ' prefix

                        // Check for special events or end marker
                        if (content === '[DONE]') {
                            console.log('Stream complete');

                            // Add assistant's response to history
                            if (fullResponse) {
                                messageHistory.push({
                                    role: 'assistant',
                                    content: fullResponse
                                });
                            }
                            continue;
                        }

                        // Try to parse as JSON for special events
                        try {
                            const parsed = JSON.parse(content);

                            if (parsed.event === 'tool_processing_start') {
                                console.log('Tool processing started');
                                showProcessingIndicator();
                                continue;
                            }

                            if (parsed.event === 'replace_content') {
                                console.log('Replacing content due to tool call');
                                // Clear the current assistant message so the follow-up response can replace it
                                if (currentMessageDiv) {
                                    currentMessageDiv.innerHTML = '';
                                    fullResponse = ''; // Reset the full response
                                }
                                continue;
                            }

                            if (parsed.event === 'tool_processing_complete') {
                                console.log('Tool processing completed');
                                hideProcessingIndicator();
                                continue;
                            }

                            // Handle follow-up response event
                            if (parsed.event === 'follow_up_response' && parsed.content) {
                                console.log('Received follow-up response');

                                // Add this follow-up response to message history
                                messageHistory.push({
                                    role: 'assistant',
                                    content: parsed.content
                                });

                                // Create a new message div for the follow-up response;
                                // addMessage already handles markdown rendering
                                addMessage(parsed.content, false);
                                continue;
                            }

                            // Handle the complete response event
                            if (parsed.event === 'complete_response') {
                                console.log('Received complete response from backend');
                                // Use the complete response from the backend
                                fullResponse = parsed.content;

                                // Use marked.js to render markdown
                                try {
                                    // Configure marked options
                                    marked.use({
                                        breaks: true,     // Add line breaks on single newlines
                                        gfm: true,        // Use GitHub Flavored Markdown
                                        headerIds: false, // Don't add IDs to headers
                                        mangle: false,    // Don't mangle email addresses
                                    });

                                    // Render markdown to HTML
                                    currentMessageDiv.innerHTML = marked.parse(fullResponse);
                                } catch (e) {
                                    console.error('Error rendering markdown:', e);
                                    // Fall back to innerText if markdown parsing fails
                                    currentMessageDiv.innerText = fullResponse;
                                }
                                continue;
                            }

                            // If we have JSON that's not a special event, it might be content
                            currentMessageDiv.textContent += JSON.stringify(parsed);
                            fullResponse += JSON.stringify(parsed);
                        } catch (e) {
                            // Not JSON, handle as normal content
                            // console.log('Received chunk:', JSON.stringify(content));

                            // Add to full response
                            fullResponse += content;

                            // Create a properly formatted display
                            if (!currentMessageDiv.classList.contains('pre-formatted')) {
                                currentMessageDiv.classList.add('pre-formatted');
                            }

                            // Always rebuild the entire content from the full response
                            currentMessageDiv.innerHTML = marked.parse(fullResponse);
                        }
                        // Scroll to bottom
                        chatMessages.scrollTop = chatMessages.scrollHeight;
                    }
                }
            }
        } catch (error) {
            console.error('Error:', error);
            if (currentMessageDiv) {
                currentMessageDiv.textContent = `Error: ${error.message}`;
                currentMessageDiv.style.color = 'red';
            }
        } finally {
            // Always re-enable input and hide indicators
            messageInput.disabled = false;
            sendButton.disabled = false;
            hideProcessingIndicator();
        }
    }

    sendButton.addEventListener('click', sendMessage);
    messageInput.addEventListener('keypress', function(e) {
        if (e.key === 'Enter') {
            sendMessage();
        }
    });

    refreshButton.addEventListener('click', clearChat);
});
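The stream parser above keys off a small set of `data:` events. For orientation, a sketch in Python of the framing the backend emits, assembled from the views.py hunk later in this diff; the function and its arguments are illustrative, not part of the commit.

import json

def sse_frames(text_chunks: list[str], follow_up_text: str | None = None):
    # Plain text deltas are sent as bare "data:" lines, each frame ending in a blank line.
    for chunk in text_chunks:
        yield f"data: {chunk}\n\n"
    # Tool-call handling is signalled with JSON-encoded special events.
    if follow_up_text is not None:
        yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
        yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"
        yield f"data: {json.dumps({'event': 'follow_up_response', 'content': follow_up_text})}\n\n"
        yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
    # The assembled text is repeated once as a 'complete_response' event, then the end marker.
    yield f"data: {json.dumps({'event': 'complete_response', 'content': ''.join(text_chunks)})}\n\n"
    yield "data: [DONE]\n\n"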
@@ -3,6 +3,7 @@
 {% block head_meta %}
 {{ super() }}
 <meta name="csrf-token" content="{{ csrf_token() }}">
+<link rel="stylesheet" href="{{ url_for('wingman.static', filename='css/wingman_chat.css') }}">
 {% endblock %}

 {% block content %}

@@ -77,25 +78,7 @@
 </div>
 </div>
-
-<style>
-.provider-section {
-    margin-bottom: 20px;
-}
-.provider-name {
-    font-size: 16px;
-    font-weight: bold;
-    margin-bottom: 10px;
-    color: #666;
-}
-.model-option {
-    margin-left: 15px;
-    margin-bottom: 8px;
-}
-.model-option label {
-    display: block;
-    cursor: pointer;
-}
-</style>
-
 </div>
 </div>

@@ -129,256 +112,6 @@
 </div>
 </div>

-<style>
-.message {
-    margin-bottom: 15px;
-    max-width: 80%;
-    clear: both;
-}
-
-.message-user {
-    float: right;
-    background-color: #f0f7ff;
-    border: 1px solid #d1e6ff;
-    border-radius: 15px 15px 0 15px;
-    padding: 10px 15px;
-}
-
-.message-assistant {
-    float: left;
-    background-color: #f8f9fa;
-    border: 1px solid #e9ecef;
-    border-radius: 15px 15px 15px 0;
-    padding: 10px 15px;
-}
-
-#chat-messages::after {
-    content: "";
-    clear: both;
-    display: table;
-}
-
-.panel-body::-webkit-scrollbar {
-    width: 8px;
-}
-
-.panel-body::-webkit-scrollbar-track {
-    background: #f1f1f1;
-}
-
-.panel-body::-webkit-scrollbar-thumb {
-    background: #888;
-    border-radius: 4px;
-}
-
-.panel-body::-webkit-scrollbar-thumb:hover {
-    background: #555;
-}
-</style>
-
-<script>
-document.addEventListener('DOMContentLoaded', function() {
-    // Add title attributes for tooltips
-    document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
-        el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
-    });
-
-    // Handle model selection and model name input
-    const modelNameInput = document.getElementById('modelName');
-    const modelRadios = document.querySelectorAll('input[name="model"]');
-
-    modelRadios.forEach(function(radio) {
-        radio.addEventListener('change', function() {
-            const provider = this.value.split(':')[0]; // Get provider from value instead of data attribute
-            const modelName = this.getAttribute('data-model-name');
-            console.log('Selected provider:', provider);
-            console.log('Model name:', modelName);
-
-            if (provider === 'openrouter') {
-                console.log('Enabling model name input');
-                modelNameInput.disabled = false;
-                modelNameInput.value = '';
-                modelNameInput.placeholder = 'Enter model name for OpenRouter';
-            } else {
-                console.log('Disabling model name input');
-                modelNameInput.disabled = true;
-                modelNameInput.value = modelName;
-            }
-        });
-    });
-
-    // Set initial state based on default selection
-    const defaultSelected = document.querySelector('input[name="model"]:checked');
-    if (defaultSelected) {
-        const provider = defaultSelected.value.split(':')[0]; // Get provider from value instead of data attribute
-        const modelName = defaultSelected.getAttribute('data-model-name');
-        console.log('Initial provider:', provider);
-        console.log('Initial model name:', modelName);
-
-        if (provider === 'openrouter') {
-            console.log('Initially enabling model name input');
-            modelNameInput.disabled = false;
-            modelNameInput.value = '';
-            modelNameInput.placeholder = 'Enter model name for OpenRouter';
-        } else {
-            console.log('Initially disabling model name input');
-            modelNameInput.disabled = true;
-            modelNameInput.value = modelName;
-        }
-    }
-
-    const messageInput = document.getElementById('message-input');
-    const sendButton = document.getElementById('send-button');
-    const refreshButton = document.getElementById('refresh-button');
-    const chatMessages = document.getElementById('chat-messages');
-
-    let currentMessageDiv = null;
-    let messageHistory = [];
-
-    function clearChat() {
-        // Clear the chat messages
-        chatMessages.innerHTML = '';
-        // Reset message history
-        messageHistory = [];
-        // Clear the input field
-        messageInput.value = '';
-        // Enable input if it was disabled
-        messageInput.disabled = false;
-        sendButton.disabled = false;
-    }
-
-    function addMessage(content, isUser) {
-        const messageDiv = document.createElement('div');
-        messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
-        messageDiv.textContent = content;
-        chatMessages.appendChild(messageDiv);
-        chatMessages.scrollTop = chatMessages.scrollHeight;
-        return messageDiv;
-    }
-
-    async function sendMessage() {
-        const message = messageInput.value.trim();
-        if (!message) return;
-
-        // Get selected model
-        const selectedModel = document.querySelector('input[name="model"]:checked');
-        if (!selectedModel) {
-            alert('Please select a model');
-            return;
-        }
-
-        const [provider, modelId] = selectedModel.value.split(':');
-        const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
-
-        // Clear input and add user message
-        messageInput.value = '';
-        addMessage(message, true);
-
-        try {
-            // Add user message to history
-            messageHistory.push({
-                role: 'user',
-                content: message
-            });
-
-            // Use full message history for the request
-            const messages = [...messageHistory];
-
-            // Create assistant message div
-            currentMessageDiv = addMessage('', false);
-
-            // Get API key
-            const apiKey = document.getElementById('api-key').value.trim();
-            if (!apiKey) {
-                alert('Please enter an API key');
-                return;
-            }
-
-            // Debug log the request
-            const requestData = {
-                provider: provider,
-                model: modelName,
-                messages: messages,
-                api_key: apiKey,
-                stream: true,
-                temperature: 0.7
-            };
-            console.log('Sending request:', {...requestData, api_key: '***'});
-
-            // Get CSRF token
-            const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
-            if (!csrfToken) {
-                throw new Error('CSRF token not found. Please refresh the page.');
-            }
-
-            // Send request
-            const response = await fetch('/wingman/chat', {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'X-CSRFToken': csrfToken
-                },
-                body: JSON.stringify({
-                    provider: provider,
-                    model: modelName,
-                    messages: messages,
-                    api_key: apiKey,
-                    stream: true,
-                    temperature: 0.7
-                })
-            });
-
-            if (!response.ok) {
-                const error = await response.json();
-                throw new Error(error.error || 'Failed to get response');
-            }
-
-            // Handle streaming response
-            const reader = response.body.getReader();
-            const decoder = new TextDecoder();
-            let fullResponse = '';
-
-            while (true) {
-                const { value, done } = await reader.read();
-                if (done) break;
-
-                const chunk = decoder.decode(value);
-                const lines = chunk.split('\n');
-
-                for (const line of lines) {
-                    if (line.startsWith('data: ')) {
-                        const content = line.slice(6);
-                        if (content) {
-                            currentMessageDiv.textContent += content;
-                            fullResponse += content;
-                            chatMessages.scrollTop = chatMessages.scrollHeight;
-                        }
-                    }
-                }
-            }
-
-            // Add assistant's response to history
-            if (fullResponse) {
-                messageHistory.push({
-                    role: 'assistant',
-                    content: fullResponse
-                });
-            }
-        } catch (error) {
-            console.error('Error:', error);
-            currentMessageDiv.textContent = `Error: ${error.message}`;
-            currentMessageDiv.style.color = 'red';
-        }
-    }
-
-    sendButton.addEventListener('click', sendMessage);
-    messageInput.addEventListener('keypress', function(e) {
-        if (e.key === 'Enter') {
-            sendMessage();
-        }
-    });
-
-    refreshButton.addEventListener('click', clearChat);
-});
-</script>
+<script src="https://cdn.jsdelivr.net/npm/marked@9.1.6/marked.min.js"></script>
+<script src="{{ url_for('wingman.static', filename='js/wingman_chat.js') }}"></script>
 {% endblock %}
15  src/airflow_wingman/tools/__init__.py  Normal file
@@ -0,0 +1,15 @@
"""
Tools module for Airflow Wingman.

This module contains the tools used by Airflow Wingman to interact with Airflow.
"""

from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools
from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools

__all__ = [
    "convert_to_openai_tools",
    "convert_to_anthropic_tools",
    "list_airflow_tools",
    "execute_airflow_tool",
]
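Since the package re-exports its helpers, call sites elsewhere in the plugin can import everything from one place. A small illustrative snippet (the cookie value is a placeholder):

from airflow_wingman.tools import convert_to_openai_tools, list_airflow_tools

tools = list_airflow_tools("session=<redacted>")  # placeholder session cookie
openai_defs = convert_to_openai_tools(tools)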
143  src/airflow_wingman/tools/conversion.py  Normal file
@@ -0,0 +1,143 @@
"""
Conversion utilities for Airflow Wingman tools.

This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""

import logging
from typing import Any


def convert_to_openai_tools(airflow_tools: list) -> list:
    """
    Convert Airflow tools to OpenAI tool definitions.

    Args:
        airflow_tools: List of Airflow tools from MCP server

    Returns:
        List of OpenAI tool definitions
    """
    openai_tools = []

    for tool in airflow_tools:
        # Initialize the OpenAI tool structure
        openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}}

        # Extract parameters directly from inputSchema if available
        if hasattr(tool, "inputSchema") and tool.inputSchema:
            # Set the type and required fields directly from the schema
            if "type" in tool.inputSchema:
                openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"]

            # Add required parameters if specified
            if "required" in tool.inputSchema:
                openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"]

            # Add properties from the input schema
            if "properties" in tool.inputSchema:
                for param_name, param_info in tool.inputSchema["properties"].items():
                    # Create parameter definition
                    param_def = {}

                    # Handle different schema constructs
                    if "anyOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "anyOf")
                    elif "oneOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "oneOf")
                    elif "allOf" in param_info:
                        _handle_schema_construct(param_def, param_info, "allOf")
                    elif "type" in param_info:
                        param_def["type"] = param_info["type"]
                        # Add format if available
                        if "format" in param_info:
                            param_def["format"] = param_info["format"]
                    else:
                        param_def["type"] = "string"  # Default type

                    # Add description from title or param name
                    param_def["description"] = param_info.get("description", param_info.get("title", param_name))

                    # Add enum values if available
                    if "enum" in param_info:
                        param_def["enum"] = param_info["enum"]

                    # Add default value if available
                    if "default" in param_info and param_info["default"] is not None:
                        param_def["default"] = param_info["default"]

                    # Add to properties
                    openai_tool["function"]["parameters"]["properties"][param_name] = param_def

        openai_tools.append(openai_tool)

    return openai_tools


def convert_to_anthropic_tools(airflow_tools: list) -> list:
    """
    Convert Airflow tools to Anthropic tool definitions.

    Args:
        airflow_tools: List of Airflow tools from MCP server

    Returns:
        List of Anthropic tool definitions
    """
    logger = logging.getLogger("airflow.plugins.wingman")
    logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format")
    anthropic_tools = []

    for tool in airflow_tools:
        # Initialize the Anthropic tool structure
        anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}}

        # Extract parameters directly from inputSchema if available
        if hasattr(tool, "inputSchema") and tool.inputSchema:
            # Copy the input schema directly, as Anthropic's format is similar to JSON Schema
            anthropic_tool["input_schema"] = tool.inputSchema
        else:
            # Create a minimal schema if none exists
            anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []}

        anthropic_tools.append(anthropic_tool)

    logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format")
    return anthropic_tools


def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None:
    """
    Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf.

    Args:
        param_def: Parameter definition to update
        param_info: Parameter info from the schema
        construct_type: Type of construct (anyOf, oneOf, allOf)
    """
    # Get the list of schemas from the construct
    schemas = param_info[construct_type]

    # Find the first schema with a type
    for schema in schemas:
        if "type" in schema:
            param_def["type"] = schema["type"]

            # Add format if available
            if "format" in schema:
                param_def["format"] = schema["format"]

            # Add enum values if available
            if "enum" in schema:
                param_def["enum"] = schema["enum"]

            # Add default value if available
            if "default" in schema and schema["default"] is not None:
                param_def["default"] = schema["default"]

            break

    # If no type was found, default to string
    if "type" not in param_def:
        param_def["type"] = "string"
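To make the mapping concrete, a small hand-built example (the tool object is a stand-in exposing only the attributes the converter reads; the tool name and schema are invented):

from types import SimpleNamespace

# Hypothetical MCP tool, shaped like what convert_to_openai_tools expects.
tool = SimpleNamespace(
    name="get_dag",
    description="Fetch a DAG by id",
    inputSchema={"type": "object", "properties": {"dag_id": {"type": "string", "title": "Dag Id"}}, "required": ["dag_id"]},
)
print(convert_to_openai_tools([tool])[0])
# {'type': 'function', 'function': {'name': 'get_dag', 'description': 'Fetch a DAG by id',
#   'parameters': {'type': 'object', 'properties': {'dag_id': {'type': 'string', 'description': 'Dag Id'}},
#                  'required': ['dag_id']}}}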
153  src/airflow_wingman/tools/execution.py  Normal file
@@ -0,0 +1,153 @@
"""
Tool execution module for Airflow Wingman.

This module contains functions to list and execute Airflow tools.
"""

import asyncio
import json
import logging
import traceback

from airflow import configuration
from airflow_mcp_server.config import AirflowConfig
from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool

# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")


async def _list_airflow_tools_async(cookie: str) -> list:
    """
    Async implementation to list available Airflow tools.

    Args:
        cookie: Cookie for authentication

    Returns:
        List of available Airflow tools
    """
    try:
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")

        # Format the cookie properly if it doesn't already have the 'session=' prefix
        formatted_cookie = cookie
        if cookie and not cookie.startswith("session="):
            formatted_cookie = f"session={cookie}"
            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")

        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

        # Get available tools
        logger.info("Getting Airflow tools...")
        tools = await get_airflow_tools(config=config, mode="safe")
        logger.info(f"Got {len(tools)} tools")
        return tools
    except Exception as e:
        error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return []


def list_airflow_tools(cookie: str) -> list:
    """
    Synchronous wrapper to list available Airflow tools.

    Args:
        cookie: Cookie for authentication

    Returns:
        List of available Airflow tools
    """
    return asyncio.run(_list_airflow_tools_async(cookie))


async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str:
    """
    Async implementation to execute an Airflow tool.

    Args:
        tool_name: Name of the tool to execute
        arguments: Arguments to pass to the tool
        cookie: Cookie for authentication

    Returns:
        Result of the tool execution as a string
    """
    try:
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")

        # Format the cookie properly if it doesn't already have the 'session=' prefix
        formatted_cookie = cookie
        if cookie and not cookie.startswith("session="):
            formatted_cookie = f"session={cookie}"
            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")

        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

        # Get the tool
        logger.info(f"Getting tool: {tool_name}")
        tool = await get_tool(config=config, name=tool_name)

        if not tool:
            error_msg = f"Tool not found: {tool_name}"
            logger.error(error_msg)
            return json.dumps({"error": error_msg})

        # Execute the tool - ensure the client is in an async context
        logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")

        # The AirflowClient needs to be used as an async context manager
        # to properly initialize its session
        async with tool.client as client:  # noqa: F841
            # Now the client has a _session attribute and is in an async context
            result = await tool.run(arguments)

        # Convert result to string
        if isinstance(result, dict | list):
            result_str = json.dumps(result, indent=2)
        else:
            result_str = str(result)

        logger.info(f"Tool execution result: {result_str[:100]}...")
        return result_str
    except Exception as e:
        error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return json.dumps({"error": error_msg})


def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
    """
    Synchronous wrapper to execute an Airflow tool.

    Args:
        tool_name: Name of the tool to execute
        arguments: Arguments to pass to the tool
        cookie: Cookie for authentication

    Returns:
        Result of the tool execution as a string
    """
    # Create a new event loop for this execution
    # so that we are always in a clean async context
    loop = asyncio.new_event_loop()

    try:
        # Set the event loop for this thread
        asyncio.set_event_loop(loop)

        # Run the async function in the new event loop
        result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
        return result
    except Exception as e:
        error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return json.dumps({"error": error_msg})
    finally:
        # Always close the loop to free resources
        loop.close()
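For reference, a hypothetical call through the synchronous wrapper (the tool name, arguments, and cookie are invented for illustration). Note that failures come back encoded as a JSON error string rather than as a raised exception:

result = execute_airflow_tool("get_dag", {"dag_id": "example_dag"}, "session=<redacted>")
print(result)  # pretty-printed JSON on success, or '{"error": "..."}' on failure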
@@ -1,6 +1,9 @@
|
|||||||
"""Views for Airflow Wingman plugin."""
|
"""Views for Airflow Wingman plugin."""
|
||||||
|
|
||||||
from flask import Response, request, stream_with_context
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import Response, request, session
|
||||||
from flask.json import jsonify
|
from flask.json import jsonify
|
||||||
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
|
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
|
||||||
|
|
||||||
@@ -8,6 +11,10 @@ from airflow_wingman.llm_client import LLMClient
|
|||||||
from airflow_wingman.llms_models import MODELS
|
from airflow_wingman.llms_models import MODELS
|
||||||
from airflow_wingman.notes import INTERFACE_MESSAGES
|
from airflow_wingman.notes import INTERFACE_MESSAGES
|
||||||
from airflow_wingman.prompt_engineering import prepare_messages
|
from airflow_wingman.prompt_engineering import prepare_messages
|
||||||
|
from airflow_wingman.tools import list_airflow_tools
|
||||||
|
|
||||||
|
# Create a properly namespaced logger for the Airflow plugin
|
||||||
|
logger = logging.getLogger("airflow.plugins.wingman")
|
||||||
|
|
||||||
|
|
||||||
class WingmanView(AppBuilderBaseView):
|
class WingmanView(AppBuilderBaseView):
|
||||||
@@ -28,8 +35,43 @@ class WingmanView(AppBuilderBaseView):
|
|||||||
try:
|
try:
|
||||||
data = self._validate_chat_request(request.get_json())
|
data = self._validate_chat_request(request.get_json())
|
||||||
|
|
||||||
# Create a new client for this request
|
if data.get("cookie"):
|
||||||
client = LLMClient(data["api_key"])
|
session["airflow_cookie"] = data["cookie"]
|
||||||
|
|
||||||
|
# Get available Airflow tools using the stored cookie
|
||||||
|
airflow_tools = []
|
||||||
|
airflow_cookie = request.cookies.get("session")
|
||||||
|
if airflow_cookie:
|
||||||
|
try:
|
||||||
|
airflow_tools = list_airflow_tools(airflow_cookie)
|
||||||
|
logger.info(f"Loaded {len(airflow_tools)} Airflow tools")
|
||||||
|
if len(airflow_tools) > 0:
|
||||||
|
logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}")
|
||||||
|
else:
|
||||||
|
logger.warning("No Airflow tools were loaded")
|
||||||
|
except Exception as e:
|
||||||
|
# Log the error but continue without tools
|
||||||
|
logger.error(f"Error fetching Airflow tools: {str(e)}")
|
||||||
|
|
||||||
|
# Prepare messages with Airflow tools included in the prompt
|
||||||
|
data["messages"] = prepare_messages(data["messages"])
|
||||||
|
|
||||||
|
# Get provider name from request or use default
|
||||||
|
provider_name = data.get("provider", "openai")
|
||||||
|
|
||||||
|
# Get base URL from models configuration based on provider
|
||||||
|
base_url = MODELS.get(provider_name, {}).get("endpoint")
|
||||||
|
|
||||||
|
# Log the request parameters (excluding API key for security)
|
||||||
|
safe_data = {k: v for k, v in data.items() if k != "api_key"}
|
||||||
|
logger.info(f"Chat request: provider={provider_name}, model={data.get('model')}, stream={data.get('stream')}")
|
||||||
|
logger.info(f"Request parameters: {json.dumps(safe_data)[:200]}...")
|
||||||
|
|
||||||
|
# Create a new client for this request with the appropriate provider
|
||||||
|
client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url)
|
||||||
|
|
||||||
|
# Set the Airflow tools for the client to use
|
||||||
|
client.set_airflow_tools(airflow_tools)
|
||||||
|
|
||||||
if data["stream"]:
|
if data["stream"]:
|
||||||
return self._handle_streaming_response(client, data)
|
return self._handle_streaming_response(client, data)
|
||||||
@@ -46,39 +88,116 @@ class WingmanView(AppBuilderBaseView):
|
|||||||
if not data:
|
if not data:
|
||||||
raise ValueError("No data provided")
|
raise ValueError("No data provided")
|
||||||
|
|
||||||
required_fields = ["provider", "model", "messages", "api_key"]
|
required_fields = ["model", "messages", "api_key"]
|
||||||
missing = [f for f in required_fields if not data.get(f)]
|
missing = [f for f in required_fields if not data.get(f)]
|
||||||
if missing:
|
if missing:
|
||||||
raise ValueError(f"Missing required fields: {', '.join(missing)}")
|
raise ValueError(f"Missing required fields: {', '.join(missing)}")
|
||||||
|
|
||||||
# Prepare messages with system instruction while maintaining history
|
# Validate provider if provided
|
||||||
messages = data["messages"]
|
provider = data.get("provider", "openai")
|
||||||
messages = prepare_messages(messages)
|
if provider not in MODELS:
|
||||||
|
raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"provider": data["provider"],
|
|
||||||
"model": data["model"],
|
"model": data["model"],
|
||||||
"messages": messages,
|
"messages": data["messages"],
|
||||||
"api_key": data["api_key"],
|
"api_key": data["api_key"],
|
||||||
"stream": data.get("stream", False),
|
"stream": data.get("stream", True),
|
||||||
"temperature": data.get("temperature", 0.7),
|
"temperature": data.get("temperature", 0.4),
|
||||||
"max_tokens": data.get("max_tokens"),
|
"max_tokens": data.get("max_tokens"),
|
||||||
|
"cookie": data.get("cookie"),
|
||||||
|
"provider": provider,
|
||||||
|
"base_url": data.get("base_url"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ ... @@
     def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
         """Handle streaming response."""
-        def generate():
-            for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
-                yield f"data: {chunk}\n\n"
-
-        response = Response(stream_with_context(generate()), mimetype="text/event-stream")
-        response.headers["Content-Type"] = "text/event-stream"
-        response.headers["Cache-Control"] = "no-cache"
-        response.headers["Connection"] = "keep-alive"
-        return response
+        try:
+            logger.info("Beginning streaming response")
+            # Get the cookie at the beginning of the request handler
+            airflow_cookie = request.cookies.get("session")
+            logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")
+
+            # Request a streaming response from the provider
+            streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
+
+            def stream_response(cookie=airflow_cookie):
+                complete_response = ""
+
+                # Stream the initial response
+                for chunk in streaming_response:
+                    if chunk:
+                        complete_response += chunk
+                        yield f"data: {chunk}\n\n"
+
+                # Log the complete assembled response
+                logger.info("COMPLETE RESPONSE START >>>")
+                logger.info(complete_response)
+                logger.info("<<< COMPLETE RESPONSE END")
+
+                # Check for tool calls and make a follow-up request if needed
+                has_tool_calls = client.provider.has_tool_calls(streaming_response)
+                logger.info(f"Has tool calls: {has_tool_calls}")
+                if has_tool_calls:
+                    # Signal tool processing start - frontend should disable send button
+                    yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
+
+                    # Signal to replace content - frontend should clear the current message
+                    yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"
+
+                    logger.info("Response contains tool calls, making follow-up request")
+                    logger.info(f"Using cookie from closure: {cookie is not None}")
+
+                    # Process tool calls and get the follow-up response (handles recursive tool calls).
+                    # Always stream the follow-up response for consistent handling.
+                    follow_up_response = client.process_tool_calls_and_follow_up(
+                        streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie, stream=True
+                    )
+
+                    # Collect the follow-up response
+                    follow_up_complete_response = ""
+                    for chunk in follow_up_response:
+                        if chunk:
+                            follow_up_complete_response += chunk
+
+                    # Send the follow-up response as a single event
+                    if follow_up_complete_response:
+                        follow_up_event = json.dumps({"event": "follow_up_response", "content": follow_up_complete_response})
+                        logger.info(f"Follow-up event created with length: {len(follow_up_event)}")
+                        data_line = f"data: {follow_up_event}\n\n"
+                        logger.info(f"Yielding data line with length: {len(data_line)}")
+                        yield data_line
+
+                    # Log the complete follow-up response
+                    logger.info("FOLLOW-UP RESPONSE START >>>")
+                    logger.info(follow_up_complete_response)
+                    logger.info("<<< FOLLOW-UP RESPONSE END")
+
+                    # Signal tool processing complete - frontend can re-enable send button
+                    yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
+
+                # Send the complete response as a special event (for compatibility with existing code)
+                complete_event = json.dumps({"event": "complete_response", "content": complete_response})
+                yield f"data: {complete_event}\n\n"
+
+                # Signal end of stream
+                yield "data: [DONE]\n\n"
+
+            return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
+        except Exception as e:
+            logger.error(f"Streaming error: {str(e)}")
+            return jsonify({"error": str(e)}), 500
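For reference, a minimal sketch of a client consuming this SSE stream. The event names and "data:" framing come straight from stream_response() above; the endpoint URL, payload, and session cookie value are assumptions for illustration:

    import json

    import requests

    payload = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}], "api_key": "sk-placeholder"}
    # /wingman/chat is a hypothetical mount point for this handler.
    with requests.post("http://localhost:8080/wingman/chat", json=payload, cookies={"session": "<airflow-session>"}, stream=True) as resp:
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data: "):
                continue
            chunk = raw[len("data: "):]
            if chunk == "[DONE]":  # end-of-stream sentinel
                break
            try:
                event = json.loads(chunk)
            except ValueError:
                print(chunk, end="")  # plain token delta from the first pass
                continue
            # Control events emitted above: tool_processing_start, replace_content,
            # follow_up_response, tool_processing_complete, complete_response.
            if event.get("event") == "follow_up_response":
                print(event["content"])

A design note on the closure: stream_response(cookie=airflow_cookie) binds the Airflow session cookie at definition time because the Flask request context is gone by the time the generator actually runs.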
@@ ... @@
     def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
         """Handle regular response."""
-        response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
-        return jsonify(response)
+        try:
+            logger.info("Beginning regular (non-streaming) response")
+            response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
+            logger.info("COMPLETE RESPONSE START >>>")
+            logger.info(f"Response to frontend: {json.dumps(response)}")
+            logger.info("<<< COMPLETE RESPONSE END")
+            return jsonify(response)
+        except Exception as e:
+            logger.error(f"Regular response error: {str(e)}")
+            return jsonify({"error": str(e)}), 500
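And a corresponding sketch for the non-streaming path, under the same hypothetical endpoint: a success returns the provider's response as JSON, while a failure now surfaces as {"error": "..."} with HTTP 500 instead of an unhandled exception:

    import requests

    payload = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}], "api_key": "sk-placeholder", "stream": False}
    resp = requests.post("http://localhost:8080/wingman/chat", json=payload)
    if resp.status_code == 500:
        print("server-side failure:", resp.json()["error"])
    else:
        print(resp.json())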