feat: implement OpenAIProvider with client initialization, message handling, and utility functions

2025-03-26 19:59:01 +00:00
parent bae517a322
commit 678f395649
8 changed files with 522 additions and 443 deletions


@@ -0,0 +1,170 @@
# src/providers/openai_provider/tools.py
import json
import logging
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

logger = logging.getLogger(__name__)


def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
# Check if choices exist and are not empty
if response.choices:
return bool(response.choices[0].message.tool_calls)
else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False
elif isinstance(response, Stream):
# This check remains unreliable for unconsumed streams.
# LLMClient needs robust handling after consumption.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False # Assume no for unconsumed stream for now
else:
            # Unexpected type (e.g. an already-consumed or re-wrapped stream)
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False


def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return []
# Check if choices exist and are not empty
if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
                    # No prefix, so the owning server is unknown. Leave server_name as
                    # None and let the consumer decide how to route (or fail).
                    logger.warning(f"Could not determine server_name from tool name '{call.function.name}'; leaving it as None.")
                    server_name = None
                    func_name = call.function.name
                # Arguments are usually a JSON string, but may already be a parsed dict
arguments_obj = None
try:
if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments)
else:
# Assuming it might already be a dict if not a string (less common)
arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}")
                    # Options: skip the tool, pass the raw string, or pass an error.
                    # For now, pass an error dict that preserves the raw string.
                    arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"function_name": func_name,
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
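
# Example of one entry in the list returned by parse_tool_calls
# (IDs and values hypothetical):
#   {"id": "call_abc123", "server_name": "weather",
#    "function_name": "get_forecast", "arguments": {"city": "Oslo"}}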


def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
    # The result may be a dict (including error payloads) or a plain string/number.
    # OpenAI expects the message content to be a string, typically JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
elif isinstance(result, str):
            content = result  # Pass plain strings through unchanged
else:
content = str(result) # Ensure it's a string otherwise
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
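
# Example follow-up message (hypothetical call ID and result):
#   format_tool_results("call_abc123", {"temp_c": 12})
#   -> {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_c": 12}'}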


def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
    # This function appears identical to the one in src/tools/conversion.py.
    # We could remove it here and import from that central location instead.
    # For now, keep the duplicate to preserve modularity until a decision is made.
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
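
# Illustration of the conversion above (hypothetical tool definition):
#   {"server_name": "weather", "name": "get_forecast",
#    "description": "Get a city forecast",
#    "inputSchema": {"type": "object", "properties": {"city": {"type": "string"}}}}
# becomes:
#   {"type": "function",
#    "function": {"name": "weather__get_forecast",
#                 "description": "Get a city forecast",
#                 "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}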


def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}