feat: implement OpenAIProvider with client initialization, message handling, and utility functions
This commit is contained in:
170
src/providers/openai_provider/tools.py
Normal file
170
src/providers/openai_provider/tools.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# src/providers/openai_provider/tools.py
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
# Check if choices exist and are not empty
|
||||
if response.choices:
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
|
||||
return False
|
||||
elif isinstance(response, Stream):
|
||||
# This check remains unreliable for unconsumed streams.
|
||||
# LLMClient needs robust handling after consumption.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
return False # Assume no for unconsumed stream for now
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
return []
|
||||
|
||||
# Check if choices exist and are not empty
|
||||
if not response.choices:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
|
||||
return []
|
||||
|
||||
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
func_name = call.function.name
|
||||
|
||||
# Arguments might be a string needing JSON parsing, or already parsed dict
|
||||
arguments_obj = None
|
||||
try:
|
||||
if isinstance(call.function.arguments, str):
|
||||
arguments_obj = json.loads(call.function.arguments)
|
||||
else:
|
||||
# Assuming it might already be a dict if not a string (less common)
|
||||
arguments_obj = call.function.arguments
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
|
||||
logger.error(f"Raw arguments string: {call.function.arguments}")
|
||||
# Decide how to handle: skip tool, pass raw string, pass error?
|
||||
# Passing raw string for now, but this might break consumers.
|
||||
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"function_name": func_name,
|
||||
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
|
||||
|
||||
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
elif isinstance(result, str):
|
||||
content = result # Allow plain strings if result is already string
|
||||
else:
|
||||
content = str(result) # Ensure it's a string otherwise
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
# This function seems identical to the one in src/tools/conversion.py
|
||||
# We can potentially remove it from here and import from the central location.
|
||||
# For now, keep it duplicated to maintain modularity until a decision is made.
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
logger.debug(f"Converted tool: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
|
||||
def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
Reference in New Issue
Block a user