fix tool listing and temperature

2025-03-01 17:40:58 +00:00
parent ab39631815
commit 7df5e3c55e
7 changed files with 189 additions and 25 deletions

View File

@@ -50,7 +50,9 @@ class LLMClient:
""" """
self.airflow_tools = tools self.airflow_tools = tools
def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]: def chat_completion(
self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False
) -> dict[str, Any] | tuple[Any, Any]:
""" """
Send a chat completion request to the LLM provider. Send a chat completion request to the LLM provider.
@@ -60,9 +62,12 @@ class LLMClient:
             temperature: Sampling temperature (0-1)
             max_tokens: Maximum tokens to generate
             stream: Whether to stream the response (default is True)
+            return_response_obj: If True and streaming, returns both the response object and generator

         Returns:
-            Dictionary with the response content or a generator for streaming
+            If stream=False: Dictionary with the response content
+            If stream=True and return_response_obj=False: Generator for streaming
+            If stream=True and return_response_obj=True: Tuple of (response_obj, generator)
         """
         # Get provider-specific tool definitions from Airflow tools
         provider_tools = self.provider.convert_tools(self.airflow_tools)
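
For reference, a hedged sketch of how a caller might use the two streaming shapes described above. The constructor arguments, provider and model names, and message content are placeholders, not part of this commit:

```python
# Hypothetical caller; provider, model and message values are illustrative only.
client = LLMClient(provider_name="openai", api_key="sk-...", base_url=None)

# stream=False: a plain dict with the response content
result = client.chat_completion(
    messages=[{"role": "user", "content": "List my DAGs"}],
    model="gpt-4o", stream=False,
)

# stream=True + return_response_obj=True: the raw response plus a chunk generator,
# which is what the Wingman view now uses to detect tool calls after streaming
response_obj, generator = client.chat_completion(
    messages=[{"role": "user", "content": "List my DAGs"}],
    model="gpt-4o", stream=True, return_response_obj=True,
)
for chunk in generator:
    print(chunk, end="")
```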
@@ -73,9 +78,12 @@ class LLMClient:
         response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools)
         logger.info(f"Received response from {self.provider_name}")

-        # If streaming, return the generator directly
+        # If streaming, handle based on return_response_obj flag
         if stream:
             logger.info(f"Using streaming response from {self.provider_name}")
-            return self.provider.get_streaming_content(response)
+            if return_response_obj:
+                return response, self.provider.get_streaming_content(response)
+            else:
+                return self.provider.get_streaming_content(response)

         # For non-streaming responses, handle tool calls if present
@@ -135,6 +143,85 @@ class LLMClient:
         return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
+
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5):
+        """
+        Process tool calls recursively from a response and make follow-up requests until
+        there are no more tool calls or max_iterations is reached.
+
+        Returns a generator for streaming the final follow-up response.
+
+        Args:
+            response: The original response object containing tool calls
+            messages: List of message dictionaries with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-1)
+            max_tokens: Maximum tokens to generate
+            max_iterations: Maximum number of tool call iterations to prevent infinite loops
+
+        Returns:
+            Generator for streaming the final follow-up response
+        """
+        try:
+            iteration = 0
+            current_response = response
+
+            cookie = session.get("airflow_cookie")
+            if not cookie:
+                error_msg = "No Airflow cookie available"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+                return
+
+            # Process tool calls recursively until there are no more or max_iterations is reached
+            while self.provider.has_tool_calls(current_response) and iteration < max_iterations:
+                iteration += 1
+                logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}")
+
+                # Process tool calls and get results
+                tool_results = self.provider.process_tool_calls(current_response, cookie)
+
+                # Make follow-up request with tool results
+                logger.info(f"Making follow-up request with tool results (iteration {iteration})")
+
+                # Only stream on the final iteration
+                should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response)
+
+                follow_up_response = self.provider.create_follow_up_completion(
+                    messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=current_response, stream=should_stream
+                )
+
+                # Check if this follow-up response has more tool calls
+                if not self.provider.has_tool_calls(follow_up_response):
+                    logger.info(f"No more tool calls after iteration {iteration}")
+                    # Final response - stream it back to the caller
+                    if not should_stream:
+                        # If we didn't stream this response, yield the full content at once
+                        content = self.provider.get_content(follow_up_response)
+                        yield content
+                        return
+                    else:
+                        # Forward the streaming generator chunk by chunk; a bare
+                        # `return value` inside a generator would discard the chunks
+                        yield from self.provider.get_streaming_content(follow_up_response)
+                        return
+
+                # Update current_response for the next iteration
+                current_response = follow_up_response
+
+            # If we've reached max_iterations and still have tool calls, log a warning
+            if iteration == max_iterations and self.provider.has_tool_calls(current_response):
+                logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
+                # Stream the final response even if it has tool calls
+                yield from self.provider.get_streaming_content(follow_up_response)
+                return
+
+            # If we didn't process any tool calls (shouldn't happen), return an error
+            if iteration == 0:
+                error_msg = "No tool calls found in response"
+                logger.error(error_msg)
+                yield f"Error: {error_msg}"
+        except Exception as e:
+            error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_msg)
+            yield f"Error: {str(e)}"
+
     def refresh_tools(self, cookie: str) -> None:
         """
         Refresh the available Airflow tools.
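
Because process_tool_calls_and_follow_up is itself a generator, the final streamed content has to be forwarded with `yield from` rather than returned; a minimal, self-contained sketch of the difference (names are illustrative, not part of the plugin):

```python
def provider_stream():
    # Stand-in for provider.get_streaming_content(...)
    yield from ["Hello", ", ", "world"]

def broken():
    yield "prefix "
    return provider_stream()  # value is attached to StopIteration and lost to the caller's loop

def fixed():
    yield "prefix "
    yield from provider_stream()  # chunks are re-yielded to the caller

print("".join(broken()))  # -> "prefix "
print("".join(fixed()))   # -> "prefix Hello, world"
```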

View File

@@ -5,6 +5,7 @@ This module contains the Anthropic provider implementation that handles
 API requests, tool conversion, and response processing for Anthropic's Claude models.
 """

+import json
 import logging
 import traceback
 from typing import Any
@@ -50,7 +51,7 @@ class AnthropicProvider(BaseLLMProvider):
         return convert_to_anthropic_tools(airflow_tools)

     def create_chat_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Make API request to Anthropic.
@@ -84,6 +85,12 @@ class AnthropicProvider(BaseLLMProvider):
             # Add tools if provided
             if tools and len(tools) > 0:
                 params["tools"] = tools
+            else:
+                logger.warning("No tools included in request")
+
+            # Log the full request parameters (with sensitive information redacted)
+            log_params = params.copy()
+            logger.info(f"Request parameters: {json.dumps(log_params)}")

             # Make the API request
             response = self.client.messages.create(**params)
@@ -185,7 +192,7 @@ class AnthropicProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.
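
Note that `params.copy()` by itself does not redact anything: the full message contents end up in the log line. If redaction is actually wanted, one possible approach is sketched below; the helper name and truncation length are assumptions, not part of this commit:

```python
import json

def redact_params(params: dict, max_chars: int = 200) -> dict:
    """Illustrative helper: copy request params and truncate message contents before logging."""
    log_params = dict(params)
    log_params["messages"] = [
        {**m, "content": str(m.get("content", ""))[:max_chars]}
        for m in params.get("messages", [])
    ]
    return log_params

# logger.info(f"Request parameters: {json.dumps(redact_params(params), default=str)}")
```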

View File

@@ -33,7 +33,7 @@ class BaseLLMProvider(ABC):
     @abstractmethod
     def create_chat_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Make API request to provider.
@@ -80,7 +80,7 @@ class BaseLLMProvider(ABC):
     @abstractmethod
     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.

View File

@@ -52,7 +52,7 @@ class OpenAIProvider(BaseLLMProvider):
         return convert_to_openai_tools(airflow_tools)

     def create_chat_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Make API request to OpenAI.
@@ -77,6 +77,23 @@ class OpenAIProvider(BaseLLMProvider):
         try:
             logger.info(f"Sending chat completion request to OpenAI with model: {model}")

+            # Log information about tools
+            if not has_tools:
+                logger.warning("No tools included in request")
+
+            # Log request parameters
+            request_params = {
+                "model": model,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "stream": stream,
+                "tools": tools if has_tools else None,
+                "tool_choice": tool_choice,
+            }
+            logger.info(f"Request parameters: {json.dumps(request_params)}")
+
             response = self.client.chat.completions.create(
                 model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice
             )
@@ -143,7 +160,7 @@ class OpenAIProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.

View File

@@ -108,10 +108,18 @@ document.addEventListener('DOMContentLoaded', function() {
     function showProcessingIndicator() {
         processingIndicator.classList.add('visible');
         chatMessages.scrollTop = chatMessages.scrollHeight;
+
+        // Disable send button and input field during tool processing
+        sendButton.disabled = true;
+        messageInput.disabled = true;
     }

     function hideProcessingIndicator() {
         processingIndicator.classList.remove('visible');
+
+        // Re-enable send button and input field after tool processing
+        sendButton.disabled = false;
+        messageInput.disabled = false;
     }

     async function sendMessage() {
@@ -169,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() {
                 messages: messages,
                 api_key: apiKey,
                 stream: true,
-                temperature: 0.7
+                temperature: 0.4,
             };

             console.log('Sending request:', {...requestData, api_key: '***'});
@@ -231,6 +239,17 @@ document.addEventListener('DOMContentLoaded', function() {
                         continue;
                     }

+                    if (parsed.event === 'replace_content') {
+                        console.log('Replacing content due to tool call');
+                        // Clear the current message content
+                        const currentMessageDiv = document.querySelector('.message.assistant:last-child .message-content');
+                        if (currentMessageDiv) {
+                            currentMessageDiv.innerHTML = '';
+                            fullResponse = ''; // Reset the full response
+                        }
+                        continue;
+                    }
+
                     if (parsed.event === 'tool_processing_complete') {
                         console.log('Tool processing completed');
                         hideProcessingIndicator();

View File

@@ -5,6 +5,7 @@ This module contains functions to convert between different tool formats
 for various LLM providers (OpenAI, Anthropic, etc.).
 """

+import logging
 from typing import Any
@@ -84,6 +85,8 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list:
     Returns:
         List of Anthropic tool definitions
     """
+    logger = logging.getLogger("airflow.plugins.wingman")
+    logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format")
     anthropic_tools = []

     for tool in airflow_tools:
@@ -100,6 +103,7 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list:
         anthropic_tools.append(anthropic_tool)

+    logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format")
     return anthropic_tools

View File

@@ -40,12 +40,18 @@ class WingmanView(AppBuilderBaseView):
         # Get available Airflow tools using the stored cookie
         airflow_tools = []
-        if session.get("airflow_cookie"):
+        airflow_cookie = request.cookies.get("session")
+        if airflow_cookie:
             try:
-                airflow_tools = list_airflow_tools(session["airflow_cookie"])
+                airflow_tools = list_airflow_tools(airflow_cookie)
+                logger.info(f"Loaded {len(airflow_tools)} Airflow tools")
+                if len(airflow_tools) > 0:
+                    logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}")
+                else:
+                    logger.warning("No Airflow tools were loaded")
             except Exception as e:
                 # Log the error but continue without tools
-                print(f"Error fetching Airflow tools: {str(e)}")
+                logger.error(f"Error fetching Airflow tools: {str(e)}")

         # Prepare messages with Airflow tools included in the prompt
         data["messages"] = prepare_messages(data["messages"])
@@ -97,7 +103,7 @@ class WingmanView(AppBuilderBaseView):
"messages": data["messages"], "messages": data["messages"],
"api_key": data["api_key"], "api_key": data["api_key"],
"stream": data.get("stream", True), "stream": data.get("stream", True),
"temperature": data.get("temperature", 0.7), "temperature": data.get("temperature", 0.4),
"max_tokens": data.get("max_tokens"), "max_tokens": data.get("max_tokens"),
"cookie": data.get("cookie"), "cookie": data.get("cookie"),
"provider": provider, "provider": provider,
@@ -108,27 +114,51 @@ class WingmanView(AppBuilderBaseView):
"""Handle streaming response.""" """Handle streaming response."""
try: try:
logger.info("Beginning streaming response") logger.info("Beginning streaming response")
generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) # Use the enhanced chat_completion method with return_response_obj=True
response_obj, generator = client.chat_completion(
messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True
)
def stream_response(): def stream_response():
complete_response = "" complete_response = ""
# Send SSE format for each chunk # Stream the initial response
for chunk in generator: for chunk in generator:
if chunk: if chunk:
complete_response += chunk complete_response += chunk
yield f"data: {chunk}\n\n" yield f"data: {chunk}\n\n"
# Log the complete assembled response at the end # Log the complete assembled response
logger.info("COMPLETE RESPONSE START >>>") logger.info("COMPLETE RESPONSE START >>>")
logger.info(complete_response) logger.info(complete_response)
logger.info("<<< COMPLETE RESPONSE END") logger.info("<<< COMPLETE RESPONSE END")
# Send the complete response as a special event # Check for tool calls and make follow-up if needed
if client.provider.has_tool_calls(response_obj):
# Signal tool processing start - frontend should disable send button
yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
# Signal to replace content - frontend should clear the current message
yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"
logger.info("Response contains tool calls, making follow-up request")
# Process tool calls and get follow-up response (handles recursive tool calls)
follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"])
# Stream the follow-up response
for chunk in follow_up_response:
if chunk:
yield f"data: {chunk}\n\n"
# Signal tool processing complete - frontend can re-enable send button
yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
# Send the complete response as a special event (for compatibility with existing code)
complete_event = json.dumps({"event": "complete_response", "content": complete_response}) complete_event = json.dumps({"event": "complete_response", "content": complete_response})
yield f"data: {complete_event}\n\n" yield f"data: {complete_event}\n\n"
# Signal the end of the stream # Signal end of stream
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
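
With this change the SSE stream interleaves plain text chunks with JSON control events (tool_processing_start, replace_content, tool_processing_complete, complete_response, then [DONE]). A minimal sketch of a client consuming that protocol; the route, payload fields, and model name are placeholders, not part of this commit:

```python
import json
import requests

resp = requests.post(
    "http://localhost:8080/wingman/chat",  # assumed route
    json={"messages": [{"role": "user", "content": "How many DAGs failed today?"}],
          "model": "gpt-4o", "stream": True, "temperature": 0.4, "api_key": "..."},
    stream=True,
)

answer = ""
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    try:
        event = json.loads(payload)
    except json.JSONDecodeError:
        answer += payload  # plain text chunk
        continue
    if event.get("event") == "replace_content":
        answer = ""  # discard pre-tool-call text, mirroring the JS client
    # tool_processing_start/complete and complete_response can be handled as needed

print(answer)
```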