Refactor Google and OpenAI provider response handling and tool utilities
- Improved error handling and logging in Google response processing. - Simplified streaming content extraction and error detection in Google provider. - Enhanced content extraction logic in OpenAI provider to handle edge cases. - Streamlined tool conversion functions for both Google and OpenAI providers. - Removed redundant comments and improved code readability across multiple files. - Updated context window retrieval and message truncation logic for better performance. - Ensured consistent handling of tool calls and arguments in OpenAI responses.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/tools.py
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -13,20 +12,16 @@ logger = logging.getLogger(__name__)
|
||||
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
# Check if choices exist and are not empty
|
||||
if isinstance(response, ChatCompletion):
|
||||
if response.choices:
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
|
||||
return False
|
||||
elif isinstance(response, Stream):
|
||||
# This check remains unreliable for unconsumed streams.
|
||||
# LLMClient needs robust handling after consumption.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
return False # Assume no for unconsumed stream for now
|
||||
return False
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
@@ -36,14 +31,12 @@ def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bo
|
||||
|
||||
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
return []
|
||||
|
||||
# Check if choices exist and are not empty
|
||||
if not response.choices:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
|
||||
return []
|
||||
@@ -55,38 +48,30 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
server_name = None
|
||||
func_name = call.function.name
|
||||
|
||||
# Arguments might be a string needing JSON parsing, or already parsed dict
|
||||
arguments_obj = None
|
||||
try:
|
||||
if isinstance(call.function.arguments, str):
|
||||
arguments_obj = json.loads(call.function.arguments)
|
||||
else:
|
||||
# Assuming it might already be a dict if not a string (less common)
|
||||
arguments_obj = call.function.arguments
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
|
||||
logger.error(f"Raw arguments string: {call.function.arguments}")
|
||||
# Decide how to handle: skip tool, pass raw string, pass error?
|
||||
# Passing raw string for now, but this might break consumers.
|
||||
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"server_name": server_name,
|
||||
"function_name": func_name,
|
||||
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
|
||||
"arguments": arguments_obj,
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
@@ -94,20 +79,18 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
return []
|
||||
|
||||
|
||||
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
elif isinstance(result, str):
|
||||
content = result # Allow plain strings if result is already string
|
||||
content = result
|
||||
else:
|
||||
content = str(result) # Ensure it's a string otherwise
|
||||
content = str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
@@ -122,9 +105,6 @@ def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
# This function seems identical to the one in src/tools/conversion.py
|
||||
# We can potentially remove it from here and import from the central location.
|
||||
# For now, keep it duplicated to maintain modularity until a decision is made.
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
@@ -137,7 +117,6 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
@@ -145,7 +124,7 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
"parameters": input_schema,
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
@@ -159,11 +138,9 @@ def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user