# src/providers/openai_provider/utils.py
import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given model."""
    # Default to a safe fallback if model or provider info is missing
    default_window = 8000
    try:
        # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
        provider_models = MODELS.get("openai", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        # Fallback if specific model ID not found in our list
        logger.warning(
            f"Context window for OpenAI model '{model}' not found in MODELS config. "
            f"Using default: {default_window}"
        )
        return default_window
    except Exception as e:
        logger.error(
            f"Error retrieving context window for model {model}: {e}. Using default: {default_window}",
            exc_info=True,
        )
        return default_window


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
    """
    Estimates the token count for OpenAI messages using char count / 4 approximation.

    Note: This is less accurate than using tiktoken.
    """
    total_chars = 0
    for message in messages:
        total_chars += len(message.get("role", ""))
        content = message.get("content")
        if isinstance(content, str):
            total_chars += len(content)
        # Rough approximation for function/tool call overhead if needed later

    # Using math.ceil to round up, ensuring we don't underestimate too much.
    estimated_tokens = math.ceil(total_chars / 4.0)
    logger.debug(
        f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages"
    )
    return estimated_tokens


def truncate_messages(
    messages: list[dict[str, str]], model: str
) -> tuple[list[dict[str, str]], int, int]:
    """
    Truncates messages from the beginning if estimated token count exceeds the limit.
    Preserves the first message if it's a system prompt.

    Returns:
        - The potentially truncated list of messages.
        - The initial estimated token count.
        - The final estimated token count after truncation (if any).
    """
    context_limit = get_context_window(model)
    # Add a buffer to be safer with approximation
    buffer = 200  # Reduce buffer slightly as we round up now
    effective_limit = context_limit - buffer

    initial_estimated_count = estimate_openai_token_count(messages)
    final_estimated_count = initial_estimated_count
    truncated_messages = list(messages)  # Make a copy

    # Identify if the first message is a system prompt
    has_system_prompt = False
    if truncated_messages and truncated_messages[0].get("role") == "system":
        has_system_prompt = True
        # If only system prompt exists, don't truncate further
        if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
            logger.warning(
                f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit "
                f"({effective_limit}). Cannot truncate further."
            )
            # Return original messages to avoid removing the only message
            return messages, initial_estimated_count, final_estimated_count

    while final_estimated_count > effective_limit:
        if has_system_prompt and len(truncated_messages) <= 1:
            # Should not happen if check above works, but safety break
            logger.warning("Truncation stopped: Only system prompt remains.")
            break
        if not has_system_prompt and len(truncated_messages) <= 0:
            logger.warning("Truncation stopped: No messages left.")
            break  # No messages left

        # Determine index to remove: 1 if system prompt exists and list is long enough, else 0
        remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

        if remove_index >= len(truncated_messages):
            logger.error(
                f"Truncation logic error: remove_index {remove_index} out of bounds "
                f"for {len(truncated_messages)} messages."
            )
            break  # Avoid index error

        removed_message = truncated_messages.pop(remove_index)
        logger.debug(
            f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) "
            f"due to context limit."
        )

        # Recalculate estimated count
        final_estimated_count = estimate_openai_token_count(truncated_messages)
        logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

        # Safety break if list becomes unexpectedly empty
        if not truncated_messages:
            logger.warning("Truncation resulted in empty message list.")
            break

    if initial_estimated_count != final_estimated_count:
        logger.info(
            f"Truncated messages for model {model}. "
            f"Initial estimated tokens: {initial_estimated_count}, "
            f"Final estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )
    else:
        logger.debug(
            f"No truncation needed for model {model}. "
            f"Estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )

    return truncated_messages, initial_estimated_count, final_estimated_count