feat: enhance token usage tracking and context management for LLM providers

2025-03-26 17:27:41 +00:00
parent 49aebc12d5
commit 15ecb9fc48
5 changed files with 395 additions and 68 deletions

View File

@@ -1,16 +1,14 @@
# src/providers/anthropic_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any
from anthropic import Anthropic, Stream
from anthropic import Anthropic, APIError, Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
# Use relative imports for modules within the same package
from providers.base import BaseProvider
# Use absolute imports as per Ruff warning and user instructions
from src.llm_models import MODELS
from src.tools.conversion import convert_to_anthropic_tools
@@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider):
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Anthropic model."""
default_window = 100000 # Default fallback for Anthropic
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
"""Counts tokens for Anthropic messages using the official client."""
# Note: Anthropic's count_tokens may not accept the message-list format used for
# creation; it generally expects plain text, so message content is concatenated here.
# This is a simplification and may be inaccurate, especially for tool calls/results;
# a more robust approach would format messages into a single canonical string representation.
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
# Simple concatenation - might need refinement for complex content types (tool calls/results)
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list): # Handle tool results/calls if represented as list
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
# Use the client's count_tokens method if it is available and accepts plain text.
# Check the Anthropic documentation for the correct usage; this assumes
# self.client.count_tokens exists and takes a `text` argument.
count = self.client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
return count
except APIError as api_err:
# Handle potential errors if count_tokens itself is an API call or fails
logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
# Fall back to the approximation if the official count fails.
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
return estimated_tokens
except AttributeError:
# Fallback if count_tokens method doesn't exist or works differently
logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
return estimated_tokens
except Exception as e:
logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
return estimated_tokens
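# --- Illustrative sketch, not part of this commit ---
# Newer releases of the anthropic SDK expose a server-side count endpoint,
# client.messages.count_tokens(), which accepts the same message list used for
# message creation and avoids the plain-text concatenation above. Availability
# and the exact signature depend on the installed SDK version, so treat this as
# a hedged alternative rather than a drop-in replacement.
def count_anthropic_tokens_via_api(client: Anthropic, model: str, messages: list[dict[str, Any]], system_prompt: str | None = None) -> int | None:
    """Returns the server-reported input token count, or None if the endpoint is unavailable."""
    try:
        params: dict[str, Any] = {"model": model, "messages": messages}
        if system_prompt:
            params["system"] = system_prompt
        return client.messages.count_tokens(**params).input_tokens
    except (AttributeError, TypeError, APIError):
        return None  # Caller falls back to the char/4 approximation above.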
def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
"""
Truncates messages for Anthropic, preserving system prompt.
Returns:
- Potentially truncated list of messages.
- Original system prompt (or None).
- Initial token count.
- Final token count.
"""
context_limit = self._get_context_window(model)
buffer = 200 # Safety buffer
effective_limit = context_limit - buffer
initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages) # Copy
# Anthropic requires alternating user/assistant messages, so truncation needs care.
# We remove from the beginning; the system prompt is handled separately, so index 0
# of the list passed here is always the oldest conversational message.
while final_token_count > effective_limit and len(truncated_messages) > 0:
# Always remove the oldest message (index 0)
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
# Ensuring alternation after removal can be complex; for simplicity, just remove
# and recount for now. A more robust approach would remove user/assistant pairs
# together (a pair-wise variant is sketched after this method).
final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
# Safety break
if not truncated_messages:
logger.warning("Truncation resulted in empty message list for Anthropic.")
break
if initial_token_count != final_token_count:
logger.info(
f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
# Ensure the remaining messages start with 'user' role if no system prompt
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
# This can happen with the simple pop(0) logic when user/assistant pairs are not
# removed together. Prepend a minimal user message so the API's role requirements are met.
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
# Note: the placeholder is deliberately not recounted; it is tiny, but the total could
# edge slightly above the effective limit. A stricter alternative would be to only log
# a warning and leave the message list untouched.
return truncated_messages, system_prompt, initial_token_count, final_token_count
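# --- Illustrative sketch, not part of this commit ---
# The comments above note that plain pop(0) can break Anthropic's required
# user/assistant alternation. One hedged alternative is to drop the oldest
# user/assistant pair per iteration, which keeps the list starting with a
# 'user' message whenever it did originally. This assumes strict alternation;
# histories containing tool results may need additional handling.
def drop_oldest_turn(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Removes the oldest user message together with its assistant reply, if present."""
    trimmed = list(messages)
    if trimmed and trimmed[0].get("role") == "user":
        trimmed.pop(0)
        if trimmed and trimmed[0].get("role") == "assistant":
            trimmed.pop(0)
    elif trimmed:
        trimmed.pop(0)  # Unexpected leading role: fall back to removing a single message.
    return trimmed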
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
"""Converts standard message format to Anthropic's format, extracting system prompt."""
anthropic_messages = []
@@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider):
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[MessageStreamEvent] | Message:
"""Creates a chat completion using the Anthropic API."""
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
"""Creates a chat completion using the Anthropic API, handling context truncation."""
logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Context Truncation ---
# First, convert to Anthropic format to separate system prompt
temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
# Then, truncate based on token count
truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
# --------------------------
# Anthropic requires max_tokens
if max_tokens is None:
max_tokens = 4096 # Default value if not provided
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
system_prompt, anthropic_messages = self._convert_messages(messages)
# system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above
try:
completion_params = {
"model": model,
"messages": anthropic_messages,
"messages": truncated_anthropic_msgs, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if system_prompt:
completion_params["system"] = system_prompt
if final_system_prompt: # Use potentially modified system prompt
completion_params["system"] = final_system_prompt
if tools:
completion_params["tools"] = tools
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
@@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider):
response = self.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
# --- Capture Actual Usage ---
actual_usage = None
if isinstance(response, Message) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
"completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
# Anthropic doesn't typically provide total_tokens directly in usage block
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
}
logger.info(f"Actual Anthropic API usage: {actual_usage}")
# TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation.
return response
# --------------------------
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise
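# --- Illustrative sketch, not part of this commit ---
# Regarding the streaming TODO above: Anthropic's streaming protocol reports
# usage incrementally. At the time of writing, the 'message_start' event carries
# the input token count and 'message_delta' events carry the running output
# token count, so usage can be accumulated while the stream is consumed
# (event/field names should be verified against the installed SDK).
def usage_from_anthropic_stream(stream: Stream[MessageStreamEvent]) -> dict[str, int]:
    """Drains a raw event stream and returns the accumulated token usage."""
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    for event in stream:
        # A real caller would also forward content_block_delta text to the UI here.
        if event.type == "message_start":
            usage["prompt_tokens"] = event.message.usage.input_tokens
        elif event.type == "message_delta":
            usage["completion_tokens"] = event.usage.output_tokens
    return usage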
@@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider):
except Exception as e:
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming Anthropic response."""
try:
if isinstance(response, Message) and response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
# "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
}
logger.debug(f"Extracted usage from Anthropic response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
return None

View File (src/providers/base.py)

@@ -134,6 +134,20 @@ class BaseProvider(abc.ABC):
"""
pass
@abc.abstractmethod
def get_usage(self, response: Any) -> dict[str, int] | None:
"""
Extracts token usage information from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
A dictionary containing 'prompt_tokens' and 'completion_tokens',
or None if usage information is not available.
"""
pass
# Optional: Add a method for follow-up completions if the provider API
# requires a specific structure different from just appending messages.
# def create_follow_up_completion(...) -> Any:
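# --- Illustrative sketch, not part of this commit ---
# With get_usage() on the base interface, a caller (e.g. the LLMClient mentioned
# elsewhere in this commit; the helper name here is hypothetical) can accumulate
# usage uniformly across providers after each non-streaming call:
def record_usage(provider: "BaseProvider", response: Any, totals: dict[str, int]) -> None:
    """Adds a response's reported token usage to running totals, if available."""
    usage = provider.get_usage(response)
    if not usage:
        return
    totals["prompt_tokens"] = totals.get("prompt_tokens", 0) + usage.get("prompt_tokens", 0)
    totals["completion_tokens"] = totals.get("completion_tokens", 0) + usage.get("completion_tokens", 0)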

View File

@@ -1,6 +1,7 @@
# src/providers/openai_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any
@@ -8,8 +9,8 @@ from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from providers.base import BaseProvider
from src.llm_models import MODELS # Use absolute import
from src.llm_models import MODELS
from src.providers.base import BaseProvider
logger = logging.getLogger(__name__)
@@ -29,6 +30,110 @@ class OpenAIProvider(BaseProvider):
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given model."""
# Default to a safe fallback if model or provider info is missing
default_window = 8000
try:
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
# Fallback if specific model ID not found in our list
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
# Rough approximation for function/tool call overhead if needed later
# Use math.ceil to round up so the estimate errs on the high side rather than underestimating.
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens
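# --- Illustrative sketch, not part of this commit ---
# The docstring above notes that tiktoken would be more accurate than char/4.
# A hedged version is shown below; it requires the optional `tiktoken`
# dependency, and the ~4-token per-message overhead follows OpenAI's cookbook
# guidance, so it is still approximate.
def estimate_tokens_with_tiktoken(messages: list[dict[str, str]], model: str) -> int:
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # Reasonable default for unknown models.
    total = 0
    for message in messages:
        total += 4  # Approximate per-message formatting overhead.
        for value in message.values():
            if isinstance(value, str):
                total += len(encoding.encode(value))
    return total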
def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.
Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = self._get_context_window(model)
# Add a safety buffer to absorb approximation error
buffer = 200 # Buffer reduced slightly since the estimate now rounds up
effective_limit = context_limit - buffer
initial_estimated_count = self._estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count
truncated_messages = list(messages) # Make a copy
# Identify if the first message is a system prompt
has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
# If only system prompt exists, don't truncate further
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
# Return original messages to avoid removing the only message
return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
# Should not happen if check above works, but safety break
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break # No messages left
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break # Avoid index error
removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
# Recalculate estimated count
final_estimated_count = self._estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
# Safety break if list becomes unexpectedly empty
if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break
if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
return truncated_messages, initial_estimated_count, final_estimated_count
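# --- Illustrative sketch, not part of this commit ---
# Back-of-the-envelope check of the budget used above: with the 8,000-token
# default window and the 200-token buffer, the char/4 estimate allows roughly
# (8000 - 200) * 4 = 31,200 characters of history before truncation starts.
def approx_char_budget(context_window: int, buffer: int = 200) -> int:
    """Characters of history permitted under the char/4 estimate (approximate)."""
    return (context_window - buffer) * 4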
def create_chat_completion(
self,
messages: list[dict[str, str]],
@@ -37,13 +142,19 @@ class OpenAIProvider(BaseProvider):
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API."""
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
# Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly?
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Truncation Step ---
truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
# -----------------------
try:
completion_params = {
"model": model,
"messages": messages,
"messages": truncated_messages, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
@@ -71,7 +182,34 @@ class OpenAIProvider(BaseProvider):
response = self.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
# --- Capture Actual Usage (for UI display later) ---
# This part is tricky. Usage info is easily available on the *non-streaming* response.
# For streaming, it's often not available until the stream is fully consumed,
# or sometimes via response headers or a final event (provider-dependent).
# For now, let's focus on getting it from the non-streaming case.
# We need a way to pass this back alongside the content/stream.
# Option 1: Modify return type (complex for stream/non-stream union)
# Option 2: Store it in the provider instance (stateful, maybe bad)
# Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)
# Let's try returning it alongside for non-streaming, and figure out streaming later.
# This requires changing the BaseProvider interface and LLMClient handling.
# For now, just log it here.
actual_usage = None
if isinstance(response, ChatCompletion) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
logger.info(f"Actual OpenAI API usage: {actual_usage}")
# TODO: How to handle usage for streaming responses? Needs investigation.
# Return the raw response for now. LLMClient will process it.
return response
# ----------------------------------------------------
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
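# --- Illustrative sketch, not part of this commit ---
# Regarding the streaming TODO above: the Chat Completions API can report usage
# for streams when stream_options={"include_usage": True} is passed to
# chat.completions.create; the final chunk then arrives with a populated
# `usage` field and an empty `choices` list. Support depends on the API/SDK
# version in use, so this is a hedged sketch rather than a confirmed path.
def usage_from_openai_stream(stream: Stream[ChatCompletionChunk]) -> dict[str, int] | None:
    """Drains a stream and returns the usage reported in its final chunk, if any."""
    usage = None
    for chunk in stream:
        # A real caller would also forward chunk.choices[0].delta content here.
        if chunk.usage is not None:
            usage = {
                "prompt_tokens": chunk.usage.prompt_tokens,
                "completion_tokens": chunk.usage.completion_tokens,
                "total_tokens": chunk.usage.total_tokens,
            }
    return usage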
@@ -233,7 +371,20 @@ class OpenAIProvider(BaseProvider):
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming OpenAI response."""
try:
if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
# "total_tokens": response.usage.total_tokens, # Optional
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None
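# --- Illustrative sketch, not part of this commit ---
# Example of how a caller might combine the non-streaming path with get_usage().
# Constructor arguments and the model id are placeholders, not part of this diff:
#
#   provider = OpenAIProvider(api_key="sk-...")  # hypothetical constructor signature
#   response = provider.create_chat_completion(
#       messages=[{"role": "user", "content": "Hello"}],
#       model="gpt-4o",
#       stream=False,
#   )
#   usage = provider.get_usage(response)
#   if usage:
#       logger.info(f"prompt={usage['prompt_tokens']} completion={usage['completion_tokens']}")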