From 1c54053e6522cd4156b3b3cb08f296466cd343a2 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 19:36:08 +0000
Subject: [PATCH] Fix OpenAI tool calling

---
 .../providers/openai_provider.py | 101 ++++++++++++++----
 1 file changed, 83 insertions(+), 18 deletions(-)

diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py
index 61e4a58..310dde7 100644
--- a/src/airflow_wingman/providers/openai_provider.py
+++ b/src/airflow_wingman/providers/openai_provider.py
@@ -209,7 +209,7 @@ class OpenAIProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] | None = None, original_response: Any = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -221,22 +221,51 @@ class OpenAIProvider(BaseLLMProvider):
             max_tokens: Maximum tokens to generate
             tool_results: Results of tool executions
             original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in OpenAI format

         Returns:
-            OpenAI response object
+            OpenAI response object or StreamingResponse if streaming
         """
         if not original_response or not tool_results:
             return original_response

-        # Get the original message with tool calls
-        original_message = original_response.choices[0].message
+        # Handle StreamingResponse objects
+        if isinstance(original_response, StreamingResponse):
+            logger.info("Processing StreamingResponse in create_follow_up_completion")
+            # Extract tool calls from StreamingResponse
+            tool_calls = []
+            if original_response.tool_call is not None:
+                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
+                tool_call = original_response.tool_call
+                # Create a simplified tool call structure for the assistant message
+                tool_calls.append({
+                    "id": tool_call.get("id", ""),
+                    "type": "function",
+                    "function": {
+                        "name": tool_call.get("name", ""),
+                        "arguments": json.dumps(tool_call.get("input", {}))
+                    }
+                })
+
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": tool_calls,
+            }
+        else:
+            # Handle regular OpenAI response objects
+            logger.info("Processing regular OpenAI response in create_follow_up_completion")
+            # Get the original message with tool calls
+            original_message = original_response.choices[0].message

-        # Create a new message with the tool calls
-        assistant_message = {
-            "role": "assistant",
-            "content": None,
-            "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
-        }
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
+            }

         # Create tool result messages
         tool_messages = []
@@ -247,14 +276,14 @@ class OpenAIProvider(BaseLLMProvider):
         new_messages = messages + [assistant_message] + tool_messages

         # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
         return self.create_chat_completion(
             messages=new_messages,
             model=model,
             temperature=temperature,
             max_tokens=max_tokens,
-            stream=False,
-            tools=None,  # No tools needed for follow-up
+            stream=stream,
+            tools=tools,  # Pass tools parameter for follow-up
         )

     def get_content(self, response: Any) -> str:
@@ -292,6 +321,9 @@ class OpenAIProvider(BaseLLMProvider):

         def generate():
             nonlocal tool_call, tool_use_detected, current_tool_call
+
+            # Flag to track if we've yielded any content
+            has_yielded_content = False

             for chunk in response:
                 # Check for tool call in the delta
@@ -324,16 +356,25 @@ class OpenAIProvider(BaseLLMProvider):

                     # Update the arguments if they're provided in this chunk
                     if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                        # Instead of trying to parse each chunk as JSON, accumulate the arguments
+                        # and only parse the complete JSON at the end
+                        if "_raw_arguments" not in current_tool_call:
+                            current_tool_call["_raw_arguments"] = ""
+
+                        # Accumulate the raw arguments
+                        current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments
+
+                        # Try to parse the accumulated arguments
                         try:
-                            # Try to parse the arguments JSON
-                            arguments = json.loads(delta_tool_call.function.arguments)
+                            arguments = json.loads(current_tool_call["_raw_arguments"])
                             if isinstance(arguments, dict):
-                                current_tool_call["input"].update(arguments)
+                                # Successfully parsed the complete JSON
+                                current_tool_call["input"] = arguments  # Replace instead of update
                                 # Update the StreamingResponse object's tool_call attribute
                                 streaming_response.tool_call = current_tool_call
                         except json.JSONDecodeError:
-                            # If the arguments are not valid JSON, just log a warning
-                            logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+                            # This is expected for partial JSON - we'll try again with the next chunk
+                            logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")

                     # Skip yielding content for tool call chunks
                     continue
@@ -341,7 +382,30 @@ class OpenAIProvider(BaseLLMProvider):
                 # For the final chunk, set the tool_call attribute
                 if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                     logger.info("Streaming response finished with tool_calls reason")
+
+                    # If we haven't yielded any content yet and we're finishing with tool_calls,
+                    # yield a placeholder message so the frontend has something to display
+                    if not has_yielded_content and tool_use_detected:
+                        logger.info("Yielding placeholder content for tool call")
+                        yield "I'll help you with that."  # Simple placeholder message
+                        has_yielded_content = True
                     if current_tool_call:
+                        # One final attempt to parse the arguments if we have accumulated raw arguments
+                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
+                            try:
+                                arguments = json.loads(current_tool_call["_raw_arguments"])
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"] = arguments
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
+                                # If we still can't parse it, use an empty dict as fallback
+                                if not current_tool_call["input"]:
+                                    current_tool_call["input"] = {}
+
+                        # Remove the raw arguments from the final tool call
+                        if "_raw_arguments" in current_tool_call:
+                            del current_tool_call["_raw_arguments"]
+
                         tool_call = current_tool_call
                         logger.info(f"Final tool call: {json.dumps(tool_call)}")
                         # Update the StreamingResponse object's tool_call attribute
@@ -352,6 +416,7 @@ class OpenAIProvider(BaseLLMProvider):
                 if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                     content = chunk.choices[0].delta.content
                     yield content
+                    has_yielded_content = True

         # Create the generator
         gen = generate()
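
For context on what the follow-up request sends: create_follow_up_completion relies on the standard OpenAI chat-completions tool protocol, where the assistant turn that requested the tool is replayed with its tool_calls array and each result is attached as a role "tool" message keyed by tool_call_id. A minimal sketch of that message sequence, with a made-up call id, tool name, and payloads (none of these values come from the patch):

    import json

    # Hypothetical id, tool name, and payloads, for illustration only.
    assistant_message = {
        "role": "assistant",
        "content": None,  # content is empty when the assistant only calls tools
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "list_dags", "arguments": json.dumps({"limit": 5})},
            }
        ],
    }
    tool_message = {
        "role": "tool",
        "tool_call_id": "call_abc123",  # must match the id in tool_calls above
        "content": json.dumps({"dags": ["example_dag"]}),
    }
    new_messages = [{"role": "user", "content": "List my DAGs"}, assistant_message, tool_message]

The second completion request is then made over new_messages, which is why the patch threads stream and tools through to it rather than hard-coding stream=False and tools=None.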
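
The heart of the streaming fix is an accumulate-then-parse pattern: each streamed delta carries an arbitrary fragment of the tool call's JSON arguments, so json.loads on any single fragment almost always fails; the fix concatenates the fragments and retries a full parse after each one, keeping only a successful parse. A self-contained sketch of the pattern using only the standard library (the fragment list and function name are illustrative, not taken from the codebase):

    import json

    def accumulate_tool_arguments(fragments: list[str]) -> dict:
        """Concatenate streamed argument fragments, parsing whenever the JSON completes."""
        raw = ""
        parsed: dict = {}
        for fragment in fragments:
            raw += fragment
            try:
                candidate = json.loads(raw)
                if isinstance(candidate, dict):
                    parsed = candidate  # replace, not update: raw already holds the full JSON
            except json.JSONDecodeError:
                pass  # expected while the JSON is still partial; wait for the next fragment
        return parsed

    # '{"dag_id": "example"}' arriving split across three deltas still parses correctly:
    print(accumulate_tool_arguments(['{"dag_', 'id": "exa', 'mple"}']))  # {'dag_id': 'example'}

This is why the patch replaces current_tool_call["input"] wholesale on a successful parse instead of calling .update() per chunk, and why a parse failure mid-stream is logged at debug level rather than as a warning.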