220 lines
10 KiB
Python
220 lines
10 KiB
Python
import atexit
|
|
import configparser
|
|
import logging
|
|
|
|
import streamlit as st
|
|
|
|
from llm_client import LLMClient
|
|
from src.custom_mcp.manager import SyncMCPManager
|
|
|
|
# Application-wide logging: INFO and above, each record stamped with time,
# logger name and level. basicConfig is a no-op if the root logger already
# has handlers, so re-running this module does not stack handlers.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_config(path="config/config.ini"):
    """Read the application's INI configuration file.

    Args:
        path: Location of the INI file. A relative path is resolved against
            the current working directory.

    Returns:
        The populated ``configparser.ConfigParser``.

    Raises:
        FileNotFoundError: if the file is missing or could not be opened.
            ``ConfigParser.read`` silently skips such files, so an empty
            result list is the only failure signal.
    """
    config = configparser.ConfigParser()
    # TODO: Improve config file path handling (e.g., environment variable, absolute path)
    config_files_read = config.read(path)
    if not config_files_read:
        raise FileNotFoundError("config.ini not found or could not be read.")
    logger.info(f"Read configuration from: {config_files_read}")
    return config


def _mcp_config_path(config):
    """Return the MCP servers JSON path: ``[mcp] servers_json`` if set, else the default."""
    if config.has_section("mcp") and config["mcp"].get("servers_json"):
        mcp_config_path = config["mcp"]["servers_json"]
        logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
        return mcp_config_path
    mcp_config_path = "config/mcp_config.json"  # Default
    logger.info(f"Using default MCP config path: {mcp_config_path}")
    return mcp_config_path


def _read_llm_settings(config):
    """Extract and validate the LLM provider settings from *config*.

    ``[base] provider`` names the section that holds the provider's
    ``model``, ``api_key`` and optional ``base_url``.

    Returns:
        ``(provider_name, model_name, api_key, base_url)``. ``base_url`` is
        ``None`` when it is not configured.

    Raises:
        ValueError: if ``[base] provider`` is missing or empty, the named
            provider section does not exist, or its ``api_key`` / ``model``
            is missing or empty.
    """
    # 1. Determine provider from [base] section.
    if not (config.has_section("base") and config["base"].get("provider")):
        raise ValueError("Missing 'provider' setting in [base] section of config.ini")
    provider_name = config["base"].get("provider")
    logger.info(f"Provider selected from [base] section: {provider_name}")

    # 2. Read details from the specific provider's section.
    if not config.has_section(provider_name):
        raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
    provider_config = config[provider_name]
    model_name = provider_config.get("model")
    api_key = provider_config.get("api_key")
    base_url = provider_config.get("base_url")  # Optional
    logger.info(f"Read configuration from [{provider_name}] section.")

    # 3. Validate required settings.
    if not api_key:
        raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
    if not model_name:
        raise ValueError(f"Missing 'model' name in [{provider_name}] section of config.ini")
    return provider_name, model_name, api_key, base_url


def init_session_state():
    """Initialize session-state variables, including the LLM and MCP clients.

    Called on every Streamlit rerun. Only keys not yet present in
    ``st.session_state`` are touched:

    * ``messages``: set to an empty list.
    * ``client``, ``model_name``, ``provider_name``: built from
      ``config/config.ini`` (see ``_read_llm_settings`` for the layout).
      The MCP server config path comes from ``[mcp] servers_json``, or
      ``config/mcp_config.json`` when that is unset.

    The LLM settings are validated *before* the MCP manager is started.
    Otherwise a broken config.ini would start a new set of MCP servers on
    every rerun. If construction fails after the manager has started, its
    atexit hook is removed and it is shut down immediately.

    On any failure the error is logged, shown with ``st.error``, and the
    script run is halted with ``st.stop()``.
    """
    if "messages" not in st.session_state:
        st.session_state.messages = []
        logger.info("Initialized session state: messages")

    if "client" in st.session_state:
        return

    logger.info("Attempting to initialize clients...")
    mcp_manager = None
    shutdown_registered = False
    try:
        config = _load_config()
        # Fail fast on bad LLM config before any MCP server is started.
        provider_name, model_name, api_key, base_url = _read_llm_settings(config)

        # --- MCP Manager Setup ---
        mcp_manager = SyncMCPManager(_mcp_config_path(config))
        if mcp_manager.initialize():
            logger.info("MCP Manager initialized successfully.")
            # Register shutdown hook for MCP manager
            atexit.register(mcp_manager.shutdown)
            shutdown_registered = True
            logger.info("Registered MCP Manager shutdown hook.")
        else:
            # Log warning but continue - LLMClient will operate without tools
            logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")

        # --- LLM Client Setup ---
        logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
        st.session_state.client = LLMClient(
            provider_name=provider_name,
            api_key=api_key,
            mcp_manager=mcp_manager,
            base_url=base_url,
        )
        st.session_state.model_name = model_name
        st.session_state.provider_name = provider_name  # Store provider name
        logger.info("LLMClient initialized successfully.")

    except Exception as e:
        logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
        if shutdown_registered:
            # The manager started but will never be used: "client" stays unset,
            # so the next rerun creates a new one. Release it now rather than at
            # interpreter exit, and drop the hook so shutdown is not called twice.
            atexit.unregister(mcp_manager.shutdown)
            try:
                mcp_manager.shutdown()
            except Exception:
                logger.warning("Error while shutting down MCP Manager after failed initialization.", exc_info=True)
        st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
        # Stop the app if initialization fails critically
        st.stop()
