From 15ecb9fc489dd75f739e27ddd964932082228369 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Mar 2025 17:27:41 +0000 Subject: [PATCH] feat: enhance token usage tracking and context management for LLM providers --- src/app.py | 83 ++++++------- src/llm_client.py | 17 ++- src/providers/anthropic_provider.py | 178 ++++++++++++++++++++++++++-- src/providers/base.py | 14 +++ src/providers/openai_provider.py | 171 ++++++++++++++++++++++++-- 5 files changed, 395 insertions(+), 68 deletions(-) diff --git a/src/app.py b/src/app.py index 2f99d17..15509d9 100644 --- a/src/app.py +++ b/src/app.py @@ -1,6 +1,5 @@ import atexit import configparser -import json import logging import streamlit as st @@ -99,8 +98,14 @@ def display_chat_messages(): """Displays chat messages stored in session state.""" for message in st.session_state.messages: with st.chat_message(message["role"]): - # Simple markdown display for now + # Display content st.markdown(message["content"]) + # Display usage if available (for assistant messages) + if message["role"] == "assistant" and "usage" in message: + usage = message["usage"] + prompt_tokens = usage.get("prompt_tokens", "N/A") + completion_tokens = usage.get("completion_tokens", "N/A") + st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}") def handle_user_input(): @@ -116,60 +121,50 @@ def handle_user_input(): response_placeholder = st.empty() full_response = "" error_occurred = False + response_usage = None # Initialize usage info logger.info("Processing message via LLMClient...") - # Use the new client and method, always requesting stream for UI - response_stream = st.session_state.client.chat_completion( + # Use the new client and method + # NOTE: Setting stream=False to easily get usage info from the response dict. + # A more complex solution is needed to get usage with streaming. 
+ response_data = st.session_state.client.chat_completion( messages=st.session_state.messages, - model=st.session_state.model_name, # Get model from session state - stream=True, + model=st.session_state.model_name, + stream=False, # Set to False for usage info ) - # Handle the response (stream generator or error dict) - if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict): - logger.debug("Processing response stream...") - for chunk in response_stream: - # Check for potential error JSON yielded by the stream - try: - # Attempt to parse chunk as JSON only if it looks like it - if isinstance(chunk, str) and chunk.strip().startswith("{"): - error_data = json.loads(chunk) - if isinstance(error_data, dict) and "error" in error_data: - full_response = f"Error: {error_data['error']}" - logger.error(f"Error received in stream: {full_response}") - st.error(full_response) - error_occurred = True - break # Stop processing stream on error - # If not error JSON, treat as content chunk - if not error_occurred and isinstance(chunk, str): - full_response += chunk - response_placeholder.markdown(full_response + "▌") # Add cursor effect - except (json.JSONDecodeError, TypeError): - # Not JSON or not error structure, treat as content chunk - if not error_occurred and isinstance(chunk, str): - full_response += chunk - response_placeholder.markdown(full_response + "▌") # Add cursor effect + # Handle the response (now expecting a dict) + if isinstance(response_data, dict): + if "error" in response_data: + full_response = f"Error: {response_data['error']}" + logger.error(f"Error returned from chat_completion: {full_response}") + st.error(full_response) + error_occurred = True + else: + full_response = response_data.get("content", "") + response_usage = response_data.get("usage") # Get usage dict + if not full_response and not error_occurred: # Check error_occurred flag too + logger.warning("Empty content received from LLMClient.") + # Display nothing or a placeholder? Let's display nothing. 
+ # full_response = "[Empty Response]" + # Display the full response at once (no streaming) + response_placeholder.markdown(full_response) + logger.debug("Non-streaming response processed.") - if not error_occurred: - response_placeholder.markdown(full_response) # Final update without cursor - logger.debug("Stream processing complete.") - - elif isinstance(response_stream, dict) and "error" in response_stream: - # Handle error dict returned directly (e.g., API error before streaming) - full_response = f"Error: {response_stream['error']}" - logger.error(f"Error returned directly from chat_completion: {full_response}") - st.error(full_response) - error_occurred = True else: # Unexpected response type full_response = "[Unexpected response format from LLMClient]" - logger.error(f"Unexpected response type: {type(response_stream)}") + logger.error(f"Unexpected response type: {type(response_data)}") st.error(full_response) error_occurred = True - # Only add non-error, non-empty responses to history - if not error_occurred and full_response: - st.session_state.messages.append({"role": "assistant", "content": full_response}) + # Add response to history, including usage if available + if not error_occurred and full_response: # Only add if no error and content exists + assistant_message = {"role": "assistant", "content": full_response} + if response_usage: + assistant_message["usage"] = response_usage + logger.info(f"Assistant response usage: {response_usage}") + st.session_state.messages.append(assistant_message) logger.info("Assistant response added to history.") elif error_occurred: logger.warning("Assistant response not added to history due to error.") diff --git a/src/llm_client.py b/src/llm_client.py index f68a3f2..daf18cf 100644 --- a/src/llm_client.py +++ b/src/llm_client.py @@ -72,8 +72,9 @@ class LLMClient: Returns: If stream=True: A generator yielding content chunks. - If stream=False: A dictionary containing the final content or an error. - e.g., {"content": "..."} or {"error": "..."} + If stream=False: A dictionary containing the final content, usage, or an error. + e.g., {"content": "...", "usage": {"prompt_tokens": ..., "completion_tokens": ...}} + or {"error": "..."} """ # Ensure tools are up-to-date (optional, could be done less frequently) # self._refresh_mcp_tools() @@ -173,8 +174,12 @@ class LLMClient: tools=provider_tools, # Pass tools again? Some providers might need it. 
) final_content = self.provider.get_content(follow_up_response) + final_usage = self.provider.get_usage(follow_up_response) # Get usage from follow-up logger.info("Received follow-up response content.") - return {"content": final_content} + result_dict = {"content": final_content} + if final_usage: + result_dict["usage"] = final_usage + return result_dict except Exception as tool_handling_err: logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True) @@ -183,7 +188,11 @@ class LLMClient: else: # No tool calls logger.info("No tool calls detected.") content = self.provider.get_content(response) - return {"content": content} + usage = self.provider.get_usage(response) # Get usage from initial response + result_dict = {"content": content} + if usage: + result_dict["usage"] = usage + return result_dict except Exception as e: error_msg = f"LLM API Error: {str(e)}" diff --git a/src/providers/anthropic_provider.py b/src/providers/anthropic_provider.py index 14f96f6..7645e51 100644 --- a/src/providers/anthropic_provider.py +++ b/src/providers/anthropic_provider.py @@ -1,16 +1,14 @@ # src/providers/anthropic_provider.py import json import logging +import math from collections.abc import Generator from typing import Any -from anthropic import Anthropic, Stream +from anthropic import Anthropic, APIError, Stream from anthropic.types import Message, MessageStreamEvent, TextDelta -# Use relative imports for modules within the same package from providers.base import BaseProvider - -# Use absolute imports as per Ruff warning and user instructions from src.llm_models import MODELS from src.tools.conversion import convert_to_anthropic_tools @@ -33,6 +31,126 @@ class AnthropicProvider(BaseProvider): logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True) raise + def _get_context_window(self, model: str) -> int: + """Retrieves the context window size for a given Anthropic model.""" + default_window = 100000 # Default fallback for Anthropic + try: + provider_models = MODELS.get("anthropic", {}).get("models", []) + for m in provider_models: + if m.get("id") == model: + return m.get("context_window", default_window) + logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}") + return default_window + except Exception as e: + logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) + return default_window + + def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int: + """Counts tokens for Anthropic messages using the official client.""" + # Note: Anthropic's count_tokens might not directly accept the message list format used for creation. + # It often expects plain text. We need to concatenate the content appropriately. + # This is a simplification and might not be perfectly accurate, especially with tool calls/results. + # A more robust approach might involve formatting messages into a single string representation. 
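# Sketch of the "more robust approach" noted above (illustrative, not part of this patch):
# newer releases of the anthropic SDK expose a server-side counting endpoint that accepts
# the structured message list directly, avoiding the lossy text concatenation below. The
# helper name and exact signature here are assumptions tied to the SDK version in use.
def count_via_endpoint(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> int:
    params: dict[str, Any] = {"model": model, "messages": messages}
    if system_prompt:
        params["system"] = system_prompt
    return client.messages.count_tokens(**params).input_tokens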
+ text_to_count = "" + if system_prompt: + text_to_count += f"System: {system_prompt}\n\n" + for message in messages: + role = message.get("role") + content = message.get("content") + # Simple concatenation - might need refinement for complex content types (tool calls/results) + if isinstance(content, str): + text_to_count += f"{role}: {content}\n" + elif isinstance(content, list): # Handle tool results/calls if represented as list + try: + content_str = json.dumps(content) + text_to_count += f"{role}: {content_str}\n" + except Exception: + text_to_count += f"{role}: [Unserializable Content]\n" + + try: + # Use the client's count_tokens method if available and works with text + # Check Anthropic documentation for the correct usage + # Assuming self.client.count_tokens exists and takes text + count = self.client.count_tokens(text=text_to_count) + logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}") + return count + except APIError as api_err: + # Handle potential errors if count_tokens itself is an API call or fails + logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True) + # Fallback to approximation if official count fails? + estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI + logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}") + return estimated_tokens + except AttributeError: + # Fallback if count_tokens method doesn't exist or works differently + logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.") + estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI + return estimated_tokens + except Exception as e: + logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True) + estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation + logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}") + return estimated_tokens + + def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]: + """ + Truncates messages for Anthropic, preserving system prompt. + + Returns: + - Potentially truncated list of messages. + - Original system prompt (or None). + - Initial token count. + - Final token count. + """ + context_limit = self._get_context_window(model) + buffer = 200 # Safety buffer + effective_limit = context_limit - buffer + + initial_token_count = self._count_anthropic_tokens(messages, system_prompt) + final_token_count = initial_token_count + + truncated_messages = list(messages) # Copy + + # Anthropic requires alternating user/assistant messages. Truncation needs care. + # We remove from the beginning (after potential system prompt). + # Removing the oldest message (index 0 of the list passed here, as system is separate) + + while final_token_count > effective_limit and len(truncated_messages) > 0: + # Always remove the oldest message (index 0) + removed_message = truncated_messages.pop(0) + logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.") + + # Ensure alternation after removal if possible (might be complex) + # For simplicity, just remove and recount for now. + # A more robust approach might need to remove pairs (user/assistant). 
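# Sketch of the pair-wise removal idea mentioned above (illustrative helper, not part of
# this patch): dropping the oldest user/assistant exchange together keeps the remaining
# list starting with a "user" turn and preserves the alternation Anthropic expects.
def drop_oldest_exchange(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]:
    if len(msgs) >= 2 and msgs[0].get("role") == "user" and msgs[1].get("role") == "assistant":
        return msgs[2:]  # remove the oldest user/assistant pair in one step
    return msgs[1:] if msgs else msgs  # otherwise fall back to dropping a single message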
+ + final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt) + logger.debug(f"Recalculated Anthropic tokens: {final_token_count}") + + # Safety break + if not truncated_messages: + logger.warning("Truncation resulted in empty message list for Anthropic.") + break + + if initial_token_count != final_token_count: + logger.info( + f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})" + ) + else: + logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})") + + # Ensure the remaining messages start with 'user' role if no system prompt + if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user": + logger.warning("First message after truncation is not 'user'. Prepending placeholder.") + # This might indicate an issue with the simple pop(0) logic if pairs weren't removed. + # For now, prepend a basic user message. + truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"}) + # Recount after adding placeholder? Might exceed limit again. Risky. + # Let's log a warning instead of adding potentially problematic content. + # logger.warning("First message after truncation is not 'user'. This might cause issues with Anthropic API.") + + return truncated_messages, system_prompt, initial_token_count, final_token_count + def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]: """Converts standard message format to Anthropic's format, extracting system prompt.""" anthropic_messages = [] @@ -93,26 +211,33 @@ class AnthropicProvider(BaseProvider): stream: bool = True, tools: list[dict[str, Any]] | None = None, ) -> Stream[MessageStreamEvent] | Message: - """Creates a chat completion using the Anthropic API.""" - logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}") + """Creates a chat completion using the Anthropic API, handling context truncation.""" + logger.debug(f"Anthropic create_chat_completion called. 
Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + + # --- Context Truncation --- + # First, convert to Anthropic format to separate system prompt + temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages) + # Then, truncate based on token count + truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model) + # -------------------------- # Anthropic requires max_tokens if max_tokens is None: max_tokens = 4096 # Default value if not provided logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}") - system_prompt, anthropic_messages = self._convert_messages(messages) + # system_prompt, anthropic_messages = self._convert_messages(messages) # Moved above try: completion_params = { "model": model, - "messages": anthropic_messages, + "messages": truncated_anthropic_msgs, # Use truncated messages "temperature": temperature, "max_tokens": max_tokens, "stream": stream, } - if system_prompt: - completion_params["system"] = system_prompt + if final_system_prompt: # Use potentially modified system prompt + completion_params["system"] = final_system_prompt if tools: completion_params["tools"] = tools # Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call @@ -129,7 +254,22 @@ class AnthropicProvider(BaseProvider): response = self.client.messages.create(**completion_params) logger.debug("Anthropic API call successful.") + + # --- Capture Actual Usage --- + actual_usage = None + if isinstance(response, Message) and response.usage: + actual_usage = { + "prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens + "completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens + # Anthropic doesn't typically provide total_tokens directly in usage block + "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + } + logger.info(f"Actual Anthropic API usage: {actual_usage}") + # TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation. 
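# One possible answer to the TODO above (a sketch; event and field names follow the
# documented Anthropic streaming format and may vary across SDK versions): input_tokens
# arrive in the initial "message_start" event and cumulative output_tokens in
# "message_delta" events. A caller that iterates the returned stream could collect usage
# like this, alongside the loop that forwards content deltas:
def usage_from_stream(events) -> dict[str, int]:
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    for event in events:
        if event.type == "message_start":
            usage["prompt_tokens"] = event.message.usage.input_tokens
        elif event.type == "message_delta" and getattr(event, "usage", None):
            usage["completion_tokens"] = event.usage.output_tokens
    return usage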
+ return response + # -------------------------- + except Exception as e: logger.error(f"Anthropic API error: {e}", exc_info=True) raise @@ -293,3 +433,21 @@ class AnthropicProvider(BaseProvider): except Exception as e: logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True) return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} + + def get_usage(self, response: Any) -> dict[str, int] | None: + """Extracts token usage from a non-streaming Anthropic response.""" + try: + if isinstance(response, Message) and response.usage: + usage = { + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + # "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional + } + logger.debug(f"Extracted usage from Anthropic response: {usage}") + return usage + else: + logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}") + return None + except Exception as e: + logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True) + return None diff --git a/src/providers/base.py b/src/providers/base.py index 25fb02c..3332509 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -134,6 +134,20 @@ class BaseProvider(abc.ABC): """ pass + @abc.abstractmethod + def get_usage(self, response: Any) -> dict[str, int] | None: + """ + Extracts token usage information from a non-streaming response object. + + Args: + response: The non-streaming response object. + + Returns: + A dictionary containing 'prompt_tokens' and 'completion_tokens', + or None if usage information is not available. + """ + pass + # Optional: Add a method for follow-up completions if the provider API # requires a specific structure different from just appending messages. # def create_follow_up_completion(...) -> Any: diff --git a/src/providers/openai_provider.py b/src/providers/openai_provider.py index 7306081..7b16eb4 100644 --- a/src/providers/openai_provider.py +++ b/src/providers/openai_provider.py @@ -1,6 +1,7 @@ # src/providers/openai_provider.py import json import logging +import math from collections.abc import Generator from typing import Any @@ -8,8 +9,8 @@ from openai import OpenAI, Stream from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall -from providers.base import BaseProvider -from src.llm_models import MODELS # Use absolute import +from src.llm_models import MODELS +from src.providers.base import BaseProvider logger = logging.getLogger(__name__) @@ -29,6 +30,110 @@ class OpenAIProvider(BaseProvider): logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True) raise + def _get_context_window(self, model: str) -> int: + """Retrieves the context window size for a given model.""" + # Default to a safe fallback if model or provider info is missing + default_window = 8000 + try: + # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts + provider_models = MODELS.get("openai", {}).get("models", []) + for m in provider_models: + if m.get("id") == model: + return m.get("context_window", default_window) + # Fallback if specific model ID not found in our list + logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. 
Using default: {default_window}") + return default_window + except Exception as e: + logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) + return default_window + + def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int: + """ + Estimates the token count for OpenAI messages using char count / 4 approximation. + Note: This is less accurate than using tiktoken. + """ + total_chars = 0 + for message in messages: + total_chars += len(message.get("role", "")) + content = message.get("content") + if isinstance(content, str): + total_chars += len(content) + # Rough approximation for function/tool call overhead if needed later + # Using math.ceil to round up, ensuring we don't underestimate too much. + estimated_tokens = math.ceil(total_chars / 4.0) + logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages") + return estimated_tokens + + def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]: + """ + Truncates messages from the beginning if estimated token count exceeds the limit. + Preserves the first message if it's a system prompt. + + Returns: + - The potentially truncated list of messages. + - The initial estimated token count. + - The final estimated token count after truncation (if any). + """ + context_limit = self._get_context_window(model) + # Add a buffer to be safer with approximation + buffer = 200 # Reduce buffer slightly as we round up now + effective_limit = context_limit - buffer + + initial_estimated_count = self._estimate_openai_token_count(messages) + final_estimated_count = initial_estimated_count + + truncated_messages = list(messages) # Make a copy + + # Identify if the first message is a system prompt + has_system_prompt = False + if truncated_messages and truncated_messages[0].get("role") == "system": + has_system_prompt = True + # If only system prompt exists, don't truncate further + if len(truncated_messages) == 1 and final_estimated_count > effective_limit: + logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). 
Cannot truncate further.") + # Return original messages to avoid removing the only message + return messages, initial_estimated_count, final_estimated_count + + while final_estimated_count > effective_limit: + if has_system_prompt and len(truncated_messages) <= 1: + # Should not happen if check above works, but safety break + logger.warning("Truncation stopped: Only system prompt remains.") + break + if not has_system_prompt and len(truncated_messages) <= 0: + logger.warning("Truncation stopped: No messages left.") + break # No messages left + + # Determine index to remove: 1 if system prompt exists and list is long enough, else 0 + remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0 + + if remove_index >= len(truncated_messages): + logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.") + break # Avoid index error + + removed_message = truncated_messages.pop(remove_index) + logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.") + + # Recalculate estimated count + final_estimated_count = self._estimate_openai_token_count(truncated_messages) + logger.debug(f"Recalculated estimated tokens: {final_estimated_count}") + + # Safety break if list becomes unexpectedly empty + if not truncated_messages: + logger.warning("Truncation resulted in empty message list.") + break + + if initial_estimated_count != final_estimated_count: + logger.info( + f"Truncated messages for model {model}. " + f"Initial estimated tokens: {initial_estimated_count}, " + f"Final estimated tokens: {final_estimated_count}, " + f"Limit: {context_limit} (Effective: {effective_limit})" + ) + else: + logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})") + + return truncated_messages, initial_estimated_count, final_estimated_count + def create_chat_completion( self, messages: list[dict[str, str]], @@ -37,13 +142,19 @@ class OpenAIProvider(BaseProvider): max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None, - ) -> Stream[ChatCompletionChunk] | ChatCompletion: - """Creates a chat completion using the OpenAI API.""" - logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}") + # Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming + ) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly? + """Creates a chat completion using the OpenAI API, handling context window truncation.""" + logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + + # --- Truncation Step --- + truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model) + # ----------------------- + try: completion_params = { "model": model, - "messages": messages, + "messages": truncated_messages, # Use truncated messages "temperature": temperature, "max_tokens": max_tokens, "stream": stream, @@ -71,7 +182,34 @@ class OpenAIProvider(BaseProvider): response = self.client.chat.completions.create(**completion_params) logger.debug("OpenAI API call successful.") + + # --- Capture Actual Usage (for UI display later) --- + # This part is tricky. Usage info is easily available on the *non-streaming* response. 
+ # For streaming, it's often not available until the stream is fully consumed, + # or sometimes via response headers or a final event (provider-dependent). + # For now, let's focus on getting it from the non-streaming case. + # We need a way to pass this back alongside the content/stream. + # Option 1: Modify return type (complex for stream/non-stream union) + # Option 2: Store it in the provider instance (stateful, maybe bad) + # Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure) + + # Let's try returning it alongside for non-streaming, and figure out streaming later. + # This requires changing the BaseProvider interface and LLMClient handling. + # For now, just log it here. + actual_usage = None + if isinstance(response, ChatCompletion) and response.usage: + actual_usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + logger.info(f"Actual OpenAI API usage: {actual_usage}") + # TODO: How to handle usage for streaming responses? Needs investigation. + + # Return the raw response for now. LLMClient will process it. return response + # ---------------------------------------------------- + except Exception as e: logger.error(f"OpenAI API error: {e}", exc_info=True) # Re-raise for the LLMClient to handle @@ -233,7 +371,20 @@ class OpenAIProvider(BaseProvider): logger.error(f"Error extracting original message with calls: {e}", exc_info=True) return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} - -# Register this provider (if using the registration mechanism) -# from . import register_provider -# register_provider("openai", OpenAIProvider) + def get_usage(self, response: Any) -> dict[str, int] | None: + """Extracts token usage from a non-streaming OpenAI response.""" + try: + if isinstance(response, ChatCompletion) and response.usage: + usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + # "total_tokens": response.usage.total_tokens, # Optional + } + logger.debug(f"Extracted usage from OpenAI response: {usage}") + return usage + else: + logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}") + return None + except Exception as e: + logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True) + return None
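For the TODO about usage on streaming OpenAI responses, one possible approach (assuming the
installed openai SDK and the target endpoint support the stream_options parameter, as recent
versions do) is to request a final usage-bearing chunk and read it while forwarding content
deltas; the model id below is only a placeholder:

    from openai import OpenAI

    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},  # ask for a trailing usage chunk
    )

    usage = None
    content_parts: list[str] = []
    for chunk in response:
        if chunk.usage is not None:  # only the final chunk carries usage; its choices list is empty
            usage = {
                "prompt_tokens": chunk.usage.prompt_tokens,
                "completion_tokens": chunk.usage.completion_tokens,
                "total_tokens": chunk.usage.total_tokens,
            }
        for choice in chunk.choices:
            if choice.delta.content:
                content_parts.append(choice.delta.content)

Similarly, the char/4 estimate in _estimate_openai_token_count could be tightened with
tiktoken where it is available (a sketch; the per-message overhead constant is an
approximation of the chat framing, and unknown model ids fall back to a default encoding):

    import tiktoken

    def estimate_tokens(messages: list[dict[str, str]], model: str) -> int:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")  # fallback for unknown model ids
        # ~4 extra tokens per message approximates the chat message framing overhead.
        return sum(len(encoding.encode(str(m.get("content") or ""))) + 4 for m in messages)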