feat: enhance token usage tracking and context management for LLM providers
This commit is contained in:
@@ -1,16 +1,14 @@
|
||||
# src/providers/anthropic_provider.py
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic import Anthropic, APIError, Stream
|
||||
from anthropic.types import Message, MessageStreamEvent, TextDelta
|
||||
|
||||
# Use relative imports for modules within the same package
|
||||
from providers.base import BaseProvider
|
||||
|
||||
# Use absolute imports as per Ruff warning and user instructions
|
||||
from src.llm_models import MODELS
|
||||
from src.tools.conversion import convert_to_anthropic_tools
|
||||
|
||||
@@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider):
|
||||
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_context_window(self, model: str) -> int:
|
||||
"""Retrieves the context window size for a given Anthropic model."""
|
||||
default_window = 100000 # Default fallback for Anthropic
|
||||
try:
|
||||
provider_models = MODELS.get("anthropic", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
if m.get("id") == model:
|
||||
return m.get("context_window", default_window)
|
||||
logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
|
||||
return default_window
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
|
||||
return default_window
|
||||
|
||||
def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
|
||||
"""Counts tokens for Anthropic messages using the official client."""
|
||||
# Note: Anthropic's count_tokens might not directly accept the message list format used for creation.
|
||||
# It often expects plain text. We need to concatenate the content appropriately.
|
||||
# This is a simplification and might not be perfectly accurate, especially with tool calls/results.
|
||||
# A more robust approach might involve formatting messages into a single string representation.
|
||||
text_to_count = ""
|
||||
if system_prompt:
|
||||
text_to_count += f"System: {system_prompt}\n\n"
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
# Simple concatenation - might need refinement for complex content types (tool calls/results)
|
||||
if isinstance(content, str):
|
||||
text_to_count += f"{role}: {content}\n"
|
||||
elif isinstance(content, list): # Handle tool results/calls if represented as list
|
||||
try:
|
||||
content_str = json.dumps(content)
|
||||
text_to_count += f"{role}: {content_str}\n"
|
||||
except Exception:
|
||||
text_to_count += f"{role}: [Unserializable Content]\n"
|
||||
|
||||
try:
|
||||
# Use the client's count_tokens method if available and works with text
|
||||
# Check Anthropic documentation for the correct usage
|
||||
# Assuming self.client.count_tokens exists and takes text
|
||||
count = self.client.count_tokens(text=text_to_count)
|
||||
logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
|
||||
return count
|
||||
except APIError as api_err:
|
||||
# Handle potential errors if count_tokens itself is an API call or fails
|
||||
logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
|
||||
# Fallback to approximation if official count fails?
|
||||
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
|
||||
logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
|
||||
return estimated_tokens
|
||||
except AttributeError:
|
||||
# Fallback if count_tokens method doesn't exist or works differently
|
||||
logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
|
||||
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
|
||||
return estimated_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
|
||||
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
|
||||
logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
|
||||
return estimated_tokens
|
||||
|
||||
def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
|
||||
"""
|
||||
Truncates messages for Anthropic, preserving system prompt.
|
||||
|
||||
Returns:
|
||||
- Potentially truncated list of messages.
|
||||
- Original system prompt (or None).
|
||||
- Initial token count.
|
||||
- Final token count.
|
||||
"""
|
||||
context_limit = self._get_context_window(model)
|
||||
buffer = 200 # Safety buffer
|
||||
effective_limit = context_limit - buffer
|
||||
|
||||
initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
|
||||
final_token_count = initial_token_count
|
||||
|
||||
truncated_messages = list(messages) # Copy
|
||||
|
||||
# Anthropic requires alternating user/assistant messages. Truncation needs care.
|
||||
# We remove from the beginning (after potential system prompt).
|
||||
# Removing the oldest message (index 0 of the list passed here, as system is separate)
|
||||
|
||||
while final_token_count > effective_limit and len(truncated_messages) > 0:
|
||||
# Always remove the oldest message (index 0)
|
||||
removed_message = truncated_messages.pop(0)
|
||||
logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
|
||||
|
||||
# Ensure alternation after removal if possible (might be complex)
|
||||
# For simplicity, just remove and recount for now.
|
||||
# A more robust approach might need to remove pairs (user/assistant).
|
||||
|
||||
final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
|
||||
logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
|
||||
|
||||
# Safety break
|
||||
if not truncated_messages:
|
||||
logger.warning("Truncation resulted in empty message list for Anthropic.")
|
||||
break
|
||||
|
||||
if initial_token_count != final_token_count:
|
||||
logger.info(
|
||||
f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
|
||||
|
||||
# Ensure the remaining messages start with 'user' role if no system prompt
|
||||
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
|
||||
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
|
||||
# This might indicate an issue with the simple pop(0) logic if pairs weren't removed.
|
||||
# For now, prepend a basic user message.
|
||||
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
|
||||
# Recount after adding placeholder? Might exceed limit again. Risky.
|
||||
# Let's log a warning instead of adding potentially problematic content.
|
||||
# logger.warning("First message after truncation is not 'user'. This might cause issues with Anthropic API.")
|
||||
|
||||
return truncated_messages, system_prompt, initial_token_count, final_token_count
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
|
||||
"""Converts standard message format to Anthropic's format, extracting system prompt."""
|
||||
anthropic_messages = []
|
||||
@@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider):
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[MessageStreamEvent] | Message:
|
||||
"""Creates a chat completion using the Anthropic API."""
|
||||
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
|
||||
"""Creates a chat completion using the Anthropic API, handling context truncation."""
|
||||
logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# --- Context Truncation ---
|
||||
# First, convert to Anthropic format to separate system prompt
|
||||
temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
|
||||
# Then, truncate based on token count
|
||||
truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
|
||||
# --------------------------
|
||||
|
||||
# Anthropic requires max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = 4096 # Default value if not provided
|
||||
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
|
||||
|
||||
system_prompt, anthropic_messages = self._convert_messages(messages)
|
||||
# system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
"messages": truncated_anthropic_msgs, # Use truncated messages
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if system_prompt:
|
||||
completion_params["system"] = system_prompt
|
||||
if final_system_prompt: # Use potentially modified system prompt
|
||||
completion_params["system"] = final_system_prompt
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
|
||||
@@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider):
|
||||
|
||||
response = self.client.messages.create(**completion_params)
|
||||
logger.debug("Anthropic API call successful.")
|
||||
|
||||
# --- Capture Actual Usage ---
|
||||
actual_usage = None
|
||||
if isinstance(response, Message) and response.usage:
|
||||
actual_usage = {
|
||||
"prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
|
||||
"completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
|
||||
# Anthropic doesn't typically provide total_tokens directly in usage block
|
||||
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
|
||||
}
|
||||
logger.info(f"Actual Anthropic API usage: {actual_usage}")
|
||||
# TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation.
|
||||
|
||||
return response
|
||||
# --------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider):
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
|
||||
def get_usage(self, response: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage from a non-streaming Anthropic response."""
|
||||
try:
|
||||
if isinstance(response, Message) and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
# "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
|
||||
}
|
||||
logger.debug(f"Extracted usage from Anthropic response: {usage}")
|
||||
return usage
|
||||
else:
|
||||
logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user