feat: enhance logging in MCP client and manager for better debugging and error tracking

This commit is contained in:
2025-03-26 06:55:52 +00:00
parent ccd0a1e45b
commit ec39844bf1
3 changed files with 447 additions and 122 deletions

View File

@@ -2,11 +2,15 @@
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
# Get a logger for this module
logger = logging.getLogger(__name__)
class MCPClient:
"""Lightweight MCP client with JSON-RPC communication."""
@@ -21,180 +25,291 @@ class MCPClient:
self.request_id = 0
self.responses = {}
self._shutdown = False
# Use a logger specific to this client instance
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
async def _receive_loop(self):
    """Read stdout from the MCP server until EOF or cancellation.

    Each line is expected to be a JSON-RPC message. Responses (objects
    carrying an ``id`` plus ``result`` or ``error``) are stashed in
    ``self.responses`` keyed by id so the request methods can poll for
    them; notifications and any other JSON are only logged.
    """
    try:
        # Re-check self.process every iteration: stop() may null it concurrently.
        while self.process and self.process.stdout and not self.process.stdout.at_eof():
            line_bytes = await self.process.stdout.readline()
            if not line_bytes:
                self.logger.debug("STDOUT EOF reached.")
                break
            line_str = line_bytes.decode().strip()
            self.logger.debug(f"STDOUT Raw line: {line_str}")
            try:
                message = json.loads(line_str)
                if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
                    # A response to one of our requests: park it for the poller.
                    self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
                    self.responses[message["id"]] = message
                elif "jsonrpc" in message and "method" in message:
                    # Server-initiated notification; nothing waits on these.
                    self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
                else:
                    self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
            except json.JSONDecodeError:
                self.logger.warning("STDOUT Failed to parse line as JSON.")
            except Exception as e:
                self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
    except asyncio.CancelledError:
        self.logger.debug("STDOUT Receive loop cancelled.")
    except Exception as e:
        self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
    finally:
        self.logger.debug("STDOUT Receive loop finished.")
async def _stderr_loop(self):
    """Drain the MCP server's stderr stream.

    Every stderr line is surfaced at WARNING level so server-side
    diagnostics show up in the client's own logs. The loop ends on
    EOF, task cancellation, or once the process handle goes away.
    """
    try:
        while True:
            proc = self.process
            # Re-read self.process each pass; stop() may clear it concurrently.
            if not proc or not proc.stderr or proc.stderr.at_eof():
                break
            raw = await proc.stderr.readline()
            if not raw:
                self.logger.debug("STDERR EOF reached.")
                break
            self.logger.warning(f"STDERR: {raw.decode().strip()}")  # Log stderr as warning
    except asyncio.CancelledError:
        self.logger.debug("STDERR Stderr loop cancelled.")
    except Exception as err:
        self.logger.error(f"STDERR Stderr loop error: {err}", exc_info=True)
    finally:
        self.logger.debug("STDERR Stderr loop finished.")
async def _send_message(self, message: dict) -> bool:
    """Serialize *message* as JSON and write it to the server's stdin.

    Args:
        message: A JSON-RPC request or notification dict.

    Returns:
        True if the message was written and drained; False when the
        process/stdin is unavailable or the write failed. On a
        connection reset the process handle is dropped so later calls
        fail fast instead of writing to a dead pipe.
    """
    if not self.process or not self.process.stdin:
        self.logger.warning("STDIN Cannot send message, process or stdin not available.")
        return False
    try:
        # Messages are newline-delimited JSON on the wire.
        data = json.dumps(message) + "\n"
        self.logger.debug(f"STDIN Sending: {data.strip()}")
        self.process.stdin.write(data.encode())
        await self.process.stdin.drain()
        return True
    except ConnectionResetError:
        self.logger.error("STDIN Connection reset while sending message.")
        self.process = None  # Mark process as dead
        return False
    except Exception as e:
        self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
        return False
