feat: implement OpenAIProvider with client initialization, message handling, and utility functions
src/providers/openai_provider/utils.py (new file, 114 lines added)
@@ -0,0 +1,114 @@
# src/providers/openai_provider/utils.py
import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given model."""
    # Default to a safe fallback if model or provider info is missing
    default_window = 8000
    try:
        # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
        provider_models = MODELS.get("openai", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        # Fall back if the specific model ID is not found in our list
        logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
    """
    Estimates the token count for OpenAI messages using a char count / 4 approximation.

    Note: This is less accurate than using tiktoken.
    """
    total_chars = 0
    for message in messages:
        total_chars += len(message.get("role", ""))
        content = message.get("content")
        if isinstance(content, str):
            total_chars += len(content)
        # Rough approximation for function/tool call overhead could be added here later
    # Use math.ceil to round up so we don't underestimate too much.
    estimated_tokens = math.ceil(total_chars / 4.0)
    logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
    return estimated_tokens


def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
    """
    Truncates messages from the beginning if the estimated token count exceeds the limit.
    Preserves the first message if it is a system prompt.

    Returns:
        - The potentially truncated list of messages.
        - The initial estimated token count.
        - The final estimated token count after truncation (if any).
    """
    context_limit = get_context_window(model)
    # Add a buffer to be safer with the approximation
    buffer = 200  # Reduced slightly since the estimate now rounds up
    effective_limit = context_limit - buffer

    initial_estimated_count = estimate_openai_token_count(messages)
    final_estimated_count = initial_estimated_count

    truncated_messages = list(messages)  # Make a copy

    # Identify whether the first message is a system prompt
    has_system_prompt = False
    if truncated_messages and truncated_messages[0].get("role") == "system":
        has_system_prompt = True
        # If only the system prompt exists, don't truncate further
        if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
            logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
            # Return the original messages to avoid removing the only message
            return messages, initial_estimated_count, final_estimated_count

    while final_estimated_count > effective_limit:
        if has_system_prompt and len(truncated_messages) <= 1:
            # Should not happen given the check above, but acts as a safety break
            logger.warning("Truncation stopped: only the system prompt remains.")
            break
        if not has_system_prompt and len(truncated_messages) == 0:
            logger.warning("Truncation stopped: no messages left.")
            break  # No messages left

        # Determine the index to remove: 1 if a system prompt exists and the list is long enough, else 0
        remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

        if remove_index >= len(truncated_messages):
            logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
            break  # Avoid an index error

        removed_message = truncated_messages.pop(remove_index)
        logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

        # Recalculate the estimated count
        final_estimated_count = estimate_openai_token_count(truncated_messages)
        logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

        # Safety break if the list becomes unexpectedly empty
        if not truncated_messages:
            logger.warning("Truncation resulted in an empty message list.")
            break

    if initial_estimated_count != final_estimated_count:
        logger.info(
            f"Truncated messages for model {model}. "
            f"Initial estimated tokens: {initial_estimated_count}, "
            f"Final estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )
    else:
        logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

    return truncated_messages, initial_estimated_count, final_estimated_count
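
Note: the OpenAIProvider class named in the commit message is not part of this hunk. Below is a minimal sketch of how these helpers might be wired into a chat call; the class shape, method names, and the "gpt-4o-mini" default are illustrative assumptions rather than the committed implementation, and only the official openai>=1.x SDK calls are taken as given.

# sketch_provider_usage.py -- illustrative only, not part of this commit
import logging

from openai import OpenAI  # official openai>=1.x SDK

from src.providers.openai_provider.utils import truncate_messages

logger = logging.getLogger(__name__)


class OpenAIProviderSketch:
    """Hypothetical wrapper that trims the prompt to the model's context window before sending."""

    def __init__(self, api_key: str | None = None):
        # The committed provider's client initialization may differ (base URL, organization, timeouts, ...).
        self.client = OpenAI(api_key=api_key)

    def chat(self, messages: list[dict[str, str]], model: str = "gpt-4o-mini") -> str:
        # Trim oldest non-system messages until the estimate fits the model's context window.
        trimmed, before, after = truncate_messages(messages, model)
        if before != after:
            logger.info("Dropped older messages: ~%d -> ~%d estimated tokens", before, after)
        response = self.client.chat.completions.create(model=model, messages=trimmed)
        return response.choices[0].message.content or ""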
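
The docstring of estimate_openai_token_count notes that the char/4 heuristic is less accurate than tiktoken. A sketch of a tiktoken-based alternative follows, assuming the third-party tiktoken package were added as a dependency (it is not introduced by this commit).

# sketch_tiktoken_count.py -- illustrative alternative, not part of this commit
import tiktoken


def estimate_tokens_tiktoken(messages: list[dict[str, str]], model: str) -> int:
    """Counts role and content tokens with tiktoken, falling back to cl100k_base for unknown models."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    total = 0
    for message in messages:
        total += len(encoding.encode(message.get("role", "")))
        content = message.get("content")
        if isinstance(content, str):
            total += len(encoding.encode(content))
    # Per-message framing overhead is ignored here, mirroring the char/4 estimate above.
    return total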