refactor: remove custom MCP client implementation files

2025-03-26 11:00:43 +00:00
parent 80ba05338f
commit b4986e0eb9
2 changed files with 0 additions and 555 deletions

@@ -1,5 +0,0 @@
"""Custom MCP client implementation focused on OpenAI integration."""
from .client import MCPClient, run_interaction
__all__ = ["MCPClient", "run_interaction"]

@@ -1,550 +0,0 @@
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
# Get a logger for this module
logger = logging.getLogger(__name__)


class MCPClient:
    """Lightweight MCP client with JSON-RPC communication."""

    def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
        self.server_name = server_name
        self.command = command
        self.args = args or []
        self.env = env or {}
        self.process = None
        self.tools = []
        self.request_id = 0
        self.responses = {}
        self._shutdown = False
        # Use a logger specific to this client instance
        self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
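
    # The client speaks newline-delimited JSON-RPC 2.0 over the server
    # subprocess's stdio. An illustrative exchange (made-up values, not a
    # captured transcript):
    #   -> {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}
    #   <- {"jsonrpc": "2.0", "id": 1, "result": {"tools": [...]}}
    # _receive_loop files each reply into self.responses keyed by its "id",
    # and the request methods below poll that dict for their own reply.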

    async def _receive_loop(self):
        """Listen for responses from the MCP server."""
        try:
            while self.process and self.process.stdout and not self.process.stdout.at_eof():
                line_bytes = await self.process.stdout.readline()
                if not line_bytes:
                    self.logger.debug("STDOUT EOF reached.")
                    break
                line_str = line_bytes.decode().strip()
                self.logger.debug(f"STDOUT Raw line: {line_str}")
                try:
                    message = json.loads(line_str)
                    if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
                        self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
                        self.responses[message["id"]] = message
                    elif "jsonrpc" in message and "method" in message:
                        self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
                    else:
                        self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
                except json.JSONDecodeError:
                    self.logger.warning("STDOUT Failed to parse line as JSON.")
                except Exception as e:
                    self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
        except asyncio.CancelledError:
            self.logger.debug("STDOUT Receive loop cancelled.")
        except Exception as e:
            self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
        finally:
            self.logger.debug("STDOUT Receive loop finished.")

    async def _stderr_loop(self):
        """Listen for stderr messages from the MCP server."""
        try:
            while self.process and self.process.stderr and not self.process.stderr.at_eof():
                line_bytes = await self.process.stderr.readline()
                if not line_bytes:
                    self.logger.debug("STDERR EOF reached.")
                    break
                line_str = line_bytes.decode().strip()
                self.logger.warning(f"STDERR: {line_str}")  # Log stderr as warning
        except asyncio.CancelledError:
            self.logger.debug("STDERR Stderr loop cancelled.")
        except Exception as e:
            self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True)
        finally:
            self.logger.debug("STDERR Stderr loop finished.")

    async def _send_message(self, message: dict) -> bool:
        """Send a JSON-RPC message to the MCP server."""
        if not self.process or not self.process.stdin:
            self.logger.warning("STDIN Cannot send message, process or stdin not available.")
            return False
        try:
            data = json.dumps(message) + "\n"
            self.logger.debug(f"STDIN Sending: {data.strip()}")
            self.process.stdin.write(data.encode())
            await self.process.stdin.drain()
            return True
        except ConnectionResetError:
            self.logger.error("STDIN Connection reset while sending message.")
            self.process = None  # Mark process as dead
            return False
        except Exception as e:
            self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
            return False

    async def start(self) -> bool:
        """Start the MCP server process."""
        self.logger.info("Attempting to start server...")
        # Expand ~ in paths and prepare args
        expanded_args = []
        try:
            for a in self.args:
                if isinstance(a, str) and "~" in a:
                    expanded_args.append(os.path.expanduser(a))
                else:
                    expanded_args.append(str(a))  # Ensure all args are strings
        except Exception as e:
            self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
            return False
        # Set up environment
        env_vars = os.environ.copy()
        if self.env:
            env_vars.update(self.env)
        self.logger.debug(f"Command: {self.command}")
        self.logger.debug(f"Expanded Args: {expanded_args}")
        # Avoid logging the full environment; it may contain sensitive values
        # self.logger.debug(f"Environment: {env_vars}")
        try:
            # Start the subprocess
            self.logger.debug("Creating subprocess...")
            self.process = await asyncio.create_subprocess_exec(
                self.command,
                *expanded_args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,  # Capture stderr
                env=env_vars,
            )
            self.logger.info(f"Subprocess created with PID: {self.process.pid}")
            # Start the receive loops
            asyncio.create_task(self._receive_loop())
            asyncio.create_task(self._stderr_loop())  # Start stderr loop
            # Initialize the server
            self.logger.debug("Attempting initialization handshake...")
            init_success = await self._initialize()
            if init_success:
                self.logger.info("Initialization handshake successful (or skipped).")
                # Add delay after successful start
                await asyncio.sleep(0.5)
                self.logger.debug("Post-initialization delay complete.")
                return True
            else:
                self.logger.error("Initialization handshake failed.")
                await self.stop()  # Ensure cleanup if init fails
                return False
        except FileNotFoundError:
            self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
            return False
        except Exception as e:
            self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
            return False

    async def _initialize(self) -> bool:
        """Initialize the MCP server connection without waiting for a response."""
        self.logger.debug("Sending 'initialize' request...")
        if not self.process:
            self.logger.warning("Cannot initialize, process not running.")
            return False
        # Send initialize request
        self.request_id += 1
        req_id = self.request_id
        initialize_req = {
            "jsonrpc": "2.0",
            "id": req_id,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
                "capabilities": {},  # Add empty capabilities object
            },
        }
        if not await self._send_message(initialize_req):
            self.logger.warning("Failed to send 'initialize' request.")
            # Continue anyway for non-compliant servers
        # Send initialized notification immediately
        self.logger.debug("Sending 'initialized' notification...")
        notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
        if await self._send_message(notify):
            self.logger.debug("'initialized' notification sent.")
        else:
            self.logger.warning("Failed to send 'initialized' notification.")
            # Still return True as the server might be running
        self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
        return True  # Assume success without waiting for response
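
    # Note: a spec-compliant MCP handshake waits for the server's "initialize"
    # result before sending "notifications/initialized"; this client skips that
    # wait on purpose to tolerate servers that never reply.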

    async def list_tools(self) -> list[dict]:
        """List available tools from the MCP server."""
        if not self.process:
            self.logger.warning("Cannot list tools, process not running.")
            return []
        self.logger.debug("Sending 'tools/list' request...")
        self.request_id += 1
        req_id = self.request_id
        req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
        if not await self._send_message(req):
            self.logger.error("Failed to send 'tools/list' request.")
            return []
        # Wait for response
        self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
        start_time = asyncio.get_event_loop().time()
        timeout = 10  # seconds
        while asyncio.get_event_loop().time() - start_time < timeout:
            if req_id in self.responses:
                resp = self.responses.pop(req_id)
                self.logger.debug(f"Received 'tools/list' response: {resp}")
                if "error" in resp:
                    self.logger.error(f"'tools/list' error response: {resp['error']}")
                    return []
                if "result" in resp and "tools" in resp["result"]:
                    self.tools = resp["result"]["tools"]
                    self.logger.info(f"Successfully listed tools: {len(self.tools)}")
                    return self.tools
                else:
                    self.logger.error("Invalid 'tools/list' response format.")
                    return []
            await asyncio.sleep(0.05)
        self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
        return []
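
    # Illustrative shape of one entry in self.tools, following the MCP
    # "tools/list" result format (the values here are hypothetical):
    #   {"name": "read_file",
    #    "description": "Read a file from disk",
    #    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}}}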

    async def call_tool(self, tool_name: str, arguments: dict) -> dict:
        """Call a tool on the MCP server."""
        if not self.process:
            self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
            return {"error": "Server not started"}
        self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
        self.request_id += 1
        req_id = self.request_id
        req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
        if not await self._send_message(req):
            self.logger.error("Failed to send 'tools/call' request.")
            return {"error": "Failed to send tool call request"}
        # Wait for response
        self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
        start_time = asyncio.get_event_loop().time()
        timeout = 30  # seconds
        while asyncio.get_event_loop().time() - start_time < timeout:
            if req_id in self.responses:
                resp = self.responses.pop(req_id)
                self.logger.debug(f"Received 'tools/call' response: {resp}")
                if "error" in resp:
                    self.logger.error(f"'tools/call' error response: {resp['error']}")
                    return {"error": str(resp["error"])}
                if "result" in resp:
                    self.logger.info(f"Tool '{tool_name}' executed successfully.")
                    return resp["result"]
                else:
                    self.logger.error("Invalid 'tools/call' response format.")
                    return {"error": "Invalid tool call response format"}
            await asyncio.sleep(0.05)
        self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
        return {"error": f"Tool call timed out after {timeout}s"}

    async def stop(self):
        """Stop the MCP server process."""
        self.logger.info("Attempting to stop server...")
        if self._shutdown or not self.process:
            self.logger.debug("Server already stopped or not running.")
            return
        self._shutdown = True
        proc = self.process  # Keep a local reference
        self.process = None  # Prevent further operations
        try:
            # Send shutdown notification
            self.logger.debug("Sending 'shutdown' notification...")
            notify = {"jsonrpc": "2.0", "method": "shutdown"}
            await self._send_message(notify)  # Use the method which now handles None process
            await asyncio.sleep(0.5)  # Give server time to process
            # Close stdin
            if proc and proc.stdin:
                try:
                    if not proc.stdin.is_closing():
                        proc.stdin.close()
                        await proc.stdin.wait_closed()
                    self.logger.debug("Stdin closed.")
                except Exception as e:
                    self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
            # Terminate the process
            if proc:
                self.logger.debug(f"Terminating process {proc.pid}...")
                proc.terminate()
                try:
                    await asyncio.wait_for(proc.wait(), timeout=2.0)
                    self.logger.info(f"Process {proc.pid} terminated gracefully.")
                except asyncio.TimeoutError:  # plain TimeoutError only aliases this on Python 3.11+
                    self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
                    proc.kill()
                    await proc.wait()
                    self.logger.info(f"Process {proc.pid} killed.")
                except Exception as e:
                    self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
        except Exception as e:
            self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
        finally:
            self.logger.debug("Stop sequence finished.")
            # Ensure self.process is None even if errors occurred
            self.process = None
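
# Tool functions exposed to OpenAI are namespaced "<server>_<tool>" so one
# flat function list can be routed back to the owning MCPClient. For example
# (hypothetical names), "files_read_file" dispatches to the "files" server's
# "read_file" tool. Because split("_", 1) cuts at the first underscore,
# server names containing underscores would be mis-routed.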

async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
    """Process a tool call from OpenAI."""
    func_name = tool_call["function"]["name"]
    try:
        func_args = json.loads(tool_call["function"].get("arguments", "{}"))
    except json.JSONDecodeError as e:
        logger.error(f"Invalid tool arguments format for {func_name}: {e}")
        return {"error": "Invalid arguments format"}
    # Parse server_name and tool_name from function name
    parts = func_name.split("_", 1)
    if len(parts) != 2:
        logger.error(f"Invalid tool function name format: {func_name}")
        return {"error": "Invalid function name format"}
    server_name, tool_name = parts
    if server_name not in servers:
        logger.error(f"Tool call for unknown server: {server_name}")
        return {"error": f"Unknown server: {server_name}"}
    # Call the tool
    return await servers[server_name].call_tool(tool_name, func_args)


async def run_interaction(
    user_query: str,
    model_name: str,
    api_key: str,
    base_url: str | None,
    mcp_config: dict,
    stream: bool = False,
) -> dict | AsyncGenerator:
    """Run an interaction with OpenAI using MCP server tools.

    Args:
        user_query: The user's input query.
        model_name: The model to use for processing.
        api_key: The OpenAI API key.
        base_url: The OpenAI API base URL (optional).
        mcp_config: The MCP configuration dictionary (for servers).
        stream: Whether to stream the response.

    Returns:
        Dictionary containing response or AsyncGenerator for streaming.
    """
    # Validate passed arguments
    if not api_key:
        logger.error("API key is missing.")
        if not stream:
            return {"error": "API key is missing."}
        else:
            async def error_gen():
                yield {"error": "API key is missing."}
            return error_gen()
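
    # For reference, mcp_config is expected to follow the conventional MCP
    # server-configuration shape (the names and paths here are hypothetical):
    #   {"mcpServers": {"files": {"command": "npx",
    #                             "args": ["-y", "@modelcontextprotocol/server-filesystem", "~/docs"],
    #                             "env": {}}}}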
    # Start MCP servers using mcp_config
    servers = {}
    all_functions = []
    if mcp_config.get("mcpServers"):
        for server_name, server_config in mcp_config["mcpServers"].items():
            client = MCPClient(
                server_name=server_name,
                command=server_config.get("command"),
                args=server_config.get("args", []),
                env=server_config.get("env", {}),
            )
            if await client.start():
                tools = await client.list_tools()
                for tool in tools:
                    # Ensure parameters is a dict, default to empty if missing or not dict
                    params = tool.get("inputSchema", {})
                    if not isinstance(params, dict):
                        logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
                        params = {}
                    all_functions.append({
                        "type": "function",  # Explicitly set type for clarity with newer OpenAI API
                        "function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
                    })
                servers[server_name] = client
            else:
                logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
    else:
        logger.info("No mcpServers defined in configuration.")
    # Use the passed api_key and base_url
    openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    messages = [{"role": "user", "content": user_query}]
    tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
    if stream:
        async def response_generator():
            active_servers = list(servers.values())  # Keep track for cleanup
            try:
                while True:
                    logger.debug(f"Calling OpenAI with messages: {messages}")
                    logger.debug(f"Calling OpenAI with tools: {tool_defs}")
                    # Get OpenAI response
                    try:
                        response = await openai_client.chat.completions.create(
                            model=model_name,
                            messages=messages,
                            tools=tool_defs,
                            tool_choice="auto" if tool_defs else None,  # Only set tool_choice if tools exist
                            stream=True,
                        )
                    except Exception as e:
                        logger.error(f"OpenAI API error: {e}", exc_info=True)
                        yield {"error": f"OpenAI API error: {e}"}
                        break
                    # Process streaming response
                    full_response_content = ""
                    tool_calls = []
                    async for chunk in response:
                        delta = chunk.choices[0].delta
                        if delta.content:
                            content = delta.content
                            full_response_content += content
                            yield {"assistant_text": content, "is_chunk": True}
                        if delta.tool_calls:
                            for tc in delta.tool_calls:
                                # Initialize tool call structure if it's the first chunk for this index
                                if tc.index >= len(tool_calls):
                                    tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                                # Append parts as they arrive
                                if tc.id:
                                    tool_calls[tc.index]["id"] = tc.id
                                if tc.function and tc.function.name:
                                    tool_calls[tc.index]["function"]["name"] = tc.function.name
                                if tc.function and tc.function.arguments:
                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
                    # Add assistant message with content and potential tool calls
                    assistant_message = {"role": "assistant", "content": full_response_content}
                    if tool_calls:
                        # Filter out incomplete tool calls just in case
                        valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
                        if valid_tool_calls:
                            assistant_message["tool_calls"] = valid_tool_calls
                        else:
                            logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
                    messages.append(assistant_message)
                    logger.debug(f"Assistant message added: {assistant_message}")
                    # Handle tool calls if any were successfully assembled
                    if "tool_calls" in assistant_message:
                        tool_results = []
                        for tc in assistant_message["tool_calls"]:
                            logger.info(f"Processing tool call: {tc['function']['name']}")
                            result = await process_tool_call(tc, servers)
                            tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
                        messages.extend(tool_results)
                        logger.debug(f"Tool results added: {tool_results}")
                        # Loop back to call OpenAI again with tool results
                    else:
                        # No tool calls, interaction finished for this turn
                        yield {"assistant_text": full_response_content, "is_chunk": False, "final": True}  # Signal final chunk
                        break
            except Exception as e:
                logger.error(f"Error during streaming interaction: {e}", exc_info=True)
                yield {"error": f"Interaction error: {e}"}
            finally:
                # Clean up servers
                logger.debug("Cleaning up MCP servers (stream)...")
                for server in active_servers:
                    await server.stop()
                logger.debug("MCP server cleanup finished (stream).")

        return response_generator()
    else:  # Non-streaming case
        active_servers = list(servers.values())  # Keep track for cleanup
        try:
            while True:
                logger.debug(f"Calling OpenAI with messages: {messages}")
                logger.debug(f"Calling OpenAI with tools: {tool_defs}")
                # Get OpenAI response
                try:
                    response = await openai_client.chat.completions.create(
                        model=model_name,
                        messages=messages,
                        tools=tool_defs,
                        tool_choice="auto" if tool_defs else None,  # Only set tool_choice if tools exist
                    )
                except Exception as e:
                    logger.error(f"OpenAI API error: {e}", exc_info=True)
                    return {"error": f"OpenAI API error: {e}"}
                message = response.choices[0].message
                messages.append(message)
                logger.debug(f"OpenAI response message: {message}")
                # Handle tool calls
                if message.tool_calls:
                    tool_results = []
                    for tc in message.tool_calls:
                        logger.info(f"Processing tool call: {tc.function.name}")
                        # Reconstruct dict for process_tool_call
                        tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                        result = await process_tool_call(tool_call_dict, servers)
                        tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
                    messages.extend(tool_results)
                    logger.debug(f"Tool results added: {tool_results}")
                    # Loop back to call OpenAI again with tool results
                else:
                    # No tool calls, interaction finished
                    logger.info("Interaction finished, no tool calls.")
                    return {"assistant_text": message.content or "", "tool_calls": []}
        except Exception as e:
            logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
            return {"error": f"Interaction error: {e}"}
        finally:
            # Clean up servers
            logger.debug("Cleaning up MCP servers (non-stream)...")
            for server in active_servers:
                await server.stop()
            logger.debug("MCP server cleanup finished (non-stream).")