feat: enhance token usage tracking and context management for LLM providers

2025-03-26 17:27:41 +00:00
parent 49aebc12d5
commit 15ecb9fc48
5 changed files with 395 additions and 68 deletions

View File

@@ -1,6 +1,5 @@
 import atexit
 import configparser
-import json
 import logging

 import streamlit as st
@@ -99,8 +98,14 @@ def display_chat_messages():
"""Displays chat messages stored in session state.""" """Displays chat messages stored in session state."""
for message in st.session_state.messages: for message in st.session_state.messages:
with st.chat_message(message["role"]): with st.chat_message(message["role"]):
# Simple markdown display for now # Display content
st.markdown(message["content"]) st.markdown(message["content"])
# Display usage if available (for assistant messages)
if message["role"] == "assistant" and "usage" in message:
usage = message["usage"]
prompt_tokens = usage.get("prompt_tokens", "N/A")
completion_tokens = usage.get("completion_tokens", "N/A")
st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}")
def handle_user_input(): def handle_user_input():
@@ -116,60 +121,50 @@ def handle_user_input():
         response_placeholder = st.empty()
         full_response = ""
         error_occurred = False
+        response_usage = None # Initialize usage info

         logger.info("Processing message via LLMClient...")

-        # Use the new client and method, always requesting stream for UI
-        response_stream = st.session_state.client.chat_completion(
+        # Use the new client and method
+        # NOTE: Setting stream=False to easily get usage info from the response dict.
+        # A more complex solution is needed to get usage with streaming.
+        response_data = st.session_state.client.chat_completion(
             messages=st.session_state.messages,
-            model=st.session_state.model_name, # Get model from session state
-            stream=True,
+            model=st.session_state.model_name,
+            stream=False, # Set to False for usage info
         )

-        # Handle the response (stream generator or error dict)
-        if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
-            logger.debug("Processing response stream...")
-            for chunk in response_stream:
-                # Check for potential error JSON yielded by the stream
-                try:
-                    # Attempt to parse chunk as JSON only if it looks like it
-                    if isinstance(chunk, str) and chunk.strip().startswith("{"):
-                        error_data = json.loads(chunk)
-                        if isinstance(error_data, dict) and "error" in error_data:
-                            full_response = f"Error: {error_data['error']}"
-                            logger.error(f"Error received in stream: {full_response}")
-                            st.error(full_response)
-                            error_occurred = True
-                            break # Stop processing stream on error
-                    # If not error JSON, treat as content chunk
-                    if not error_occurred and isinstance(chunk, str):
-                        full_response += chunk
-                        response_placeholder.markdown(full_response + "▌") # Add cursor effect
-                except (json.JSONDecodeError, TypeError):
-                    # Not JSON or not error structure, treat as content chunk
-                    if not error_occurred and isinstance(chunk, str):
-                        full_response += chunk
-                        response_placeholder.markdown(full_response + "▌") # Add cursor effect
-            if not error_occurred:
-                response_placeholder.markdown(full_response) # Final update without cursor
-            logger.debug("Stream processing complete.")
-        elif isinstance(response_stream, dict) and "error" in response_stream:
-            # Handle error dict returned directly (e.g., API error before streaming)
-            full_response = f"Error: {response_stream['error']}"
-            logger.error(f"Error returned directly from chat_completion: {full_response}")
-            st.error(full_response)
-            error_occurred = True
+        # Handle the response (now expecting a dict)
+        if isinstance(response_data, dict):
+            if "error" in response_data:
+                full_response = f"Error: {response_data['error']}"
+                logger.error(f"Error returned from chat_completion: {full_response}")
+                st.error(full_response)
+                error_occurred = True
+            else:
+                full_response = response_data.get("content", "")
+                response_usage = response_data.get("usage") # Get usage dict
+                if not full_response and not error_occurred: # Check error_occurred flag too
+                    logger.warning("Empty content received from LLMClient.")
+                    # Display nothing or a placeholder? Let's display nothing.
+                    # full_response = "[Empty Response]"
+            # Display the full response at once (no streaming)
+            response_placeholder.markdown(full_response)
+            logger.debug("Non-streaming response processed.")
         else:
             # Unexpected response type
             full_response = "[Unexpected response format from LLMClient]"
-            logger.error(f"Unexpected response type: {type(response_stream)}")
+            logger.error(f"Unexpected response type: {type(response_data)}")
             st.error(full_response)
             error_occurred = True

-        # Only add non-error, non-empty responses to history
-        if not error_occurred and full_response:
-            st.session_state.messages.append({"role": "assistant", "content": full_response})
+        # Add response to history, including usage if available
+        if not error_occurred and full_response: # Only add if no error and content exists
+            assistant_message = {"role": "assistant", "content": full_response}
+            if response_usage:
+                assistant_message["usage"] = response_usage
+                logger.info(f"Assistant response usage: {response_usage}")
+            st.session_state.messages.append(assistant_message)
             logger.info("Assistant response added to history.")
         elif error_occurred:
             logger.warning("Assistant response not added to history due to error.")

View File

@@ -72,8 +72,9 @@ class LLMClient:
         Returns:
             If stream=True: A generator yielding content chunks.
-            If stream=False: A dictionary containing the final content or an error.
-                 e.g., {"content": "..."} or {"error": "..."}
+            If stream=False: A dictionary containing the final content, usage, or an error.
+                 e.g., {"content": "...", "usage": {"prompt_tokens": ..., "completion_tokens": ...}}
+                 or {"error": "..."}
         """
         # Ensure tools are up-to-date (optional, could be done less frequently)
         # self._refresh_mcp_tools()
@@ -173,8 +174,12 @@ class LLMClient:
                         tools=provider_tools, # Pass tools again? Some providers might need it.
                     )
                     final_content = self.provider.get_content(follow_up_response)
+                    final_usage = self.provider.get_usage(follow_up_response) # Get usage from follow-up
                     logger.info("Received follow-up response content.")
-                    return {"content": final_content}
+                    result_dict = {"content": final_content}
+                    if final_usage:
+                        result_dict["usage"] = final_usage
+                    return result_dict

                 except Exception as tool_handling_err:
                     logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
@@ -183,7 +188,11 @@ class LLMClient:
             else: # No tool calls
                 logger.info("No tool calls detected.")
                 content = self.provider.get_content(response)
-                return {"content": content}
+                usage = self.provider.get_usage(response) # Get usage from initial response
+                result_dict = {"content": content}
+                if usage:
+                    result_dict["usage"] = usage
+                return result_dict

         except Exception as e:
             error_msg = f"LLM API Error: {str(e)}"

View File

@@ -1,16 +1,14 @@
 # src/providers/anthropic_provider.py
 import json
 import logging
+import math
 from collections.abc import Generator
 from typing import Any

-from anthropic import Anthropic, Stream
+from anthropic import Anthropic, APIError, Stream
 from anthropic.types import Message, MessageStreamEvent, TextDelta

-# Use relative imports for modules within the same package
 from providers.base import BaseProvider
-# Use absolute imports as per Ruff warning and user instructions
 from src.llm_models import MODELS
 from src.tools.conversion import convert_to_anthropic_tools
@@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider):
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True) logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Anthropic model."""
default_window = 100000 # Default fallback for Anthropic
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
"""Counts tokens for Anthropic messages using the official client."""
# Note: Anthropic's count_tokens might not directly accept the message list format used for creation.
# It often expects plain text. We need to concatenate the content appropriately.
# This is a simplification and might not be perfectly accurate, especially with tool calls/results.
# A more robust approach might involve formatting messages into a single string representation.
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
# Simple concatenation - might need refinement for complex content types (tool calls/results)
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list): # Handle tool results/calls if represented as list
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
# Use the client's count_tokens method if available and works with text
# Check Anthropic documentation for the correct usage
# Assuming self.client.count_tokens exists and takes text
count = self.client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
return count
except APIError as api_err:
# Handle potential errors if count_tokens itself is an API call or fails
logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
# Fallback to approximation if official count fails?
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
return estimated_tokens
except AttributeError:
# Fallback if count_tokens method doesn't exist or works differently
logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
return estimated_tokens
except Exception as e:
logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
return estimated_tokens
def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
"""
Truncates messages for Anthropic, preserving system prompt.
Returns:
- Potentially truncated list of messages.
- Original system prompt (or None).
- Initial token count.
- Final token count.
"""
context_limit = self._get_context_window(model)
buffer = 200 # Safety buffer
effective_limit = context_limit - buffer
initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages) # Copy
# Anthropic requires alternating user/assistant messages. Truncation needs care.
# We remove from the beginning (after potential system prompt).
# Removing the oldest message (index 0 of the list passed here, as system is separate)
while final_token_count > effective_limit and len(truncated_messages) > 0:
# Always remove the oldest message (index 0)
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
# Ensure alternation after removal if possible (might be complex)
# For simplicity, just remove and recount for now.
# A more robust approach might need to remove pairs (user/assistant).
final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
# Safety break
if not truncated_messages:
logger.warning("Truncation resulted in empty message list for Anthropic.")
break
if initial_token_count != final_token_count:
logger.info(
f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
# Ensure the remaining messages start with 'user' role if no system prompt
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
# This might indicate an issue with the simple pop(0) logic if pairs weren't removed.
# For now, prepend a basic user message.
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
# Recount after adding placeholder? Might exceed limit again. Risky.
# Let's log a warning instead of adding potentially problematic content.
# logger.warning("First message after truncation is not 'user'. This might cause issues with Anthropic API.")
return truncated_messages, system_prompt, initial_token_count, final_token_count
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]: def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
"""Converts standard message format to Anthropic's format, extracting system prompt.""" """Converts standard message format to Anthropic's format, extracting system prompt."""
anthropic_messages = [] anthropic_messages = []
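
The comment in `_count_anthropic_tokens` above assumes `self.client.count_tokens(text=...)` exists. Depending on the installed `anthropic` SDK version, that text-based helper may be absent, while newer SDKs expose a messages-based counting endpoint that accepts the structured message list (and system prompt) directly, which should be more accurate for tool calls/results than concatenating text. A minimal sketch under that assumption (the function name is illustrative, not part of this commit):

from anthropic import Anthropic

def count_tokens_via_api(client: Anthropic, model: str, messages: list[dict], system_prompt: str | None = None) -> int | None:
    """Sketch: count tokens for a structured message list via the API; returns None if unsupported."""
    params: dict = {"model": model, "messages": messages}
    if system_prompt:
        params["system"] = system_prompt
    try:
        result = client.messages.count_tokens(**params)
        return result.input_tokens
    except Exception:
        # Older SDKs may not provide messages.count_tokens; callers can fall back to the
        # character-count approximation used in _count_anthropic_tokens above.
        return None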
@@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider):
         stream: bool = True,
         tools: list[dict[str, Any]] | None = None,
     ) -> Stream[MessageStreamEvent] | Message:
-        """Creates a chat completion using the Anthropic API."""
-        logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
+        """Creates a chat completion using the Anthropic API, handling context truncation."""
+        logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
+
+        # --- Context Truncation ---
+        # First, convert to Anthropic format to separate system prompt
+        temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
+        # Then, truncate based on token count
+        truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
+        # --------------------------

         # Anthropic requires max_tokens
         if max_tokens is None:
             max_tokens = 4096 # Default value if not provided
             logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")

-        system_prompt, anthropic_messages = self._convert_messages(messages)
+        # system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above

         try:
             completion_params = {
                 "model": model,
-                "messages": anthropic_messages,
+                "messages": truncated_anthropic_msgs, # Use truncated messages
                 "temperature": temperature,
                 "max_tokens": max_tokens,
                 "stream": stream,
             }
-            if system_prompt:
-                completion_params["system"] = system_prompt
+            if final_system_prompt: # Use potentially modified system prompt
+                completion_params["system"] = final_system_prompt
             if tools:
                 completion_params["tools"] = tools
                 # Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
@@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider):
             response = self.client.messages.create(**completion_params)
             logger.debug("Anthropic API call successful.")

+            # --- Capture Actual Usage ---
+            actual_usage = None
+            if isinstance(response, Message) and response.usage:
+                actual_usage = {
+                    "prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
+                    "completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
+                    # Anthropic doesn't typically provide total_tokens directly in usage block
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+                }
+                logger.info(f"Actual Anthropic API usage: {actual_usage}")
+            # TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation.

             return response
+            # --------------------------
         except Exception as e:
             logger.error(f"Anthropic API error: {e}", exc_info=True)
             raise
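
On the TODO above: in Anthropic's documented streaming protocol, the `message_start` event carries `usage.input_tokens` and `message_delta` events carry the cumulative `usage.output_tokens`, so usage can be accumulated while the stream is consumed. A minimal sketch under that assumption (exact attribute names may vary by SDK version; the function is illustrative, not part of this commit):

def consume_anthropic_stream(stream) -> tuple[str, dict[str, int]]:
    """Sketch: collect streamed text and token usage from an Anthropic MessageStreamEvent iterator."""
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    text_parts: list[str] = []
    for event in stream:
        if event.type == "message_start":
            usage["prompt_tokens"] = event.message.usage.input_tokens
        elif event.type == "content_block_delta" and getattr(event.delta, "text", None):
            text_parts.append(event.delta.text)
        elif event.type == "message_delta" and getattr(event, "usage", None):
            usage["completion_tokens"] = event.usage.output_tokens  # cumulative count
    return "".join(text_parts), usage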
@@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider):
         except Exception as e:
             logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
             return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

+    def get_usage(self, response: Any) -> dict[str, int] | None:
+        """Extracts token usage from a non-streaming Anthropic response."""
+        try:
+            if isinstance(response, Message) and response.usage:
+                usage = {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    # "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
+                }
+                logger.debug(f"Extracted usage from Anthropic response: {usage}")
+                return usage
+            else:
+                logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
+                return None
+        except Exception as e:
+            logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
+            return None

View File

@@ -134,6 +134,20 @@ class BaseProvider(abc.ABC):
""" """
pass pass
@abc.abstractmethod
def get_usage(self, response: Any) -> dict[str, int] | None:
"""
Extracts token usage information from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
A dictionary containing 'prompt_tokens' and 'completion_tokens',
or None if usage information is not available.
"""
pass
# Optional: Add a method for follow-up completions if the provider API # Optional: Add a method for follow-up completions if the provider API
# requires a specific structure different from just appending messages. # requires a specific structure different from just appending messages.
# def create_follow_up_completion(...) -> Any: # def create_follow_up_completion(...) -> Any:
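
For a concrete sense of the new contract, here is a minimal sketch of the method a provider subclass would add, assuming a hypothetical provider whose raw response object exposes usage as a plain mapping (illustrative only, not part of this commit):

    def get_usage(self, response: Any) -> dict[str, int] | None:
        # Hypothetical provider: response.usage is assumed to be a dict-like object.
        usage = getattr(response, "usage", None)
        if not usage:
            return None
        return {
            "prompt_tokens": int(usage.get("prompt_tokens", 0)),
            "completion_tokens": int(usage.get("completion_tokens", 0)),
        }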

View File

@@ -1,6 +1,7 @@
 # src/providers/openai_provider.py
 import json
 import logging
+import math
 from collections.abc import Generator
 from typing import Any
@@ -8,8 +9,8 @@ from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

-from providers.base import BaseProvider
-from src.llm_models import MODELS # Use absolute import
+from src.llm_models import MODELS
+from src.providers.base import BaseProvider

 logger = logging.getLogger(__name__)
@@ -29,6 +30,110 @@ class OpenAIProvider(BaseProvider):
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True) logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given model."""
# Default to a safe fallback if model or provider info is missing
default_window = 8000
try:
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
# Fallback if specific model ID not found in our list
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
# Rough approximation for function/tool call overhead if needed later
# Using math.ceil to round up, ensuring we don't underestimate too much.
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens
def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.
Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = self._get_context_window(model)
# Add a buffer to be safer with approximation
buffer = 200 # Reduce buffer slightly as we round up now
effective_limit = context_limit - buffer
initial_estimated_count = self._estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count
truncated_messages = list(messages) # Make a copy
# Identify if the first message is a system prompt
has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
# If only system prompt exists, don't truncate further
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
# Return original messages to avoid removing the only message
return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
# Should not happen if check above works, but safety break
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break # No messages left
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break # Avoid index error
removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
# Recalculate estimated count
final_estimated_count = self._estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
# Safety break if list becomes unexpectedly empty
if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break
if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
return truncated_messages, initial_estimated_count, final_estimated_count
def create_chat_completion( def create_chat_completion(
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
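
The docstring of `_estimate_openai_token_count` above already flags that the char/4 heuristic is less accurate than tiktoken. A minimal sketch of a tiktoken-based counter that could stand in for it, assuming the `tiktoken` package is available (it is not a dependency added by this commit, and the per-message overhead constant is an approximation):

import tiktoken

def count_openai_tokens(messages: list[dict[str, str]], model: str) -> int:
    """Sketch: count chat tokens with tiktoken, falling back to a generic encoding for unknown models."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    total = 0
    for message in messages:
        total += len(encoding.encode(message.get("role", "")))
        content = message.get("content")
        if isinstance(content, str):
            total += len(encoding.encode(content))
        total += 4  # rough per-message chat formatting overhead (approximation)
    return total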
@@ -37,13 +142,19 @@ class OpenAIProvider(BaseProvider):
         max_tokens: int | None = None,
         stream: bool = True,
         tools: list[dict[str, Any]] | None = None,
-    ) -> Stream[ChatCompletionChunk] | ChatCompletion:
-        """Creates a chat completion using the OpenAI API."""
-        logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
+        # Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
+    ) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly?
+        """Creates a chat completion using the OpenAI API, handling context window truncation."""
+        logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
+
+        # --- Truncation Step ---
+        truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
+        # -----------------------

         try:
             completion_params = {
                 "model": model,
-                "messages": messages,
+                "messages": truncated_messages, # Use truncated messages
                 "temperature": temperature,
                 "max_tokens": max_tokens,
                 "stream": stream,
@@ -71,7 +182,34 @@ class OpenAIProvider(BaseProvider):
             response = self.client.chat.completions.create(**completion_params)
             logger.debug("OpenAI API call successful.")

+            # --- Capture Actual Usage (for UI display later) ---
+            # This part is tricky. Usage info is easily available on the *non-streaming* response.
+            # For streaming, it's often not available until the stream is fully consumed,
+            # or sometimes via response headers or a final event (provider-dependent).
+            # For now, let's focus on getting it from the non-streaming case.
+            # We need a way to pass this back alongside the content/stream.
+            # Option 1: Modify return type (complex for stream/non-stream union)
+            # Option 2: Store it in the provider instance (stateful, maybe bad)
+            # Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)
+            # Let's try returning it alongside for non-streaming, and figure out streaming later.
+            # This requires changing the BaseProvider interface and LLMClient handling.
+            # For now, just log it here.
+            actual_usage = None
+            if isinstance(response, ChatCompletion) and response.usage:
+                actual_usage = {
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens,
+                }
+                logger.info(f"Actual OpenAI API usage: {actual_usage}")
+            # TODO: How to handle usage for streaming responses? Needs investigation.

+            # Return the raw response for now. LLMClient will process it.
             return response
+            # ----------------------------------------------------
         except Exception as e:
             logger.error(f"OpenAI API error: {e}", exc_info=True)
             # Re-raise for the LLMClient to handle
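
On the TODO above: recent versions of the Chat Completions API can report usage for streamed responses when `stream_options={"include_usage": True}` is passed; the final chunk then carries a `usage` object (with an empty `choices` list). A minimal sketch under that assumption (requires an SDK/API version that supports `stream_options`; not part of this commit):

from openai import OpenAI

def stream_with_usage(client: OpenAI, model: str, messages: list[dict[str, str]]) -> tuple[str, dict[str, int] | None]:
    """Sketch: stream content, then return the accumulated text plus usage if the API reports it."""
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
        stream_options={"include_usage": True},  # assumes API/SDK support for this option
    )
    text_parts: list[str] = []
    usage = None
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            text_parts.append(chunk.choices[0].delta.content)
        if getattr(chunk, "usage", None):  # present only on the final chunk
            usage = {
                "prompt_tokens": chunk.usage.prompt_tokens,
                "completion_tokens": chunk.usage.completion_tokens,
            }
    return "".join(text_parts), usage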
@@ -233,7 +371,20 @@ class OpenAIProvider(BaseProvider):
logger.error(f"Error extracting original message with calls: {e}", exc_info=True) logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
# Register this provider (if using the registration mechanism) """Extracts token usage from a non-streaming OpenAI response."""
# from . import register_provider try:
# register_provider("openai", OpenAIProvider) if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
# "total_tokens": response.usage.total_tokens, # Optional
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None