feat: implement OpenAIProvider with client initialization, message handling, and utility functions

2025-03-26 19:59:01 +00:00
parent bae517a322
commit 678f395649
8 changed files with 522 additions and 443 deletions


@@ -0,0 +1,114 @@
# src/providers/openai_provider/utils.py
import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given model."""
    # Default to a safe fallback if model or provider info is missing
    default_window = 8000
    try:
        # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
        provider_models = MODELS.get("openai", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        # Fallback if specific model ID not found in our list
        logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window
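

# Illustrative sketch (not part of the original file): the MODELS shape that
# get_context_window() assumes. The model entry below is hypothetical; the
# real configuration lives in src.llm_models.
_EXAMPLE_MODELS = {
    "openai": {
        "models": [
            {"id": "gpt-4o", "context_window": 128000},
        ],
    },
}
# With a MODELS config of that shape, get_context_window("gpt-4o") would return
# 128000, while an unknown model id falls back to the 8000-token default and
# logs a warning.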


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
    """
    Estimates the token count for OpenAI messages using a char count / 4 approximation.

    Note: This is less accurate than using tiktoken.
    """
    total_chars = 0
    for message in messages:
        total_chars += len(message.get("role", ""))
        content = message.get("content")
        if isinstance(content, str):
            total_chars += len(content)
    # Rough approximation for function/tool call overhead could be added here if needed later.
    # Using math.ceil to round up, ensuring we don't underestimate too much.
    estimated_tokens = math.ceil(total_chars / 4.0)
    logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
    return estimated_tokens
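

# Illustrative sketch (not part of the original file): a more accurate counter
# using tiktoken, as hinted at in the docstring above. Assumes the tiktoken
# package is installed; per-message formatting overhead is ignored here.
def _estimate_tokens_with_tiktoken(messages: list[dict[str, str]], model: str) -> int:
    import tiktoken  # local import so this sketch adds no hard dependency

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model ids fall back to a widely used encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    total = 0
    for message in messages:
        total += len(encoding.encode(message.get("role", "")))
        content = message.get("content")
        if isinstance(content, str):
            total += len(encoding.encode(content))
    return total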


def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
    """
    Truncates messages from the beginning if the estimated token count exceeds the limit.
    Preserves the first message if it is a system prompt.

    Returns:
        - The potentially truncated list of messages.
        - The initial estimated token count.
        - The final estimated token count after truncation (if any).
    """
    context_limit = get_context_window(model)
    # Add a buffer to be safer with the approximation
    buffer = 200  # Buffer reduced slightly since we now round up
    effective_limit = context_limit - buffer

    initial_estimated_count = estimate_openai_token_count(messages)
    final_estimated_count = initial_estimated_count
    truncated_messages = list(messages)  # Make a copy

    # Identify whether the first message is a system prompt
    has_system_prompt = False
    if truncated_messages and truncated_messages[0].get("role") == "system":
        has_system_prompt = True
        # If only the system prompt exists, don't truncate further
        if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
            logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
            # Return the original messages to avoid removing the only message
            return messages, initial_estimated_count, final_estimated_count

    while final_estimated_count > effective_limit:
        if has_system_prompt and len(truncated_messages) <= 1:
            # Should not happen given the check above, but acts as a safety break
            logger.warning("Truncation stopped: only the system prompt remains.")
            break
        if not has_system_prompt and len(truncated_messages) == 0:
            logger.warning("Truncation stopped: no messages left.")
            break  # No messages left

        # Remove index 1 if a system prompt exists and the list is long enough, else index 0
        remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
        if remove_index >= len(truncated_messages):
            logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
            break  # Avoid an index error

        removed_message = truncated_messages.pop(remove_index)
        logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

        # Recalculate the estimated count
        final_estimated_count = estimate_openai_token_count(truncated_messages)
        logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

        # Safety break if the list becomes unexpectedly empty
        if not truncated_messages:
            logger.warning("Truncation resulted in an empty message list.")
            break

    if initial_estimated_count != final_estimated_count:
        logger.info(
            f"Truncated messages for model {model}. "
            f"Initial estimated tokens: {initial_estimated_count}, "
            f"Final estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )
    else:
        logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

    return truncated_messages, initial_estimated_count, final_estimated_count
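

# Illustrative usage sketch (not part of the original file): how a provider
# might call truncate_messages() before sending a request. The model id below
# is hypothetical; see the OpenAIProvider implementation for the real flow.
if __name__ == "__main__":
    sample_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarise the release notes."},
    ]
    safe_messages, tokens_before, tokens_after = truncate_messages(sample_messages, "gpt-4o")
    print(f"Estimated tokens before/after truncation: {tokens_before}/{tokens_after}")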