diff --git a/config/sample_mcp_config.json b/config/sample_mcp_config.json
index 492bbe4..16b9b7e 100644
--- a/config/sample_mcp_config.json
+++ b/config/sample_mcp_config.json
@@ -1,12 +1,12 @@
 {
-    "mcpServers": {
-        "dolphin-demo-database-sqlite": {
-            "command": "uvx",
-            "args": [
-                "mcp-server-sqlite",
-                "--db-path",
-                "~/.dolphin/dolphin.db"
-            ]
-        }
+    "mcpServers": {
+        "mcp-server-sqlite": {
+            "command": "uvx",
+            "args": [
+                "mcp-server-sqlite",
+                "--db-path",
+                "~/.mcpapp/mcpapp.db"
+            ]
         }
+    }
 }
diff --git a/pyproject.toml b/pyproject.toml
index fe11f97..475ed09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,7 @@ lint.select = [
     "T10", # flake8-debugger
     "A", # flake8-builtins
     "UP", # pyupgrade
+    "TID", # flake8-tidy-imports
 ]
 
 lint.ignore = [
@@ -81,7 +82,7 @@ skip-magic-trailing-comma = false
 combine-as-imports = true
 
 [tool.ruff.lint.mccabe]
-max-complexity = 12
+max-complexity = 16
 
 [tool.ruff.lint.flake8-tidy-imports]
 # Disallow all relative imports.
diff --git a/src/app.py b/src/app.py
index 7fa66f9..4bba335 100644
--- a/src/app.py
+++ b/src/app.py
@@ -1,29 +1,111 @@
 import atexit
+import configparser
+import json  # For handling potential error JSON in stream
+import logging
 
 import streamlit as st
 
-from openai_client import OpenAIClient
+# Updated imports
+from llm_client import LLMClient
+from src.custom_mcp.manager import SyncMCPManager  # Updated import path
+
+# Configure logging for the app
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 
 
 def init_session_state():
+    """Initializes session state variables including clients."""
     if "messages" not in st.session_state:
         st.session_state.messages = []
+        logger.info("Initialized session state: messages")
+
     if "client" not in st.session_state:
-        st.session_state.client = OpenAIClient()
-        # Register cleanup for MCP servers
-        if hasattr(st.session_state.client, "mcp_manager"):
-            atexit.register(st.session_state.client.mcp_manager.shutdown)
+        logger.info("Attempting to initialize clients...")
+        try:
+            config = configparser.ConfigParser()
+            # TODO: Improve config file path handling (e.g., environment variable, absolute path)
+            config_files_read = config.read("config/config.ini")
+            if not config_files_read:
+                raise FileNotFoundError("config.ini not found or could not be read.")
+            logger.info(f"Read configuration from: {config_files_read}")
+
+            # --- MCP Manager Setup ---
+            mcp_config_path = "config/mcp_config.json"  # Default
+            if config.has_section("mcp") and config["mcp"].get("servers_json"):
+                mcp_config_path = config["mcp"]["servers_json"]
+                logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
+            else:
+                logger.info(f"Using default MCP config path: {mcp_config_path}")
+
+            mcp_manager = SyncMCPManager(mcp_config_path)
+            if not mcp_manager.initialize():
+                # Log warning but continue - LLMClient will operate without tools
+                logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
+            else:
+                logger.info("MCP Manager initialized successfully.")
+                # Register shutdown hook for MCP manager
+                atexit.register(mcp_manager.shutdown)
+                logger.info("Registered MCP Manager shutdown hook.")
+
+            # --- LLM Client Setup ---
+            # Determine provider details (e.g., from an [llm] section)
+            provider_name = "openai"  # Default
+            model_name = None  # Must be provided in config
+            api_key = None
+            base_url = None
+
+            # Prioritize [llm] section, fallback to [openai] for compatibility
+            if config.has_section("llm"):
+                logger.info("Reading configuration from [llm] section.")
+                provider_name = config["llm"].get("provider", provider_name)
+                model_name = config["llm"].get("model")
+                api_key = config["llm"].get("api_key")
+                base_url = config["llm"].get("base_url")  # Optional
+            elif config.has_section("openai"):
+                logger.warning("Using legacy [openai] section for configuration.")
+                provider_name = "openai"  # Force openai if using this section
+                model_name = config["openai"].get("model")
+                api_key = config["openai"].get("api_key")
+                base_url = config["openai"].get("base_url")  # Optional
+            else:
+                raise ValueError("Missing [llm] or [openai] section in config.ini")
+
+            # Validate required config
+            if not api_key:
+                raise ValueError("Missing 'api_key' in config.ini ([llm] or [openai] section)")
+            if not model_name:
+                raise ValueError("Missing 'model' name in config.ini ([llm] or [openai] section)")
+
+            logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
+            st.session_state.client = LLMClient(
+                provider_name=provider_name,
+                api_key=api_key,
+                mcp_manager=mcp_manager,
+                base_url=base_url,  # Pass None if not provided
+            )
+            st.session_state.model_name = model_name  # Store model name for chat requests
+            logger.info("LLMClient initialized successfully.")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
+            st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
+            # Stop the app if initialization fails critically
+            st.stop()
 
 
 def display_chat_messages():
+    """Displays chat messages stored in session state."""
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
+            # Simple markdown display for now
             st.markdown(message["content"])
 
 
 def handle_user_input():
+    """Handles user input, calls LLMClient, and displays the response."""
     if prompt := st.chat_input("Type your message..."):
-        print(f"User input received: {prompt}")  # Debug log
+        logger.info(f"User input received: '{prompt[:50]}...'")
         st.session_state.messages.append({"role": "user", "content": prompt})
         with st.chat_message("user"):
             st.markdown(prompt)
@@ -32,39 +114,85 @@ def handle_user_input():
             with st.chat_message("assistant"):
                 response_placeholder = st.empty()
                 full_response = ""
+                error_occurred = False
 
-                print("Processing message...")  # Debug log
-                response = st.session_state.client.get_chat_response(st.session_state.messages)
+                logger.info("Processing message via LLMClient...")
+                # Use the new client and method, always requesting stream for UI
+                response_stream = st.session_state.client.chat_completion(
+                    messages=st.session_state.messages,
+                    model=st.session_state.model_name,  # Get model from session state
+                    stream=True,
+                )
 
-                # Handle both MCP and standard OpenAI responses
-                # Check if it's NOT a dict (assuming stream is not a dict)
-                if not isinstance(response, dict):
-                    # Standard OpenAI streaming response
-                    for chunk in response:
-                        # Ensure chunk has choices and delta before accessing
-                        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
-                            full_response += chunk.choices[0].delta.content
-                            response_placeholder.markdown(full_response + "▌")
+                # Handle the response (stream generator or error dict)
+                if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
+                    logger.debug("Processing response stream...")
+                    for chunk in response_stream:
+                        # Check for potential error JSON yielded by the stream
+                        try:
+                            # Attempt to parse chunk as JSON only if it looks like it
+                            if isinstance(chunk, str) and chunk.strip().startswith("{"):
+                                error_data = json.loads(chunk)
+                                if isinstance(error_data, dict) and "error" in error_data:
+                                    full_response = f"Error: {error_data['error']}"
+                                    logger.error(f"Error received in stream: {full_response}")
+                                    st.error(full_response)
+                                    error_occurred = True
+                                    break  # Stop processing stream on error
+                            # If not error JSON, treat as content chunk
+                            if not error_occurred and isinstance(chunk, str):
+                                full_response += chunk
+                                response_placeholder.markdown(full_response + "▌")  # Add cursor effect
+                        except (json.JSONDecodeError, TypeError):
+                            # Not JSON or not error structure, treat as content chunk
+                            if not error_occurred and isinstance(chunk, str):
+                                full_response += chunk
+                                response_placeholder.markdown(full_response + "▌")  # Add cursor effect
+
+                    if not error_occurred:
+                        response_placeholder.markdown(full_response)  # Final update without cursor
+                        logger.debug("Stream processing complete.")
+
+                elif isinstance(response_stream, dict) and "error" in response_stream:
+                    # Handle error dict returned directly (e.g., API error before streaming)
+                    full_response = f"Error: {response_stream['error']}"
+                    logger.error(f"Error returned directly from chat_completion: {full_response}")
+                    st.error(full_response)
+                    error_occurred = True
                 else:
-                    # MCP non-streaming response
-                    full_response = response.get("assistant_text", "")
-                    response_placeholder.markdown(full_response)
+                    # Unexpected response type
+                    full_response = "[Unexpected response format from LLMClient]"
+                    logger.error(f"Unexpected response type: {type(response_stream)}")
+                    st.error(full_response)
+                    error_occurred = True
 
-                response_placeholder.markdown(full_response)
-                st.session_state.messages.append({"role": "assistant", "content": full_response})
-                print("Message processed successfully")  # Debug log
+                # Only add non-error, non-empty responses to history
+                if not error_occurred and full_response:
+                    st.session_state.messages.append({"role": "assistant", "content": full_response})
+                    logger.info("Assistant response added to history.")
+                elif error_occurred:
+                    logger.warning("Assistant response not added to history due to error.")
+                else:
+                    logger.warning("Empty assistant response received, not added to history.")
 
         except Exception as e:
-            st.error(f"Error processing message: {str(e)}")
-            print(f"Error details: {str(e)}")  # Debug log
+            logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
+            st.error(f"An unexpected error occurred: {str(e)}")
 
 
 def main():
-    st.title("Streamlit Chat App")
-    init_session_state()
-    display_chat_messages()
-    handle_user_input()
+    """Main function to run the Streamlit app."""
+    st.title("MCP Chat App")  # Updated title
+    try:
+        init_session_state()
+        display_chat_messages()
+        handle_user_input()
+    except Exception as e:
+        # Catch potential errors during rendering or handling
+        logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
+        st.error(f"A critical application error occurred: {e}")
 
 
 if __name__ == "__main__":
+    logger.info("Starting Streamlit Chat App...")
     main()
diff --git a/src/custom_mcp/__init__.py b/src/custom_mcp/__init__.py
new file mode 100644
index 0000000..4c52e6a
--- /dev/null
+++ b/src/custom_mcp/__init__.py
@@ -0,0 +1 @@
+# This file makes src/mcp a Python package
diff --git a/src/custom_mcp/client.py b/src/custom_mcp/client.py
new file mode 100644
index 0000000..01d8b9a
--- /dev/null
+++ b/src/custom_mcp/client.py
@@ -0,0 +1,281 @@
+# src/mcp/client.py
+"""Client class for managing and interacting with a single MCP server process."""
+
+import asyncio
+import logging
+from typing import Any
+
+from custom_mcp import process, protocol
+
+logger = logging.getLogger(__name__)
+
+# Define reasonable timeouts
+LIST_TOOLS_TIMEOUT = 20.0  # Seconds (using the increased value from previous step)
+CALL_TOOL_TIMEOUT = 110.0  # Seconds
+
+
+class MCPClient:
+    """
+    Manages the lifecycle and async communication with a single MCP server process.
+    """
+
+    def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
+        """
+        Initializes the client for a specific server configuration.
+
+        Args:
+            server_name: Unique name for the server (for logging).
+            command: The command executable.
+            args: List of arguments for the command.
+            config_env: Server-specific environment variables.
+ """ + self.server_name = server_name + self.command = command + self.args = args + self.config_env = config_env + self.process: asyncio.subprocess.Process | None = None + self.reader: asyncio.StreamReader | None = None + self.writer: asyncio.StreamWriter | None = None + self._stderr_task: asyncio.Task | None = None + self._request_counter = 0 + self._is_running = False + self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger + + async def _log_stderr(self): + """Logs stderr output from the server process.""" + if not self.process or not self.process.stderr: + self.logger.debug("Stderr logging skipped: process or stderr not available.") + return + stderr_reader = self.process.stderr + try: + while not stderr_reader.at_eof(): + line = await stderr_reader.readline() + if line: + self.logger.warning(f"[stderr] {line.decode().strip()}") + except asyncio.CancelledError: + self.logger.debug("Stderr logging task cancelled.") + except Exception as e: + # Log errors but don't crash the logger task if possible + self.logger.error(f"Error reading stderr: {e}", exc_info=True) + finally: + self.logger.debug("Stderr logging task finished.") + + async def start(self) -> bool: + """ + Starts the MCP server subprocess and sets up communication streams. + + Returns: + True if the process started successfully, False otherwise. + """ + if self._is_running: + self.logger.warning("Start called but client is already running.") + return True + + self.logger.info("Starting MCP server process...") + try: + self.process = await process.start_mcp_process(self.command, self.args, self.config_env) + self.reader = self.process.stdout + self.writer = self.process.stdin + + if self.reader is None or self.writer is None: + self.logger.error("Failed to get stdout/stdin streams after process start.") + await self.stop() # Attempt cleanup + return False + + # Start background task to monitor stderr + self._stderr_task = asyncio.create_task(self._log_stderr()) + + # --- Start MCP Initialization Handshake --- + self.logger.info("Starting MCP initialization handshake...") + self._request_counter += 1 + init_req_id = self._request_counter + initialize_req = { + "jsonrpc": "2.0", + "id": init_req_id, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", # Use a recent version + "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client + "capabilities": {}, # Client capabilities (can be empty) + }, + } + + # Define a timeout for initialization + INITIALIZE_TIMEOUT = 15.0 # Seconds + + try: + # Send initialize request + await protocol.send_request(self.writer, initialize_req) + self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). 
Waiting for response...") + + # Wait for initialize response + init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT) + + if init_response and init_response.get("id") == init_req_id: + if "error" in init_response: + self.logger.error(f"Server returned error during initialization: {init_response['error']}") + await self.stop() + return False + elif "result" in init_response: + self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided + + # Send initialized notification (using standard method name) + initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}} + await protocol.send_request(self.writer, initialized_notify) + self.logger.info("'notifications/initialized' notification sent.") + + self._is_running = True + self.logger.info("MCP server process started and initialized successfully.") + return True + else: + self.logger.error("Invalid 'initialize' response format (missing result/error).") + await self.stop() + return False + elif init_response: + self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}") + await self.stop() + return False + else: # Timeout case + self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.") + await self.stop() + return False + + except ConnectionResetError: + self.logger.error("Connection lost during initialization handshake. Stopping client.") + await self.stop() + return False + except Exception as e: + self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True) + await self.stop() + return False + # --- End MCP Initialization Handshake --- + + except Exception as e: + self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True) + self.process = None # Ensure process is None on failure + self.reader = None + self.writer = None + self._is_running = False + return False + + async def stop(self): + """Stops the MCP server subprocess gracefully.""" + if not self._is_running and not self.process: + self.logger.debug("Stop called but client is not running.") + return + + self.logger.info("Stopping MCP server process...") + self._is_running = False # Mark as stopping + + # Cancel stderr logging task + if self._stderr_task and not self._stderr_task.done(): + self._stderr_task.cancel() + try: + await self._stderr_task + except asyncio.CancelledError: + self.logger.debug("Stderr task successfully cancelled.") + except Exception as e: + self.logger.error(f"Error waiting for stderr task cancellation: {e}") + self._stderr_task = None + + # Stop the process using the utility function + if self.process: + await process.stop_mcp_process(self.process, self.server_name) + + # Nullify references + self.process = None + self.reader = None + self.writer = None + self.logger.info("MCP server process stopped.") + + async def list_tools(self) -> list[dict[str, Any]] | None: + """ + Sends a 'tools/list' request and waits for the response. + + Returns: + A list of tool dictionaries, or None on error/timeout. 
+ """ + if not self._is_running or not self.writer or not self.reader: + self.logger.error("Cannot list tools: client not running or streams unavailable.") + return None + + self._request_counter += 1 + req_id = self._request_counter + request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id} + + try: + await protocol.send_request(self.writer, request) + response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT) + + if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]: + tools = response["result"]["tools"] + if isinstance(tools, list): + self.logger.info(f"Successfully listed {len(tools)} tools.") + return tools + else: + self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}") + return None + elif response and "error" in response: + self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}") + return None + else: + # Includes timeout case (read_response returns None) + self.logger.error(f"No valid response or timeout for listTools ID {req_id}.") + return None + + except ConnectionResetError: + self.logger.error("Connection lost during listTools request. Stopping client.") + await self.stop() + return None + except Exception as e: + self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True) + return None + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None: + """ + Sends a 'tools/call' request and waits for the response. + + Args: + tool_name: The name of the tool to call. + arguments: The arguments for the tool. + + Returns: + The result dictionary from the server, or None on error/timeout. + """ + if not self._is_running or not self.writer or not self.reader: + self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.") + return None + + self._request_counter += 1 + req_id = self._request_counter + request = { + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments}, + "id": req_id, + } + + try: + await protocol.send_request(self.writer, request) + response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT) + + if response and "result" in response: + # Assuming result is the desired payload + self.logger.info(f"Tool '{tool_name}' executed successfully.") + return response["result"] + elif response and "error" in response: + self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}") + # Return the error structure itself? Or just None? Returning error dict for now. + return {"error": response["error"]} + else: + # Includes timeout case + self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.") + return None + + except ConnectionResetError: + self.logger.error(f"Connection lost during callTool '{tool_name}'. 
Stopping client.") + await self.stop() + return None + except Exception as e: + self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True) + return None diff --git a/src/custom_mcp/manager.py b/src/custom_mcp/manager.py new file mode 100644 index 0000000..dcf739c --- /dev/null +++ b/src/custom_mcp/manager.py @@ -0,0 +1,366 @@ +# src/mcp/manager.py +"""Synchronous manager for multiple MCPClient instances.""" + +import asyncio +import json +import logging +import threading +from typing import Any + +# Use relative imports within the mcp package +from custom_mcp.client import MCPClient + +# Configure basic logging +# Consider moving this to the main app entry point if not already done +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts) +INITIALIZE_TIMEOUT = 60.0 # Seconds +SHUTDOWN_TIMEOUT = 30.0 # Seconds +LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds +EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds + + +class SyncMCPManager: + """ + Manages the lifecycle of multiple MCPClient instances and provides a + synchronous interface to interact with them using a background event loop. + """ + + def __init__(self, config_path: str = "config/mcp_config.json"): + """ + Initializes the manager, loads config, but does not start servers yet. + + Args: + config_path: Path to the MCP server configuration JSON file. + """ + self.config_path = config_path + self.config: dict[str, Any] | None = None + # Stores server_name -> MCPClient instance + self.servers: dict[str, MCPClient] = {} + self.initialized = False + self._lock = threading.Lock() + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + logger.info(f"Initializing SyncMCPManager with config path: {config_path}") + self._load_config() + + def _load_config(self): + """Load MCP configuration from JSON file.""" + logger.debug(f"Attempting to load MCP config from: {self.config_path}") + try: + # Using direct file access + with open(self.config_path) as f: + self.config = json.load(f) + logger.info("MCP configuration loaded successfully.") + logger.debug(f"Config content: {self.config}") + except FileNotFoundError: + logger.error(f"MCP config file not found at {self.config_path}") + self.config = None + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}") + self.config = None + except Exception as e: + logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True) + self.config = None + + # --- Background Event Loop Management --- + + def _run_event_loop(self): + """Target function for the background event loop thread.""" + try: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + finally: + if self._loop and not self._loop.is_closed(): + # Clean up remaining tasks before closing + try: + tasks = asyncio.all_tasks(self._loop) + if tasks: + logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...") + for task in tasks: + task.cancel() + # Allow cancellation to propagate + self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + logger.debug("Outstanding tasks cancelled.") + self._loop.run_until_complete(self._loop.shutdown_asyncgens()) + except Exception as e: + logger.error(f"Error during event loop cleanup: {e}") + finally: + 
self._loop.close() + asyncio.set_event_loop(None) + logger.info("Event loop thread finished.") + + def _start_event_loop_thread(self): + """Starts the background event loop thread if not already running.""" + if self._thread is None or not self._thread.is_alive(): + self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True) + self._thread.start() + logger.info("Event loop thread started.") + # Wait briefly for the loop to become available and running + while self._loop is None or not self._loop.is_running(): + # Use time.sleep in sync context + import time + + time.sleep(0.01) + logger.debug("Event loop is running.") + + def _stop_event_loop_thread(self): + """Stops the background event loop thread.""" + if self._loop and self._loop.is_running(): + logger.info("Requesting event loop stop...") + self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread and self._thread.is_alive(): + logger.info("Waiting for event loop thread to join...") + self._thread.join(timeout=5) + if self._thread.is_alive(): + logger.warning("Event loop thread did not stop gracefully.") + self._loop = None + self._thread = None + logger.info("Event loop stopped.") + + # --- Public Synchronous Interface --- + + def initialize(self) -> bool: + """ + Initializes and starts all configured MCP servers synchronously. + + Returns: + True if all servers started successfully, False otherwise. + """ + logger.info("Manager initialization requested.") + if not self.config or not self.config.get("mcpServers"): + logger.warning("Initialization skipped: No valid configuration loaded.") + return False + + with self._lock: + if self.initialized: + logger.debug("Initialization skipped: Already initialized.") + return True + + self._start_event_loop_thread() + if not self._loop: + logger.error("Failed to start event loop for initialization.") + return False + + logger.info("Submitting asynchronous server initialization...") + + # Prepare coroutine to start all clients + + async def _async_init_all(): + tasks = [] + for server_name, server_config in self.config["mcpServers"].items(): + command = server_config.get("command") + args = server_config.get("args", []) + config_env = server_config.get("env", {}) + if not command: + logger.error(f"Skipping server {server_name}: Missing 'command'.") + continue + + client = MCPClient(server_name, command, args, config_env) + self.servers[server_name] = client + tasks.append(client.start()) # Append the start coroutine + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check results - True means success, False or Exception means failure + all_success = True + failed_servers = [] + for i, result in enumerate(results): + server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained + if isinstance(result, Exception) or result is False: + all_success = False + failed_servers.append(server_name) + # Remove failed client from managed servers + if server_name in self.servers: + del self.servers[server_name] + logger.error(f"Failed to start client for server '{server_name}'. 
Result/Error: {result}") + + if not all_success: + logger.error(f"Initialization failed for servers: {failed_servers}") + return all_success + + # Run the initialization coroutine in the background loop + future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop) + try: + success = future.result(timeout=INITIALIZE_TIMEOUT) + if success: + logger.info("Asynchronous initialization completed successfully.") + self.initialized = True + else: + logger.error("Asynchronous initialization failed.") + self.initialized = False + # Attempt to clean up any partially started servers + self.shutdown() # Call sync shutdown + except TimeoutError: + logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.") + self.initialized = False + self.shutdown() # Clean up + success = False + except Exception as e: + logger.error(f"Exception during initialization future result: {e}", exc_info=True) + self.initialized = False + self.shutdown() # Clean up + success = False + + return self.initialized + + def shutdown(self): + """Shuts down all managed MCP servers synchronously.""" + logger.info("Manager shutdown requested.") + with self._lock: + # Check servers dict too, in case init was partial + if not self.initialized and not self.servers: + logger.debug("Shutdown skipped: Not initialized or no servers running.") + # Ensure loop is stopped if it exists + if self._thread and self._thread.is_alive(): + self._stop_event_loop_thread() + return + + if not self._loop or not self._loop.is_running(): + logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.") + # Attempt direct cleanup if loop isn't running (shouldn't happen ideally) + # This part is tricky as MCPClient.stop is async. + # For simplicity, we might just log and rely on process termination on app exit. + # Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now. + self.servers = {} + self.initialized = False + if self._thread and self._thread.is_alive(): + self._stop_event_loop_thread() + return + + logger.info("Submitting asynchronous server shutdown...") + + # Prepare coroutine to stop all clients + + async def _async_shutdown_all(): + tasks = [client.stop() for client in self.servers.values()] + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + # Run the shutdown coroutine in the background loop + future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop) + try: + future.result(timeout=SHUTDOWN_TIMEOUT) + logger.info("Asynchronous shutdown completed.") + except TimeoutError: + logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.") + # Processes might still be running, OS will clean up on exit hopefully + except Exception as e: + logger.error(f"Exception during shutdown future result: {e}", exc_info=True) + finally: + # Always mark as uninitialized and clear servers dict + self.servers = {} + self.initialized = False + # Stop the background thread + self._stop_event_loop_thread() + + logger.info("Manager shutdown complete.") + + def list_all_tools(self) -> list[dict[str, Any]]: + """ + Retrieves tools from all initialized MCP servers synchronously. + + Returns: + A list of tool definitions in the standard internal format, + aggregated from all servers. Returns empty list on failure. 
+ """ + if not self.initialized or not self.servers: + logger.warning("Cannot list tools: Manager not initialized or no servers running.") + return [] + + if not self._loop or not self._loop.is_running(): + logger.error("Cannot list tools: Event loop not running.") + return [] + + logger.info(f"Requesting tools from {len(self.servers)} servers...") + + # Prepare coroutine to list tools from all clients + async def _async_list_all(): + tasks = [] + server_names_in_order = [] + for server_name, client in self.servers.items(): + tasks.append(client.list_tools()) + server_names_in_order.append(server_name) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_tools = [] + for i, result in enumerate(results): + server_name = server_names_in_order[i] + if isinstance(result, Exception): + logger.error(f"Error listing tools for server '{server_name}': {result}") + elif result is None: + # MCPClient.list_tools returns None on timeout/error + logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).") + elif isinstance(result, list): + # Add server_name to each tool definition + for tool in result: + tool["server_name"] = server_name + all_tools.extend(result) + logger.debug(f"Received {len(result)} tools from {server_name}") + else: + logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.") + return all_tools + + # Run the coroutine in the background loop + future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop) + try: + aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT) + logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.") + return aggregated_tools + except TimeoutError: + logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.") + return [] + except Exception as e: + logger.error(f"Exception during listing all tools future result: {e}", exc_info=True) + return [] + + def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None: + """ + Executes a specific tool on the designated MCP server synchronously. + + Args: + server_name: The name of the server hosting the tool. + tool_name: The name of the tool to execute. + arguments: A dictionary of arguments for the tool. + + Returns: + The result content from the tool execution (dict), + an error dict ({"error": ...}), or None on timeout/comm failure. 
+ """ + if not self.initialized: + logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.") + return None + + client = self.servers.get(server_name) + if not client: + logger.error(f"Cannot execute tool: Server '{server_name}' not found.") + return None + + if not self._loop or not self._loop.is_running(): + logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.") + return None + + logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}") + + # Run the client's call_tool coroutine in the background loop + future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop) + try: + result = future.result(timeout=EXECUTE_TOOL_TIMEOUT) + # MCPClient.call_tool returns the result dict or an error dict or None + if result is None: + logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).") + elif isinstance(result, dict) and "error" in result: + logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}") + else: + logger.info(f"Tool '{tool_name}' execution successful.") + return result # Return result dict, error dict, or None + except TimeoutError: + logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.") + return None + except Exception as e: + logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True) + return None diff --git a/src/custom_mcp/process.py b/src/custom_mcp/process.py new file mode 100644 index 0000000..ec312e9 --- /dev/null +++ b/src/custom_mcp/process.py @@ -0,0 +1,128 @@ +# src/mcp/process.py +"""Async utilities for managing MCP server subprocesses.""" + +import asyncio +import logging +import os +import subprocess + +logger = logging.getLogger(__name__) + + +async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process: + """ + Starts an MCP server subprocess using asyncio.create_subprocess_shell. + + Handles argument expansion and environment merging. + + Args: + command: The main command executable. + args: A list of arguments for the command. + config_env: Server-specific environment variables from config. + + Returns: + The started asyncio.subprocess.Process object. + + Raises: + FileNotFoundError: If the command is not found. + Exception: For other errors during subprocess creation. 
+ """ + logger.debug(f"Preparing to start process for command: {command}") + + # --- Add tilde expansion for arguments --- + expanded_args = [] + try: + for arg in args: + if isinstance(arg, str) and "~" in arg: + expanded_args.append(os.path.expanduser(arg)) + else: + # Ensure all args are strings for list2cmdline + expanded_args.append(str(arg)) + logger.debug(f"Expanded args: {expanded_args}") + except Exception as e: + logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True) + raise ValueError(f"Failed to expand arguments: {e}") from e + + # --- Merge os.environ with config_env --- + merged_env = {**os.environ, **config_env} + # logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values + + # Combine command and expanded args into a single string for shell execution + try: + cmd_string = subprocess.list2cmdline([command] + expanded_args) + logger.debug(f"Executing shell command: {cmd_string}") + except Exception as e: + logger.error(f"Error creating command string: {e}", exc_info=True) + raise ValueError(f"Failed to create command string: {e}") from e + + # --- Start the subprocess using shell --- + try: + process = await asyncio.create_subprocess_shell( + cmd_string, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=merged_env, + ) + logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}") + return process + except FileNotFoundError: + logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'") + raise # Re-raise specific error + except Exception as e: + logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True) + raise # Re-raise other errors + + +async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"): + """ + Attempts to gracefully stop the MCP server subprocess. + + Args: + process: The asyncio.subprocess.Process object to stop. + server_name: A name for logging purposes. 
+ """ + if process is None or process.returncode is not None: + logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.") + return + + pid = process.pid + logger.info(f"Attempting to stop process {server_name} (PID: {pid})...") + + # Close stdin first + if process.stdin and not process.stdin.is_closing(): + try: + process.stdin.close() + await process.stdin.wait_closed() + logger.debug(f"Stdin closed for {server_name} (PID: {pid})") + except Exception as e: + logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}") + + # Attempt graceful termination + try: + process.terminate() + logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})") + await asyncio.wait_for(process.wait(), timeout=5.0) + logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).") + except TimeoutError: + logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.") + try: + process.kill() + await process.wait() # Wait for kill to complete + logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).") + except ProcessLookupError: + logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.") + except Exception as e_kill: + logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}") + except ProcessLookupError: + logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.") + except Exception as e_term: + logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}") + # Attempt kill as fallback if terminate failed and process might still be running + if process.returncode is None: + try: + process.kill() + await process.wait() + logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).") + except Exception as e_kill_fallback: + logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}") diff --git a/src/custom_mcp/protocol.py b/src/custom_mcp/protocol.py new file mode 100644 index 0000000..1bb6f33 --- /dev/null +++ b/src/custom_mcp/protocol.py @@ -0,0 +1,75 @@ +# src/mcp/protocol.py +"""Async utilities for MCP JSON-RPC communication over streams.""" + +import asyncio +import json +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]): + """ + Sends a JSON-RPC request dictionary to the MCP server's stdin stream. + + Args: + writer: The asyncio StreamWriter connected to the process stdin. + request_dict: The request dictionary to send. + + Raises: + ConnectionResetError: If the connection is lost during send. + Exception: For other stream writing errors. 
+ """ + try: + request_json = json.dumps(request_dict) + "\n" + writer.write(request_json.encode("utf-8")) + await writer.drain() + logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}") + except ConnectionResetError: + logger.error(f"Connection lost while sending request ID {request_dict.get('id')}") + raise # Re-raise for the caller (MCPClient) to handle + except Exception as e: + logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True) + raise # Re-raise for the caller + + +async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None: + """ + Reads and parses a JSON-RPC response line from the MCP server's stdout stream. + + Args: + reader: The asyncio StreamReader connected to the process stdout. + timeout: Seconds to wait for a response line. + + Returns: + The parsed response dictionary, or None if timeout or error occurs. + """ + response_str = None + try: + response_json = await asyncio.wait_for(reader.readline(), timeout=timeout) + if not response_json: + logger.warning("Received empty response line (EOF?).") + return None + + response_str = response_json.decode("utf-8").strip() + if not response_str: + logger.warning("Received empty response string after strip.") + return None + + logger.debug(f"Received response line: {response_str}") + response_dict = json.loads(response_str) + return response_dict + + except TimeoutError: + logger.error(f"Timeout ({timeout}s) waiting for response.") + return None + except asyncio.IncompleteReadError: + logger.warning("Connection closed while waiting for response.") + return None + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'") + return None + except Exception as e: + logger.error(f"Error reading response: {e}", exc_info=True) + return None diff --git a/src/llm_client.py b/src/llm_client.py new file mode 100644 index 0000000..f68a3f2 --- /dev/null +++ b/src/llm_client.py @@ -0,0 +1,219 @@ +# src/llm_client.py +""" +Generic LLM client supporting multiple providers and MCP tool integration. +""" + +import json +import logging +from collections.abc import Generator +from typing import Any + +from src.custom_mcp.manager import SyncMCPManager # Updated import path +from src.providers import BaseProvider, create_llm_provider + +logger = logging.getLogger(__name__) + + +class LLMClient: + """ + Handles chat completion requests to various LLM providers through a unified + interface, integrating with MCP tools via SyncMCPManager. + """ + + def __init__( + self, + provider_name: str, + api_key: str, + mcp_manager: SyncMCPManager, + base_url: str | None = None, + ): + """ + Initialize the LLM client. + + Args: + provider_name: Name of the provider (e.g., 'openai', 'anthropic'). + api_key: API key for the provider. + mcp_manager: An initialized instance of SyncMCPManager. + base_url: Optional base URL for the provider API. 
+ """ + logger.info(f"Initializing LLMClient for provider: {provider_name}") + self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url) + self.mcp_manager = mcp_manager + self.mcp_tools: list[dict[str, Any]] = [] + self._refresh_mcp_tools() # Initial tool load + + def _refresh_mcp_tools(self): + """Retrieves the latest tools from the MCP manager.""" + logger.info("Refreshing MCP tools...") + try: + self.mcp_tools = self.mcp_manager.list_all_tools() + logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.") + except Exception as e: + logger.error(f"Error refreshing MCP tools: {e}", exc_info=True) + # Keep existing tools if refresh fails + + def chat_completion( + self, + messages: list[dict[str, str]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + ) -> Generator[str, None, None] | dict[str, Any]: + """ + Send a chat completion request, handling potential tool calls. + + Args: + messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}). + model: Model identifier string. + temperature: Sampling temperature. + max_tokens: Maximum tokens to generate. + stream: Whether to stream the response. + + Returns: + If stream=True: A generator yielding content chunks. + If stream=False: A dictionary containing the final content or an error. + e.g., {"content": "..."} or {"error": "..."} + """ + # Ensure tools are up-to-date (optional, could be done less frequently) + # self._refresh_mcp_tools() + + # Convert tools to the provider-specific format + try: + provider_tools = self.provider.convert_tools(self.mcp_tools) + logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}") + except Exception as e: + logger.error(f"Error converting tools for provider: {e}", exc_info=True) + provider_tools = None # Proceed without tools if conversion fails + + try: + logger.info(f"Sending chat completion request to provider with model: {model}") + response = self.provider.create_chat_completion( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=stream, + tools=provider_tools, + ) + logger.info("Received response from provider.") + + if stream: + # Streaming with tool calls requires more complex handling (like airflow_wingman's + # process_tool_calls_and_follow_up). For now, we'll yield the initial stream + # and handle tool calls *after* the stream completes if detected (less ideal UX). + # A better approach involves checking for tool calls before streaming fully. + # This simplified version just streams the first response. + logger.info("Streaming response...") + # NOTE: This simple version doesn't handle tool calls during streaming well. + # It will stream the initial response which might *contain* the tool call request, + # but won't execute it within the stream. 
+ return self._stream_generator(response) + + else: # Non-streaming + logger.info("Processing non-streaming response...") + if self.provider.has_tool_calls(response): + logger.info("Tool calls detected in response.") + # Simplified non-streaming tool call handling (one round) + try: + tool_calls = self.provider.parse_tool_calls(response) + logger.debug(f"Parsed tool calls: {tool_calls}") + + tool_results = [] + original_message_with_calls = self.provider.get_original_message_with_calls(response) # Provider needs to implement this + messages.append(original_message_with_calls) # Add assistant's turn with tool requests + + for tool_call in tool_calls: + server_name = tool_call.get("server_name") # Needs to be parsed by provider + func_name = tool_call.get("function_name") + func_args_str = tool_call.get("arguments") + call_id = tool_call.get("id") + + if not server_name or not func_name or func_args_str is None or call_id is None: + logger.error(f"Skipping invalid tool call data: {tool_call}") + # Add error result? + result_content = {"error": "Invalid tool call structure from LLM"} + else: + try: + # Arguments might be a JSON string, parse them + arguments = json.loads(func_args_str) + logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}") + # Execute synchronously using the manager + execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments) + logger.debug(f"Tool execution result: {execution_result}") + + if execution_result is None: + result_content = {"error": f"Tool execution failed or timed out for {func_name}"} + elif isinstance(execution_result, dict) and "error" in execution_result: + result_content = execution_result # Propagate error from tool/server + else: + # Assuming result is the content payload + result_content = execution_result + + except json.JSONDecodeError: + logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}") + result_content = {"error": f"Invalid arguments format for tool {func_name}"} + except Exception as exec_err: + logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True) + result_content = {"error": f"Exception during tool execution: {str(exec_err)}"} + + # Format result for the provider's follow-up message + formatted_result = self.provider.format_tool_results(call_id, result_content) + tool_results.append(formatted_result) + messages.append(formatted_result) # Add tool result message + + # Make follow-up call + logger.info("Making follow-up request with tool results...") + follow_up_response = self.provider.create_chat_completion( + messages=messages, # Now includes assistant's turn and tool results + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=False, # Follow-up is non-streaming here + tools=provider_tools, # Pass tools again? Some providers might need it. + ) + final_content = self.provider.get_content(follow_up_response) + logger.info("Received follow-up response content.") + return {"content": final_content} + + except Exception as tool_handling_err: + logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True) + return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"} + + else: # No tool calls + logger.info("No tool calls detected.") + content = self.provider.get_content(response) + return {"content": content} + + except Exception as e: + error_msg = f"LLM API Error: {str(e)}" + logger.error(error_msg, exc_info=True) + if stream: + # How to signal error in a stream? 
Yield a specific error message? + # This simple generator won't handle it well. Returning an error dict for now. + return {"error": error_msg} # Or raise? + else: + return {"error": error_msg} + + def _stream_generator(self, response: Any) -> Generator[str, None, None]: + """Helper to yield content from the provider's streaming method.""" + try: + # Use yield from for cleaner and potentially more efficient delegation + yield from self.provider.get_streaming_content(response) + except Exception as e: + logger.error(f"Error during streaming: {e}", exc_info=True) + yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk + + +# Example of how a provider might need to implement get_original_message_with_calls +# This would be in the specific provider class (e.g., openai_provider.py) +# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]: +# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta +# # or in the non-streaming response's choice message +# # Needs careful implementation based on provider's response structure +# assistant_message = { +# "role": "assistant", +# "content": None, # Often null when tool calls are present +# "tool_calls": [...] # Extracted tool calls in provider format +# } +# return assistant_message diff --git a/src/llm_models.py b/src/llm_models.py new file mode 100644 index 0000000..67f65ae --- /dev/null +++ b/src/llm_models.py @@ -0,0 +1,61 @@ +MODELS = { + "openai": { + "name": "OpenAI", + "endpoint": "https://api.openai.com/v1", + "models": [ + { + "id": "gpt-4o", + "name": "GPT-4o", + "default": True, + "context_window": 128000, + "description": "Input $5/M tokens, Output $15/M tokens", + } + ], + }, + "anthropic": { + "name": "Anthropic", + "endpoint": "https://api.anthropic.com/v1/messages", + "models": [ + { + "id": "claude-3-7-sonnet-20250219", + "name": "Claude 3.7 Sonnet", + "default": True, + "context_window": 200000, + "description": "Input $3/M tokens, Output $15/M tokens", + }, + { + "id": "claude-3-5-haiku-20241022", + "name": "Claude 3.5 Haiku", + "default": False, + "context_window": 200000, + "description": "Input $0.80/M tokens, Output $4/M tokens", + }, + ], + }, + "google": { + "name": "Google Gemini", + "endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent", + "models": [ + { + "id": "gemini-2.0-flash", + "name": "Gemini 2.0 Flash", + "default": True, + "context_window": 1000000, + "description": "Input $0.1/M tokens, Output $0.4/M tokens", + } + ], + }, + "openrouter": { + "name": "OpenRouter", + "endpoint": "https://openrouter.ai/api/v1/chat/completions", + "models": [ + { + "id": "custom", + "name": "Custom Model", + "default": False, + "context_window": 128000, # Default context window, will be updated based on model + "description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')", + }, + ], + }, +} diff --git a/src/mcp_manager.py b/src/mcp_manager.py deleted file mode 100644 index 82c905a..0000000 --- a/src/mcp_manager.py +++ /dev/null @@ -1,234 +0,0 @@ -"""Synchronous wrapper for managing MCP servers using our custom implementation.""" - -import asyncio -import importlib.resources -import json -import logging # Import logging -import threading - -from custom_mcp_client import MCPClient, run_interaction - -# Configure basic logging for the application if not already configured -# This basic config helps ensure logs are visible during development 
-logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-
-# Get a logger for this module
-logger = logging.getLogger(__name__)
-
-
-class SyncMCPManager:
-    """Synchronous wrapper for managing MCP servers and interactions"""
-
-    def __init__(self, config_path: str = "config/mcp_config.json"):
-        self.config_path = config_path
-        self.config = None
-        self.servers = {}
-        self.initialized = False
-        self._lock = threading.Lock()
-        logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
-        self._load_config()
-
-    def _load_config(self):
-        """Load MCP configuration from JSON file using importlib"""
-        logger.debug(f"Attempting to load MCP config from: {self.config_path}")
-        try:
-            # First try to load as a package resource
-            try:
-                # Try anchoring to the project name defined in pyproject.toml
-                # This *might* work depending on editable installs or context.
-                resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
-                with resource_path.open("r") as f:
-                    self.config = json.load(f)
-                logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
-                # REMOVED: raise FileNotFoundError
-
-            except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError):  # Added NotADirectoryError
-                logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
-                # Fall back to direct file access relative to CWD
-                with open(self.config_path) as f:
-                    self.config = json.load(f)
-                logger.debug("Loaded config via direct file access.")
-
-            logger.info("MCP configuration loaded successfully.")
-            logger.debug(f"Config content: {self.config}")  # Log content only if loaded
-
-        except FileNotFoundError:
-            logger.error(f"MCP config file not found at {self.config_path}")
-            self.config = None
-        except json.JSONDecodeError as e:
-            logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
-            self.config = None
-        except Exception as e:
-            logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
-            self.config = None
-
-    def initialize(self) -> bool:
-        """Initialize and start all MCP servers synchronously"""
-        logger.info("Initialize requested.")
-        if not self.config:
-            logger.warning("Initialization skipped: No configuration loaded.")
-            return False
-        if not self.config.get("mcpServers"):
-            logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
-            return False
-
-        if self.initialized:
-            logger.debug("Initialization skipped: Already initialized.")
-            return True
-
-        with self._lock:
-            if self.initialized:  # Double-check after acquiring lock
-                logger.debug("Initialization skipped inside lock: Already initialized.")
-                return True
-
-            logger.info("Starting asynchronous initialization...")
-            # Run async initialization in a new event loop
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)  # Ensure this loop is used by tasks
-            success = loop.run_until_complete(self._async_initialize())
-            loop.close()
-            asyncio.set_event_loop(None)  # Clean up
-
-            if success:
-                logger.info("Asynchronous initialization completed successfully.")
-                self.initialized = True
-            else:
-                logger.error("Asynchronous initialization failed.")
-                self.initialized = False  # Ensure state is False on failure
-
-            return self.initialized
-
-    async def _async_initialize(self) -> bool:
-        """Async implementation of server initialization"""
-        logger.debug("Starting _async_initialize...")
-        all_success = True
-        if not self.config or not self.config.get("mcpServers"):
-            logger.warning("_async_initialize: No config or mcpServers found.")
-            return False
-
-        tasks = []
-        server_names = list(self.config["mcpServers"].keys())
-
-        async def start_server(server_name, server_config):
-            logger.info(f"Initializing server: {server_name}")
-            try:
-                client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
-
-                logger.debug(f"Attempting to start client for {server_name}...")
-                if await client.start():
-                    logger.info(f"Client for {server_name} started successfully.")
-                    tools = await client.list_tools()
-                    logger.info(f"Tools listed for {server_name}: {len(tools)}")
-                    self.servers[server_name] = {"client": client, "tools": tools}
-                    return True
-                else:
-                    logger.error(f"Failed to start MCP server: {server_name}")
-                    return False
-            except Exception as e:
-                logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
-                return False
-
-        # Start servers concurrently
-        for server_name in server_names:
-            server_config = self.config["mcpServers"][server_name]
-            tasks.append(start_server(server_name, server_config))
-
-        results = await asyncio.gather(*tasks)
-
-        # Check if all servers started successfully
-        all_success = all(results)
-
-        if all_success:
-            logger.debug("_async_initialize completed: All servers started successfully.")
-        else:
-            failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
-            logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
-            # Optionally shutdown servers that did start if partial success is not desired
-            # await self._async_shutdown()  # Uncomment to enforce all-or-nothing startup
-
-        return all_success
-
-    def shutdown(self):
-        """Shut down all MCP servers synchronously"""
-        logger.info("Shutdown requested.")
-        if not self.initialized:
-            logger.debug("Shutdown skipped: Not initialized.")
-            return
-
-        with self._lock:
-            if not self.initialized:
-                logger.debug("Shutdown skipped inside lock: Not initialized.")
-                return
-
-            logger.info("Starting asynchronous shutdown...")
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            loop.run_until_complete(self._async_shutdown())
-            loop.close()
-            asyncio.set_event_loop(None)
-
-            self.servers = {}
-            self.initialized = False
-            logger.info("Shutdown complete.")
-
-    async def _async_shutdown(self):
-        """Async implementation of server shutdown"""
-        logger.debug("Starting _async_shutdown...")
-        tasks = []
-        for server_name, server_info in self.servers.items():
-            logger.debug(f"Initiating shutdown for server: {server_name}")
-            tasks.append(server_info["client"].stop())
-
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-        for i, result in enumerate(results):
-            server_name = list(self.servers.keys())[i]
-            if isinstance(result, Exception):
-                logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
-            else:
-                logger.debug(f"Shutdown completed for server: {server_name}")
-        logger.debug("_async_shutdown finished.")
-
-    # Updated process_query signature
-    def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
-        """
-        Process a query using MCP tools synchronously
-
-        Args:
-            query: The user's input query.
-            model_name: The model to use for processing.
-            api_key: The OpenAI API key.
-            base_url: The OpenAI API base URL.
-
-        Returns:
-            Dictionary containing response or error.
-        """
-        if not self.initialized and not self.initialize():
-            logger.error("process_query called but MCP manager failed to initialize.")
-            return {"error": "Failed to initialize MCP servers"}
-
-        logger.debug(f"Processing query synchronously: '{query}'")
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            # Pass api_key and base_url to _async_process_query
-            result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
-            logger.debug(f"Synchronous query processing result: {result}")
-            return result
-        except Exception as e:
-            logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
-            return {"error": f"Processing error: {str(e)}"}
-        finally:
-            loop.close()
-
-    # Updated _async_process_query signature
-    async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
-        """Async implementation of query processing"""
-        # Pass api_key, base_url, and the MCP config separately to run_interaction
-        return await run_interaction(
-            user_query=query,
-            model_name=model_name,
-            api_key=api_key,
-            base_url=base_url,
-            mcp_config=self.config,  # self.config only contains MCP server definitions now
-            stream=False,
-        )
diff --git a/src/providers/__init__.py b/src/providers/__init__.py
new file mode 100644
index 0000000..b7c1dcc
--- /dev/null
+++ b/src/providers/__init__.py
@@ -0,0 +1,71 @@
+# src/providers/__init__.py
+import logging
+
+from providers.base import BaseProvider
+
+# Import specific provider implementations here as they are created
+from providers.openai_provider import OpenAIProvider
+
+# from .anthropic_provider import AnthropicProvider
+# from .google_provider import GoogleProvider
+# from .openrouter_provider import OpenRouterProvider
+
+logger = logging.getLogger(__name__)
+
+# Map provider names (lowercase) to their corresponding class implementations
+PROVIDER_MAP: dict[str, type[BaseProvider]] = {
+    "openai": OpenAIProvider,
+    # "anthropic": AnthropicProvider,
+    # "google": GoogleProvider,
+    # "openrouter": OpenRouterProvider,
+}
+
+
+def register_provider(name: str, provider_class: type[BaseProvider]):
+    """Registers a provider class."""
+    if name.lower() in PROVIDER_MAP:
+        logger.warning(f"Provider '{name}' is already registered. Overwriting.")
+    PROVIDER_MAP[name.lower()] = provider_class
+    logger.info(f"Registered provider: {name}")
+
+
+def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
+    """
+    Factory function to create an instance of a specific LLM provider.
+
+    Args:
+        provider_name: The name of the provider (e.g., 'openai', 'anthropic').
+        api_key: The API key for the provider.
+        base_url: Optional base URL for the provider's API.
+
+    Returns:
+        An instance of the requested BaseProvider subclass.
+
+    Raises:
+        ValueError: If the requested provider_name is not registered.
+    """
+    provider_class = PROVIDER_MAP.get(provider_name.lower())
+
+    if provider_class is None:
+        available = ", ".join(PROVIDER_MAP.keys()) or "None"
+        raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
+
+    logger.info(f"Creating LLM provider instance for: {provider_name}")
+    try:
+        return provider_class(api_key=api_key, base_url=base_url)
+    except Exception as e:
+        logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
+        raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
+
+
+def get_available_providers() -> list[str]:
+    """Returns a list of registered provider names."""
+    return list(PROVIDER_MAP.keys())
+
+
+# Example of how specific providers would register themselves if structured as plugins,
+# but for now, we'll explicitly import and map them above.
+# def load_providers():
+#     # Potentially load providers dynamically if designed as plugins
+#     pass
+# load_providers()
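Assuming the factory above is importable as written (the exact import path depends on how src/ ends up on sys.path), usage from the client layer might look like this sketch; the api_key value is a placeholder, and "openai-compatible" is a made-up registration name:

from providers import create_llm_provider, get_available_providers, register_provider
from providers.openai_provider import OpenAIProvider

print(get_available_providers())  # e.g. ['openai']

# Explicit registration is also possible, e.g. for an OpenAI-compatible proxy.
register_provider("openai-compatible", OpenAIProvider)

provider = create_llm_provider(
    provider_name="openai",
    api_key="sk-...",   # placeholder; the app reads this from config.ini
    base_url=None,      # falls back to the provider's default endpoint
)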
diff --git a/src/providers/base.py b/src/providers/base.py
new file mode 100644
index 0000000..25fb02c
--- /dev/null
+++ b/src/providers/base.py
@@ -0,0 +1,140 @@
+# src/providers/base.py
+import abc
+from collections.abc import Generator
+from typing import Any
+
+
+class BaseProvider(abc.ABC):
+    """
+    Abstract base class for LLM providers.
+
+    Defines the common interface for interacting with different LLM APIs,
+    including handling chat completions and tool usage.
+    """
+
+    def __init__(self, api_key: str, base_url: str | None = None):
+        """
+        Initialize the provider.
+
+        Args:
+            api_key: The API key for the provider.
+            base_url: Optional base URL for the provider's API.
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+
+    @abc.abstractmethod
+    def create_chat_completion(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.4,
+        max_tokens: int | None = None,
+        stream: bool = True,
+        tools: list[dict[str, Any]] | None = None,
+    ) -> Any:
+        """
+        Send a chat completion request to the LLM provider.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'.
+            model: Model identifier.
+            temperature: Sampling temperature (0-1).
+            max_tokens: Maximum tokens to generate.
+            stream: Whether to stream the response.
+            tools: Optional list of tools in the provider-specific format.
+
+        Returns:
+            Provider-specific response object (e.g., API response, stream object).
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
+        """
+        Extracts and yields content chunks from a streaming response object.
+
+        Args:
+            response: The streaming response object returned by create_chat_completion.
+
+        Yields:
+            String chunks of the response content.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_content(self, response: Any) -> str:
+        """
+        Extracts the complete content from a non-streaming response object.
+
+        Args:
+            response: The non-streaming response object.
+
+        Returns:
+            The complete response content as a string.
+        """
+        pass
+
+    @abc.abstractmethod
+    def has_tool_calls(self, response: Any) -> bool:
+        """
+        Checks if the response object contains tool calls.
+
+        Args:
+            response: The response object (streaming or non-streaming).
+
+        Returns:
+            True if tool calls are present, False otherwise.
+        """
+        pass
+
+    @abc.abstractmethod
+    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
+        """
+        Parses tool calls from the response object.
+
+        Args:
+            response: The response object containing tool calls.
+
+        Returns:
+            A list of dictionaries, each representing a tool call with details
+            like 'id', 'function_name', 'arguments'. The exact structure might
+            vary slightly based on provider needs but should contain enough
+            info for execution.
+        """
+        pass
+
+    @abc.abstractmethod
+    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
+        """
+        Formats the result of a tool execution into the structure expected
+        by the provider for follow-up requests.
+
+        Args:
+            tool_call_id: The unique ID of the tool call (from parse_tool_calls).
+            result: The data returned by the tool execution.
+
+        Returns:
+            A dictionary representing the tool result in the provider's format.
+        """
+        pass
+
+    @abc.abstractmethod
+    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """
+        Converts a list of tools from the standard internal format to the
+        provider-specific format required for the API call.
+
+        Args:
+            tools: List of tool definitions in the standard internal format.
+                   Each dict contains 'server_name', 'name', 'description', 'input_schema'.
+
+        Returns:
+            List of tool definitions in the provider-specific format.
+        """
+        pass
+
+    # Optional: Add a method for follow-up completions if the provider API
+    # requires a specific structure different from just appending messages.
+    # def create_follow_up_completion(...) -> Any:
+    #     pass
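A new backend would subclass BaseProvider and implement each abstract method. The sketch below shows only the shape of such a subclass; EchoProvider is hypothetical and returns canned data instead of calling a real API, which can be handy for wiring tests:

from collections.abc import Generator
from typing import Any

from providers.base import BaseProvider


class EchoProvider(BaseProvider):
    """Toy provider that echoes the last user message; no external calls."""

    def create_chat_completion(self, messages, model, temperature=0.4,
                               max_tokens=None, stream=True, tools=None) -> Any:
        return {"content": messages[-1]["content"] if messages else ""}

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        yield response["content"]

    def get_content(self, response: Any) -> str:
        return response["content"]

    def has_tool_calls(self, response: Any) -> bool:
        return False

    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        return {"role": "tool", "tool_call_id": tool_call_id, "content": str(result)}

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return tools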
diff --git a/src/providers/openai_provider.py b/src/providers/openai_provider.py
new file mode 100644
index 0000000..7306081
--- /dev/null
+++ b/src/providers/openai_provider.py
@@ -0,0 +1,239 @@
+# src/providers/openai_provider.py
+import json
+import logging
+from collections.abc import Generator
+from typing import Any
+
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
+
+from providers.base import BaseProvider
+from src.llm_models import MODELS  # Use absolute import
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """Provider implementation for OpenAI and compatible APIs."""
+
+    def __init__(self, api_key: str, base_url: str | None = None):
+        # Use default OpenAI endpoint if base_url is not provided
+        effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
+        super().__init__(api_key, effective_base_url)
+        logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
+        try:
+            # TODO: Add default headers like in original client?
+            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+        except Exception as e:
+            logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
+            raise
+
+    def create_chat_completion(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.4,
+        max_tokens: int | None = None,
+        stream: bool = True,
+        tools: list[dict[str, Any]] | None = None,
+    ) -> Stream[ChatCompletionChunk] | ChatCompletion:
+        """Creates a chat completion using the OpenAI API."""
+        logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
+        try:
+            completion_params = {
+                "model": model,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "stream": stream,
+            }
+            if tools:
+                completion_params["tools"] = tools
+                completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools
+
+            # Remove None values like max_tokens if not provided
+            completion_params = {k: v for k, v in completion_params.items() if v is not None}
+
+            # --- Added Debug Logging ---
+            log_params = completion_params.copy()
+            # Avoid logging full messages if they are too long
+            if "messages" in log_params:
+                log_params["messages"] = [
+                    {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
+                    for msg in log_params["messages"][-2:]  # Log last 2 messages summary
+                ]
+            # Specifically log tools structure if present
+            tools_log = log_params.get("tools", "Not Present")
+            logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
+            logger.debug(f"Full API Params (messages summarized): {log_params}")
+            # --- End Added Debug Logging ---
+
+            response = self.client.chat.completions.create(**completion_params)
+            logger.debug("OpenAI API call successful.")
+            return response
+        except Exception as e:
+            logger.error(f"OpenAI API error: {e}", exc_info=True)
+            # Re-raise for the LLMClient to handle
+            raise
+
+    def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
+        """Yields content chunks from an OpenAI streaming response."""
+        logger.debug("Processing OpenAI stream...")
+        full_delta = ""
+        try:
+            for chunk in response:
+                delta = chunk.choices[0].delta.content
+                if delta:
+                    full_delta += delta
+                    yield delta
+            logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
+        except Exception as e:
+            logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
+            # Yield an error message? Or let the generator stop?
+            yield json.dumps({"error": f"Stream processing error: {str(e)}"})
+
+    def get_content(self, response: ChatCompletion) -> str:
+        """Extracts content from a non-streaming OpenAI response."""
+        try:
+            content = response.choices[0].message.content
+            logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
+            return content or ""  # Return empty string if content is None
+        except Exception as e:
+            logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
+            return f"[Error extracting content: {str(e)}]"
+
+    def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
+        """Checks if the OpenAI response contains tool calls."""
+        try:
+            if isinstance(response, ChatCompletion):  # Non-streaming
+                return bool(response.choices[0].message.tool_calls)
+            elif hasattr(response, "_iterator"):  # Check if it looks like our stream wrapper
+                # This is tricky for streams. We'd need to peek at the first chunk(s)
+                # or buffer the response. For simplicity, this check might be unreliable
+                # for streams *before* they are consumed. LLMClient needs robust handling.
+                logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
+                # A more robust check would involve consuming the start of the stream
+                # or relying on the structure after consumption.
+                return False  # Assume no for unconsumed stream for now
+            else:
+                # If it's already a consumed stream or unexpected type
+                logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
+                return False
+        except Exception as e:
+            logger.error(f"Error checking for tool calls: {e}", exc_info=True)
+            return False
+
+    def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
+        """Parses tool calls from a non-streaming OpenAI response."""
+        # This implementation assumes a non-streaming response or a fully buffered stream
+        parsed_calls = []
+        try:
+            if not isinstance(response, ChatCompletion):
+                logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
+                # Attempt to handle buffered stream if possible? Complex.
+                return []
+
+            tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
+            if not tool_calls:
+                return []
+
+            logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
+            for call in tool_calls:
+                if call.type == "function":
+                    # Attempt to parse server_name from function name if prefixed
+                    # e.g., "server-name__actual-tool-name"
+                    parts = call.function.name.split("__", 1)
+                    if len(parts) == 2:
+                        server_name, func_name = parts
+                    else:
+                        # If no prefix, how do we know the server? Needs refinement.
+                        # Defaulting to None or a default server? Log warning.
+                        logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
+                        server_name = None  # Or raise error, or use a default?
+                        func_name = call.function.name
+
+                    parsed_calls.append({
+                        "id": call.id,
+                        "server_name": server_name,  # May be None if not prefixed
+                        "function_name": func_name,
+                        "arguments": call.function.arguments,  # Arguments are already a string here
+                    })
+                else:
+                    logger.warning(f"Unsupported tool call type: {call.type}")
+
+            return parsed_calls
+        except Exception as e:
+            logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
+            return []  # Return empty list on error
+
+    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
+        """Formats a tool result for an OpenAI follow-up request."""
+        # Result might be a dict (including potential errors) or simple string/number.
+        # OpenAI expects the content to be a string, often JSON.
+        try:
+            if isinstance(result, dict):
+                content = json.dumps(result)
+            else:
+                content = str(result)  # Ensure it's a string
+        except Exception as e:
+            logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
+            content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
+
+        logger.debug(f"Formatting tool result for call ID {tool_call_id}")
+        return {
+            "role": "tool",
+            "tool_call_id": tool_call_id,
+            "content": content,
+        }
+
+    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Converts internal tool format to OpenAI's format."""
+        openai_tools = []
+        logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
+        for tool in tools:
+            server_name = tool.get("server_name")
+            tool_name = tool.get("name")
+            description = tool.get("description")
+            input_schema = tool.get("inputSchema")
+
+            if not server_name or not tool_name or not description or not input_schema:
+                logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
+                continue
+
+            # Prefix tool name with server name to avoid clashes and allow routing
+            prefixed_tool_name = f"{server_name}__{tool_name}"
+
+            openai_tool_format = {
+                "type": "function",
+                "function": {
+                    "name": prefixed_tool_name,
+                    "description": description,
+                    "parameters": input_schema,  # OpenAI uses JSON Schema directly
+                },
+            }
+            openai_tools.append(openai_tool_format)
+            logger.debug(f"Converted tool: {prefixed_tool_name}")
+
+        return openai_tools
+
+    # Helper needed by LLMClient's current tool handling logic
+    def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
+        """Extracts the assistant's message containing tool calls."""
+        try:
+            if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
+                message = response.choices[0].message
+                # Convert Pydantic model to dict for message history
+                return message.model_dump(exclude_unset=True)
+            else:
+                logger.warning("Could not extract original message with tool calls from response.")
+                # Return a placeholder or raise error?
+                return {"role": "assistant", "content": "[Could not extract tool calls message]"}
+        except Exception as e:
+            logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
+            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
+
+
+# Register this provider (if using the registration mechanism)
+# from . import register_provider
+# register_provider("openai", OpenAIProvider)
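The prefixing scheme shared by convert_tools and parse_tool_calls is what lets a single model response route back to the right MCP server. A rough, standalone illustration of that naming round trip, using a made-up tool definition (the dict keys mirror what convert_tools reads):

# Internal tool definition (shape assumed from convert_tools above).
tool = {
    "server_name": "mcp-server-sqlite",
    "name": "read_query",
    "description": "Run a read-only SQL query",
    "inputSchema": {"type": "object", "properties": {"sql": {"type": "string"}}},
}

# convert_tools() would expose it to the OpenAI API as:
prefixed = f"{tool['server_name']}__{tool['name']}"  # "mcp-server-sqlite__read_query"

# parse_tool_calls() later splits the returned function name to recover the route:
server_name, func_name = prefixed.split("__", 1)
assert (server_name, func_name) == ("mcp-server-sqlite", "read_query")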