|
|
|
|
|
|
def display_chat_messages():
    """Render the stored conversation, one chat bubble per history entry.

    An assistant entry that has a ``usage`` key also gets a caption with its
    prompt and completion token counts. A count missing from the usage dict
    is shown as "N/A".
    """
    for entry in st.session_state.messages:
        role = entry["role"]
        show_usage = role == "assistant" and "usage" in entry
        with st.chat_message(role):
            st.markdown(entry["content"])
            if show_usage:
                counts = entry["usage"]
                prompt_count = counts.get("prompt_tokens", "N/A")
                completion_count = counts.get("completion_tokens", "N/A")
                st.caption(f"Tokens: Prompt {prompt_count}, Completion {completion_count}")
|
|
|
|
|
|
def handle_user_input():
    """Read a prompt from the chat box, query the LLM client, and render the reply.

    Does nothing when ``st.chat_input`` yields no prompt on this rerun.
    Otherwise it:

    1. Appends the user message to ``st.session_state.messages`` and renders it.
    2. Calls ``st.session_state.client.chat_completion`` non-streaming, with
       role/content-only copies of the history.
    3. Renders the reply. It shows ``st.error`` instead when the client returns
       a dict with an ``"error"`` key or a non-dict.
    4. Appends a non-empty, error-free reply to the history, attaching the
       client's ``usage`` dict when present.

    On any failure the user message stays in the history without an assistant
    reply. Exceptions are logged and shown with ``st.error``, not re-raised.
    """
    prompt = st.chat_input("Type your message...")
    if not prompt:
        return

    logger.info(f"User input received: '{prompt[:50]}...'")
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    try:
        with st.chat_message("assistant"):
            response_placeholder = st.empty()
            response_usage = None
            error_occurred = False

            logger.info("Processing message via LLMClient...")
            # Send only role/content. History entries may also carry the UI-only
            # "usage" bookkeeping, which is not part of a chat message and can be
            # rejected by providers that validate message fields. Copies also
            # protect the stored history from in-place edits by the client.
            # NOTE(review): assumes LLMClient does not rely on mutating the list
            # it is given; confirm against llm_client.
            api_messages = [
                {"role": message["role"], "content": message["content"]}
                for message in st.session_state.messages
            ]
            # NOTE: stream=False so usage info is available in the response dict.
            # A more complex solution is needed to get usage with streaming.
            response_data = st.session_state.client.chat_completion(
                messages=api_messages,
                model=st.session_state.model_name,
                stream=False,
            )

            if not isinstance(response_data, dict):
                full_response = "[Unexpected response format from LLMClient]"
                logger.error(f"Unexpected response type: {type(response_data)}")
                st.error(full_response)
                error_occurred = True
            elif "error" in response_data:
                full_response = f"Error: {response_data['error']}"
                logger.error(f"Error returned from chat_completion: {full_response}")
                st.error(full_response)
                error_occurred = True
            else:
                # `or ""` keeps a None content from being rendered as text.
                full_response = response_data.get("content") or ""
                response_usage = response_data.get("usage")
                if not full_response:
                    logger.warning("Empty content received from LLMClient.")
                # Non-streaming: display the full response at once.
                response_placeholder.markdown(full_response)
                logger.debug("Non-streaming response processed.")

            # Record the reply only when it is error-free and non-empty.
            if error_occurred:
                logger.warning("Assistant response not added to history due to error.")
            elif not full_response:
                logger.warning("Empty assistant response received, not added to history.")
            else:
                assistant_message = {"role": "assistant", "content": full_response}
                if response_usage:
                    assistant_message["usage"] = response_usage
                    logger.info(f"Assistant response usage: {response_usage}")
                st.session_state.messages.append(assistant_message)
                logger.info("Assistant response added to history.")

    except Exception as e:
        logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
        st.error(f"An unexpected error occurred: {str(e)}")
|
|
|
|
|
|
def _mcp_header_counts(manager):
    """Return ``(server_count, tool_count)`` for the page header.

    Both are 0 when there is no manager or it is not initialized. The tool
    count becomes the string "N/A" if listing the tools raises.
    """
    if not (manager and manager.initialized):
        return 0, 0
    server_total = len(manager.servers)
    try:
        # Listing every tool may be slow with many tools/servers.
        tool_total = len(manager.list_all_tools())
    except Exception as e:
        logger.warning(f"Could not retrieve tool count for header: {e}")
        tool_total = "N/A"
    return server_total, tool_total


def main():
    """Run one pass of the Streamlit chat app.

    Steps: initialize session state, render the header (provider, MCP
    server/tool counts, model), then the chat history and the input box.
    Any ``Exception`` escaping these steps is logged as critical and shown
    with ``st.error``.
    """
    try:
        init_session_state()

        provider_label = st.session_state.get("provider_name", "Unknown Provider")
        model_label = st.session_state.get("model_name", "Unknown Model")
        servers, tools = _mcp_header_counts(st.session_state.client.mcp_manager)

        st.markdown(f"# Say Hi to **{provider_label.capitalize()}**!")
        st.write(f"MCP Servers: **{servers}** | Tools: **{tools}**")
        st.write(f"Model: **{model_label}**")
        st.divider()

        display_chat_messages()
        handle_user_input()
    except Exception as e:
        # Catch potential errors during rendering or handling.
        logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
        st.error(f"A critical application error occurred: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. The app depends on this guard being taken when the
    # file is launched, presumably via `streamlit run`.
    logger.info("Starting Streamlit Chat App...")
    main()
|