- Improved error handling and logging in Google response processing.
- Simplified streaming content extraction and error detection in Google provider.
- Enhanced content extraction logic in OpenAI provider to handle edge cases.
- Streamlined tool conversion functions for both Google and OpenAI providers.
- Removed redundant comments and improved code readability across multiple files.
- Updated context window retrieval and message truncation logic for better performance.
- Ensured consistent handling of tool calls and arguments in OpenAI responses.
148 lines
6.0 KiB
Python
148 lines
6.0 KiB
Python
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from openai import Stream
|
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
    """Report whether an OpenAI response carries tool calls.

    Only a fully materialized ``ChatCompletion`` can be inspected reliably;
    streams and unexpected types are logged and treated as having none.
    """
    try:
        if isinstance(response, ChatCompletion):
            # Non-streaming responses expose tool calls on the first choice.
            if not response.choices:
                logger.warning("No choices found in OpenAI non-streaming response for tool check.")
                return False
            return bool(response.choices[0].message.tool_calls)
        if isinstance(response, Stream):
            # Chunks have not been consumed yet, so there is nothing to inspect.
            logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
            return False
        logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
        return False
    except Exception as e:
        # Defensive boundary: inspection failures must never propagate to callers.
        logger.error(f"Error checking for tool calls: {e}", exc_info=True)
        return False
|
|
|
|
|
|
def _split_prefixed_tool_name(full_name: str) -> tuple[str | None, str]:
    """Split a ``server__function`` tool name into (server_name, function_name).

    Returns ``(None, full_name)`` and logs a warning when the ``__``
    separator is absent, so the caller can still route the call.
    """
    parts = full_name.split("__", 1)
    if len(parts) == 2:
        return parts[0], parts[1]
    logger.warning(f"Could not determine server_name from tool name '{full_name}'. Assuming default or error needed.")
    return None, full_name


def _decode_arguments(call: ChatCompletionMessageToolCall, func_name: str) -> Any:
    """Decode a tool call's arguments, tolerating malformed JSON.

    Returns the parsed object, the value unchanged when it is not a string,
    or an error payload (dict with ``error``/``raw_arguments`` keys) when
    JSON decoding fails, so downstream code never sees an exception.
    """
    raw = call.function.arguments
    if not isinstance(raw, str):
        return raw
    try:
        return json.loads(raw)
    except json.JSONDecodeError as json_err:
        logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
        logger.error(f"Raw arguments string: {raw}")
        return {"error": "Failed to parse arguments", "raw_arguments": raw}


def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
    """Parses tool calls from a non-streaming OpenAI response.

    Each returned dict has keys ``id``, ``server_name`` (``None`` when the
    tool name carried no server prefix), ``function_name`` and ``arguments``.
    Returns an empty list on wrong input type, missing choices, absent tool
    calls, or any unexpected failure — callers never see an exception.
    """
    parsed_calls: list[dict[str, Any]] = []
    try:
        if not isinstance(response, ChatCompletion):
            logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
            return []

        if not response.choices:
            logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
            return []

        tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
        if not tool_calls:
            return []

        logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
        for call in tool_calls:
            if call.type != "function":
                logger.warning(f"Unsupported tool call type: {call.type}")
                continue
            server_name, func_name = _split_prefixed_tool_name(call.function.name)
            parsed_calls.append({
                "id": call.id,
                "server_name": server_name,
                "function_name": func_name,
                "arguments": _decode_arguments(call, func_name),
            })

        return parsed_calls
    except Exception as e:
        logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
        return []
|
|
|
|
|
|
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
    """Formats a tool result for an OpenAI follow-up request."""
    # Coerce the result into a string payload: dicts become JSON, strings
    # pass through untouched, and anything else is stringified.
    try:
        if isinstance(result, dict):
            content = json.dumps(result)
        else:
            content = result if isinstance(result, str) else str(result)
    except Exception as e:
        # Encoding failed (e.g. non-serializable dict values) — ship a
        # structured error payload instead of propagating the exception.
        logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
        content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

    logger.debug(f"Formatting tool result for call ID {tool_call_id}")
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": content,
    }
|
|
|
|
|
|
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts internal tool format to OpenAI's format.

    Each input tool dict must provide ``server_name``, ``name``,
    ``description`` and ``inputSchema``. The OpenAI function name is
    prefixed as ``"{server_name}__{name}"`` so responses can be routed
    back to the originating server. Invalid definitions are skipped with
    a warning rather than raising.
    """
    openai_tools = []
    logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
    for tool in tools:
        server_name = tool.get("server_name")
        tool_name = tool.get("name")
        description = tool.get("description")
        input_schema = tool.get("inputSchema")

        # Bug fix: an empty dict is a valid JSON schema (a tool that takes
        # no parameters), so only a *missing* schema disqualifies the tool.
        # The previous truthiness check silently dropped zero-parameter tools.
        if not server_name or not tool_name or not description or input_schema is None:
            logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
            continue

        prefixed_tool_name = f"{server_name}__{tool_name}"

        openai_tool_format = {
            "type": "function",
            "function": {
                "name": prefixed_tool_name,
                "description": description,
                "parameters": input_schema,
            },
        }
        openai_tools.append(openai_tool_format)
        logger.debug(f"Converted tool: {prefixed_tool_name}")

    return openai_tools
|
|
|
|
|
|
def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
    """Extracts the assistant's message containing tool calls."""
    try:
        usable = (
            isinstance(response, ChatCompletion)
            and bool(response.choices)
            and bool(response.choices[0].message.tool_calls)
        )
        if not usable:
            logger.warning("Could not extract original message with tool calls from response.")
            return {"role": "assistant", "content": "[Could not extract tool calls message]"}
        # Serialize only the fields the API actually populated.
        return response.choices[0].message.model_dump(exclude_unset=True)
    except Exception as e:
        logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
        return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|