fix tool listing and temperature
@@ -50,7 +50,9 @@ class LLMClient:
"""
self.airflow_tools = tools

def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]:
def chat_completion(
self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False
) -> dict[str, Any] | tuple[Any, Any]:
"""
Send a chat completion request to the LLM provider.

@@ -60,9 +62,12 @@ class LLMClient:
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response (default is True)
return_response_obj: If True and streaming, returns both the response object and generator

Returns:
Dictionary with the response content or a generator for streaming
If stream=False: Dictionary with the response content
If stream=True and return_response_obj=False: Generator for streaming
If stream=True and return_response_obj=True: Tuple of (response_obj, generator)
"""
# Get provider-specific tool definitions from Airflow tools
provider_tools = self.provider.convert_tools(self.airflow_tools)
@@ -73,10 +78,13 @@ class LLMClient:
response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
logger.info(f"Received response from {self.provider_name}")

# If streaming, return the generator directly
# If streaming, handle based on return_response_obj flag
if stream:
logger.info(f"Using streaming response from {self.provider_name}")
return self.provider.get_streaming_content(response)
if return_response_obj:
return response, self.provider.get_streaming_content(response)
else:
return self.provider.get_streaming_content(response)

# For non-streaming responses, handle tool calls if present
if self.provider.has_tool_calls(response):
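For orientation, a minimal usage sketch of the three return shapes described in the docstring above, assuming an already constructed `LLMClient` instance named `client` and an illustrative model id (neither appears in this diff):

    messages = [{"role": "user", "content": "Which DAGs failed today?"}]

    # stream=False: a plain dictionary with the response content
    result = client.chat_completion(messages, model="gpt-4o", stream=False)

    # stream=True (the new default): a generator of content chunks
    for chunk in client.chat_completion(messages, model="gpt-4o"):
        print(chunk, end="")

    # stream=True with return_response_obj=True: (response_obj, generator),
    # letting the caller inspect tool calls while still streaming content
    response_obj, generator = client.chat_completion(messages, model="gpt-4o", return_response_obj=True)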
@@ -135,6 +143,85 @@ class LLMClient:

return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)

def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5):
"""
Process tool calls recursively from a response and make follow-up requests until
there are no more tool calls or max_iterations is reached.
Returns a generator for streaming the final follow-up response.

Args:
response: The original response object containing tool calls
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
max_iterations: Maximum number of tool call iterations to prevent infinite loops

Returns:
Generator for streaming the final follow-up response
"""
try:
iteration = 0
current_response = response
cookie = session.get("airflow_cookie")

if not cookie:
error_msg = "No Airflow cookie available"
logger.error(error_msg)
yield f"Error: {error_msg}"
return

# Process tool calls recursively until there are no more or max_iterations is reached
while self.provider.has_tool_calls(current_response) and iteration < max_iterations:
iteration += 1
logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}")

# Process tool calls and get results
tool_results = self.provider.process_tool_calls(current_response, cookie)

# Make follow-up request with tool results
logger.info(f"Making follow-up request with tool results (iteration {iteration})")

# Only stream on the final iteration
should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response)

follow_up_response = self.provider.create_follow_up_completion(
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=current_response, stream=should_stream
)

# Check if this follow-up response has more tool calls
if not self.provider.has_tool_calls(follow_up_response):
logger.info(f"No more tool calls after iteration {iteration}")
# Final response - stream its content to the caller
if not should_stream:
# If we didn't stream this response, yield its full content in one chunk
content = self.provider.get_content(follow_up_response)
yield content
return
else:
# Relay the streaming chunks (a bare return would not reach the caller of this generator)
yield from self.provider.get_streaming_content(follow_up_response)
return

# Update current_response for the next iteration
current_response = follow_up_response

# If we've reached max_iterations and still have tool calls, log a warning
if iteration == max_iterations and self.provider.has_tool_calls(current_response):
logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
# Stream the final response even if it has tool calls
yield from self.provider.get_streaming_content(follow_up_response)
return

# If we didn't process any tool calls (shouldn't happen), return an error
if iteration == 0:
error_msg = "No tool calls found in response"
logger.error(error_msg)
yield f"Error: {error_msg}"

except Exception as e:
error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
yield f"Error: {str(e)}"

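Since `process_tool_calls_and_follow_up` is itself a generator (it uses `yield`), the final streaming content has to be re-yielded with `yield from`: a bare `return <generator>` inside a generator only sets `StopIteration.value`, so the caller would never see those chunks. A small self-contained illustration of the difference:

    def broken():
        yield "a"
        return iter(["b", "c"])      # the caller never receives "b" or "c"

    def fixed():
        yield "a"
        yield from iter(["b", "c"])  # "b" and "c" are streamed to the caller

    print(list(broken()))  # ['a']
    print(list(fixed()))   # ['a', 'b', 'c']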
def refresh_tools(self, cookie: str) -> None:
"""
Refresh the available Airflow tools.

@@ -5,6 +5,7 @@ This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""

import json
import logging
import traceback
from typing import Any
@@ -50,7 +51,7 @@ class AnthropicProvider(BaseLLMProvider):
return convert_to_anthropic_tools(airflow_tools)

def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to Anthropic.
@@ -84,6 +85,12 @@ class AnthropicProvider(BaseLLMProvider):
# Add tools if provided
if tools and len(tools) > 0:
params["tools"] = tools
else:
logger.warning("No tools included in request")

# Log the full request parameters (with sensitive information redacted)
log_params = params.copy()
logger.info(f"Request parameters: {json.dumps(log_params)}")

# Make the API request
response = self.client.messages.create(**params)
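The comment above refers to redacting sensitive information before logging, while the hunk itself only copies `params`; a hedged sketch of what such a redaction step could look like (the helper name and the fields it touches are hypothetical, not part of this commit):

    def redact_for_logging(params: dict) -> dict:
        # Shallow copy so the real request is untouched; replace bulky/sensitive fields.
        safe = dict(params)
        if "messages" in safe:
            safe["messages"] = f"<{len(safe['messages'])} messages omitted>"
        return safe

    # log_params = redact_for_logging(params)  # would make the log line match the comment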
@@ -185,7 +192,7 @@ class AnthropicProvider(BaseLLMProvider):
return results

def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.

@@ -33,7 +33,7 @@ class BaseLLMProvider(ABC):

@abstractmethod
def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to provider.
@@ -80,7 +80,7 @@ class BaseLLMProvider(ABC):

@abstractmethod
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.

@@ -52,7 +52,7 @@ class OpenAIProvider(BaseLLMProvider):
return convert_to_openai_tools(airflow_tools)

def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to OpenAI.
@@ -77,6 +77,23 @@

try:
logger.info(f"Sending chat completion request to OpenAI with model: {model}")

# Log information about tools
if not has_tools:
logger.warning("No tools included in request")

# Log request parameters
request_params = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
"tools": tools if has_tools else None,
"tool_choice": tool_choice,
}
logger.info(f"Request parameters: {json.dumps(request_params)}")

response = self.client.chat.completions.create(
model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
)
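The `tools` list passed here is expected to follow OpenAI's function-calling schema; a sketch of roughly what one converted Airflow tool might look like (the actual mapping lives in `convert_to_openai_tools`; the tool name and parameters below are illustrative):

    example_openai_tool = {
        "type": "function",
        "function": {
            "name": "list_dags",                # illustrative tool name
            "description": "List DAGs in the Airflow instance",
            "parameters": {                     # JSON Schema for the arguments
                "type": "object",
                "properties": {"limit": {"type": "integer"}},
                "required": [],
            },
        },
    }
    # Passed as tools=[example_openai_tool], typically alongside tool_choice="auto".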
@@ -143,7 +160,7 @@
return results

def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
) -> Any:
"""
Create a follow-up completion with tool results.

@@ -81,7 +81,7 @@ document.addEventListener('DOMContentLoaded', function() {
messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;

messageDiv.classList.add('pre-formatted');

// Use marked.js to render markdown
try {
// Configure marked options
@@ -91,7 +91,7 @@ document.addEventListener('DOMContentLoaded', function() {
headerIds: false, // Don't add IDs to headers
mangle: false, // Don't mangle email addresses
});

// Render markdown to HTML
messageDiv.innerHTML = marked.parse(content);
} catch (e) {
@@ -99,7 +99,7 @@ document.addEventListener('DOMContentLoaded', function() {
// Fallback to innerText if markdown parsing fails
messageDiv.innerText = content;
}

chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
return messageDiv;
@@ -108,10 +108,18 @@ document.addEventListener('DOMContentLoaded', function() {
function showProcessingIndicator() {
processingIndicator.classList.add('visible');
chatMessages.scrollTop = chatMessages.scrollHeight;

// Disable send button and input field during tool processing
sendButton.disabled = true;
messageInput.disabled = true;
}

function hideProcessingIndicator() {
processingIndicator.classList.remove('visible');

// Re-enable send button and input field after tool processing
sendButton.disabled = false;
messageInput.disabled = false;
}

async function sendMessage() {
@@ -169,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() {
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.7
temperature: 0.4,
};
console.log('Sending request:', {...requestData, api_key: '***'});

@@ -231,6 +239,17 @@ document.addEventListener('DOMContentLoaded', function() {
continue;
}

if (parsed.event === 'replace_content') {
console.log('Replacing content due to tool call');
// Clear the current message content
const currentMessageDiv = document.querySelector('.message.assistant:last-child .message-content');
if (currentMessageDiv) {
currentMessageDiv.innerHTML = '';
fullResponse = ''; // Reset the full response
}
continue;
}

if (parsed.event === 'tool_processing_complete') {
console.log('Tool processing completed');
hideProcessingIndicator();
@@ -242,7 +261,7 @@ document.addEventListener('DOMContentLoaded', function() {
console.log('Received complete response from backend');
// Use the complete response from the backend
fullResponse = parsed.content;

// Use marked.js to render markdown
try {
// Configure marked options
@@ -252,7 +271,7 @@ document.addEventListener('DOMContentLoaded', function() {
headerIds: false, // Don't add IDs to headers
mangle: false, // Don't mangle email addresses
});

// Render markdown to HTML
currentMessageDiv.innerHTML = marked.parse(fullResponse);
} catch (e) {

@@ -5,6 +5,7 @@ This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""

import logging
from typing import Any

@@ -84,6 +85,8 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list:
Returns:
List of Anthropic tool definitions
"""
logger = logging.getLogger("airflow.plugins.wingman")
logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format")
anthropic_tools = []

for tool in airflow_tools:
@@ -100,6 +103,7 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list:

anthropic_tools.append(anthropic_tool)

logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format")
return anthropic_tools

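Anthropic's tool schema differs from OpenAI's: top-level `name` and `description` plus an `input_schema`, rather than a nested `function` object. A sketch of the shape a converted tool presumably takes (field values are illustrative; the real mapping is this `convert_to_anthropic_tools` function):

    example_anthropic_tool = {
        "name": "list_dags",               # illustrative tool name
        "description": "List DAGs in the Airflow instance",
        "input_schema": {                  # JSON Schema for the tool input
            "type": "object",
            "properties": {"limit": {"type": "integer"}},
            "required": [],
        },
    }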
@@ -40,12 +40,18 @@ class WingmanView(AppBuilderBaseView):

# Get available Airflow tools using the stored cookie
airflow_tools = []
if session.get("airflow_cookie"):
airflow_cookie = request.cookies.get("session")
if airflow_cookie:
try:
airflow_tools = list_airflow_tools(session["airflow_cookie"])
airflow_tools = list_airflow_tools(airflow_cookie)
logger.info(f"Loaded {len(airflow_tools)} Airflow tools")
if len(airflow_tools) > 0:
logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}")
else:
logger.warning("No Airflow tools were loaded")
except Exception as e:
# Log the error but continue without tools
print(f"Error fetching Airflow tools: {str(e)}")
logger.error(f"Error fetching Airflow tools: {str(e)}")

# Prepare messages with Airflow tools included in the prompt
data["messages"] = prepare_messages(data["messages"])
@@ -97,7 +103,7 @@ class WingmanView(AppBuilderBaseView):
"messages": data["messages"],
"api_key": data["api_key"],
"stream": data.get("stream", True),
"temperature": data.get("temperature", 0.7),
"temperature": data.get("temperature", 0.4),
"max_tokens": data.get("max_tokens"),
"cookie": data.get("cookie"),
"provider": provider,
@@ -108,27 +114,51 @@
"""Handle streaming response."""
try:
logger.info("Beginning streaming response")
generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
# Use the enhanced chat_completion method with return_response_obj=True
response_obj, generator = client.chat_completion(
messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True
)

def stream_response():
complete_response = ""

# Send SSE format for each chunk
# Stream the initial response
for chunk in generator:
if chunk:
complete_response += chunk
yield f"data: {chunk}\n\n"

# Log the complete assembled response at the end
# Log the complete assembled response
logger.info("COMPLETE RESPONSE START >>>")
logger.info(complete_response)
logger.info("<<< COMPLETE RESPONSE END")

# Send the complete response as a special event
# Check for tool calls and make follow-up if needed
if client.provider.has_tool_calls(response_obj):
# Signal tool processing start - frontend should disable send button
yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"

# Signal to replace content - frontend should clear the current message
yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"

logger.info("Response contains tool calls, making follow-up request")

# Process tool calls and get follow-up response (handles recursive tool calls)
follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"])

# Stream the follow-up response
for chunk in follow_up_response:
if chunk:
yield f"data: {chunk}\n\n"

# Signal tool processing complete - frontend can re-enable send button
yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"

# Send the complete response as a special event (for compatibility with existing code)
complete_event = json.dumps({"event": "complete_response", "content": complete_response})
yield f"data: {complete_event}\n\n"

# Signal the end of the stream
# Signal end of stream
yield "data: [DONE]\n\n"

return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
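Taken together, `stream_response` emits plain content chunks, then (only when tool calls are present) the `tool_processing_start`, `replace_content`, follow-up chunks and `tool_processing_complete` events, then a `complete_response` event, and finally `[DONE]`. A minimal sketch of how a client might interpret one SSE line from this endpoint (the real handling is in the JavaScript above; this is only illustrative):

    import json

    def handle_sse_line(line: str) -> None:
        if not line.startswith("data: "):
            return
        payload = line[len("data: "):]
        if payload == "[DONE]":
            print("\nstream finished")
            return
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            print(payload, end="")                 # plain content chunk
            return
        if not isinstance(event, dict):
            print(payload, end="")                 # JSON-looking content chunk
        elif event.get("event") == "replace_content":
            print("\n-- clear current assistant message --")
        elif event.get("event") == "complete_response":
            print("\nfull response:", event["content"])
        # tool_processing_start / tool_processing_complete would toggle the UI state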