feat: enhance token usage tracking and context management for LLM providers
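
Summary of the new non-streaming flow, as a minimal sketch distilled from the diff below (example values are invented): LLMClient.chat_completion(..., stream=False) now returns a plain dict carrying the content plus an optional "usage" entry, which src/app.py attaches to the assistant message and renders with st.caption.

    # Illustrative only; see src/app.py and the provider changes below.
    response_data = {
        "content": "Hello!",
        "usage": {"prompt_tokens": 42, "completion_tokens": 7},
    }
    assistant_message = {"role": "assistant", "content": response_data["content"]}
    if response_data.get("usage"):
        assistant_message["usage"] = response_data["usage"]  # later shown via st.caption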
src/app.py (81 changes)
@@ -1,6 +1,5 @@
 import atexit
 import configparser
-import json
 import logging

 import streamlit as st
@@ -99,8 +98,14 @@ def display_chat_messages():
 """Displays chat messages stored in session state."""
 for message in st.session_state.messages:
 with st.chat_message(message["role"]):
-# Simple markdown display for now
+# Display content
 st.markdown(message["content"])
+# Display usage if available (for assistant messages)
+if message["role"] == "assistant" and "usage" in message:
+usage = message["usage"]
+prompt_tokens = usage.get("prompt_tokens", "N/A")
+completion_tokens = usage.get("completion_tokens", "N/A")
+st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}")


 def handle_user_input():
@@ -116,60 +121,50 @@ def handle_user_input():
 response_placeholder = st.empty()
 full_response = ""
 error_occurred = False
+response_usage = None # Initialize usage info

 logger.info("Processing message via LLMClient...")
-# Use the new client and method, always requesting stream for UI
-response_stream = st.session_state.client.chat_completion(
+# Use the new client and method
+# NOTE: Setting stream=False to easily get usage info from the response dict.
+# A more complex solution is needed to get usage with streaming.
+response_data = st.session_state.client.chat_completion(
 messages=st.session_state.messages,
-model=st.session_state.model_name, # Get model from session state
-stream=True,
+model=st.session_state.model_name,
+stream=False, # Set to False for usage info
 )

-# Handle the response (stream generator or error dict)
-if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
-logger.debug("Processing response stream...")
-for chunk in response_stream:
-# Check for potential error JSON yielded by the stream
-try:
-# Attempt to parse chunk as JSON only if it looks like it
-if isinstance(chunk, str) and chunk.strip().startswith("{"):
-error_data = json.loads(chunk)
-if isinstance(error_data, dict) and "error" in error_data:
-full_response = f"Error: {error_data['error']}"
-logger.error(f"Error received in stream: {full_response}")
-st.error(full_response)
-error_occurred = True
-break # Stop processing stream on error
-# If not error JSON, treat as content chunk
-if not error_occurred and isinstance(chunk, str):
-full_response += chunk
-response_placeholder.markdown(full_response + "▌") # Add cursor effect
-except (json.JSONDecodeError, TypeError):
-# Not JSON or not error structure, treat as content chunk
-if not error_occurred and isinstance(chunk, str):
-full_response += chunk
-response_placeholder.markdown(full_response + "▌") # Add cursor effect
-
-if not error_occurred:
-response_placeholder.markdown(full_response) # Final update without cursor
-logger.debug("Stream processing complete.")
-
-elif isinstance(response_stream, dict) and "error" in response_stream:
-# Handle error dict returned directly (e.g., API error before streaming)
-full_response = f"Error: {response_stream['error']}"
-logger.error(f"Error returned directly from chat_completion: {full_response}")
+# Handle the response (now expecting a dict)
+if isinstance(response_data, dict):
+if "error" in response_data:
+full_response = f"Error: {response_data['error']}"
+logger.error(f"Error returned from chat_completion: {full_response}")
 st.error(full_response)
 error_occurred = True
+else:
+full_response = response_data.get("content", "")
+response_usage = response_data.get("usage") # Get usage dict
+if not full_response and not error_occurred: # Check error_occurred flag too
+logger.warning("Empty content received from LLMClient.")
+# Display nothing or a placeholder? Let's display nothing.
+# full_response = "[Empty Response]"
+# Display the full response at once (no streaming)
+response_placeholder.markdown(full_response)
+logger.debug("Non-streaming response processed.")

 else:
 # Unexpected response type
 full_response = "[Unexpected response format from LLMClient]"
-logger.error(f"Unexpected response type: {type(response_stream)}")
+logger.error(f"Unexpected response type: {type(response_data)}")
 st.error(full_response)
 error_occurred = True

-# Only add non-error, non-empty responses to history
-if not error_occurred and full_response:
-st.session_state.messages.append({"role": "assistant", "content": full_response})
+# Add response to history, including usage if available
+if not error_occurred and full_response: # Only add if no error and content exists
+assistant_message = {"role": "assistant", "content": full_response}
+if response_usage:
+assistant_message["usage"] = response_usage
+logger.info(f"Assistant response usage: {response_usage}")
+st.session_state.messages.append(assistant_message)
 logger.info("Assistant response added to history.")
 elif error_occurred:
 logger.warning("Assistant response not added to history due to error.")
@@ -72,8 +72,9 @@ class LLMClient:

 Returns:
 If stream=True: A generator yielding content chunks.
-If stream=False: A dictionary containing the final content or an error.
-e.g., {"content": "..."} or {"error": "..."}
+If stream=False: A dictionary containing the final content, usage, or an error.
+e.g., {"content": "...", "usage": {"prompt_tokens": ..., "completion_tokens": ...}}
+or {"error": "..."}
 """
 # Ensure tools are up-to-date (optional, could be done less frequently)
 # self._refresh_mcp_tools()
@@ -173,8 +174,12 @@ class LLMClient:
 tools=provider_tools, # Pass tools again? Some providers might need it.
 )
 final_content = self.provider.get_content(follow_up_response)
+final_usage = self.provider.get_usage(follow_up_response) # Get usage from follow-up
 logger.info("Received follow-up response content.")
-return {"content": final_content}
+result_dict = {"content": final_content}
+if final_usage:
+result_dict["usage"] = final_usage
+return result_dict

 except Exception as tool_handling_err:
 logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
@@ -183,7 +188,11 @@ class LLMClient:
 else: # No tool calls
 logger.info("No tool calls detected.")
 content = self.provider.get_content(response)
-return {"content": content}
+usage = self.provider.get_usage(response) # Get usage from initial response
+result_dict = {"content": content}
+if usage:
+result_dict["usage"] = usage
+return result_dict

 except Exception as e:
 error_msg = f"LLM API Error: {str(e)}"
@@ -1,16 +1,14 @@
 # src/providers/anthropic_provider.py
 import json
 import logging
+import math
 from collections.abc import Generator
 from typing import Any

-from anthropic import Anthropic, Stream
+from anthropic import Anthropic, APIError, Stream
 from anthropic.types import Message, MessageStreamEvent, TextDelta

-# Use relative imports for modules within the same package
 from providers.base import BaseProvider

-# Use absolute imports as per Ruff warning and user instructions
 from src.llm_models import MODELS
 from src.tools.conversion import convert_to_anthropic_tools

@@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider):
 logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
 raise

+def _get_context_window(self, model: str) -> int:
+"""Retrieves the context window size for a given Anthropic model."""
+default_window = 100000 # Default fallback for Anthropic
+try:
+provider_models = MODELS.get("anthropic", {}).get("models", [])
+for m in provider_models:
+if m.get("id") == model:
+return m.get("context_window", default_window)
+logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
+return default_window
+except Exception as e:
+logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
+return default_window
+
+def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
+"""Counts tokens for Anthropic messages using the official client."""
+# Note: Anthropic's count_tokens might not directly accept the message list format used for creation.
+# It often expects plain text. We need to concatenate the content appropriately.
+# This is a simplification and might not be perfectly accurate, especially with tool calls/results.
+# A more robust approach might involve formatting messages into a single string representation.
+text_to_count = ""
+if system_prompt:
+text_to_count += f"System: {system_prompt}\n\n"
+for message in messages:
+role = message.get("role")
+content = message.get("content")
+# Simple concatenation - might need refinement for complex content types (tool calls/results)
+if isinstance(content, str):
+text_to_count += f"{role}: {content}\n"
+elif isinstance(content, list): # Handle tool results/calls if represented as list
+try:
+content_str = json.dumps(content)
+text_to_count += f"{role}: {content_str}\n"
+except Exception:
+text_to_count += f"{role}: [Unserializable Content]\n"
+
+try:
+# Use the client's count_tokens method if available and works with text
+# Check Anthropic documentation for the correct usage
+# Assuming self.client.count_tokens exists and takes text
+count = self.client.count_tokens(text=text_to_count)
+logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
+return count
+except APIError as api_err:
+# Handle potential errors if count_tokens itself is an API call or fails
+logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
+# Fallback to approximation if official count fails?
+estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
+logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
+return estimated_tokens
+except AttributeError:
+# Fallback if count_tokens method doesn't exist or works differently
+logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
+estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
+return estimated_tokens
+except Exception as e:
+logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
+estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
+logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
+return estimated_tokens
+
+def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
+"""
+Truncates messages for Anthropic, preserving system prompt.
+
+Returns:
+- Potentially truncated list of messages.
+- Original system prompt (or None).
+- Initial token count.
+- Final token count.
+"""
+context_limit = self._get_context_window(model)
+buffer = 200 # Safety buffer
+effective_limit = context_limit - buffer
+
+initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
+final_token_count = initial_token_count
+
+truncated_messages = list(messages) # Copy
+
+# Anthropic requires alternating user/assistant messages. Truncation needs care.
+# We remove from the beginning (after potential system prompt).
+# Removing the oldest message (index 0 of the list passed here, as system is separate)
+
+while final_token_count > effective_limit and len(truncated_messages) > 0:
+# Always remove the oldest message (index 0)
+removed_message = truncated_messages.pop(0)
+logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
+
+# Ensure alternation after removal if possible (might be complex)
+# For simplicity, just remove and recount for now.
+# A more robust approach might need to remove pairs (user/assistant).
+
+final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
+logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
+
+# Safety break
+if not truncated_messages:
+logger.warning("Truncation resulted in empty message list for Anthropic.")
+break
+
+if initial_token_count != final_token_count:
+logger.info(
+f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
+)
+else:
+logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
+
+# Ensure the remaining messages start with 'user' role if no system prompt
+if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
+logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
+# This might indicate an issue with the simple pop(0) logic if pairs weren't removed.
+# For now, prepend a basic user message.
+truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
+# Recount after adding placeholder? Might exceed limit again. Risky.
+# Let's log a warning instead of adding potentially problematic content.
+# logger.warning("First message after truncation is not 'user'. This might cause issues with Anthropic API.")
+
+return truncated_messages, system_prompt, initial_token_count, final_token_count
+
 def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
 """Converts standard message format to Anthropic's format, extracting system prompt."""
 anthropic_messages = []
@@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider):
 stream: bool = True,
 tools: list[dict[str, Any]] | None = None,
 ) -> Stream[MessageStreamEvent] | Message:
-"""Creates a chat completion using the Anthropic API."""
-logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
+"""Creates a chat completion using the Anthropic API, handling context truncation."""
+logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
+
+# --- Context Truncation ---
+# First, convert to Anthropic format to separate system prompt
+temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
+# Then, truncate based on token count
+truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
+# --------------------------

 # Anthropic requires max_tokens
 if max_tokens is None:
 max_tokens = 4096 # Default value if not provided
 logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")

-system_prompt, anthropic_messages = self._convert_messages(messages)
+# system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above

 try:
 completion_params = {
 "model": model,
-"messages": anthropic_messages,
+"messages": truncated_anthropic_msgs, # Use truncated messages
 "temperature": temperature,
 "max_tokens": max_tokens,
 "stream": stream,
 }
-if system_prompt:
-completion_params["system"] = system_prompt
+if final_system_prompt: # Use potentially modified system prompt
+completion_params["system"] = final_system_prompt
 if tools:
 completion_params["tools"] = tools
 # Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
@@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider):

 response = self.client.messages.create(**completion_params)
 logger.debug("Anthropic API call successful.")
+
+# --- Capture Actual Usage ---
+actual_usage = None
+if isinstance(response, Message) and response.usage:
+actual_usage = {
+"prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
+"completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
+# Anthropic doesn't typically provide total_tokens directly in usage block
+"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+}
+logger.info(f"Actual Anthropic API usage: {actual_usage}")
+# TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation.
+
 return response
+# --------------------------

 except Exception as e:
 logger.error(f"Anthropic API error: {e}", exc_info=True)
 raise
@@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider):
 except Exception as e:
 logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
 return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
+
+def get_usage(self, response: Any) -> dict[str, int] | None:
+"""Extracts token usage from a non-streaming Anthropic response."""
+try:
+if isinstance(response, Message) and response.usage:
+usage = {
+"prompt_tokens": response.usage.input_tokens,
+"completion_tokens": response.usage.output_tokens,
+# "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
+}
+logger.debug(f"Extracted usage from Anthropic response: {usage}")
+return usage
+else:
+logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
+return None
+except Exception as e:
+logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
+return None
@@ -134,6 +134,20 @@ class BaseProvider(abc.ABC):
 """
 pass
+
+@abc.abstractmethod
+def get_usage(self, response: Any) -> dict[str, int] | None:
+"""
+Extracts token usage information from a non-streaming response object.
+
+Args:
+response: The non-streaming response object.
+
+Returns:
+A dictionary containing 'prompt_tokens' and 'completion_tokens',
+or None if usage information is not available.
+"""
+pass

 # Optional: Add a method for follow-up completions if the provider API
 # requires a specific structure different from just appending messages.
 # def create_follow_up_completion(...) -> Any:
@@ -1,6 +1,7 @@
 # src/providers/openai_provider.py
 import json
 import logging
+import math
 from collections.abc import Generator
 from typing import Any

@@ -8,8 +9,8 @@ from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

-from providers.base import BaseProvider
-from src.llm_models import MODELS # Use absolute import
+from src.llm_models import MODELS
+from src.providers.base import BaseProvider

 logger = logging.getLogger(__name__)

@@ -29,6 +30,110 @@ class OpenAIProvider(BaseProvider):
 logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
 raise

+def _get_context_window(self, model: str) -> int:
+"""Retrieves the context window size for a given model."""
+# Default to a safe fallback if model or provider info is missing
+default_window = 8000
+try:
+# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
+provider_models = MODELS.get("openai", {}).get("models", [])
+for m in provider_models:
+if m.get("id") == model:
+return m.get("context_window", default_window)
+# Fallback if specific model ID not found in our list
+logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
+return default_window
+except Exception as e:
+logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
+return default_window
+
+def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
+"""
+Estimates the token count for OpenAI messages using char count / 4 approximation.
+Note: This is less accurate than using tiktoken.
+"""
+total_chars = 0
+for message in messages:
+total_chars += len(message.get("role", ""))
+content = message.get("content")
+if isinstance(content, str):
+total_chars += len(content)
+# Rough approximation for function/tool call overhead if needed later
+# Using math.ceil to round up, ensuring we don't underestimate too much.
+estimated_tokens = math.ceil(total_chars / 4.0)
+logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
+return estimated_tokens
+
+def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
+"""
+Truncates messages from the beginning if estimated token count exceeds the limit.
+Preserves the first message if it's a system prompt.
+
+Returns:
+- The potentially truncated list of messages.
+- The initial estimated token count.
+- The final estimated token count after truncation (if any).
+"""
+context_limit = self._get_context_window(model)
+# Add a buffer to be safer with approximation
+buffer = 200 # Reduce buffer slightly as we round up now
+effective_limit = context_limit - buffer
+
+initial_estimated_count = self._estimate_openai_token_count(messages)
+final_estimated_count = initial_estimated_count
+
+truncated_messages = list(messages) # Make a copy
+
+# Identify if the first message is a system prompt
+has_system_prompt = False
+if truncated_messages and truncated_messages[0].get("role") == "system":
+has_system_prompt = True
+# If only system prompt exists, don't truncate further
+if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
+logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
+# Return original messages to avoid removing the only message
+return messages, initial_estimated_count, final_estimated_count
+
+while final_estimated_count > effective_limit:
+if has_system_prompt and len(truncated_messages) <= 1:
+# Should not happen if check above works, but safety break
+logger.warning("Truncation stopped: Only system prompt remains.")
+break
+if not has_system_prompt and len(truncated_messages) <= 0:
+logger.warning("Truncation stopped: No messages left.")
+break # No messages left
+
+# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
+remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
+
+if remove_index >= len(truncated_messages):
+logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
+break # Avoid index error
+
+removed_message = truncated_messages.pop(remove_index)
+logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
+
+# Recalculate estimated count
+final_estimated_count = self._estimate_openai_token_count(truncated_messages)
+logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
+
+# Safety break if list becomes unexpectedly empty
+if not truncated_messages:
+logger.warning("Truncation resulted in empty message list.")
+break
+
+if initial_estimated_count != final_estimated_count:
+logger.info(
+f"Truncated messages for model {model}. "
+f"Initial estimated tokens: {initial_estimated_count}, "
+f"Final estimated tokens: {final_estimated_count}, "
+f"Limit: {context_limit} (Effective: {effective_limit})"
+)
+else:
+logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
+
+return truncated_messages, initial_estimated_count, final_estimated_count
+
 def create_chat_completion(
 self,
 messages: list[dict[str, str]],
@@ -37,13 +142,19 @@ class OpenAIProvider(BaseProvider):
 max_tokens: int | None = None,
 stream: bool = True,
 tools: list[dict[str, Any]] | None = None,
-) -> Stream[ChatCompletionChunk] | ChatCompletion:
-"""Creates a chat completion using the OpenAI API."""
-logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
+# Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
+) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly?
+"""Creates a chat completion using the OpenAI API, handling context window truncation."""
+logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
+
+# --- Truncation Step ---
+truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
+# -----------------------

 try:
 completion_params = {
 "model": model,
-"messages": messages,
+"messages": truncated_messages, # Use truncated messages
 "temperature": temperature,
 "max_tokens": max_tokens,
 "stream": stream,
@@ -71,7 +182,34 @@ class OpenAIProvider(BaseProvider):

 response = self.client.chat.completions.create(**completion_params)
 logger.debug("OpenAI API call successful.")
+
+# --- Capture Actual Usage (for UI display later) ---
+# This part is tricky. Usage info is easily available on the *non-streaming* response.
+# For streaming, it's often not available until the stream is fully consumed,
+# or sometimes via response headers or a final event (provider-dependent).
+# For now, let's focus on getting it from the non-streaming case.
+# We need a way to pass this back alongside the content/stream.
+# Option 1: Modify return type (complex for stream/non-stream union)
+# Option 2: Store it in the provider instance (stateful, maybe bad)
+# Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)
+
+# Let's try returning it alongside for non-streaming, and figure out streaming later.
+# This requires changing the BaseProvider interface and LLMClient handling.
+# For now, just log it here.
+actual_usage = None
+if isinstance(response, ChatCompletion) and response.usage:
+actual_usage = {
+"prompt_tokens": response.usage.prompt_tokens,
+"completion_tokens": response.usage.completion_tokens,
+"total_tokens": response.usage.total_tokens,
+}
+logger.info(f"Actual OpenAI API usage: {actual_usage}")
+# TODO: How to handle usage for streaming responses? Needs investigation.
+
+# Return the raw response for now. LLMClient will process it.
 return response
+# ----------------------------------------------------

 except Exception as e:
 logger.error(f"OpenAI API error: {e}", exc_info=True)
 # Re-raise for the LLMClient to handle
@@ -233,7 +371,20 @@ class OpenAIProvider(BaseProvider):
 logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
 return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

-# Register this provider (if using the registration mechanism)
-# from . import register_provider
-# register_provider("openai", OpenAIProvider)
+def get_usage(self, response: Any) -> dict[str, int] | None:
+"""Extracts token usage from a non-streaming OpenAI response."""
+try:
+if isinstance(response, ChatCompletion) and response.usage:
+usage = {
+"prompt_tokens": response.usage.prompt_tokens,
+"completion_tokens": response.usage.completion_tokens,
+# "total_tokens": response.usage.total_tokens, # Optional
+}
+logger.debug(f"Extracted usage from OpenAI response: {usage}")
+return usage
+else:
+logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
+return None
+except Exception as e:
+logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
+return None
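
Note on the truncation heuristic added above (illustrative arithmetic, not part of the diff): the OpenAI provider estimates tokens as ceil(total_chars / 4) rather than using tiktoken. With the assumed default 8000-token window and the 200-token buffer, the effective limit is 7800 estimated tokens, so roughly 31,200 characters of history fit before the provider starts popping messages from the front (keeping a leading system prompt) and re-estimating until the count fits:

    ceil(31_200 / 4) = 7_800 <= 7_800  ->  no truncation
    ceil(31_201 / 4) = 7_801  > 7_800  ->  drop the oldest non-system message and re-estimate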