async def start(self) -> bool:
    """Spawn the MCP server subprocess and perform the handshake.

    Expands ``~`` in arguments, merges ``self.env`` over the inherited
    environment, starts the process with piped stdio, launches the
    stdout/stderr reader tasks, then runs the initialize handshake.

    Returns:
        True on success; False if argument expansion, process creation,
        or initialization fails (the process is stopped on init failure).
    """
    self.logger.info("Attempting to start server...")
    # Expand ~ in paths and prepare args
    expanded_args = []
    try:
        for a in self.args:
            if isinstance(a, str) and "~" in a:
                expanded_args.append(os.path.expanduser(a))
            else:
                expanded_args.append(str(a))  # Ensure all args are strings
    except Exception as e:
        self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
        return False
    # Set up environment: inherit ours, then overlay server-specific vars.
    env_vars = os.environ.copy()
    if self.env:
        env_vars.update(self.env)
    self.logger.debug(f"Command: {self.command}")
    self.logger.debug(f"Expanded Args: {expanded_args}")
    # Avoid logging full env unless necessary for debugging sensitive info
    try:
        self.logger.debug("Creating subprocess...")
        self.process = await asyncio.create_subprocess_exec(
            self.command,
            *expanded_args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,  # Capture stderr
            env=env_vars,
        )
        self.logger.info(f"Subprocess created with PID: {self.process.pid}")
        # Start the receive loops
        asyncio.create_task(self._receive_loop())
        asyncio.create_task(self._stderr_loop())  # Start stderr loop
        self.logger.debug("Attempting initialization handshake...")
        init_success = await self._initialize()
        if init_success:
            self.logger.info("Initialization handshake successful (or skipped).")
            # Add delay after successful start
            await asyncio.sleep(0.5)
            self.logger.debug("Post-initialization delay complete.")
            return True
        else:
            self.logger.error("Initialization handshake failed.")
            await self.stop()  # Ensure cleanup if init fails
            return False
    except FileNotFoundError:
        self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
        return False
    except Exception as e:
        self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
        return False
async def _initialize(self) -> bool:
    """Initialize the MCP server connection. Modified to not wait for response.

    Sends the JSON-RPC ``initialize`` request followed immediately by
    the ``notifications/initialized`` notification, without polling for
    the initialize response — a deliberate accommodation for servers
    that never answer the handshake.

    Returns:
        True unless the process is not running at all; send failures
        are logged but tolerated.
    """
    self.logger.debug("Sending 'initialize' request...")
    if not self.process:
        self.logger.warning("Cannot initialize, process not running.")
        return False
    # Send initialize request
    self.request_id += 1
    req_id = self.request_id
    initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
    if not await self._send_message(initialize_req):
        self.logger.warning("Failed to send 'initialize' request.")
        # Continue anyway for non-compliant servers
    # Send initialized notification immediately
    self.logger.debug("Sending 'initialized' notification...")
    notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
    if await self._send_message(notify):
        self.logger.debug("'initialized' notification sent.")
    else:
        self.logger.warning("Failed to send 'initialized' notification.")
        # Still return True as the server might be running
    self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
    return True  # Assume success without waiting for response
