fix OpenAI tool calling

2025-03-02 19:36:08 +00:00
parent dfa5e8c25d
commit 1c54053e65


@@ -209,7 +209,7 @@ class OpenAIProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -221,22 +221,51 @@ class OpenAIProvider(BaseLLMProvider):
             max_tokens: Maximum tokens to generate
             tool_results: Results of tool executions
             original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in OpenAI format

         Returns:
-            OpenAI response object
+            OpenAI response object or StreamingResponse if streaming
         """
         if not original_response or not tool_results:
             return original_response

-        # Get the original message with tool calls
-        original_message = original_response.choices[0].message
+        # Handle StreamingResponse objects
+        if isinstance(original_response, StreamingResponse):
+            logger.info("Processing StreamingResponse in create_follow_up_completion")
+            # Extract tool calls from StreamingResponse
+            tool_calls = []
+            if original_response.tool_call is not None:
+                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
+                tool_call = original_response.tool_call
+                # Create a simplified tool call structure for the assistant message
+                tool_calls.append({
+                    "id": tool_call.get("id", ""),
+                    "type": "function",
+                    "function": {
+                        "name": tool_call.get("name", ""),
+                        "arguments": json.dumps(tool_call.get("input", {}))
+                    }
+                })
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": tool_calls,
+            }
+        else:
+            # Handle regular OpenAI response objects
+            logger.info("Processing regular OpenAI response in create_follow_up_completion")
+            # Get the original message with tool calls
+            original_message = original_response.choices[0].message

-        # Create a new message with the tool calls
-        assistant_message = {
-            "role": "assistant",
-            "content": None,
-            "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
-        }
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
+            }

         # Create tool result messages
         tool_messages = []
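
For reference, the follow-up sequence assembled here follows the standard OpenAI tool-calling shape: the assistant message carries tool_calls with arguments as a JSON string, and each result comes back as a role "tool" message whose tool_call_id matches the call id. A minimal standalone sketch (the literal ids and values are illustrative, not taken from this codebase):

import json

messages = [{"role": "user", "content": "What's the weather in Paris?"}]

# Assistant turn that requested the tool; arguments must be a JSON string
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_abc123",  # illustrative id
        "type": "function",
        "function": {"name": "get_weather", "arguments": json.dumps({"city": "Paris"})},
    }],
}

# One tool message per call; tool_call_id must match the id above
tool_message = {
    "role": "tool",
    "tool_call_id": "call_abc123",
    "content": json.dumps({"temp_c": 18}),  # tool output serialized as a string
}

new_messages = messages + [assistant_message] + [tool_message]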
@@ -247,14 +276,14 @@ class OpenAIProvider(BaseLLMProvider):
         new_messages = messages + [assistant_message] + tool_messages

         # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
         return self.create_chat_completion(
             messages=new_messages,
             model=model,
             temperature=temperature,
             max_tokens=max_tokens,
-            stream=False,
-            tools=None,  # No tools needed for follow-up
+            stream=stream,
+            tools=tools,  # Pass tools parameter for follow-up
         )

     def get_content(self, response: Any) -> str:
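
With the two new parameters, the caller decides whether the follow-up turn streams and whether the model may chain further tool calls. A hypothetical invocation (provider, results, first_response, and TOOLS are assumed names for illustration, not identifiers from this diff):

follow_up = provider.create_follow_up_completion(
    messages=messages,
    model="gpt-4o",                    # assumed model name
    temperature=0.4,
    tool_results=results,              # outputs of the executed tool calls
    original_response=first_response,  # the response that requested the tools
    stream=True,                       # new: stream the follow-up answer
    tools=TOOLS,                       # new: keep tools available for chained calls
)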
@@ -292,6 +321,9 @@ class OpenAIProvider(BaseLLMProvider):
         def generate():
             nonlocal tool_call, tool_use_detected, current_tool_call
+            # Flag to track if we've yielded any content
+            has_yielded_content = False

             for chunk in response:
                 # Check for tool call in the delta
@@ -324,16 +356,25 @@ class OpenAIProvider(BaseLLMProvider):
                 # Update the arguments if they're provided in this chunk
                 if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                    # Instead of trying to parse each chunk as JSON, accumulate the arguments
+                    # and only parse the complete JSON at the end
+                    if "_raw_arguments" not in current_tool_call:
+                        current_tool_call["_raw_arguments"] = ""
+                    # Accumulate the raw arguments
+                    current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments
+                    # Try to parse the accumulated arguments
                     try:
-                        # Try to parse the arguments JSON
-                        arguments = json.loads(delta_tool_call.function.arguments)
+                        arguments = json.loads(current_tool_call["_raw_arguments"])
                         if isinstance(arguments, dict):
-                            current_tool_call["input"].update(arguments)
+                            # Successfully parsed the complete JSON
+                            current_tool_call["input"] = arguments  # Replace instead of update
                             # Update the StreamingResponse object's tool_call attribute
                             streaming_response.tool_call = current_tool_call
                     except json.JSONDecodeError:
-                        # If the arguments are not valid JSON, just log a warning
-                        logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+                        # This is expected for partial JSON - we'll try again with the next chunk
+                        logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")

                 # Skip yielding content for tool call chunks
                 continue
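
The underlying issue is that OpenAI streams a tool call's arguments as raw JSON fragments split arbitrarily across chunks, so parsing any single fragment almost always fails. The accumulate-then-parse pattern can be checked in isolation; the three-way split below is fabricated for illustration:

import json

fragments = ['{"city": "Pa', 'ris", "unit"', ': "celsius"}']  # simulated stream chunks
raw, parsed = "", None
for piece in fragments:
    raw += piece                  # accumulate the raw argument text
    try:
        parsed = json.loads(raw)  # only succeeds once the JSON is complete
    except json.JSONDecodeError:
        continue                  # expected for partial JSON; keep accumulating

assert parsed == {"city": "Paris", "unit": "celsius"}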
@@ -341,7 +382,30 @@ class OpenAIProvider(BaseLLMProvider):
                 # For the final chunk, set the tool_call attribute
                 if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                     logger.info("Streaming response finished with tool_calls reason")
+                    # If we haven't yielded any content yet and we're finishing with tool_calls,
+                    # yield a placeholder message so the frontend has something to display
+                    if not has_yielded_content and tool_use_detected:
+                        logger.info("Yielding placeholder content for tool call")
+                        yield "I'll help you with that."  # Simple placeholder message
+                        has_yielded_content = True

                     if current_tool_call:
+                        # One final attempt to parse the arguments if we have accumulated raw arguments
+                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
+                            try:
+                                arguments = json.loads(current_tool_call["_raw_arguments"])
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"] = arguments
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
+                                # If we still can't parse it, use an empty dict as fallback
+                                if not current_tool_call["input"]:
+                                    current_tool_call["input"] = {}
+                        # Remove the raw arguments from the final tool call
+                        if "_raw_arguments" in current_tool_call:
+                            del current_tool_call["_raw_arguments"]
                         tool_call = current_tool_call
                         logger.info(f"Final tool call: {json.dumps(tool_call)}")
                         # Update the StreamingResponse object's tool_call attribute
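
The finalization above amounts to one last parse attempt, an empty-dict fallback, and removal of the scratch key. Read as a small pure function (finalize_tool_call is a hypothetical name, not part of this codebase):

import json

def finalize_tool_call(tool_call: dict) -> dict:
    # Mirrors the finalization logic above; hypothetical standalone helper.
    raw = tool_call.pop("_raw_arguments", "")  # drop the scratch key either way
    if raw:
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, dict):
                tool_call["input"] = parsed
        except json.JSONDecodeError:
            if not tool_call.get("input"):
                tool_call["input"] = {}  # fallback so downstream code sees a dict
    return tool_call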
@@ -352,6 +416,7 @@ class OpenAIProvider(BaseLLMProvider):
                 if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                     content = chunk.choices[0].delta.content
                     yield content
+                    has_yielded_content = True

         # Create the generator
         gen = generate()