From 7df5e3c55efc40f53d420c1169763f4bcdf6c5ce Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sat, 1 Mar 2025 17:40:58 +0000 Subject: [PATCH] fix tool listing and temperature --- src/airflow_wingman/llm_client.py | 95 ++++++++++++++++++- .../providers/anthropic_provider.py | 11 ++- src/airflow_wingman/providers/base.py | 4 +- .../providers/openai_provider.py | 21 +++- src/airflow_wingman/static/js/wingman_chat.js | 31 ++++-- src/airflow_wingman/tools/conversion.py | 4 + src/airflow_wingman/views.py | 48 ++++++++-- 7 files changed, 189 insertions(+), 25 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index 727991b..b81a19b 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -50,7 +50,9 @@ class LLMClient: """ self.airflow_tools = tools - def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]: + def chat_completion( + self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False + ) -> dict[str, Any] | tuple[Any, Any]: """ Send a chat completion request to the LLM provider. @@ -60,9 +62,12 @@ class LLMClient: temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate stream: Whether to stream the response (default is True) + return_response_obj: If True and streaming, returns both the response object and generator Returns: - Dictionary with the response content or a generator for streaming + If stream=False: Dictionary with the response content + If stream=True and return_response_obj=False: Generator for streaming + If stream=True and return_response_obj=True: Tuple of (response_obj, generator) """ # Get provider-specific tool definitions from Airflow tools provider_tools = self.provider.convert_tools(self.airflow_tools) @@ -73,10 +78,13 @@ class LLMClient: response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools) logger.info(f"Received response from {self.provider_name}") - # If streaming, return the generator directly + # If streaming, handle based on return_response_obj flag if stream: logger.info(f"Using streaming response from {self.provider_name}") - return self.provider.get_streaming_content(response) + if return_response_obj: + return response, self.provider.get_streaming_content(response) + else: + return self.provider.get_streaming_content(response) # For non-streaming responses, handle tool calls if present if self.provider.has_tool_calls(response): @@ -135,6 +143,85 @@ class LLMClient: return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) + def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5): + """ + Process tool calls recursively from a response and make follow-up requests until + there are no more tool calls or max_iterations is reached. + Returns a generator for streaming the final follow-up response. 
+ + Args: + response: The original response object containing tool calls + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + max_iterations: Maximum number of tool call iterations to prevent infinite loops + + Returns: + Generator for streaming the final follow-up response + """ + try: + iteration = 0 + current_response = response + cookie = session.get("airflow_cookie") + + if not cookie: + error_msg = "No Airflow cookie available" + logger.error(error_msg) + yield f"Error: {error_msg}" + return + + # Process tool calls recursively until there are no more or max_iterations is reached + while self.provider.has_tool_calls(current_response) and iteration < max_iterations: + iteration += 1 + logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}") + + # Process tool calls and get results + tool_results = self.provider.process_tool_calls(current_response, cookie) + + # Make follow-up request with tool results + logger.info(f"Making follow-up request with tool results (iteration {iteration})") + + # Only stream on the final iteration + should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response) + + follow_up_response = self.provider.create_follow_up_completion( + messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=current_response, stream=should_stream + ) + + # Check if this follow-up response has more tool calls + if not self.provider.has_tool_calls(follow_up_response): + logger.info(f"No more tool calls after iteration {iteration}") + # Final response - return the streaming content + if not should_stream: + # If we didn't stream this response, we need to make a streaming version + content = self.provider.get_content(follow_up_response) + yield content + return + else: + # Return the streaming generator + return self.provider.get_streaming_content(follow_up_response) + + # Update current_response for the next iteration + current_response = follow_up_response + + # If we've reached max_iterations and still have tool calls, log a warning + if iteration == max_iterations and self.provider.has_tool_calls(current_response): + logger.warning(f"Reached maximum tool call iterations ({max_iterations})") + # Stream the final response even if it has tool calls + return self.provider.get_streaming_content(follow_up_response) + + # If we didn't process any tool calls (shouldn't happen), return an error + if iteration == 0: + error_msg = "No tool calls found in response" + logger.error(error_msg) + yield f"Error: {error_msg}" + + except Exception as e: + error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + yield f"Error: {str(e)}" + def refresh_tools(self, cookie: str) -> None: """ Refresh the available Airflow tools. diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index d18dd1f..40e5787 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -5,6 +5,7 @@ This module contains the Anthropic provider implementation that handles API requests, tool conversion, and response processing for Anthropic's Claude models. 
""" +import json import logging import traceback from typing import Any @@ -50,7 +51,7 @@ class AnthropicProvider(BaseLLMProvider): return convert_to_anthropic_tools(airflow_tools) def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to Anthropic. @@ -84,6 +85,12 @@ class AnthropicProvider(BaseLLMProvider): # Add tools if provided if tools and len(tools) > 0: params["tools"] = tools + else: + logger.warning("No tools included in request") + + # Log the full request parameters (with sensitive information redacted) + log_params = params.copy() + logger.info(f"Request parameters: {json.dumps(log_params)}") # Make the API request response = self.client.messages.create(**params) @@ -185,7 +192,7 @@ class AnthropicProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py index dcdbc44..87f4e7c 100644 --- a/src/airflow_wingman/providers/base.py +++ b/src/airflow_wingman/providers/base.py @@ -33,7 +33,7 @@ class BaseLLMProvider(ABC): @abstractmethod def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to provider. @@ -80,7 +80,7 @@ class BaseLLMProvider(ABC): @abstractmethod def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py index 512970d..a6d5c0c 100644 --- a/src/airflow_wingman/providers/openai_provider.py +++ b/src/airflow_wingman/providers/openai_provider.py @@ -52,7 +52,7 @@ class OpenAIProvider(BaseLLMProvider): return convert_to_openai_tools(airflow_tools) def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to OpenAI. 
@@ -77,6 +77,23 @@ class OpenAIProvider(BaseLLMProvider): try: logger.info(f"Sending chat completion request to OpenAI with model: {model}") + + # Log information about tools + if not has_tools: + logger.warning("No tools included in request") + + # Log request parameters + request_params = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": stream, + "tools": tools if has_tools else None, + "tool_choice": tool_choice, + } + logger.info(f"Request parameters: {json.dumps(request_params)}") + response = self.client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice ) @@ -143,7 +160,7 @@ class OpenAIProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js index 5b83be3..553cbba 100644 --- a/src/airflow_wingman/static/js/wingman_chat.js +++ b/src/airflow_wingman/static/js/wingman_chat.js @@ -81,7 +81,7 @@ document.addEventListener('DOMContentLoaded', function() { messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; messageDiv.classList.add('pre-formatted'); - + // Use marked.js to render markdown try { // Configure marked options @@ -91,7 +91,7 @@ document.addEventListener('DOMContentLoaded', function() { headerIds: false, // Don't add IDs to headers mangle: false, // Don't mangle email addresses }); - + // Render markdown to HTML messageDiv.innerHTML = marked.parse(content); } catch (e) { @@ -99,7 +99,7 @@ document.addEventListener('DOMContentLoaded', function() { // Fallback to innerText if markdown parsing fails messageDiv.innerText = content; } - + chatMessages.appendChild(messageDiv); chatMessages.scrollTop = chatMessages.scrollHeight; return messageDiv; @@ -108,10 +108,18 @@ document.addEventListener('DOMContentLoaded', function() { function showProcessingIndicator() { processingIndicator.classList.add('visible'); chatMessages.scrollTop = chatMessages.scrollHeight; + + // Disable send button and input field during tool processing + sendButton.disabled = true; + messageInput.disabled = true; } function hideProcessingIndicator() { processingIndicator.classList.remove('visible'); + + // Re-enable send button and input field after tool processing + sendButton.disabled = false; + messageInput.disabled = false; } async function sendMessage() { @@ -169,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() { messages: messages, api_key: apiKey, stream: true, - temperature: 0.7 + temperature: 0.4, }; console.log('Sending request:', {...requestData, api_key: '***'}); @@ -231,6 +239,17 @@ document.addEventListener('DOMContentLoaded', function() { continue; } + if (parsed.event === 'replace_content') { + console.log('Replacing content due to tool call'); + // Clear the current message content + const currentMessageDiv = document.querySelector('.message.assistant:last-child .message-content'); + if (currentMessageDiv) { + 
currentMessageDiv.innerHTML = ''; + fullResponse = ''; // Reset the full response + } + continue; + } + if (parsed.event === 'tool_processing_complete') { console.log('Tool processing completed'); hideProcessingIndicator(); @@ -242,7 +261,7 @@ document.addEventListener('DOMContentLoaded', function() { console.log('Received complete response from backend'); // Use the complete response from the backend fullResponse = parsed.content; - + // Use marked.js to render markdown try { // Configure marked options @@ -252,7 +271,7 @@ document.addEventListener('DOMContentLoaded', function() { headerIds: false, // Don't add IDs to headers mangle: false, // Don't mangle email addresses }); - + // Render markdown to HTML currentMessageDiv.innerHTML = marked.parse(fullResponse); } catch (e) { diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py index 7a30c02..09010e2 100644 --- a/src/airflow_wingman/tools/conversion.py +++ b/src/airflow_wingman/tools/conversion.py @@ -5,6 +5,7 @@ This module contains functions to convert between different tool formats for various LLM providers (OpenAI, Anthropic, etc.). """ +import logging from typing import Any @@ -84,6 +85,8 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list: Returns: List of Anthropic tool definitions """ + logger = logging.getLogger("airflow.plugins.wingman") + logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format") anthropic_tools = [] for tool in airflow_tools: @@ -100,6 +103,7 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list: anthropic_tools.append(anthropic_tool) + logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format") return anthropic_tools diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index c5d6233..bc60804 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -40,12 +40,18 @@ class WingmanView(AppBuilderBaseView): # Get available Airflow tools using the stored cookie airflow_tools = [] - if session.get("airflow_cookie"): + airflow_cookie = request.cookies.get("session") + if airflow_cookie: try: - airflow_tools = list_airflow_tools(session["airflow_cookie"]) + airflow_tools = list_airflow_tools(airflow_cookie) + logger.info(f"Loaded {len(airflow_tools)} Airflow tools") + if len(airflow_tools) > 0: + logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}") + else: + logger.warning("No Airflow tools were loaded") except Exception as e: # Log the error but continue without tools - print(f"Error fetching Airflow tools: {str(e)}") + logger.error(f"Error fetching Airflow tools: {str(e)}") # Prepare messages with Airflow tools included in the prompt data["messages"] = prepare_messages(data["messages"]) @@ -97,7 +103,7 @@ class WingmanView(AppBuilderBaseView): "messages": data["messages"], "api_key": data["api_key"], "stream": data.get("stream", True), - "temperature": data.get("temperature", 0.7), + "temperature": data.get("temperature", 0.4), "max_tokens": data.get("max_tokens"), "cookie": data.get("cookie"), "provider": provider, @@ -108,27 +114,51 @@ class WingmanView(AppBuilderBaseView): """Handle streaming response.""" try: logger.info("Beginning streaming response") - generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) + # Use the enhanced chat_completion method with return_response_obj=True + response_obj, generator = 
client.chat_completion( + messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True + ) def stream_response(): complete_response = "" - # Send SSE format for each chunk + # Stream the initial response for chunk in generator: if chunk: complete_response += chunk yield f"data: {chunk}\n\n" - # Log the complete assembled response at the end + # Log the complete assembled response logger.info("COMPLETE RESPONSE START >>>") logger.info(complete_response) logger.info("<<< COMPLETE RESPONSE END") - # Send the complete response as a special event + # Check for tool calls and make follow-up if needed + if client.provider.has_tool_calls(response_obj): + # Signal tool processing start - frontend should disable send button + yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n" + + # Signal to replace content - frontend should clear the current message + yield f"data: {json.dumps({'event': 'replace_content'})}\n\n" + + logger.info("Response contains tool calls, making follow-up request") + + # Process tool calls and get follow-up response (handles recursive tool calls) + follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"]) + + # Stream the follow-up response + for chunk in follow_up_response: + if chunk: + yield f"data: {chunk}\n\n" + + # Signal tool processing complete - frontend can re-enable send button + yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n" + + # Send the complete response as a special event (for compatibility with existing code) complete_event = json.dumps({"event": "complete_response", "content": complete_response}) yield f"data: {complete_event}\n\n" - # Signal the end of the stream + # Signal end of stream yield "data: [DONE]\n\n" return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
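
For reference, a minimal sketch (outside the patch itself) of a client consuming the SSE stream that stream_response() above produces. The event names (tool_processing_start, replace_content, tool_processing_complete, complete_response) and the "[DONE]" terminator are taken from the diff, as are the payload fields read in views.py; the endpoint URL is an assumption, authentication and the Airflow session cookie are elided, and a plain text chunk that happened to parse as JSON would be misrouted by this simplistic dispatch.

    import json

    import requests  # any HTTP client with response streaming works; requests is assumed here

    # Hypothetical host and route: the real URL depends on where the Wingman
    # plugin registers its chat endpoint in the Airflow webserver.
    URL = "http://localhost:8080/wingman/chat"

    payload = {
        "provider": "openai",                   # or "anthropic"
        "model": "gpt-4o",                      # example model name
        "api_key": "sk-...",                    # placeholder
        "messages": [{"role": "user", "content": "List my DAGs"}],
        "stream": True,
        "temperature": 0.4,
    }

    with requests.post(URL, json=payload, stream=True) as resp:
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data: "):
                continue
            data = raw[len("data: "):]
            if data == "[DONE]":
                # stream_response() ends every stream with this sentinel
                break
            try:
                event = json.loads(data)
            except ValueError:
                # Plain content chunk (initial answer or follow-up after tool calls)
                print(data, end="", flush=True)
                continue
            name = event.get("event")
            if name == "tool_processing_start":
                print("\n[tool calls detected; waiting for follow-up...]")
            elif name == "replace_content":
                # Mirror the frontend: discard the partial answer shown so far
                print("\n[partial answer discarded]")
            elif name == "tool_processing_complete":
                print("\n[tool processing finished]")
            elif name == "complete_response":
                # Full assembled text of the first response, already printed above
                pass

The same ordering is what wingman_chat.js relies on: content chunks first, then the control events when the first response contained tool calls, then the follow-up chunks, and finally the completion events.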