async def list_tools(self) -> list[dict]:
    """Fetch the server's tool catalog via ``tools/list``.

    Polls ``self.responses`` (filled by the receive loop) for up to 10
    seconds. On success the tool list is cached on ``self.tools``.

    Returns:
        The list of tool descriptor dicts, or an empty list on any
        failure (process down, send error, error response, bad format,
        or timeout).
    """
    if not self.process:
        self.logger.warning("Cannot list tools, process not running.")
        return []
    self.logger.debug("Sending 'tools/list' request...")
    self.request_id += 1
    req_id = self.request_id
    req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
    if not await self._send_message(req):
        self.logger.error("Failed to send 'tools/list' request.")
        return []
    # Wait for response
    self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
    start_time = asyncio.get_event_loop().time()
    timeout = 10  # seconds
    while asyncio.get_event_loop().time() - start_time < timeout:
        if req_id in self.responses:
            resp = self.responses.pop(req_id)
            self.logger.debug(f"Received 'tools/list' response: {resp}")
            if "error" in resp:
                self.logger.error(f"'tools/list' error response: {resp['error']}")
                return []
            if "result" in resp and "tools" in resp["result"]:
                self.tools = resp["result"]["tools"]
                self.logger.info(f"Successfully listed tools: {len(self.tools)}")
                return self.tools
            else:
                self.logger.error("Invalid 'tools/list' response format.")
                return []
        await asyncio.sleep(0.05)
    self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
    return []
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
    """Invoke *tool_name* on the MCP server via ``tools/call``.

    Args:
        tool_name: Name of the tool as advertised by ``tools/list``.
        arguments: Tool arguments, passed through as JSON.

    Returns:
        The ``result`` object from the server on success, otherwise a
        dict with a single ``"error"`` key describing the failure
        (process down, send failure, error response, bad format, or a
        30-second timeout).
    """
    if not self.process:
        self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
        return {"error": "Server not started"}
    self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
    self.request_id += 1
    req_id = self.request_id
    req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
    if not await self._send_message(req):
        self.logger.error("Failed to send 'tools/call' request.")
        return {"error": "Failed to send tool call request"}
    # Wait for response
    self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
    start_time = asyncio.get_event_loop().time()
    timeout = 30  # seconds: tool execution may legitimately be slow
    while asyncio.get_event_loop().time() - start_time < timeout:
        if req_id in self.responses:
            resp = self.responses.pop(req_id)
            self.logger.debug(f"Received 'tools/call' response: {resp}")
            if "error" in resp:
                self.logger.error(f"'tools/call' error response: {resp['error']}")
                return {"error": str(resp["error"])}
            if "result" in resp:
                self.logger.info(f"Tool '{tool_name}' executed successfully.")
                return resp["result"]
            else:
                self.logger.error("Invalid 'tools/call' response format.")
                return {"error": "Invalid tool call response format"}
        await asyncio.sleep(0.05)
    self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
    return {"error": f"Tool call timed out after {timeout}s"}
async def stop(self):
    """Stop the MCP server process.

    Sends a best-effort ``shutdown`` notification, closes stdin, then
    terminates (and if necessary kills) the subprocess. Idempotent:
    subsequent calls are no-ops once ``_shutdown`` is set.

    Fix: the notification is sent *before* ``self.process`` is cleared;
    previously it was cleared first, so ``_send_message`` bailed out and
    the notification never reached the server.
    """
    self.logger.info("Attempting to stop server...")
    if self._shutdown or not self.process:
        self.logger.debug("Server already stopped or not running.")
        return
    self._shutdown = True
    proc = self.process  # Keep a local reference for cleanup
    try:
        # Send shutdown notification while stdin is still reachable.
        self.logger.debug("Sending 'shutdown' notification...")
        notify = {"jsonrpc": "2.0", "method": "shutdown"}
        await self._send_message(notify)
        self.process = None  # Prevent further operations
        await asyncio.sleep(0.5)  # Give server time to process
        # Close stdin
        if proc and proc.stdin:
            try:
                if not proc.stdin.is_closing():
                    proc.stdin.close()
                await proc.stdin.wait_closed()
                self.logger.debug("Stdin closed.")
            except Exception as e:
                self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
        # Terminate the process
        if proc:
            self.logger.debug(f"Terminating process {proc.pid}...")
            proc.terminate()
            try:
                # NOTE: asyncio.TimeoutError is an alias of TimeoutError on 3.11+;
                # on older interpreters catch asyncio.TimeoutError explicitly.
                await asyncio.wait_for(proc.wait(), timeout=2.0)
                self.logger.info(f"Process {proc.pid} terminated gracefully.")
            except TimeoutError:
                self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
                proc.kill()
                await proc.wait()
                self.logger.info(f"Process {proc.pid} killed.")
            except Exception as e:
                self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
    except Exception as e:
        self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
    finally:
        self.logger.debug("Stop sequence finished.")
        # Ensure self.process is None even if errors occurred
        self.process = None
