Refactor Google and OpenAI provider response handling and tool utilities

- Improved error handling and logging in Google response processing.
- Simplified streaming content extraction and error detection in Google provider.
- Enhanced content extraction logic in OpenAI provider to handle edge cases.
- Streamlined tool conversion functions for both Google and OpenAI providers.
- Removed redundant comments and improved code readability across multiple files.
- Updated context window retrieval and message truncation logic for better performance.
- Ensured consistent handling of tool calls and arguments in OpenAI responses.
This commit is contained in:
2025-03-28 04:20:39 +00:00
parent 51e3058961
commit 247835e595
27 changed files with 265 additions and 564 deletions

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/tools.py
import json
import logging
from typing import Any
@@ -13,20 +12,16 @@ logger = logging.getLogger(__name__)
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
# Check if choices exist and are not empty
if isinstance(response, ChatCompletion):
if response.choices:
return bool(response.choices[0].message.tool_calls)
else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False
elif isinstance(response, Stream):
# This check remains unreliable for unconsumed streams.
# LLMClient needs robust handling after consumption.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False # Assume no for unconsumed stream for now
return False
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
@@ -36,14 +31,12 @@ def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bo
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return []
# Check if choices exist and are not empty
if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return []
@@ -55,38 +48,30 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default?
server_name = None
func_name = call.function.name
# Arguments might be a string needing JSON parsing, or already parsed dict
arguments_obj = None
try:
if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments)
else:
# Assuming it might already be a dict if not a string (less common)
arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}")
# Decide how to handle: skip tool, pass raw string, pass error?
# Passing raw string for now, but this might break consumers.
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"server_name": server_name,
"function_name": func_name,
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
"arguments": arguments_obj,
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
@@ -94,20 +79,18 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
return []
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
elif isinstance(result, str):
content = result # Allow plain strings if result is already string
content = result
else:
content = str(result) # Ensure it's a string otherwise
content = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
@@ -122,9 +105,6 @@ def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
# This function seems identical to the one in src/tools/conversion.py
# We can potentially remove it from here and import from the central location.
# For now, keep it duplicated to maintain modularity until a decision is made.
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
@@ -137,7 +117,6 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
@@ -145,7 +124,7 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
"parameters": input_schema,
},
}
openai_tools.append(openai_tool_format)
@@ -159,11 +138,9 @@ def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)