# src/providers/openai_provider/utils.py

import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given model."""
    # Default to a safe fallback if model or provider info is missing
    default_window = 8000
    try:
        # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
        provider_models = MODELS.get("openai", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        # Fallback if the specific model ID is not found in our list
        logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window
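

# Illustrative only -- a minimal sketch of the registry shape that
# get_context_window() expects from MODELS (imported from src.llm_models).
# The model ids and window sizes below are assumed example values, not the
# actual contents of the config:
#
#   MODELS = {
#       "openai": {
#           "models": [
#               {"id": "gpt-4o", "context_window": 128000},
#               {"id": "gpt-3.5-turbo", "context_window": 16385},
#           ],
#       },
#   }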


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
    """
    Estimates the token count for OpenAI messages using a char count / 4 approximation.

    Note: This is less accurate than using tiktoken.
    """
    total_chars = 0
    for message in messages:
        total_chars += len(message.get("role", ""))
        content = message.get("content")
        if isinstance(content, str):
            total_chars += len(content)
        # Rough approximation for function/tool call overhead if needed later
    # Using math.ceil to round up, ensuring we don't underestimate too much.
    estimated_tokens = math.ceil(total_chars / 4.0)
    logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
    return estimated_tokens
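

# Rough illustration of the char/4 heuristic above (hand-computed, for
# reference only; actual tokenizer counts will differ):
#
#   estimate_openai_token_count([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ])
#   -> 10 role chars + 34 content chars = 44 chars, ceil(44 / 4) = 11 tokens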


def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
    """
    Truncates messages from the beginning if the estimated token count exceeds the limit.
    Preserves the first message if it's a system prompt.

    Returns:
        - The potentially truncated list of messages.
        - The initial estimated token count.
        - The final estimated token count after truncation (if any).
    """
    context_limit = get_context_window(model)
    # Add a buffer to be safer with the approximation
    buffer = 200  # Buffer can stay small because the estimate already rounds up
    effective_limit = context_limit - buffer

    initial_estimated_count = estimate_openai_token_count(messages)
    final_estimated_count = initial_estimated_count

    truncated_messages = list(messages)  # Make a copy

    # Identify whether the first message is a system prompt
    has_system_prompt = False
    if truncated_messages and truncated_messages[0].get("role") == "system":
        has_system_prompt = True
        # If only the system prompt exists, don't truncate further
        if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
            logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
            # Return the original messages to avoid removing the only message
            return messages, initial_estimated_count, final_estimated_count

    while final_estimated_count > effective_limit:
        if has_system_prompt and len(truncated_messages) <= 1:
            # Should not happen if the check above works, but safety break
            logger.warning("Truncation stopped: Only system prompt remains.")
            break
        if not has_system_prompt and len(truncated_messages) <= 0:
            logger.warning("Truncation stopped: No messages left.")
            break  # No messages left

        # Determine index to remove: 1 if a system prompt exists and the list is long enough, else 0
        remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

        if remove_index >= len(truncated_messages):
            logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
            break  # Avoid an index error

        removed_message = truncated_messages.pop(remove_index)
        logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

        # Recalculate the estimated count
        final_estimated_count = estimate_openai_token_count(truncated_messages)
        logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

        # Safety break if the list becomes unexpectedly empty
        if not truncated_messages:
            logger.warning("Truncation resulted in an empty message list.")
            break

    if initial_estimated_count != final_estimated_count:
        logger.info(
            f"Truncated messages for model {model}. "
            f"Initial estimated tokens: {initial_estimated_count}, "
            f"Final estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )
    else:
        logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

    return truncated_messages, initial_estimated_count, final_estimated_count
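

# --- Usage sketch (illustrative, not part of the provider pipeline) ----------
# A minimal example of how a caller might trim a conversation before building
# an OpenAI chat request. The model id and messages are assumptions for
# demonstration; an id missing from MODELS falls back to the 8000-token default.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    sample_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarise our conversation so far."},
    ]
    trimmed, before, after = truncate_messages(sample_messages, model="gpt-4o")
    print(f"Kept {len(trimmed)} of {len(sample_messages)} messages ({before} -> {after} estimated tokens)")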