feat: implement OpenAIProvider with client initialization, message handling, and utility functions

This commit is contained in:
2025-03-26 19:59:01 +00:00
parent bae517a322
commit 678f395649
8 changed files with 522 additions and 443 deletions

View File

@@ -1,390 +0,0 @@
# src/providers/openai_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from src.llm_models import MODELS
from src.providers.base import BaseProvider
logger = logging.getLogger(__name__)
class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""
def __init__(self, api_key: str, base_url: str | None = None):
# Use default OpenAI endpoint if base_url is not provided
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
try:
# TODO: Add default headers like in original client?
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given model."""
# Default to a safe fallback if model or provider info is missing
default_window = 8000
try:
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
# Fallback if specific model ID not found in our list
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
# Rough approximation for function/tool call overhead if needed later
# Using math.ceil to round up, ensuring we don't underestimate too much.
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens
def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.
Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = self._get_context_window(model)
# Add a buffer to be safer with approximation
buffer = 200 # Reduce buffer slightly as we round up now
effective_limit = context_limit - buffer
initial_estimated_count = self._estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count
truncated_messages = list(messages) # Make a copy
# Identify if the first message is a system prompt
has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
# If only system prompt exists, don't truncate further
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
# Return original messages to avoid removing the only message
return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
# Should not happen if check above works, but safety break
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break # No messages left
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break # Avoid index error
removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
# Recalculate estimated count
final_estimated_count = self._estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
# Safety break if list becomes unexpectedly empty
if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break
if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
return truncated_messages, initial_estimated_count, final_estimated_count
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
# Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly?
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Truncation Step ---
truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
# -----------------------
try:
completion_params = {
"model": model,
"messages": truncated_messages, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None}
# --- Added Debug Logging ---
log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:] # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
# --- End Added Debug Logging ---
response = self.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
# --- Capture Actual Usage (for UI display later) ---
# This part is tricky. Usage info is easily available on the *non-streaming* response.
# For streaming, it's often not available until the stream is fully consumed,
# or sometimes via response headers or a final event (provider-dependent).
# For now, let's focus on getting it from the non-streaming case.
# We need a way to pass this back alongside the content/stream.
# Option 1: Modify return type (complex for stream/non-stream union)
# Option 2: Store it in the provider instance (stateful, maybe bad)
# Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)
# Let's try returning it alongside for non-streaming, and figure out streaming later.
# This requires changing the BaseProvider interface and LLMClient handling.
# For now, just log it here.
actual_usage = None
if isinstance(response, ChatCompletion) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
logger.info(f"Actual OpenAI API usage: {actual_usage}")
# TODO: How to handle usage for streaming responses? Needs investigation.
# Return the raw response for now. LLMClient will process it.
return response
# ----------------------------------------------------
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or "" # Return empty string if content is None
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
return bool(response.choices[0].message.tool_calls)
elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper
# This is tricky for streams. We'd need to peek at the first chunk(s)
# or buffer the response. For simplicity, this check might be unreliable
# for streams *before* they are consumed. LLMClient needs robust handling.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
# A more robust check would involve consuming the start of the stream
# or relying on the structure after consumption.
return False # Assume no for unconsumed stream for now
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
# Attempt to handle buffered stream if possible? Complex.
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default?
func_name = call.function.name
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"function_name": func_name,
"arguments": call.function.arguments, # Arguments are already a string here
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
else:
content = str(result) # Ensure it's a string
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
# Helper needed by LLMClient's current tool handling logic
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming OpenAI response."""
try:
if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
# "total_tokens": response.usage.total_tokens, # Optional
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,66 @@
# src/providers/openai_provider/__init__.py
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from providers.openai_provider.client import initialize_client
from providers.openai_provider.completion import create_chat_completion
from providers.openai_provider.response import get_content, get_streaming_content, get_usage
from providers.openai_provider.tools import (
convert_tools,
format_tool_results,
get_original_message_with_calls,
has_tool_calls,
parse_tool_calls,
)
from src.providers.base import BaseProvider
class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""
def __init__(self, api_key: str, base_url: str | None = None):
# BaseProvider __init__ might not be needed if client init handles base_url logic
# super().__init__(api_key, base_url) # Let's see if we need this
self.client = initialize_client(api_key, base_url)
# Store api_key and base_url if needed by BaseProvider or other methods
self.api_key = api_key
self.base_url = self.client.base_url # Get effective base_url from client
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
# Pass self (provider instance) to the helper function
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
return get_streaming_content(response)
def get_content(self, response: ChatCompletion) -> str:
return get_content(response)
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
# This method might need the full response after streaming, handled by LLMClient
return has_tool_calls(response)
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
return parse_tool_calls(response)
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
return format_tool_results(tool_call_id, result)
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return convert_tools(tools)
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
return get_original_message_with_calls(response)
def get_usage(self, response: Any) -> dict[str, int] | None:
return get_usage(response)

