feat: enhance token usage tracking and context management for LLM providers

2025-03-26 17:27:41 +00:00
parent 49aebc12d5
commit 15ecb9fc48
5 changed files with 395 additions and 68 deletions


@@ -1,6 +1,5 @@
import atexit
import configparser
import json
import logging
import streamlit as st
@@ -99,8 +98,14 @@ def display_chat_messages():
"""Displays chat messages stored in session state."""
for message in st.session_state.messages:
with st.chat_message(message["role"]):
# Simple markdown display for now
# Display content
st.markdown(message["content"])
# Display usage if available (for assistant messages)
if message["role"] == "assistant" and "usage" in message:
usage = message["usage"]
prompt_tokens = usage.get("prompt_tokens", "N/A")
completion_tokens = usage.get("completion_tokens", "N/A")
st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}")
def handle_user_input():
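
For reference, a minimal sketch of the message dict this display logic expects once usage is attached (key names follow the OpenAI-style usage object; the literal values are made up and the exact fields depend on the provider):

# Hypothetical example of a stored assistant message with usage attached
assistant_message = {
    "role": "assistant",
    "content": "Hello! How can I help?",
    "usage": {
        "prompt_tokens": 42,       # tokens sent in the request context
        "completion_tokens": 17,   # tokens generated in the reply
    },
}
# display_chat_messages() would then render the caption "Tokens: Prompt 42, Completion 17"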
@@ -116,60 +121,50 @@ def handle_user_input():
response_placeholder = st.empty()
full_response = ""
error_occurred = False
response_usage = None # Initialize usage info
logger.info("Processing message via LLMClient...")
# Use the new client and method, always requesting stream for UI
response_stream = st.session_state.client.chat_completion(
# Use the new client and method
# NOTE: Setting stream=False to easily get usage info from the response dict.
# A more complex solution is needed to get usage with streaming.
response_data = st.session_state.client.chat_completion(
messages=st.session_state.messages,
model=st.session_state.model_name, # Get model from session state
stream=True,
model=st.session_state.model_name,
stream=False, # Set to False for usage info
)
# Handle the response (stream generator or error dict)
if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
logger.debug("Processing response stream...")
for chunk in response_stream:
# Check for potential error JSON yielded by the stream
try:
# Attempt to parse chunk as JSON only if it looks like it
if isinstance(chunk, str) and chunk.strip().startswith("{"):
error_data = json.loads(chunk)
if isinstance(error_data, dict) and "error" in error_data:
full_response = f"Error: {error_data['error']}"
logger.error(f"Error received in stream: {full_response}")
st.error(full_response)
error_occurred = True
break # Stop processing stream on error
# If not error JSON, treat as content chunk
if not error_occurred and isinstance(chunk, str):
full_response += chunk
response_placeholder.markdown(full_response + "▌") # Add cursor effect
except (json.JSONDecodeError, TypeError):
# Not JSON or not error structure, treat as content chunk
if not error_occurred and isinstance(chunk, str):
full_response += chunk
response_placeholder.markdown(full_response + "▌") # Add cursor effect
# Handle the response (now expecting a dict)
if isinstance(response_data, dict):
if "error" in response_data:
full_response = f"Error: {response_data['error']}"
logger.error(f"Error returned from chat_completion: {full_response}")
st.error(full_response)
error_occurred = True
else:
full_response = response_data.get("content", "")
response_usage = response_data.get("usage") # Get usage dict
if not full_response and not error_occurred: # Check error_occurred flag too
logger.warning("Empty content received from LLMClient.")
# Display nothing or a placeholder? Let's display nothing.
# full_response = "[Empty Response]"
# Display the full response at once (no streaming)
response_placeholder.markdown(full_response)
logger.debug("Non-streaming response processed.")
if not error_occurred:
response_placeholder.markdown(full_response) # Final update without cursor
logger.debug("Stream processing complete.")
elif isinstance(response_stream, dict) and "error" in response_stream:
# Handle error dict returned directly (e.g., API error before streaming)
full_response = f"Error: {response_stream['error']}"
logger.error(f"Error returned directly from chat_completion: {full_response}")
st.error(full_response)
error_occurred = True
else:
# Unexpected response type
full_response = "[Unexpected response format from LLMClient]"
logger.error(f"Unexpected response type: {type(response_stream)}")
logger.error(f"Unexpected response type: {type(response_data)}")
st.error(full_response)
error_occurred = True
# Only add non-error, non-empty responses to history
if not error_occurred and full_response:
st.session_state.messages.append({"role": "assistant", "content": full_response})
# Add response to history, including usage if available
if not error_occurred and full_response: # Only add if no error and content exists
assistant_message = {"role": "assistant", "content": full_response}
if response_usage:
assistant_message["usage"] = response_usage
logger.info(f"Assistant response usage: {response_usage}")
st.session_state.messages.append(assistant_message)
logger.info("Assistant response added to history.")
elif error_occurred:
logger.warning("Assistant response not added to history due to error.")