import json
import logging
import math
from typing import Any

from anthropic import Anthropic

from src.llm_models import MODELS

logger = logging.getLogger(__name__)

def get_context_window(model: str) -> int:
    """Return the context window size for an Anthropic model, or a safe default."""
    default_window = 100000
    try:
        provider_models = MODELS.get("anthropic", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        logger.warning(f"Context window for Anthropic model '{model}' not found. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window

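# Illustrative sketch (an assumption, not taken from src.llm_models): the lookup
# above expects MODELS to be shaped roughly like the literal below, i.e. a dict
# keyed by provider with a "models" list whose entries carry "id" and
# "context_window" fields. The real registry may hold additional keys per entry.
#
# MODELS = {
#     "anthropic": {
#         "models": [
#             {"id": "claude-3-5-sonnet-20241022", "context_window": 200000},
#         ],
#     },
# }
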
def count_anthropic_tokens(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
    """Count tokens for a conversation, falling back to a rough estimate on failure."""
    # Flatten the system prompt and messages into a single string for counting.
    text_to_count = ""
    if system_prompt:
        text_to_count += f"System: {system_prompt}\n\n"
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if isinstance(content, str):
            text_to_count += f"{role}: {content}\n"
        elif isinstance(content, list):
            # Structured content blocks (e.g. tool use) are serialized as JSON.
            try:
                content_str = json.dumps(content)
                text_to_count += f"{role}: {content_str}\n"
            except Exception:
                text_to_count += f"{role}: [Unserializable Content]\n"
    try:
        count = client.count_tokens(text=text_to_count)
        logger.debug(f"Counted Anthropic tokens: {count}")
        return count
    except Exception as e:
        logger.error(f"Error counting Anthropic tokens: {e}", exc_info=True)
        # Rough heuristic: assume roughly 4 characters per token.
        estimated_tokens = math.ceil(len(text_to_count) / 4.0)
        logger.warning(f"Falling back to approximation: {estimated_tokens}")
        return estimated_tokens
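

# Usage sketch (illustrative, not part of the original module): how these helpers
# might be combined to check whether a conversation still fits in a model's
# context window. The model id below is only an example, and constructing
# Anthropic() assumes ANTHROPIC_API_KEY is set in the environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = Anthropic()
    model = "claude-3-5-sonnet-20241022"
    messages = [{"role": "user", "content": "Summarize the attached report."}]
    used = count_anthropic_tokens(client, messages, system_prompt="You are concise.")
    window = get_context_window(model)
    logger.info(f"{used} tokens used; {window - used} of {window} remain for '{model}'")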