diff --git a/pyproject.toml b/pyproject.toml index 04b1a34..fe11f97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["src"] +[tool.hatch.build.targets.wheel.force-include] +"config" = "config" [tool.ruff] line-length = 200 diff --git a/src/custom_mcp_client/client.py b/src/custom_mcp_client/client.py index 7a0b372..ba7ef5a 100644 --- a/src/custom_mcp_client/client.py +++ b/src/custom_mcp_client/client.py @@ -2,11 +2,15 @@ import asyncio import json +import logging import os from collections.abc import AsyncGenerator from openai import AsyncOpenAI +# Get a logger for this module +logger = logging.getLogger(__name__) + class MCPClient: """Lightweight MCP client with JSON-RPC communication.""" @@ -21,180 +25,291 @@ class MCPClient: self.request_id = 0 self.responses = {} self._shutdown = False + # Use a logger specific to this client instance + self.logger = logging.getLogger(f"{__name__}.{self.server_name}") async def _receive_loop(self): """Listen for responses from the MCP server.""" try: - while not self.process.stdout.at_eof(): - line = await self.process.stdout.readline() - if not line: + while self.process and self.process.stdout and not self.process.stdout.at_eof(): + line_bytes = await self.process.stdout.readline() + if not line_bytes: + self.logger.debug("STDOUT EOF reached.") break + line_str = line_bytes.decode().strip() + self.logger.debug(f"STDOUT Raw line: {line_str}") try: - message = json.loads(line.decode().strip()) + message = json.loads(line_str) if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message): + self.logger.debug(f"STDOUT Parsed response for ID {message['id']}") self.responses[message["id"]] = message - except Exception: - pass - except Exception: - pass + elif "jsonrpc" in message and "method" in message: + self.logger.debug(f"STDOUT Received notification: {message.get('method')}") + else: + self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}") + except json.JSONDecodeError: + self.logger.warning("STDOUT Failed to parse line as JSON.") + except Exception as e: + self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True) + except asyncio.CancelledError: + self.logger.debug("STDOUT Receive loop cancelled.") + except Exception as e: + self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True) + finally: + self.logger.debug("STDOUT Receive loop finished.") + + async def _stderr_loop(self): + """Listen for stderr messages from the MCP server.""" + try: + while self.process and self.process.stderr and not self.process.stderr.at_eof(): + line_bytes = await self.process.stderr.readline() + if not line_bytes: + self.logger.debug("STDERR EOF reached.") + break + line_str = line_bytes.decode().strip() + self.logger.warning(f"STDERR: {line_str}") # Log stderr as warning + except asyncio.CancelledError: + self.logger.debug("STDERR Stderr loop cancelled.") + except Exception as e: + self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True) + finally: + self.logger.debug("STDERR Stderr loop finished.") async def _send_message(self, message: dict) -> bool: """Send a JSON-RPC message to the MCP server.""" - if not self.process: + if not self.process or not self.process.stdin: + self.logger.warning("STDIN Cannot send message, process or stdin not available.") return False try: data = json.dumps(message) + "\n" + self.logger.debug(f"STDIN Sending: {data.strip()}") self.process.stdin.write(data.encode()) await self.process.stdin.drain() return True - except Exception: + except ConnectionResetError: + self.logger.error("STDIN Connection reset while sending message.") + self.process = None # Mark process as dead + return False + except Exception as e: + self.logger.error(f"STDIN Error sending message: {e}", exc_info=True) return False async def start(self) -> bool: """Start the MCP server process.""" - # Expand ~ in paths + self.logger.info("Attempting to start server...") + # Expand ~ in paths and prepare args expanded_args = [] - for a in self.args: - if isinstance(a, str) and "~" in a: - expanded_args.append(os.path.expanduser(a)) - else: - expanded_args.append(a) + try: + for a in self.args: + if isinstance(a, str) and "~" in a: + expanded_args.append(os.path.expanduser(a)) + else: + expanded_args.append(str(a)) # Ensure all args are strings + except Exception as e: + self.logger.error(f"Error expanding arguments: {e}", exc_info=True) + return False # Set up environment env_vars = os.environ.copy() if self.env: env_vars.update(self.env) + self.logger.debug(f"Command: {self.command}") + self.logger.debug(f"Expanded Args: {expanded_args}") + # Avoid logging full env unless necessary for debugging sensitive info + # self.logger.debug(f"Environment: {env_vars}") + try: # Start the subprocess + self.logger.debug("Creating subprocess...") self.process = await asyncio.create_subprocess_exec( - self.command, *expanded_args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env_vars + self.command, + *expanded_args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, # Capture stderr + env=env_vars, ) + self.logger.info(f"Subprocess created with PID: {self.process.pid}") - # Start the receive loop + # Start the receive loops asyncio.create_task(self._receive_loop()) + asyncio.create_task(self._stderr_loop()) # Start stderr loop # Initialize the server - return await self._initialize() - except Exception: + self.logger.debug("Attempting initialization handshake...") + init_success = await self._initialize() + if init_success: + self.logger.info("Initialization handshake successful (or skipped).") + # Add delay after successful start + await asyncio.sleep(0.5) + self.logger.debug("Post-initialization delay complete.") + return True + else: + self.logger.error("Initialization handshake failed.") + await self.stop() # Ensure cleanup if init fails + return False + except FileNotFoundError: + self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'") + return False + except Exception as e: + self.logger.error(f"Error starting subprocess: {e}", exc_info=True) return False async def _initialize(self) -> bool: - """Initialize the MCP server connection.""" + """Initialize the MCP server connection. Modified to not wait for response.""" + self.logger.debug("Sending 'initialize' request...") if not self.process: + self.logger.warning("Cannot initialize, process not running.") return False # Send initialize request self.request_id += 1 req_id = self.request_id initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}} - await self._send_message(initialize_req) + if not await self._send_message(initialize_req): + self.logger.warning("Failed to send 'initialize' request.") + # Continue anyway for non-compliant servers - # Wait for response - start_time = asyncio.get_event_loop().time() - timeout = 10 - while asyncio.get_event_loop().time() - start_time < timeout: - if req_id in self.responses: - resp = self.responses[req_id] - del self.responses[req_id] + # Send initialized notification immediately + self.logger.debug("Sending 'initialized' notification...") + notify = {"jsonrpc": "2.0", "method": "notifications/initialized"} + if await self._send_message(notify): + self.logger.debug("'initialized' notification sent.") + else: + self.logger.warning("Failed to send 'initialized' notification.") + # Still return True as the server might be running - if "error" in resp: - return False - - # Send initialized notification - notify = {"jsonrpc": "2.0", "method": "notifications/initialized"} - await self._send_message(notify) - return True - - await asyncio.sleep(0.05) - - return False + self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).") + return True # Assume success without waiting for response async def list_tools(self) -> list[dict]: """List available tools from the MCP server.""" if not self.process: + self.logger.warning("Cannot list tools, process not running.") return [] + self.logger.debug("Sending 'tools/list' request...") self.request_id += 1 req_id = self.request_id req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}} - await self._send_message(req) + if not await self._send_message(req): + self.logger.error("Failed to send 'tools/list' request.") + return [] # Wait for response + self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...") start_time = asyncio.get_event_loop().time() - timeout = 10 + timeout = 10 # seconds while asyncio.get_event_loop().time() - start_time < timeout: if req_id in self.responses: - resp = self.responses[req_id] - del self.responses[req_id] + resp = self.responses.pop(req_id) + self.logger.debug(f"Received 'tools/list' response: {resp}") if "error" in resp: + self.logger.error(f"'tools/list' error response: {resp['error']}") return [] if "result" in resp and "tools" in resp["result"]: self.tools = resp["result"]["tools"] + self.logger.info(f"Successfully listed tools: {len(self.tools)}") return self.tools + else: + self.logger.error("Invalid 'tools/list' response format.") + return [] await asyncio.sleep(0.05) + self.logger.error(f"'tools/list' request timed out after {timeout} seconds.") return [] async def call_tool(self, tool_name: str, arguments: dict) -> dict: """Call a tool on the MCP server.""" if not self.process: + self.logger.warning(f"Cannot call tool '{tool_name}', process not running.") return {"error": "Server not started"} + self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...") self.request_id += 1 req_id = self.request_id req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}} - await self._send_message(req) + if not await self._send_message(req): + self.logger.error("Failed to send 'tools/call' request.") + return {"error": "Failed to send tool call request"} # Wait for response + self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...") start_time = asyncio.get_event_loop().time() - timeout = 30 + timeout = 30 # seconds while asyncio.get_event_loop().time() - start_time < timeout: if req_id in self.responses: - resp = self.responses[req_id] - del self.responses[req_id] + resp = self.responses.pop(req_id) + self.logger.debug(f"Received 'tools/call' response: {resp}") if "error" in resp: + self.logger.error(f"'tools/call' error response: {resp['error']}") return {"error": str(resp["error"])} if "result" in resp: + self.logger.info(f"Tool '{tool_name}' executed successfully.") return resp["result"] + else: + self.logger.error("Invalid 'tools/call' response format.") + return {"error": "Invalid tool call response format"} await asyncio.sleep(0.05) + self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.") return {"error": f"Tool call timed out after {timeout}s"} async def stop(self): """Stop the MCP server process.""" + self.logger.info("Attempting to stop server...") if self._shutdown or not self.process: + self.logger.debug("Server already stopped or not running.") return self._shutdown = True + proc = self.process # Keep a local reference + self.process = None # Prevent further operations try: # Send shutdown notification + self.logger.debug("Sending 'shutdown' notification...") notify = {"jsonrpc": "2.0", "method": "shutdown"} - await self._send_message(notify) - await asyncio.sleep(0.5) + await self._send_message(notify) # Use the method which now handles None process + await asyncio.sleep(0.5) # Give server time to process # Close stdin - if self.process.stdin: - self.process.stdin.close() + if proc and proc.stdin: + try: + if not proc.stdin.is_closing(): + proc.stdin.close() + await proc.stdin.wait_closed() + self.logger.debug("Stdin closed.") + except Exception as e: + self.logger.warning(f"Error closing stdin: {e}", exc_info=True) # Terminate the process - self.process.terminate() - try: - await asyncio.wait_for(self.process.wait(), timeout=1.0) - except TimeoutError: - self.process.kill() - - except Exception: - pass + if proc: + self.logger.debug(f"Terminating process {proc.pid}...") + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=2.0) + self.logger.info(f"Process {proc.pid} terminated gracefully.") + except TimeoutError: + self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...") + proc.kill() + await proc.wait() + self.logger.info(f"Process {proc.pid} killed.") + except Exception as e: + self.logger.error(f"Error waiting for process termination: {e}", exc_info=True) + except Exception as e: + self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True) finally: + self.logger.debug("Stop sequence finished.") + # Ensure self.process is None even if errors occurred self.process = None @@ -203,17 +318,20 @@ async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> d func_name = tool_call["function"]["name"] try: func_args = json.loads(tool_call["function"].get("arguments", "{}")) - except json.JSONDecodeError: + except json.JSONDecodeError as e: + logger.error(f"Invalid tool arguments format for {func_name}: {e}") return {"error": "Invalid arguments format"} # Parse server_name and tool_name from function name parts = func_name.split("_", 1) if len(parts) != 2: + logger.error(f"Invalid tool function name format: {func_name}") return {"error": "Invalid function name format"} server_name, tool_name = parts if server_name not in servers: + logger.error(f"Tool call for unknown server: {server_name}") return {"error": f"Unknown server: {server_name}"} # Call the tool @@ -234,82 +352,197 @@ async def run_interaction(user_query: str, model_name: str, config: dict, stream Dictionary containing response or AsyncGenerator for streaming """ # Get OpenAI configuration - api_key = config["models"][0]["apiKey"] + # TODO: Handle multiple models in config? + if not config.get("models"): + logger.error("No models defined in MCP configuration.") + # This function needs to return something compatible, maybe raise error? + # For now, returning error dict for non-streaming case + if not stream: + return {"error": "No models defined in MCP configuration."} + else: + # How to handle error in async generator? Yield an error dict? + async def error_gen(): + yield {"error": "No models defined in MCP configuration."} + + return error_gen() + + api_key = config["models"][0].get("apiKey") base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1") + if not api_key: + logger.error("apiKey missing for the first model in MCP configuration.") + if not stream: + return {"error": "apiKey missing for the first model in MCP configuration."} + else: + + async def error_gen(): + yield {"error": "apiKey missing for the first model in MCP configuration."} + + return error_gen() + # Start MCP servers servers = {} all_functions = [] - for server_name, server_config in config["mcpServers"].items(): - client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) + if config.get("mcpServers"): + for server_name, server_config in config["mcpServers"].items(): + client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) - if await client.start(): - tools = await client.list_tools() - for tool in tools: - all_functions.append({"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})}) - servers[server_name] = client + if await client.start(): + tools = await client.list_tools() + for tool in tools: + # Ensure parameters is a dict, default to empty if missing or not dict + params = tool.get("inputSchema", {}) + if not isinstance(params, dict): + logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.") + params = {} - client = AsyncOpenAI(api_key=api_key, base_url=base_url) + all_functions.append({ + "type": "function", # Explicitly set type for clarity with newer OpenAI API + "function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params}, + }) + servers[server_name] = client + else: + logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.") + else: + logger.info("No mcpServers defined in configuration.") + + openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) messages = [{"role": "user", "content": user_query}] + tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None if stream: async def response_generator(): + active_servers = list(servers.values()) # Keep track for cleanup try: while True: + logger.debug(f"Calling OpenAI with messages: {messages}") + logger.debug(f"Calling OpenAI with tools: {tool_defs}") # Get OpenAI response - response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True) - - # Process streaming response - full_response = "" - tool_calls = [] - async for chunk in response: - if chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - full_response += content - yield {"assistant_text": content, "is_chunk": True} - - if chunk.choices[0].delta.tool_calls: - for tc in chunk.choices[0].delta.tool_calls: - if len(tool_calls) <= tc.index: - tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}) - else: - tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments - - # Handle tool calls - if tool_calls: - assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls} - messages.append(assistant_message) - - for tc in tool_calls: - result = await process_tool_call(tc, servers) - messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]}) - else: + try: + response = await openai_client.chat.completions.create( + model=model_name, + messages=messages, + tools=tool_defs, + tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist + stream=True, + ) + except Exception as e: + logger.error(f"OpenAI API error: {e}", exc_info=True) + yield {"error": f"OpenAI API error: {e}"} break + # Process streaming response + full_response_content = "" + tool_calls = [] + async for chunk in response: + delta = chunk.choices[0].delta + if delta.content: + content = delta.content + full_response_content += content + yield {"assistant_text": content, "is_chunk": True} + + if delta.tool_calls: + for tc in delta.tool_calls: + # Initialize tool call structure if it's the first chunk for this index + if tc.index >= len(tool_calls): + tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}}) + + # Append parts as they arrive + if tc.id: + tool_calls[tc.index]["id"] = tc.id + if tc.function and tc.function.name: + tool_calls[tc.index]["function"]["name"] = tc.function.name + if tc.function and tc.function.arguments: + tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments + + # Add assistant message with content and potential tool calls + assistant_message = {"role": "assistant", "content": full_response_content} + if tool_calls: + # Filter out incomplete tool calls just in case + valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]] + if valid_tool_calls: + assistant_message["tool_calls"] = valid_tool_calls + else: + logger.warning("Received tool call chunks but couldn't assemble valid tool calls.") + + messages.append(assistant_message) + logger.debug(f"Assistant message added: {assistant_message}") + + # Handle tool calls if any were successfully assembled + if "tool_calls" in assistant_message: + tool_results = [] + for tc in assistant_message["tool_calls"]: + logger.info(f"Processing tool call: {tc['function']['name']}") + result = await process_tool_call(tc, servers) + tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]}) + + messages.extend(tool_results) + logger.debug(f"Tool results added: {tool_results}") + # Loop back to call OpenAI again with tool results + else: + # No tool calls, interaction finished for this turn + yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk + break + + except Exception as e: + logger.error(f"Error during streaming interaction: {e}", exc_info=True) + yield {"error": f"Interaction error: {e}"} finally: # Clean up servers - for server in servers.values(): + logger.debug("Cleaning up MCP servers (stream)...") + for server in active_servers: await server.stop() + logger.debug("MCP server cleanup finished (stream).") - else: + return response_generator() + + else: # Non-streaming case + active_servers = list(servers.values()) # Keep track for cleanup try: while True: + logger.debug(f"Calling OpenAI with messages: {messages}") + logger.debug(f"Calling OpenAI with tools: {tool_defs}") # Get OpenAI response - response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions) + try: + response = await openai_client.chat.completions.create( + model=model_name, + messages=messages, + tools=tool_defs, + tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist + ) + except Exception as e: + logger.error(f"OpenAI API error: {e}", exc_info=True) + return {"error": f"OpenAI API error: {e}"} message = response.choices[0].message messages.append(message) + logger.debug(f"OpenAI response message: {message}") # Handle tool calls if message.tool_calls: + tool_results = [] for tc in message.tool_calls: - result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers) - messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id}) + logger.info(f"Processing tool call: {tc.function.name}") + # Reconstruct dict for process_tool_call + tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + result = await process_tool_call(tool_call_dict, servers) + tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id}) + + messages.extend(tool_results) + logger.debug(f"Tool results added: {tool_results}") + # Loop back to call OpenAI again with tool results else: + # No tool calls, interaction finished + logger.info("Interaction finished, no tool calls.") return {"assistant_text": message.content or "", "tool_calls": []} + except Exception as e: + logger.error(f"Error during non-streaming interaction: {e}", exc_info=True) + return {"error": f"Interaction error: {e}"} finally: # Clean up servers - for server in servers.values(): + logger.debug("Cleaning up MCP servers (non-stream)...") + for server in active_servers: await server.stop() + logger.debug("MCP server cleanup finished (non-stream).") diff --git a/src/mcp_manager.py b/src/mcp_manager.py index f6a6af7..3a439c9 100644 --- a/src/mcp_manager.py +++ b/src/mcp_manager.py @@ -3,10 +3,18 @@ import asyncio import importlib.resources import json +import logging # Import logging import threading from custom_mcp_client import MCPClient, run_interaction +# Configure basic logging for the application if not already configured +# This basic config helps ensure logs are visible during development +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + +# Get a logger for this module +logger = logging.getLogger(__name__) + class SyncMCPManager: """Synchronous wrapper for managing MCP servers and interactions""" @@ -17,86 +25,168 @@ class SyncMCPManager: self.servers = {} self.initialized = False self._lock = threading.Lock() + logger.info(f"Initializing SyncMCPManager with config path: {config_path}") self._load_config() def _load_config(self): """Load MCP configuration from JSON file using importlib""" + logger.debug(f"Attempting to load MCP config from: {self.config_path}") try: # First try to load as a package resource try: - with importlib.resources.files("streamlit-chat-app").joinpath(self.config_path).open("r") as f: + # Try anchoring to the project name defined in pyproject.toml + # This *might* work depending on editable installs or context. + resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path) + with resource_path.open("r") as f: self.config = json.load(f) - except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError): - # Fall back to direct file access + logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.") + # REMOVED: raise FileNotFoundError + + except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError + logger.debug("Failed to load via importlib.resources, falling back to direct file access.") + # Fall back to direct file access relative to CWD with open(self.config_path) as f: self.config = json.load(f) + logger.debug("Loaded config via direct file access.") + logger.info("MCP configuration loaded successfully.") + logger.debug(f"Config content: {self.config}") # Log content only if loaded + + except FileNotFoundError: + logger.error(f"MCP config file not found at {self.config_path}") + self.config = None + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}") + self.config = None except Exception as e: - print(f"Error loading MCP config from {self.config_path}: {str(e)}") + logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True) self.config = None def initialize(self) -> bool: """Initialize and start all MCP servers synchronously""" - if not self.config or not self.config.get("mcpServers"): + logger.info("Initialize requested.") + if not self.config: + logger.warning("Initialization skipped: No configuration loaded.") + return False + if not self.config.get("mcpServers"): + logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.") return False if self.initialized: + logger.debug("Initialization skipped: Already initialized.") return True with self._lock: if self.initialized: # Double-check after acquiring lock + logger.debug("Initialization skipped inside lock: Already initialized.") return True + logger.info("Starting asynchronous initialization...") # Run async initialization in a new event loop loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) # Ensure this loop is used by tasks success = loop.run_until_complete(self._async_initialize()) loop.close() + asyncio.set_event_loop(None) # Clean up - self.initialized = success - return success + if success: + logger.info("Asynchronous initialization completed successfully.") + self.initialized = True + else: + logger.error("Asynchronous initialization failed.") + self.initialized = False # Ensure state is False on failure + + return self.initialized async def _async_initialize(self) -> bool: """Async implementation of server initialization""" - success = True - for server_name, server_config in self.config["mcpServers"].items(): + logger.debug("Starting _async_initialize...") + all_success = True + if not self.config or not self.config.get("mcpServers"): + logger.warning("_async_initialize: No config or mcpServers found.") + return False + + tasks = [] + server_names = list(self.config["mcpServers"].keys()) + + async def start_server(server_name, server_config): + logger.info(f"Initializing server: {server_name}") try: client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) + logger.debug(f"Attempting to start client for {server_name}...") if await client.start(): + logger.info(f"Client for {server_name} started successfully.") tools = await client.list_tools() + logger.info(f"Tools listed for {server_name}: {len(tools)}") self.servers[server_name] = {"client": client, "tools": tools} + return True else: - success = False - print(f"Failed to start MCP server: {server_name}") + logger.error(f"Failed to start MCP server: {server_name}") + return False except Exception as e: - print(f"Error initializing server {server_name}: {str(e)}") - success = False + logger.error(f"Error initializing server {server_name}: {e}", exc_info=True) + return False - return success + # Start servers concurrently + for server_name in server_names: + server_config = self.config["mcpServers"][server_name] + tasks.append(start_server(server_name, server_config)) + + results = await asyncio.gather(*tasks) + + # Check if all servers started successfully + all_success = all(results) + + if all_success: + logger.debug("_async_initialize completed: All servers started successfully.") + else: + failed_servers = [server_names[i] for i, res in enumerate(results) if not res] + logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}") + # Optionally shutdown servers that did start if partial success is not desired + # await self._async_shutdown() # Uncomment to enforce all-or-nothing startup + + return all_success def shutdown(self): """Shut down all MCP servers synchronously""" + logger.info("Shutdown requested.") if not self.initialized: + logger.debug("Shutdown skipped: Not initialized.") return with self._lock: if not self.initialized: + logger.debug("Shutdown skipped inside lock: Not initialized.") return + logger.info("Starting asynchronous shutdown...") loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) loop.run_until_complete(self._async_shutdown()) loop.close() + asyncio.set_event_loop(None) self.servers = {} self.initialized = False + logger.info("Shutdown complete.") async def _async_shutdown(self): """Async implementation of server shutdown""" - for server_info in self.servers.values(): - try: - await server_info["client"].stop() - except Exception as e: - print(f"Error shutting down server: {str(e)}") + logger.debug("Starting _async_shutdown...") + tasks = [] + for server_name, server_info in self.servers.items(): + logger.debug(f"Initiating shutdown for server: {server_name}") + tasks.append(server_info["client"].stop()) + + results = await asyncio.gather(*tasks, return_exceptions=True) + for i, result in enumerate(results): + server_name = list(self.servers.keys())[i] + if isinstance(result, Exception): + logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result) + else: + logger.debug(f"Shutdown completed for server: {server_name}") + logger.debug("_async_shutdown finished.") def process_query(self, query: str, model_name: str) -> dict: """