@@ -203,17 +318,20 @@ async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> d
func_name = tool_call["function"]["name"]
try:
func_args = json.loads(tool_call["function"].get("arguments", "{}"))
except json.JSONDecodeError:
except json.JSONDecodeError as e:
logger.error(f"Invalid tool arguments format for {func_name}: {e}")
return {"error": "Invalid arguments format"}
# Parse server_name and tool_name from function name
parts = func_name.split("_", 1)
if len(parts) != 2:
logger.error(f"Invalid tool function name format: {func_name}")
return {"error": "Invalid function name format"}
server_name, tool_name = parts
if server_name not in servers:
logger.error(f"Tool call for unknown server: {server_name}")
return {"error": f"Unknown server: {server_name}"}
# Call the tool
@@ -234,82 +352,197 @@ async def run_interaction(user_query: str, model_name: str, config: dict, stream
Dictionary containing response or AsyncGenerator for streaming
"""
# Get OpenAI configuration
api_key = config["models"][0]["apiKey"]
# TODO: Handle multiple models in config?
if not config.get("models"):
logger.error("No models defined in MCP configuration.")
# This function needs to return something compatible, maybe raise error?
# For now, returning error dict for non-streaming case
if not stream:
return {"error": "No models defined in MCP configuration."}
else:
# How to handle error in async generator? Yield an error dict?
async def error_gen():
yield {"error": "No models defined in MCP configuration."}
return error_gen()
api_key = config["models"][0].get("apiKey")
base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
if not api_key:
logger.error("apiKey missing for the first model in MCP configuration.")
if not stream:
return {"error": "apiKey missing for the first model in MCP configuration."}
else:
async def error_gen():
yield {"error": "apiKey missing for the first model in MCP configuration."}
return error_gen()
# Start MCP servers
servers = {}
all_functions = []
for server_name, server_config in config["mcpServers"].items():
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
if config.get("mcpServers"):
for server_name, server_config in config["mcpServers"].items():
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
if await client.start():
tools = await client.list_tools()
for tool in tools:
all_functions.append({"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})})
servers[server_name] = client
if await client.start():
tools = await client.list_tools()
for tool in tools:
# Ensure parameters is a dict, default to empty if missing or not dict
params = tool.get("inputSchema", {})
if not isinstance(params, dict):
logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
params = {}
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
all_functions.append({
"type": "function", # Explicitly set type for clarity with newer OpenAI API
"function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
})
servers[server_name] = client
else:
logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
else:
logger.info("No mcpServers defined in configuration.")
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
messages = [{"role": "user", "content": user_query}]
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
if stream:
async def response_generator():
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True)
# Process streaming response
full_response = ""
tool_calls = []
async for chunk in response:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
yield {"assistant_text": content, "is_chunk": True}
if chunk.choices[0].delta.tool_calls:
for tc in chunk.choices[0].delta.tool_calls:
if len(tool_calls) <= tc.index:
tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}})
else:
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
# Handle tool calls
if tool_calls:
assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls}
messages.append(assistant_message)
for tc in tool_calls:
result = await process_tool_call(tc, servers)
messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
else:
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
stream=True,
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
yield {"error": f"OpenAI API error: {e}"}
break
# Process streaming response
full_response_content = ""
tool_calls = []
async for chunk in response:
delta = chunk.choices[0].delta
if delta.content:
content = delta.content
full_response_content += content
yield {"assistant_text": content, "is_chunk": True}
if delta.tool_calls:
for tc in delta.tool_calls:
# Initialize tool call structure if it's the first chunk for this index
if tc.index >= len(tool_calls):
tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
# Append parts as they arrive
if tc.id:
tool_calls[tc.index]["id"] = tc.id
if tc.function and tc.function.name:
tool_calls[tc.index]["function"]["name"] = tc.function.name
if tc.function and tc.function.arguments:
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
# Add assistant message with content and potential tool calls
assistant_message = {"role": "assistant", "content": full_response_content}
if tool_calls:
# Filter out incomplete tool calls just in case
valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
if valid_tool_calls:
assistant_message["tool_calls"] = valid_tool_calls
else:
logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
messages.append(assistant_message)
logger.debug(f"Assistant message added: {assistant_message}")
# Handle tool calls if any were successfully assembled
if "tool_calls" in assistant_message:
tool_results = []
for tc in assistant_message["tool_calls"]:
logger.info(f"Processing tool call: {tc['function']['name']}")
result = await process_tool_call(tc, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished for this turn
yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk
break
except Exception as e:
logger.error(f"Error during streaming interaction: {e}", exc_info=True)
yield {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
for server in servers.values():
logger.debug("Cleaning up MCP servers (stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (stream).")
else:
return response_generator()
else: # Non-streaming case
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions)
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
return {"error": f"OpenAI API error: {e}"}
message = response.choices[0].message
messages.append(message)
logger.debug(f"OpenAI response message: {message}")
# Handle tool calls
if message.tool_calls:
tool_results = []
for tc in message.tool_calls:
result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers)
messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
logger.info(f"Processing tool call: {tc.function.name}")
# Reconstruct dict for process_tool_call
tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
result = await process_tool_call(tool_call_dict, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished
logger.info("Interaction finished, no tool calls.")
return {"assistant_text": message.content or "", "tool_calls": []}
except Exception as e:
logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
return {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
for server in servers.values():
logger.debug("Cleaning up MCP servers (non-stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (non-stream).")

View File

@@ -3,10 +3,18 @@
import asyncio
import importlib.resources
import json
import logging # Import logging
import threading
from custom_mcp_client import MCPClient, run_interaction
# Configure basic logging for the application if not already configured
# This basic config helps ensure logs are visible during development
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Get a logger for this module
logger = logging.getLogger(__name__)
class SyncMCPManager:
"""Synchronous wrapper for managing MCP servers and interactions"""
@@ -17,86 +25,168 @@ class SyncMCPManager:
self.servers = {}
self.initialized = False
self._lock = threading.Lock()
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
self._load_config()
def _load_config(self):
    """Load MCP configuration from JSON file using importlib.

    First tries to resolve ``self.config_path`` as a package resource
    anchored to the 'streamlit-chat-app' distribution, then falls back
    to opening the path directly (relative to the CWD). On any failure
    ``self.config`` is set to None rather than raising.
    """
    logger.debug(f"Attempting to load MCP config from: {self.config_path}")
    try:
        try:
            # Try anchoring to the project name defined in pyproject.toml.
            # This *might* work depending on editable installs or context.
            resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
            with resource_path.open("r") as f:
                self.config = json.load(f)
            logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
        except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError):
            logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
            # Fall back to direct file access relative to CWD
            with open(self.config_path) as f:
                self.config = json.load(f)
            logger.debug("Loaded config via direct file access.")
        logger.info("MCP configuration loaded successfully.")
        logger.debug(f"Config content: {self.config}")  # Log content only if loaded
    except FileNotFoundError:
        logger.error(f"MCP config file not found at {self.config_path}")
        self.config = None
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
        self.config = None
    except Exception as e:
        logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
        self.config = None
def initialize(self) -> bool:
    """Initialize and start all MCP servers synchronously.

    Thread-safe via ``self._lock`` with a double-checked ``initialized``
    flag. A throwaway event loop drives the async startup and is closed
    afterwards.

    Returns:
        True when all servers started (or the manager was already
        initialized); False when config is missing/empty or startup
        failed.
    """
    logger.info("Initialize requested.")
    if not self.config:
        logger.warning("Initialization skipped: No configuration loaded.")
        return False
    if not self.config.get("mcpServers"):
        logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
        return False
    if self.initialized:
        logger.debug("Initialization skipped: Already initialized.")
        return True
    with self._lock:
        if self.initialized:  # Double-check after acquiring lock
            logger.debug("Initialization skipped inside lock: Already initialized.")
            return True
        logger.info("Starting asynchronous initialization...")
        # Run async initialization in a new event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)  # Ensure this loop is used by tasks
        success = loop.run_until_complete(self._async_initialize())
        loop.close()
        asyncio.set_event_loop(None)  # Clean up
        if success:
            logger.info("Asynchronous initialization completed successfully.")
            self.initialized = True
        else:
            logger.error("Asynchronous initialization failed.")
            self.initialized = False  # Ensure state is False on failure
        return self.initialized
async def _async_initialize(self) -> bool:
    """Async implementation of server initialization.

    Starts every configured MCP server concurrently; each successful
    server is recorded in ``self.servers`` with its client and tool
    list. Partial success leaves the started servers registered.

    Returns:
        True only if every server started successfully.
    """
    logger.debug("Starting _async_initialize...")
    if not self.config or not self.config.get("mcpServers"):
        logger.warning("_async_initialize: No config or mcpServers found.")
        return False
    tasks = []
    server_names = list(self.config["mcpServers"].keys())

    async def start_server(server_name, server_config):
        # Start one server and register it; returns True on success.
        logger.info(f"Initializing server: {server_name}")
        try:
            client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
            logger.debug(f"Attempting to start client for {server_name}...")
            if await client.start():
                logger.info(f"Client for {server_name} started successfully.")
                tools = await client.list_tools()
                logger.info(f"Tools listed for {server_name}: {len(tools)}")
                self.servers[server_name] = {"client": client, "tools": tools}
                return True
            else:
                logger.error(f"Failed to start MCP server: {server_name}")
                return False
        except Exception as e:
            logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
            return False

    # Start servers concurrently
    for server_name in server_names:
        server_config = self.config["mcpServers"][server_name]
        tasks.append(start_server(server_name, server_config))
    results = await asyncio.gather(*tasks)
    # Check if all servers started successfully
    all_success = all(results)
    if all_success:
        logger.debug("_async_initialize completed: All servers started successfully.")
    else:
        failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
        logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
        # Optionally shutdown servers that did start if partial success is not desired
        # await self._async_shutdown()  # Uncomment to enforce all-or-nothing startup
    return all_success
def shutdown(self):
    """Synchronously stop every managed MCP server.

    Guards the teardown with the manager's lock (double-checked
    ``initialized`` flag) so concurrent callers cannot race, drives the
    async shutdown on a fresh event loop, then resets manager state.
    No-op when the manager is not initialized.
    """
    logger.info("Shutdown requested.")
    if not self.initialized:
        logger.debug("Shutdown skipped: Not initialized.")
        return
    with self._lock:
        # Another thread may have completed shutdown while we waited.
        if not self.initialized:
            logger.debug("Shutdown skipped inside lock: Not initialized.")
            return
        logger.info("Starting asynchronous shutdown...")
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
        event_loop.run_until_complete(self._async_shutdown())
        event_loop.close()
        asyncio.set_event_loop(None)
        self.servers = {}
        self.initialized = False
        logger.info("Shutdown complete.")
async def _async_shutdown(self):
    """Async implementation of server shutdown.

    Stops all registered clients concurrently; individual failures are
    logged (via ``return_exceptions=True``) without aborting the rest.
    The key list is computed once instead of per result (previously
    ``list(self.servers.keys())[i]`` was rebuilt on every iteration).
    """
    logger.debug("Starting _async_shutdown...")
    tasks = []
    for server_name, server_info in self.servers.items():
        logger.debug(f"Initiating shutdown for server: {server_name}")
        tasks.append(server_info["client"].stop())
    results = await asyncio.gather(*tasks, return_exceptions=True)
    # Hoisted: dict order is stable, so names align with gather results.
    server_names = list(self.servers.keys())
    for i, result in enumerate(results):
        server_name = server_names[i]
        if isinstance(result, Exception):
            logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
        else:
            logger.debug(f"Shutdown completed for server: {server_name}")
    logger.debug("_async_shutdown finished.")
def process_query(self, query: str, model_name: str) -> dict:
"""