Compare commits
3 commits: ccd0a1e45b ... a7d5a4cb33

| SHA1 |
|---|
| a7d5a4cb33 |
| 845f2e77dd |
| ec39844bf1 |
pyproject.toml

```diff
@@ -41,6 +41,8 @@ build-backend = "hatchling.build"
 
 [tool.hatch.build.targets.wheel]
 packages = ["src"]
+[tool.hatch.build.targets.wheel.force-include]
+"config" = "config"
 
 [tool.ruff]
 line-length = 200
```
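The new `force-include` table tells Hatch to copy the repository's top-level `config` directory into the built wheel, so files such as `config/config.ini` ship with the package even though they live outside `src`. A quick way to verify the packaging is sketched below; the wheel filename is hypothetical, so substitute whatever `hatch build` actually produces in `dist/`:

```python
# Minimal check that the config directory was force-included in the wheel.
# The filename is a placeholder for the artifact hatch actually builds.
import zipfile

with zipfile.ZipFile("dist/streamlit_chat_app-0.1.0-py3-none-any.whl") as wheel:
    config_entries = [name for name in wheel.namelist() if name.startswith("config/")]
    print(config_entries)  # expect the packaged config files, e.g. config/config.ini
```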
Streamlit app (handle_user_input)

```diff
@@ -37,10 +37,12 @@ def handle_user_input():
         response = st.session_state.client.get_chat_response(st.session_state.messages)
 
         # Handle both MCP and standard OpenAI responses
-        if hasattr(response, "__iter__"):
+        # Check if it's NOT a dict (assuming stream is not a dict)
+        if not isinstance(response, dict):
             # Standard OpenAI streaming response
             for chunk in response:
-                if chunk.choices[0].delta.content:
+                # Ensure chunk has choices and delta before accessing
+                if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                     full_response += chunk.choices[0].delta.content
                     response_placeholder.markdown(full_response + "▌")
         else:
```
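The switch from `hasattr(response, "__iter__")` to `not isinstance(response, dict)` matters because dicts are themselves iterable, so the old check routed wrapped MCP results (plain dicts, now that `_wrap_mcp_response` returns them unchanged) into the streaming branch. A minimal sketch of the intended dispatch; `render_response` is a hypothetical helper, not code from this change:

```python
def render_response(response, response_placeholder):
    """Hypothetical helper: route MCP dict results and OpenAI streams differently."""
    if isinstance(response, dict):
        # MCP path: the dict already carries the final text (or an error).
        return response.get("assistant_text") or response.get("error", "")
    full_response = ""
    for chunk in response:
        # Guard against chunks without usable choices/delta content.
        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
            full_response += chunk.choices[0].delta.content
            response_placeholder.markdown(full_response + "▌")
    return full_response
```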
custom_mcp_client.py

```diff
@@ -2,11 +2,15 @@
 
 import asyncio
 import json
+import logging
 import os
 from collections.abc import AsyncGenerator
 
 from openai import AsyncOpenAI
 
+# Get a logger for this module
+logger = logging.getLogger(__name__)
+
 
 class MCPClient:
     """Lightweight MCP client with JSON-RPC communication."""
```
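One consequence of this module-level logger plus the per-instance loggers added in the next hunk (`logging.getLogger(f"{__name__}.{self.server_name}")`) is that every MCP server logs under its own child of `custom_mcp_client`, so verbosity can be tuned per server through the standard logger hierarchy. A small sketch; the `filesystem` server name is hypothetical:

```python
import logging

# Keep the module's own messages at INFO...
logging.getLogger("custom_mcp_client").setLevel(logging.INFO)
# ...but turn on full DEBUG output for one noisy or suspect server.
# "filesystem" stands in for whatever server_name the config defines.
logging.getLogger("custom_mcp_client.filesystem").setLevel(logging.DEBUG)
```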
custom_mcp_client.py (continued)

```diff
@@ -21,180 +25,300 @@ class MCPClient:
         self.request_id = 0
         self.responses = {}
         self._shutdown = False
+        # Use a logger specific to this client instance
+        self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
 
     async def _receive_loop(self):
         """Listen for responses from the MCP server."""
         try:
-            while not self.process.stdout.at_eof():
-                line = await self.process.stdout.readline()
-                if not line:
+            while self.process and self.process.stdout and not self.process.stdout.at_eof():
+                line_bytes = await self.process.stdout.readline()
+                if not line_bytes:
+                    self.logger.debug("STDOUT EOF reached.")
                     break
+                line_str = line_bytes.decode().strip()
+                self.logger.debug(f"STDOUT Raw line: {line_str}")
                 try:
-                    message = json.loads(line.decode().strip())
+                    message = json.loads(line_str)
                     if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
+                        self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
                         self.responses[message["id"]] = message
-                except Exception:
-                    pass
-        except Exception:
-            pass
+                    elif "jsonrpc" in message and "method" in message:
+                        self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
+                    else:
+                        self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
+                except json.JSONDecodeError:
+                    self.logger.warning("STDOUT Failed to parse line as JSON.")
+                except Exception as e:
+                    self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
+        except asyncio.CancelledError:
+            self.logger.debug("STDOUT Receive loop cancelled.")
+        except Exception as e:
+            self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
+        finally:
+            self.logger.debug("STDOUT Receive loop finished.")
+
+    async def _stderr_loop(self):
+        """Listen for stderr messages from the MCP server."""
+        try:
+            while self.process and self.process.stderr and not self.process.stderr.at_eof():
+                line_bytes = await self.process.stderr.readline()
+                if not line_bytes:
+                    self.logger.debug("STDERR EOF reached.")
+                    break
+                line_str = line_bytes.decode().strip()
+                self.logger.warning(f"STDERR: {line_str}")  # Log stderr as warning
+        except asyncio.CancelledError:
+            self.logger.debug("STDERR Stderr loop cancelled.")
+        except Exception as e:
+            self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True)
+        finally:
+            self.logger.debug("STDERR Stderr loop finished.")
 
     async def _send_message(self, message: dict) -> bool:
         """Send a JSON-RPC message to the MCP server."""
-        if not self.process:
+        if not self.process or not self.process.stdin:
+            self.logger.warning("STDIN Cannot send message, process or stdin not available.")
             return False
         try:
             data = json.dumps(message) + "\n"
+            self.logger.debug(f"STDIN Sending: {data.strip()}")
             self.process.stdin.write(data.encode())
             await self.process.stdin.drain()
             return True
-        except Exception:
+        except ConnectionResetError:
+            self.logger.error("STDIN Connection reset while sending message.")
+            self.process = None  # Mark process as dead
+            return False
+        except Exception as e:
+            self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
             return False
 
     async def start(self) -> bool:
         """Start the MCP server process."""
-        # Expand ~ in paths
+        self.logger.info("Attempting to start server...")
+        # Expand ~ in paths and prepare args
         expanded_args = []
-        for a in self.args:
-            if isinstance(a, str) and "~" in a:
-                expanded_args.append(os.path.expanduser(a))
-            else:
-                expanded_args.append(a)
+        try:
+            for a in self.args:
+                if isinstance(a, str) and "~" in a:
+                    expanded_args.append(os.path.expanduser(a))
+                else:
+                    expanded_args.append(str(a))  # Ensure all args are strings
+        except Exception as e:
+            self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
+            return False
 
         # Set up environment
         env_vars = os.environ.copy()
         if self.env:
             env_vars.update(self.env)
 
+        self.logger.debug(f"Command: {self.command}")
+        self.logger.debug(f"Expanded Args: {expanded_args}")
+        # Avoid logging full env unless necessary for debugging sensitive info
+        # self.logger.debug(f"Environment: {env_vars}")
+
         try:
             # Start the subprocess
+            self.logger.debug("Creating subprocess...")
             self.process = await asyncio.create_subprocess_exec(
-                self.command, *expanded_args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env_vars
+                self.command,
+                *expanded_args,
+                stdin=asyncio.subprocess.PIPE,
+                stdout=asyncio.subprocess.PIPE,
+                stderr=asyncio.subprocess.PIPE,  # Capture stderr
+                env=env_vars,
             )
+            self.logger.info(f"Subprocess created with PID: {self.process.pid}")
 
-            # Start the receive loop
+            # Start the receive loops
             asyncio.create_task(self._receive_loop())
+            asyncio.create_task(self._stderr_loop())  # Start stderr loop
 
             # Initialize the server
-            return await self._initialize()
-        except Exception:
+            self.logger.debug("Attempting initialization handshake...")
+            init_success = await self._initialize()
+            if init_success:
+                self.logger.info("Initialization handshake successful (or skipped).")
+                # Add delay after successful start
+                await asyncio.sleep(0.5)
+                self.logger.debug("Post-initialization delay complete.")
+                return True
+            else:
+                self.logger.error("Initialization handshake failed.")
+                await self.stop()  # Ensure cleanup if init fails
+                return False
+        except FileNotFoundError:
+            self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
+            return False
+        except Exception as e:
+            self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
             return False
 
     async def _initialize(self) -> bool:
-        """Initialize the MCP server connection."""
+        """Initialize the MCP server connection. Modified to not wait for response."""
+        self.logger.debug("Sending 'initialize' request...")
         if not self.process:
+            self.logger.warning("Cannot initialize, process not running.")
             return False
 
         # Send initialize request
         self.request_id += 1
         req_id = self.request_id
-        initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
-        await self._send_message(initialize_req)
-
-        # Wait for response
-        start_time = asyncio.get_event_loop().time()
-        timeout = 10
-        while asyncio.get_event_loop().time() - start_time < timeout:
-            if req_id in self.responses:
-                resp = self.responses[req_id]
-                del self.responses[req_id]
-
-                if "error" in resp:
-                    return False
-
-                # Send initialized notification
-                notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
-                await self._send_message(notify)
-                return True
-
-            await asyncio.sleep(0.05)
-
-        return False
+        initialize_req = {
+            "jsonrpc": "2.0",
+            "id": req_id,
+            "method": "initialize",
+            "params": {
+                "protocolVersion": "2024-11-05",
+                "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
+                "capabilities": {},  # Add empty capabilities object
+            },
+        }
+        if not await self._send_message(initialize_req):
+            self.logger.warning("Failed to send 'initialize' request.")
+            # Continue anyway for non-compliant servers
+
+        # Send initialized notification immediately
+        self.logger.debug("Sending 'initialized' notification...")
+        notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
+        if await self._send_message(notify):
+            self.logger.debug("'initialized' notification sent.")
+        else:
+            self.logger.warning("Failed to send 'initialized' notification.")
+            # Still return True as the server might be running
+
+        self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
+        return True  # Assume success without waiting for response
 
     async def list_tools(self) -> list[dict]:
         """List available tools from the MCP server."""
         if not self.process:
+            self.logger.warning("Cannot list tools, process not running.")
             return []
 
+        self.logger.debug("Sending 'tools/list' request...")
         self.request_id += 1
         req_id = self.request_id
         req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
-        await self._send_message(req)
+        if not await self._send_message(req):
+            self.logger.error("Failed to send 'tools/list' request.")
+            return []
 
         # Wait for response
+        self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
         start_time = asyncio.get_event_loop().time()
-        timeout = 10
+        timeout = 10  # seconds
         while asyncio.get_event_loop().time() - start_time < timeout:
             if req_id in self.responses:
-                resp = self.responses[req_id]
-                del self.responses[req_id]
+                resp = self.responses.pop(req_id)
+                self.logger.debug(f"Received 'tools/list' response: {resp}")
 
                 if "error" in resp:
+                    self.logger.error(f"'tools/list' error response: {resp['error']}")
                     return []
 
                 if "result" in resp and "tools" in resp["result"]:
                     self.tools = resp["result"]["tools"]
+                    self.logger.info(f"Successfully listed tools: {len(self.tools)}")
                     return self.tools
+                else:
+                    self.logger.error("Invalid 'tools/list' response format.")
+                    return []
 
             await asyncio.sleep(0.05)
 
+        self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
         return []
 
     async def call_tool(self, tool_name: str, arguments: dict) -> dict:
         """Call a tool on the MCP server."""
         if not self.process:
+            self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
             return {"error": "Server not started"}
 
+        self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
         self.request_id += 1
         req_id = self.request_id
         req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
-        await self._send_message(req)
+        if not await self._send_message(req):
+            self.logger.error("Failed to send 'tools/call' request.")
+            return {"error": "Failed to send tool call request"}
 
         # Wait for response
+        self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
         start_time = asyncio.get_event_loop().time()
-        timeout = 30
+        timeout = 30  # seconds
         while asyncio.get_event_loop().time() - start_time < timeout:
             if req_id in self.responses:
-                resp = self.responses[req_id]
-                del self.responses[req_id]
+                resp = self.responses.pop(req_id)
+                self.logger.debug(f"Received 'tools/call' response: {resp}")
 
                 if "error" in resp:
+                    self.logger.error(f"'tools/call' error response: {resp['error']}")
                     return {"error": str(resp["error"])}
 
                 if "result" in resp:
+                    self.logger.info(f"Tool '{tool_name}' executed successfully.")
                     return resp["result"]
+                else:
+                    self.logger.error("Invalid 'tools/call' response format.")
+                    return {"error": "Invalid tool call response format"}
 
             await asyncio.sleep(0.05)
 
+        self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
         return {"error": f"Tool call timed out after {timeout}s"}
 
     async def stop(self):
         """Stop the MCP server process."""
+        self.logger.info("Attempting to stop server...")
         if self._shutdown or not self.process:
+            self.logger.debug("Server already stopped or not running.")
             return
 
         self._shutdown = True
+        proc = self.process  # Keep a local reference
+        self.process = None  # Prevent further operations
 
         try:
             # Send shutdown notification
+            self.logger.debug("Sending 'shutdown' notification...")
             notify = {"jsonrpc": "2.0", "method": "shutdown"}
-            await self._send_message(notify)
-            await asyncio.sleep(0.5)
+            await self._send_message(notify)  # Use the method which now handles None process
+            await asyncio.sleep(0.5)  # Give server time to process
 
             # Close stdin
-            if self.process.stdin:
-                self.process.stdin.close()
+            if proc and proc.stdin:
+                try:
+                    if not proc.stdin.is_closing():
+                        proc.stdin.close()
+                        await proc.stdin.wait_closed()
+                    self.logger.debug("Stdin closed.")
+                except Exception as e:
+                    self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
 
             # Terminate the process
-            self.process.terminate()
-            try:
-                await asyncio.wait_for(self.process.wait(), timeout=1.0)
-            except TimeoutError:
-                self.process.kill()
-        except Exception:
-            pass
+            if proc:
+                self.logger.debug(f"Terminating process {proc.pid}...")
+                proc.terminate()
+                try:
+                    await asyncio.wait_for(proc.wait(), timeout=2.0)
+                    self.logger.info(f"Process {proc.pid} terminated gracefully.")
+                except TimeoutError:
+                    self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
+                    proc.kill()
+                    await proc.wait()
+                    self.logger.info(f"Process {proc.pid} killed.")
+                except Exception as e:
+                    self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
+
+        except Exception as e:
+            self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
         finally:
+            self.logger.debug("Stop sequence finished.")
+            # Ensure self.process is None even if errors occurred
             self.process = None
```
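Read as a whole, the reworked client now has a clear lifecycle: spawn the subprocess with piped stdio, run background tasks that drain stdout (JSON-RPC responses keyed into `self.responses`) and stderr (logged as warnings), fire the `initialize` request and `notifications/initialized` notification without blocking on a reply, and poll `self.responses` by request ID in `list_tools`/`call_tool`. A minimal usage sketch, assuming some stdio MCP server command is available; the `npx` filesystem server is only an example:

```python
import asyncio

from custom_mcp_client import MCPClient

async def main():
    # Example stdio MCP server command; substitute any server from your config.
    client = MCPClient(
        server_name="filesystem",
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem", "~/docs"],
        env={},
    )
    try:
        if await client.start():
            tools = await client.list_tools()
            print("tools:", [t["name"] for t in tools])
            if tools:
                # Arguments must match the tool's inputSchema; {} is just a placeholder.
                print(await client.call_tool(tools[0]["name"], {}))
    finally:
        await client.stop()

asyncio.run(main())
```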
custom_mcp_client.py (continued)

```diff
@@ -203,113 +327,224 @@ async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
     func_name = tool_call["function"]["name"]
     try:
         func_args = json.loads(tool_call["function"].get("arguments", "{}"))
-    except json.JSONDecodeError:
+    except json.JSONDecodeError as e:
+        logger.error(f"Invalid tool arguments format for {func_name}: {e}")
         return {"error": "Invalid arguments format"}
 
     # Parse server_name and tool_name from function name
     parts = func_name.split("_", 1)
     if len(parts) != 2:
+        logger.error(f"Invalid tool function name format: {func_name}")
         return {"error": "Invalid function name format"}
 
     server_name, tool_name = parts
 
     if server_name not in servers:
+        logger.error(f"Tool call for unknown server: {server_name}")
         return {"error": f"Unknown server: {server_name}"}
 
     # Call the tool
     return await servers[server_name].call_tool(tool_name, func_args)
 
 
-async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator:
+async def run_interaction(
+    user_query: str,
+    model_name: str,
+    api_key: str,
+    base_url: str | None,
+    mcp_config: dict,
+    stream: bool = False,
+) -> dict | AsyncGenerator:
     """
     Run an interaction with OpenAI using MCP server tools.
 
     Args:
-        user_query: The user's input query
-        model_name: The model to use for processing
-        config: Configuration dictionary
-        stream: Whether to stream the response
+        user_query: The user's input query.
+        model_name: The model to use for processing.
+        api_key: The OpenAI API key.
+        base_url: The OpenAI API base URL (optional).
+        mcp_config: The MCP configuration dictionary (for servers).
+        stream: Whether to stream the response.
 
     Returns:
-        Dictionary containing response or AsyncGenerator for streaming
+        Dictionary containing response or AsyncGenerator for streaming.
     """
-    # Get OpenAI configuration
-    api_key = config["models"][0]["apiKey"]
-    base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
+    # Validate passed arguments
+    if not api_key:
+        logger.error("API key is missing.")
+        if not stream:
+            return {"error": "API key is missing."}
+        else:
 
-    # Start MCP servers
+            async def error_gen():
+                yield {"error": "API key is missing."}
+
+            return error_gen()
+
+    # Start MCP servers using mcp_config
     servers = {}
     all_functions = []
-    for server_name, server_config in config["mcpServers"].items():
-        client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
+    if mcp_config.get("mcpServers"):  # Use mcp_config here
+        for server_name, server_config in mcp_config["mcpServers"].items():  # Use mcp_config here
+            client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
 
-        if await client.start():
-            tools = await client.list_tools()
-            for tool in tools:
-                all_functions.append({"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})})
-            servers[server_name] = client
+            if await client.start():
+                tools = await client.list_tools()
+                for tool in tools:
+                    # Ensure parameters is a dict, default to empty if missing or not dict
+                    params = tool.get("inputSchema", {})
+                    if not isinstance(params, dict):
+                        logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
+                        params = {}
 
-    client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+                    all_functions.append({
+                        "type": "function",  # Explicitly set type for clarity with newer OpenAI API
+                        "function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
+                    })
+                servers[server_name] = client
+            else:
+                logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
+    else:
+        logger.info("No mcpServers defined in configuration.")
+
+    # Use passed api_key and base_url
+    openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)  # Use arguments
     messages = [{"role": "user", "content": user_query}]
+    tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
 
     if stream:
 
         async def response_generator():
+            active_servers = list(servers.values())  # Keep track for cleanup
             try:
                 while True:
+                    logger.debug(f"Calling OpenAI with messages: {messages}")
+                    logger.debug(f"Calling OpenAI with tools: {tool_defs}")
                     # Get OpenAI response
-                    response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True)
-
-                    # Process streaming response
-                    full_response = ""
-                    tool_calls = []
-                    async for chunk in response:
-                        if chunk.choices[0].delta.content:
-                            content = chunk.choices[0].delta.content
-                            full_response += content
-                            yield {"assistant_text": content, "is_chunk": True}
-
-                        if chunk.choices[0].delta.tool_calls:
-                            for tc in chunk.choices[0].delta.tool_calls:
-                                if len(tool_calls) <= tc.index:
-                                    tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}})
-                                else:
-                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
-
-                    # Handle tool calls
-                    if tool_calls:
-                        assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls}
-                        messages.append(assistant_message)
-
-                        for tc in tool_calls:
-                            result = await process_tool_call(tc, servers)
-                            messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
-                    else:
+                    try:
+                        response = await openai_client.chat.completions.create(
+                            model=model_name,
+                            messages=messages,
+                            tools=tool_defs,
+                            tool_choice="auto" if tool_defs else None,  # Only set tool_choice if tools exist
+                            stream=True,
+                        )
+                    except Exception as e:
+                        logger.error(f"OpenAI API error: {e}", exc_info=True)
+                        yield {"error": f"OpenAI API error: {e}"}
                         break
 
+                    # Process streaming response
+                    full_response_content = ""
+                    tool_calls = []
+                    async for chunk in response:
+                        delta = chunk.choices[0].delta
+                        if delta.content:
+                            content = delta.content
+                            full_response_content += content
+                            yield {"assistant_text": content, "is_chunk": True}
+
+                        if delta.tool_calls:
+                            for tc in delta.tool_calls:
+                                # Initialize tool call structure if it's the first chunk for this index
+                                if tc.index >= len(tool_calls):
+                                    tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
+
+                                # Append parts as they arrive
+                                if tc.id:
+                                    tool_calls[tc.index]["id"] = tc.id
+                                if tc.function and tc.function.name:
+                                    tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                if tc.function and tc.function.arguments:
+                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+
+                    # Add assistant message with content and potential tool calls
+                    assistant_message = {"role": "assistant", "content": full_response_content}
+                    if tool_calls:
+                        # Filter out incomplete tool calls just in case
+                        valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
+                        if valid_tool_calls:
+                            assistant_message["tool_calls"] = valid_tool_calls
+                        else:
+                            logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
+
+                    messages.append(assistant_message)
+                    logger.debug(f"Assistant message added: {assistant_message}")
+
+                    # Handle tool calls if any were successfully assembled
+                    if "tool_calls" in assistant_message:
+                        tool_results = []
+                        for tc in assistant_message["tool_calls"]:
+                            logger.info(f"Processing tool call: {tc['function']['name']}")
+                            result = await process_tool_call(tc, servers)
+                            tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
+
+                        messages.extend(tool_results)
+                        logger.debug(f"Tool results added: {tool_results}")
+                        # Loop back to call OpenAI again with tool results
+                    else:
+                        # No tool calls, interaction finished for this turn
+                        yield {"assistant_text": full_response_content, "is_chunk": False, "final": True}  # Signal final chunk
+                        break
+
+            except Exception as e:
+                logger.error(f"Error during streaming interaction: {e}", exc_info=True)
+                yield {"error": f"Interaction error: {e}"}
             finally:
                 # Clean up servers
-                for server in servers.values():
+                logger.debug("Cleaning up MCP servers (stream)...")
+                for server in active_servers:
                     await server.stop()
+                logger.debug("MCP server cleanup finished (stream).")
 
         return response_generator()
-    else:
+
+    else:  # Non-streaming case
+        active_servers = list(servers.values())  # Keep track for cleanup
         try:
             while True:
+                logger.debug(f"Calling OpenAI with messages: {messages}")
+                logger.debug(f"Calling OpenAI with tools: {tool_defs}")
                 # Get OpenAI response
-                response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions)
+                try:
+                    response = await openai_client.chat.completions.create(
+                        model=model_name,
+                        messages=messages,
+                        tools=tool_defs,
+                        tool_choice="auto" if tool_defs else None,  # Only set tool_choice if tools exist
+                    )
+                except Exception as e:
+                    logger.error(f"OpenAI API error: {e}", exc_info=True)
+                    return {"error": f"OpenAI API error: {e}"}
 
                 message = response.choices[0].message
                 messages.append(message)
+                logger.debug(f"OpenAI response message: {message}")
 
                 # Handle tool calls
                 if message.tool_calls:
+                    tool_results = []
                     for tc in message.tool_calls:
-                        result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers)
-                        messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
+                        logger.info(f"Processing tool call: {tc.function.name}")
+                        # Reconstruct dict for process_tool_call
+                        tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
+                        result = await process_tool_call(tool_call_dict, servers)
+                        tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
+
+                    messages.extend(tool_results)
+                    logger.debug(f"Tool results added: {tool_results}")
+                    # Loop back to call OpenAI again with tool results
                 else:
+                    # No tool calls, interaction finished
+                    logger.info("Interaction finished, no tool calls.")
                     return {"assistant_text": message.content or "", "tool_calls": []}
 
+        except Exception as e:
+            logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
+            return {"error": f"Interaction error: {e}"}
         finally:
             # Clean up servers
-            for server in servers.values():
+            logger.debug("Cleaning up MCP servers (non-stream)...")
+            for server in active_servers:
                 await server.stop()
+            logger.debug("MCP server cleanup finished (non-stream).")
```
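With the refactored signature, credentials travel as explicit arguments and `mcp_config` carries only the `mcpServers` map. A non-streaming usage sketch; the model name, key, and server entry are placeholders:

```python
import asyncio

from custom_mcp_client import run_interaction

async def demo():
    mcp_config = {
        "mcpServers": {
            # Placeholder server definition, shaped like the entries the config file defines.
            "filesystem": {
                "command": "npx",
                "args": ["-y", "@modelcontextprotocol/server-filesystem", "~/docs"],
            }
        }
    }
    result = await run_interaction(
        user_query="List the files in my docs folder.",
        model_name="gpt-4o-mini",  # placeholder model name
        api_key="sk-...",  # placeholder key
        base_url=None,  # None -> default OpenAI endpoint
        mcp_config=mcp_config,
        stream=False,
    )
    print(result.get("assistant_text") or result.get("error"))

asyncio.run(demo())
```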
mcp_manager.py

```diff
@@ -3,10 +3,18 @@
 import asyncio
 import importlib.resources
 import json
+import logging  # Import logging
 import threading
 
 from custom_mcp_client import MCPClient, run_interaction
 
+# Configure basic logging for the application if not already configured
+# This basic config helps ensure logs are visible during development
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+# Get a logger for this module
+logger = logging.getLogger(__name__)
+
 
 class SyncMCPManager:
     """Synchronous wrapper for managing MCP servers and interactions"""
```
mcp_manager.py (continued)

```diff
@@ -17,110 +25,210 @@ class SyncMCPManager:
         self.servers = {}
         self.initialized = False
         self._lock = threading.Lock()
+        logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
         self._load_config()
 
     def _load_config(self):
         """Load MCP configuration from JSON file using importlib"""
+        logger.debug(f"Attempting to load MCP config from: {self.config_path}")
         try:
             # First try to load as a package resource
             try:
-                with importlib.resources.files("streamlit-chat-app").joinpath(self.config_path).open("r") as f:
+                # Try anchoring to the project name defined in pyproject.toml
+                # This *might* work depending on editable installs or context.
+                resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
+                with resource_path.open("r") as f:
                     self.config = json.load(f)
-            except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError):
-                # Fall back to direct file access
+                logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
+                # REMOVED: raise FileNotFoundError
+
+            except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError):  # Added NotADirectoryError
+                logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
+                # Fall back to direct file access relative to CWD
                 with open(self.config_path) as f:
                     self.config = json.load(f)
+                logger.debug("Loaded config via direct file access.")
+
+            logger.info("MCP configuration loaded successfully.")
+            logger.debug(f"Config content: {self.config}")  # Log content only if loaded
+
+        except FileNotFoundError:
+            logger.error(f"MCP config file not found at {self.config_path}")
+            self.config = None
+        except json.JSONDecodeError as e:
+            logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
+            self.config = None
         except Exception as e:
-            print(f"Error loading MCP config from {self.config_path}: {str(e)}")
+            logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
             self.config = None
 
     def initialize(self) -> bool:
         """Initialize and start all MCP servers synchronously"""
-        if not self.config or not self.config.get("mcpServers"):
+        logger.info("Initialize requested.")
+        if not self.config:
+            logger.warning("Initialization skipped: No configuration loaded.")
+            return False
+        if not self.config.get("mcpServers"):
+            logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
             return False
 
         if self.initialized:
+            logger.debug("Initialization skipped: Already initialized.")
             return True
 
         with self._lock:
             if self.initialized:  # Double-check after acquiring lock
+                logger.debug("Initialization skipped inside lock: Already initialized.")
                 return True
 
+            logger.info("Starting asynchronous initialization...")
             # Run async initialization in a new event loop
             loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)  # Ensure this loop is used by tasks
             success = loop.run_until_complete(self._async_initialize())
             loop.close()
+            asyncio.set_event_loop(None)  # Clean up
 
-            self.initialized = success
-            return success
+            if success:
+                logger.info("Asynchronous initialization completed successfully.")
+                self.initialized = True
+            else:
+                logger.error("Asynchronous initialization failed.")
+                self.initialized = False  # Ensure state is False on failure
+
+            return self.initialized
 
     async def _async_initialize(self) -> bool:
         """Async implementation of server initialization"""
-        success = True
-        for server_name, server_config in self.config["mcpServers"].items():
+        logger.debug("Starting _async_initialize...")
+        all_success = True
+        if not self.config or not self.config.get("mcpServers"):
+            logger.warning("_async_initialize: No config or mcpServers found.")
+            return False
+
+        tasks = []
+        server_names = list(self.config["mcpServers"].keys())
+
+        async def start_server(server_name, server_config):
+            logger.info(f"Initializing server: {server_name}")
             try:
                 client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
 
+                logger.debug(f"Attempting to start client for {server_name}...")
                 if await client.start():
+                    logger.info(f"Client for {server_name} started successfully.")
                     tools = await client.list_tools()
+                    logger.info(f"Tools listed for {server_name}: {len(tools)}")
                     self.servers[server_name] = {"client": client, "tools": tools}
+                    return True
                 else:
-                    success = False
-                    print(f"Failed to start MCP server: {server_name}")
+                    logger.error(f"Failed to start MCP server: {server_name}")
+                    return False
             except Exception as e:
-                print(f"Error initializing server {server_name}: {str(e)}")
-                success = False
+                logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
+                return False
 
-        return success
+        # Start servers concurrently
+        for server_name in server_names:
+            server_config = self.config["mcpServers"][server_name]
+            tasks.append(start_server(server_name, server_config))
+
+        results = await asyncio.gather(*tasks)
+
+        # Check if all servers started successfully
+        all_success = all(results)
+
+        if all_success:
+            logger.debug("_async_initialize completed: All servers started successfully.")
+        else:
+            failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
+            logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
+            # Optionally shutdown servers that did start if partial success is not desired
+            # await self._async_shutdown()  # Uncomment to enforce all-or-nothing startup
+
+        return all_success
 
     def shutdown(self):
         """Shut down all MCP servers synchronously"""
+        logger.info("Shutdown requested.")
         if not self.initialized:
+            logger.debug("Shutdown skipped: Not initialized.")
             return
 
         with self._lock:
             if not self.initialized:
+                logger.debug("Shutdown skipped inside lock: Not initialized.")
                 return
 
+            logger.info("Starting asynchronous shutdown...")
             loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
             loop.run_until_complete(self._async_shutdown())
             loop.close()
+            asyncio.set_event_loop(None)
 
             self.servers = {}
             self.initialized = False
+            logger.info("Shutdown complete.")
 
     async def _async_shutdown(self):
         """Async implementation of server shutdown"""
-        for server_info in self.servers.values():
-            try:
-                await server_info["client"].stop()
-            except Exception as e:
-                print(f"Error shutting down server: {str(e)}")
+        logger.debug("Starting _async_shutdown...")
+        tasks = []
+        for server_name, server_info in self.servers.items():
+            logger.debug(f"Initiating shutdown for server: {server_name}")
+            tasks.append(server_info["client"].stop())
 
-    def process_query(self, query: str, model_name: str) -> dict:
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for i, result in enumerate(results):
+            server_name = list(self.servers.keys())[i]
+            if isinstance(result, Exception):
+                logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
+            else:
+                logger.debug(f"Shutdown completed for server: {server_name}")
+        logger.debug("_async_shutdown finished.")
+
+    # Updated process_query signature
+    def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
         """
         Process a query using MCP tools synchronously
 
         Args:
-            query: The user's input query
-            model_name: The model to use for processing
+            query: The user's input query.
+            model_name: The model to use for processing.
+            api_key: The OpenAI API key.
+            base_url: The OpenAI API base URL.
 
         Returns:
-            Dictionary containing response or error
+            Dictionary containing response or error.
         """
         if not self.initialized and not self.initialize():
+            logger.error("process_query called but MCP manager failed to initialize.")
             return {"error": "Failed to initialize MCP servers"}
 
+        logger.debug(f"Processing query synchronously: '{query}'")
         loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
         try:
-            result = loop.run_until_complete(self._async_process_query(query, model_name))
+            # Pass api_key and base_url to _async_process_query
+            result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
+            logger.debug(f"Synchronous query processing result: {result}")
             return result
         except Exception as e:
+            logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
             return {"error": f"Processing error: {str(e)}"}
         finally:
             loop.close()
 
-    async def _async_process_query(self, query: str, model_name: str) -> dict:
+    # Updated _async_process_query signature
+    async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
         """Async implementation of query processing"""
-        return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False)
+        # Pass api_key, base_url, and the MCP config separately to run_interaction
+        return await run_interaction(
+            user_query=query,
+            model_name=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            mcp_config=self.config,  # self.config only contains MCP server definitions now
+            stream=False,
+        )
```
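Throughout the manager, the sync-to-async bridge follows one pattern: create a fresh event loop, bind it with `asyncio.set_event_loop` so tasks spawned inside (like the client's receive loops) attach to it, run to completion, then close and unbind. Condensed into a reusable helper for illustration:

```python
import asyncio

def run_sync(coro):
    """Run a coroutine from synchronous code on a fresh event loop,
    mirroring the pattern SyncMCPManager uses in initialize/shutdown/process_query."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
        asyncio.set_event_loop(None)
```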
OpenAI client wrapper (OpenAIClient)

```diff
@@ -1,14 +1,19 @@
 """OpenAI client with custom MCP integration."""
 
 import configparser
+import logging  # Import logging
 
 from openai import OpenAI
 
 from mcp_manager import SyncMCPManager
 
+# Get a logger for this module
+logger = logging.getLogger(__name__)
+
 
 class OpenAIClient:
     def __init__(self):
+        logger.debug("Initializing OpenAIClient...")  # Add init log
         self.config = configparser.ConfigParser()
         self.config.read("config/config.ini")
 
```
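Given the keys this class reads (`model`, `api_key`, and an optional `base_url` fetched with `.get()`), `config/config.ini` presumably contains an `[openai]` section shaped like the one generated below; the values are placeholders:

```python
import configparser
import os

# Write a config/config.ini with the shape OpenAIClient expects; values are placeholders.
os.makedirs("config", exist_ok=True)
config = configparser.ConfigParser()
config["openai"] = {
    "model": "gpt-4o-mini",
    "api_key": "sk-...",
    # base_url is optional; OpenAIClient reads it with .get(), so it may be omitted.
    "base_url": "https://api.openai.com/v1",
}
with open("config/config.ini", "w") as f:
    config.write(f)
```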
OpenAI client wrapper (continued)

```diff
@@ -33,28 +38,31 @@ class OpenAIClient:
         try:
             # Try using MCP if available
             if self.mcp_manager and self.mcp_manager.initialize():
-                print("Using MCP with tools...")
+                logger.info("Using MCP with tools...")  # Use logger
                 last_message = messages[-1]["content"]
-                response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"])
+                # Pass API key and base URL from config.ini
+                response = self.mcp_manager.process_query(
+                    query=last_message,
+                    model_name=self.config["openai"]["model"],
+                    api_key=self.config["openai"]["api_key"],
+                    base_url=self.config["openai"].get("base_url"),  # Use .get for optional base_url
+                )
 
                 if "error" not in response:
+                    logger.debug("MCP processing successful, wrapping response.")
                     # Convert to OpenAI-compatible response format
                     return self._wrap_mcp_response(response)
 
             # Fall back to standard OpenAI
-            print(f"Using standard OpenAI API with model: {self.config['openai']['model']}")
+            logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}")  # Use logger
             return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
 
         except Exception as e:
             error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
-            print(error_msg)
+            logger.error(error_msg, exc_info=True)  # Use logger
             raise Exception(error_msg)
 
     def _wrap_mcp_response(self, response: dict):
-        """Convert MCP response to OpenAI-compatible format"""
-
-        # Create a generator to simulate streaming response
-        def response_generator():
-            yield {"choices": [{"delta": {"content": response.get("assistant_text", "")}}]}
-
-        return response_generator()
+        """Return the MCP response dictionary directly (for non-streaming)."""
+        # No conversion needed if app.py handles dicts separately
+        return response
```