View File

@@ -0,0 +1,23 @@
# src/providers/openai_provider/client.py
import logging
from openai import OpenAI
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
"""Initializes and returns an OpenAI client instance."""
# Use default OpenAI endpoint if base_url is not provided explicitly
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
try:
# TODO: Add default headers if needed, similar to the original openai_client.py?
# default_headers={"HTTP-Referer": "...", "X-Title": "..."}
client = OpenAI(api_key=api_key, base_url=effective_base_url)
return client
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise

View File

@@ -0,0 +1,80 @@
# src/providers/openai_provider/completion.py
import logging
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from providers.openai_provider.utils import truncate_messages
logger = logging.getLogger(__name__)
def create_chat_completion(
provider, # The OpenAIProvider instance
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Truncation Step ---
truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
# -----------------------
try:
completion_params = {
"model": model,
"messages": truncated_messages, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None}
# --- Added Debug Logging ---
log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:] # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
# --- End Added Debug Logging ---
response = provider.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
# --- Capture Actual Usage (for UI display later) ---
# Log usage if available (primarily non-streaming)
actual_usage = None
if isinstance(response, ChatCompletion) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
logger.info(f"Actual OpenAI API usage: {actual_usage}")
# TODO: How to handle usage for streaming responses? Needs investigation.
# Return the raw response for now. LLMClient will process it.
return response
# ----------------------------------------------------
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise

View File

@@ -0,0 +1,69 @@
# src/providers/openai_provider/response.py
import json
import logging
from collections.abc import Generator
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
logger = logging.getLogger(__name__)
def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
# Check if choices exist and are not empty
if chunk.choices:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
# Handle potential finish reasons or other stream elements if needed
# else:
# logger.debug(f"Stream chunk without choices: {chunk}") # Or handle finish reason etc.
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
# Check if choices exist and are not empty
if response.choices:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or "" # Return empty string if content is None
else:
logger.warning("No choices found in OpenAI non-streaming response.")
return "[No content received]"
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def get_usage(response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming OpenAI response."""
try:
if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
# "total_tokens": response.usage.total_tokens, # Optional
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
# Don't log warning for streams, as usage isn't expected here
if not isinstance(response, Stream):
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,170 @@
# src/providers/openai_provider/tools.py
import json
import logging
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
logger = logging.getLogger(__name__)
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
# Check if choices exist and are not empty
if response.choices:
return bool(response.choices[0].message.tool_calls)
else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False
elif isinstance(response, Stream):
# This check remains unreliable for unconsumed streams.
# LLMClient needs robust handling after consumption.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False # Assume no for unconsumed stream for now
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return []
# Check if choices exist and are not empty
if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default?
func_name = call.function.name
# Arguments might be a string needing JSON parsing, or already parsed dict
arguments_obj = None
try:
if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments)
else:
# Assuming it might already be a dict if not a string (less common)
arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}")
# Decide how to handle: skip tool, pass raw string, pass error?
# Passing raw string for now, but this might break consumers.
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"function_name": func_name,
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
elif isinstance(result, str):
content = result # Allow plain strings if result is already string
else:
content = str(result) # Ensure it's a string otherwise
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
# This function seems identical to the one in src/tools/conversion.py
# We can potentially remove it from here and import from the central location.
# For now, keep it duplicated to maintain modularity until a decision is made.
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

View File

@@ -0,0 +1,114 @@
# src/providers/openai_provider/utils.py
import logging
import math
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given model."""
# Default to a safe fallback if model or provider info is missing
default_window = 8000
try:
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
# Fallback if specific model ID not found in our list
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
# Rough approximation for function/tool call overhead if needed later
# Using math.ceil to round up, ensuring we don't underestimate too much.
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens
def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.
Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = get_context_window(model)
# Add a buffer to be safer with approximation
buffer = 200 # Reduce buffer slightly as we round up now
effective_limit = context_limit - buffer
initial_estimated_count = estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count
truncated_messages = list(messages) # Make a copy
# Identify if the first message is a system prompt
has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
# If only system prompt exists, don't truncate further
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
# Return original messages to avoid removing the only message
return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
# Should not happen if check above works, but safety break
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break # No messages left
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break # Avoid index error
removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
# Recalculate estimated count
final_estimated_count = estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
# Safety break if list becomes unexpectedly empty
if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break
if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
return truncated_messages, initial_estimated_count, final_estimated_count

View File

@@ -11,59 +11,6 @@ from typing import Any
logger = logging.getLogger(__name__)
def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to OpenAI tool definitions.
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List of OpenAI tool definitions.
"""
openai_tools = []
logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"
# Initialize the OpenAI tool structure
openai_tool = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
# Basic validation/cleaning of schema if needed could go here
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
openai_tool["function"]["parameters"] = input_schema
openai_tools.append(openai_tool)
logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")
return openai_tools
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Google Gemini format (dictionary structure).