From b4986e0eb91527494d4297bec7e008ef0da2d38c Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Mar 2025 11:00:43 +0000 Subject: [PATCH] refactor: remove custom MCP client implementation files --- src/custom_mcp_client/__init__.py | 5 - src/custom_mcp_client/client.py | 550 ------------------------------ 2 files changed, 555 deletions(-) delete mode 100644 src/custom_mcp_client/__init__.py delete mode 100644 src/custom_mcp_client/client.py diff --git a/src/custom_mcp_client/__init__.py b/src/custom_mcp_client/__init__.py deleted file mode 100644 index 4913c64..0000000 --- a/src/custom_mcp_client/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Custom MCP client implementation focused on OpenAI integration.""" - -from .client import MCPClient, run_interaction - -__all__ = ["MCPClient", "run_interaction"] diff --git a/src/custom_mcp_client/client.py b/src/custom_mcp_client/client.py deleted file mode 100644 index f619bf5..0000000 --- a/src/custom_mcp_client/client.py +++ /dev/null @@ -1,550 +0,0 @@ -"""Custom MCP client implementation with JSON-RPC and OpenAI integration.""" - -import asyncio -import json -import logging -import os -from collections.abc import AsyncGenerator - -from openai import AsyncOpenAI - -# Get a logger for this module -logger = logging.getLogger(__name__) - - -class MCPClient: - """Lightweight MCP client with JSON-RPC communication.""" - - def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None): - self.server_name = server_name - self.command = command - self.args = args or [] - self.env = env or {} - self.process = None - self.tools = [] - self.request_id = 0 - self.responses = {} - self._shutdown = False - # Use a logger specific to this client instance - self.logger = logging.getLogger(f"{__name__}.{self.server_name}") - - async def _receive_loop(self): - """Listen for responses from the MCP server.""" - try: - while self.process and self.process.stdout and not self.process.stdout.at_eof(): - line_bytes = await self.process.stdout.readline() - if not line_bytes: - self.logger.debug("STDOUT EOF reached.") - break - line_str = line_bytes.decode().strip() - self.logger.debug(f"STDOUT Raw line: {line_str}") - try: - message = json.loads(line_str) - if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message): - self.logger.debug(f"STDOUT Parsed response for ID {message['id']}") - self.responses[message["id"]] = message - elif "jsonrpc" in message and "method" in message: - self.logger.debug(f"STDOUT Received notification: {message.get('method')}") - else: - self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}") - except json.JSONDecodeError: - self.logger.warning("STDOUT Failed to parse line as JSON.") - except Exception as e: - self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True) - except asyncio.CancelledError: - self.logger.debug("STDOUT Receive loop cancelled.") - except Exception as e: - self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True) - finally: - self.logger.debug("STDOUT Receive loop finished.") - - async def _stderr_loop(self): - """Listen for stderr messages from the MCP server.""" - try: - while self.process and self.process.stderr and not self.process.stderr.at_eof(): - line_bytes = await self.process.stderr.readline() - if not line_bytes: - self.logger.debug("STDERR EOF reached.") - break - line_str = line_bytes.decode().strip() - self.logger.warning(f"STDERR: {line_str}") # Log stderr as warning - except asyncio.CancelledError: - self.logger.debug("STDERR Stderr loop cancelled.") - except Exception as e: - self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True) - finally: - self.logger.debug("STDERR Stderr loop finished.") - - async def _send_message(self, message: dict) -> bool: - """Send a JSON-RPC message to the MCP server.""" - if not self.process or not self.process.stdin: - self.logger.warning("STDIN Cannot send message, process or stdin not available.") - return False - try: - data = json.dumps(message) + "\n" - self.logger.debug(f"STDIN Sending: {data.strip()}") - self.process.stdin.write(data.encode()) - await self.process.stdin.drain() - return True - except ConnectionResetError: - self.logger.error("STDIN Connection reset while sending message.") - self.process = None # Mark process as dead - return False - except Exception as e: - self.logger.error(f"STDIN Error sending message: {e}", exc_info=True) - return False - - async def start(self) -> bool: - """Start the MCP server process.""" - self.logger.info("Attempting to start server...") - # Expand ~ in paths and prepare args - expanded_args = [] - try: - for a in self.args: - if isinstance(a, str) and "~" in a: - expanded_args.append(os.path.expanduser(a)) - else: - expanded_args.append(str(a)) # Ensure all args are strings - except Exception as e: - self.logger.error(f"Error expanding arguments: {e}", exc_info=True) - return False - - # Set up environment - env_vars = os.environ.copy() - if self.env: - env_vars.update(self.env) - - self.logger.debug(f"Command: {self.command}") - self.logger.debug(f"Expanded Args: {expanded_args}") - # Avoid logging full env unless necessary for debugging sensitive info - # self.logger.debug(f"Environment: {env_vars}") - - try: - # Start the subprocess - self.logger.debug("Creating subprocess...") - self.process = await asyncio.create_subprocess_exec( - self.command, - *expanded_args, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, # Capture stderr - env=env_vars, - ) - self.logger.info(f"Subprocess created with PID: {self.process.pid}") - - # Start the receive loops - asyncio.create_task(self._receive_loop()) - asyncio.create_task(self._stderr_loop()) # Start stderr loop - - # Initialize the server - self.logger.debug("Attempting initialization handshake...") - init_success = await self._initialize() - if init_success: - self.logger.info("Initialization handshake successful (or skipped).") - # Add delay after successful start - await asyncio.sleep(0.5) - self.logger.debug("Post-initialization delay complete.") - return True - else: - self.logger.error("Initialization handshake failed.") - await self.stop() # Ensure cleanup if init fails - return False - except FileNotFoundError: - self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'") - return False - except Exception as e: - self.logger.error(f"Error starting subprocess: {e}", exc_info=True) - return False - - async def _initialize(self) -> bool: - """Initialize the MCP server connection. Modified to not wait for response.""" - self.logger.debug("Sending 'initialize' request...") - if not self.process: - self.logger.warning("Cannot initialize, process not running.") - return False - - # Send initialize request - self.request_id += 1 - req_id = self.request_id - initialize_req = { - "jsonrpc": "2.0", - "id": req_id, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, - "capabilities": {}, # Add empty capabilities object - }, - } - if not await self._send_message(initialize_req): - self.logger.warning("Failed to send 'initialize' request.") - # Continue anyway for non-compliant servers - - # Send initialized notification immediately - self.logger.debug("Sending 'initialized' notification...") - notify = {"jsonrpc": "2.0", "method": "notifications/initialized"} - if await self._send_message(notify): - self.logger.debug("'initialized' notification sent.") - else: - self.logger.warning("Failed to send 'initialized' notification.") - # Still return True as the server might be running - - self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).") - return True # Assume success without waiting for response - - async def list_tools(self) -> list[dict]: - """List available tools from the MCP server.""" - if not self.process: - self.logger.warning("Cannot list tools, process not running.") - return [] - - self.logger.debug("Sending 'tools/list' request...") - self.request_id += 1 - req_id = self.request_id - req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}} - if not await self._send_message(req): - self.logger.error("Failed to send 'tools/list' request.") - return [] - - # Wait for response - self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...") - start_time = asyncio.get_event_loop().time() - timeout = 10 # seconds - while asyncio.get_event_loop().time() - start_time < timeout: - if req_id in self.responses: - resp = self.responses.pop(req_id) - self.logger.debug(f"Received 'tools/list' response: {resp}") - - if "error" in resp: - self.logger.error(f"'tools/list' error response: {resp['error']}") - return [] - - if "result" in resp and "tools" in resp["result"]: - self.tools = resp["result"]["tools"] - self.logger.info(f"Successfully listed tools: {len(self.tools)}") - return self.tools - else: - self.logger.error("Invalid 'tools/list' response format.") - return [] - - await asyncio.sleep(0.05) - - self.logger.error(f"'tools/list' request timed out after {timeout} seconds.") - return [] - - async def call_tool(self, tool_name: str, arguments: dict) -> dict: - """Call a tool on the MCP server.""" - if not self.process: - self.logger.warning(f"Cannot call tool '{tool_name}', process not running.") - return {"error": "Server not started"} - - self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...") - self.request_id += 1 - req_id = self.request_id - req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}} - if not await self._send_message(req): - self.logger.error("Failed to send 'tools/call' request.") - return {"error": "Failed to send tool call request"} - - # Wait for response - self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...") - start_time = asyncio.get_event_loop().time() - timeout = 30 # seconds - while asyncio.get_event_loop().time() - start_time < timeout: - if req_id in self.responses: - resp = self.responses.pop(req_id) - self.logger.debug(f"Received 'tools/call' response: {resp}") - - if "error" in resp: - self.logger.error(f"'tools/call' error response: {resp['error']}") - return {"error": str(resp["error"])} - - if "result" in resp: - self.logger.info(f"Tool '{tool_name}' executed successfully.") - return resp["result"] - else: - self.logger.error("Invalid 'tools/call' response format.") - return {"error": "Invalid tool call response format"} - - await asyncio.sleep(0.05) - - self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.") - return {"error": f"Tool call timed out after {timeout}s"} - - async def stop(self): - """Stop the MCP server process.""" - self.logger.info("Attempting to stop server...") - if self._shutdown or not self.process: - self.logger.debug("Server already stopped or not running.") - return - - self._shutdown = True - proc = self.process # Keep a local reference - self.process = None # Prevent further operations - - try: - # Send shutdown notification - self.logger.debug("Sending 'shutdown' notification...") - notify = {"jsonrpc": "2.0", "method": "shutdown"} - await self._send_message(notify) # Use the method which now handles None process - await asyncio.sleep(0.5) # Give server time to process - - # Close stdin - if proc and proc.stdin: - try: - if not proc.stdin.is_closing(): - proc.stdin.close() - await proc.stdin.wait_closed() - self.logger.debug("Stdin closed.") - except Exception as e: - self.logger.warning(f"Error closing stdin: {e}", exc_info=True) - - # Terminate the process - if proc: - self.logger.debug(f"Terminating process {proc.pid}...") - proc.terminate() - try: - await asyncio.wait_for(proc.wait(), timeout=2.0) - self.logger.info(f"Process {proc.pid} terminated gracefully.") - except TimeoutError: - self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...") - proc.kill() - await proc.wait() - self.logger.info(f"Process {proc.pid} killed.") - except Exception as e: - self.logger.error(f"Error waiting for process termination: {e}", exc_info=True) - - except Exception as e: - self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True) - finally: - self.logger.debug("Stop sequence finished.") - # Ensure self.process is None even if errors occurred - self.process = None - - -async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict: - """Process a tool call from OpenAI.""" - func_name = tool_call["function"]["name"] - try: - func_args = json.loads(tool_call["function"].get("arguments", "{}")) - except json.JSONDecodeError as e: - logger.error(f"Invalid tool arguments format for {func_name}: {e}") - return {"error": "Invalid arguments format"} - - # Parse server_name and tool_name from function name - parts = func_name.split("_", 1) - if len(parts) != 2: - logger.error(f"Invalid tool function name format: {func_name}") - return {"error": "Invalid function name format"} - - server_name, tool_name = parts - - if server_name not in servers: - logger.error(f"Tool call for unknown server: {server_name}") - return {"error": f"Unknown server: {server_name}"} - - # Call the tool - return await servers[server_name].call_tool(tool_name, func_args) - - -async def run_interaction( - user_query: str, - model_name: str, - api_key: str, - base_url: str | None, - mcp_config: dict, - stream: bool = False, -) -> dict | AsyncGenerator: - """ - Run an interaction with OpenAI using MCP server tools. - - Args: - user_query: The user's input query. - model_name: The model to use for processing. - api_key: The OpenAI API key. - base_url: The OpenAI API base URL (optional). - mcp_config: The MCP configuration dictionary (for servers). - stream: Whether to stream the response. - - Returns: - Dictionary containing response or AsyncGenerator for streaming. - """ - # Validate passed arguments - if not api_key: - logger.error("API key is missing.") - if not stream: - return {"error": "API key is missing."} - else: - - async def error_gen(): - yield {"error": "API key is missing."} - - return error_gen() - - # Start MCP servers using mcp_config - servers = {} - all_functions = [] - if mcp_config.get("mcpServers"): # Use mcp_config here - for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here - client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) - - if await client.start(): - tools = await client.list_tools() - for tool in tools: - # Ensure parameters is a dict, default to empty if missing or not dict - params = tool.get("inputSchema", {}) - if not isinstance(params, dict): - logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.") - params = {} - - all_functions.append({ - "type": "function", # Explicitly set type for clarity with newer OpenAI API - "function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params}, - }) - servers[server_name] = client - else: - logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.") - else: - logger.info("No mcpServers defined in configuration.") - - # Use passed api_key and base_url - openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments - messages = [{"role": "user", "content": user_query}] - tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None - - if stream: - - async def response_generator(): - active_servers = list(servers.values()) # Keep track for cleanup - try: - while True: - logger.debug(f"Calling OpenAI with messages: {messages}") - logger.debug(f"Calling OpenAI with tools: {tool_defs}") - # Get OpenAI response - try: - response = await openai_client.chat.completions.create( - model=model_name, - messages=messages, - tools=tool_defs, - tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist - stream=True, - ) - except Exception as e: - logger.error(f"OpenAI API error: {e}", exc_info=True) - yield {"error": f"OpenAI API error: {e}"} - break - - # Process streaming response - full_response_content = "" - tool_calls = [] - async for chunk in response: - delta = chunk.choices[0].delta - if delta.content: - content = delta.content - full_response_content += content - yield {"assistant_text": content, "is_chunk": True} - - if delta.tool_calls: - for tc in delta.tool_calls: - # Initialize tool call structure if it's the first chunk for this index - if tc.index >= len(tool_calls): - tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}}) - - # Append parts as they arrive - if tc.id: - tool_calls[tc.index]["id"] = tc.id - if tc.function and tc.function.name: - tool_calls[tc.index]["function"]["name"] = tc.function.name - if tc.function and tc.function.arguments: - tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments - - # Add assistant message with content and potential tool calls - assistant_message = {"role": "assistant", "content": full_response_content} - if tool_calls: - # Filter out incomplete tool calls just in case - valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]] - if valid_tool_calls: - assistant_message["tool_calls"] = valid_tool_calls - else: - logger.warning("Received tool call chunks but couldn't assemble valid tool calls.") - - messages.append(assistant_message) - logger.debug(f"Assistant message added: {assistant_message}") - - # Handle tool calls if any were successfully assembled - if "tool_calls" in assistant_message: - tool_results = [] - for tc in assistant_message["tool_calls"]: - logger.info(f"Processing tool call: {tc['function']['name']}") - result = await process_tool_call(tc, servers) - tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]}) - - messages.extend(tool_results) - logger.debug(f"Tool results added: {tool_results}") - # Loop back to call OpenAI again with tool results - else: - # No tool calls, interaction finished for this turn - yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk - break - - except Exception as e: - logger.error(f"Error during streaming interaction: {e}", exc_info=True) - yield {"error": f"Interaction error: {e}"} - finally: - # Clean up servers - logger.debug("Cleaning up MCP servers (stream)...") - for server in active_servers: - await server.stop() - logger.debug("MCP server cleanup finished (stream).") - - return response_generator() - - else: # Non-streaming case - active_servers = list(servers.values()) # Keep track for cleanup - try: - while True: - logger.debug(f"Calling OpenAI with messages: {messages}") - logger.debug(f"Calling OpenAI with tools: {tool_defs}") - # Get OpenAI response - try: - response = await openai_client.chat.completions.create( - model=model_name, - messages=messages, - tools=tool_defs, - tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist - ) - except Exception as e: - logger.error(f"OpenAI API error: {e}", exc_info=True) - return {"error": f"OpenAI API error: {e}"} - - message = response.choices[0].message - messages.append(message) - logger.debug(f"OpenAI response message: {message}") - - # Handle tool calls - if message.tool_calls: - tool_results = [] - for tc in message.tool_calls: - logger.info(f"Processing tool call: {tc.function.name}") - # Reconstruct dict for process_tool_call - tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} - result = await process_tool_call(tool_call_dict, servers) - tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id}) - - messages.extend(tool_results) - logger.debug(f"Tool results added: {tool_results}") - # Loop back to call OpenAI again with tool results - else: - # No tool calls, interaction finished - logger.info("Interaction finished, no tool calls.") - return {"assistant_text": message.content or "", "tool_calls": []} - - except Exception as e: - logger.error(f"Error during non-streaming interaction: {e}", exc_info=True) - return {"error": f"Interaction error: {e}"} - finally: - # Clean up servers - logger.debug("Cleaning up MCP servers (non-stream)...") - for server in active_servers: - await server.stop() - logger.debug("MCP server cleanup finished (non-stream).")