feat: enhance token usage tracking and context management for LLM providers

2025-03-26 17:27:41 +00:00
parent 49aebc12d5
commit 15ecb9fc48
5 changed files with 395 additions and 68 deletions


@@ -1,16 +1,14 @@
# src/providers/anthropic_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any
-from anthropic import Anthropic, Stream
+from anthropic import Anthropic, APIError, Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
-# Use relative imports for modules within the same package
-from providers.base import BaseProvider
+# Use absolute imports as per Ruff warning and user instructions
+from src.llm_models import MODELS
+from src.tools.conversion import convert_to_anthropic_tools
@@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider):
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Anthropic model."""
default_window = 100000 # Default fallback for Anthropic
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
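
The lookup above assumes a MODELS registry keyed by provider name, with a "models" list of entries carrying "id" and "context_window". For reference, a minimal illustrative shape (the model IDs and window sizes below are assumptions, not taken from this commit):

# Illustrative shape of src/llm_models.py; values are placeholders.
MODELS = {
    "anthropic": {
        "models": [
            {"id": "claude-3-opus-20240229", "context_window": 200000},
            {"id": "claude-3-haiku-20240307", "context_window": 200000},
        ],
    },
}
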
def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
"""Counts tokens for Anthropic messages using the official client."""
# Note: Anthropic's legacy count_tokens helper expects plain text rather than the
# structured message list used for creation, so we concatenate the content into a
# single string. This is a simplification and may be inaccurate for complex content
# such as tool calls/results; a more robust approach would format messages into a
# faithful single-string representation.
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
# Simple concatenation - might need refinement for complex content types (tool calls/results)
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list): # Handle tool results/calls if represented as list
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
# Prefer the client's count_tokens helper if it is available and accepts plain text
# (older versions of the anthropic SDK expose client.count_tokens(text)).
count = self.client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
return count
except APIError as api_err:
# Handle potential errors if count_tokens itself is an API call or fails
logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
# Fall back to the character-count approximation if the official count fails.
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
return estimated_tokens
except AttributeError:
# Fallback if count_tokens method doesn't exist or works differently
logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
return estimated_tokens
except Exception as e:
logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
return estimated_tokens
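
Newer releases of the anthropic SDK also expose a server-side counting endpoint that accepts the structured message list directly, avoiding the concatenation approximation above. A minimal sketch, assuming an SDK version that provides client.messages.count_tokens (the wrapper function name is illustrative):

from anthropic import Anthropic

def count_tokens_structured(client: Anthropic, model: str, messages: list[dict], system: str | None = None) -> int:
    # messages.count_tokens takes the same message structure as messages.create,
    # so tool calls/results are counted the way the API actually tokenizes them.
    params: dict = {"model": model, "messages": messages}
    if system:
        params["system"] = system
    return client.messages.count_tokens(**params).input_tokens
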
def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
"""
Truncates messages for Anthropic, preserving system prompt.
Returns:
- Potentially truncated list of messages.
- Original system prompt (or None).
- Initial token count.
- Final token count.
"""
context_limit = self._get_context_window(model)
buffer = 200 # Safety buffer
effective_limit = context_limit - buffer
initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages) # Copy
# Anthropic requires alternating user/assistant messages, so truncation needs care.
# We remove from the beginning of the list (the system prompt is handled separately),
# dropping the oldest message first.
while final_token_count > effective_limit and len(truncated_messages) > 0:
# Always remove the oldest message (index 0)
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
# Removing a single message can break user/assistant alternation; a more robust
# approach would remove user/assistant pairs (see the sketch after this method).
# For simplicity, just remove and recount for now.
final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
# Safety break
if not truncated_messages:
logger.warning("Truncation resulted in empty message list for Anthropic.")
break
if initial_token_count != final_token_count:
logger.info(
f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
# When there is no system prompt, Anthropic expects the conversation to start with a 'user' message.
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
# This can happen because the simple pop(0) logic removes single messages rather than pairs.
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
# The placeholder is not recounted; it adds only a few tokens, which the safety buffer absorbs.
return truncated_messages, system_prompt, initial_token_count, final_token_count
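
As noted in the comments above, dropping single messages can break Anthropic's required user/assistant alternation. A rough sketch of pair-wise truncation (the helper name is illustrative and not part of this commit):

def _truncate_preserving_alternation(self, messages, system_prompt, effective_limit):
    # Drop the oldest user/assistant pair per iteration so the remaining list
    # keeps starting with a 'user' message.
    msgs = list(messages)
    while msgs and self._count_anthropic_tokens(msgs, system_prompt) > effective_limit:
        msgs.pop(0)  # oldest message (expected role: 'user')
        if msgs and msgs[0].get("role") == "assistant":
            msgs.pop(0)  # drop the paired assistant reply to keep alternation intact
    return msgs
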
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
"""Converts standard message format to Anthropic's format, extracting system prompt."""
anthropic_messages = []
@@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider):
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[MessageStreamEvent] | Message:
"""Creates a chat completion using the Anthropic API."""
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
"""Creates a chat completion using the Anthropic API, handling context truncation."""
logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Context Truncation ---
# First, convert to Anthropic format to separate system prompt
temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
# Then, truncate based on token count
truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
# --------------------------
# Anthropic requires max_tokens
if max_tokens is None:
max_tokens = 4096 # Default value if not provided
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
-system_prompt, anthropic_messages = self._convert_messages(messages)
+# system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above
try:
completion_params = {
"model": model,
"messages": anthropic_messages,
"messages": truncated_anthropic_msgs, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
-if system_prompt:
-completion_params["system"] = system_prompt
+if final_system_prompt: # Use potentially modified system prompt
+completion_params["system"] = final_system_prompt
if tools:
completion_params["tools"] = tools
# Note: Anthropic's 'tool_choice' parameter works differently from OpenAI's 'auto';
# when tools are provided, the API defaults to automatic tool selection, so it is not set here.
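
For context, Anthropic expects each tool definition as a dict with name, description, and a JSON Schema under input_schema, which is presumably what convert_to_anthropic_tools produces. An illustrative example (tool name and schema invented for demonstration):

weather_tool = {
    "name": "get_weather",
    "description": "Get the current weather for a city.",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
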
@@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider):
response = self.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
# --- Capture Actual Usage ---
actual_usage = None
if isinstance(response, Message) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
"completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
# Anthropic doesn't typically provide total_tokens directly in usage block
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
}
logger.info(f"Actual Anthropic API usage: {actual_usage}")
# TODO: Capture usage for streaming responses. Anthropic reports input_tokens in the
# initial 'message_start' event and cumulative output_tokens in 'message_delta' events
# (see the streaming sketch after get_usage below).
# --------------------------
return response
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise
@@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider):
except Exception as e:
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming Anthropic response."""
try:
if isinstance(response, Message) and response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
# "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
}
logger.debug(f"Extracted usage from Anthropic response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
return None
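
Regarding the streaming TODO above: Anthropic's streaming protocol reports input_tokens in the initial 'message_start' event and a cumulative output_tokens count in 'message_delta' events, so usage can be accumulated while draining the stream. A hedged sketch (standalone helper; not part of this commit):

def get_streaming_usage(stream) -> dict[str, int]:
    # Accumulate usage while iterating over raw MessageStreamEvent objects.
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    for event in stream:
        if event.type == "message_start":
            usage["prompt_tokens"] = event.message.usage.input_tokens
        elif event.type == "message_delta":
            # output_tokens is cumulative, so overwrite rather than add.
            usage["completion_tokens"] = event.usage.output_tokens
    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
    return usage
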