feat: implement custom MCP client and integrate with OpenAI API for enhanced chat functionality

2025-03-25 19:00:00 +00:00
parent 1019eae9fe
commit 314b488bf9
6 changed files with 529 additions and 34 deletions

pyproject.toml
View File

@@ -10,8 +10,7 @@ authors = [
 dependencies = [
     "streamlit",
     "python-dotenv",
-    "openai",
-    "dolphin-mcp"
+    "openai"
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -81,3 +80,10 @@ combine-as-imports = true
 [tool.ruff.lint.mccabe]
 max-complexity = 12
+
+[tool.ruff.lint.flake8-tidy-imports]
+# Disallow all relative imports.
+ban-relative-imports = "all"
+
+[tool.streamlit-chat-app.config]
+mcp_config = "config/mcp_config.json"
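
For reference, a config/mcp_config.json of the shape the new client code expects might look like this (keys come from custom_mcp_client/client.py below; the server entry and all values are illustrative placeholders):

{
  "models": [
    {"apiKey": "sk-...", "apiBase": "https://api.openai.com/v1"}
  ],
  "mcpServers": {
    "filesystem": {
      "command": "npx",
      "args": ["-y", "@modelcontextprotocol/server-filesystem", "~/"],
      "env": {}
    }
  }
}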

View File

@@ -1,46 +1,68 @@
+import atexit
 import streamlit as st
 from openai_client import OpenAIClient


 def init_session_state():
     if "messages" not in st.session_state:
         st.session_state.messages = []
+    if "client" not in st.session_state:
+        st.session_state.client = OpenAIClient()
+        # Register cleanup for MCP servers
+        if hasattr(st.session_state.client, "mcp_manager"):
+            atexit.register(st.session_state.client.mcp_manager.shutdown)


 def display_chat_messages():
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])


 def handle_user_input():
     if prompt := st.chat_input("Type your message..."):
         print(f"User input received: {prompt}")  # Debug log
         st.session_state.messages.append({"role": "user", "content": prompt})
         with st.chat_message("user"):
             st.markdown(prompt)
         try:
             with st.chat_message("assistant"):
                 response_placeholder = st.empty()
                 full_response = ""
-                client = OpenAIClient()
-                print("Calling OpenAI API...")  # Debug log
-                for chunk in client.get_chat_response(st.session_state.messages):
-                    if chunk.choices[0].delta.content:
-                        full_response += chunk.choices[0].delta.content
-                        response_placeholder.markdown(full_response + "▌")
+                print("Processing message...")  # Debug log
+                response = st.session_state.client.get_chat_response(st.session_state.messages)
+                # Handle both MCP and standard OpenAI responses
+                if hasattr(response, "__iter__"):
+                    # Streaming response (standard OpenAI or wrapped MCP)
+                    for chunk in response:
+                        if chunk.choices[0].delta.content:
+                            full_response += chunk.choices[0].delta.content
+                            response_placeholder.markdown(full_response + "▌")
+                else:
+                    # MCP non-streaming response
+                    full_response = response.get("assistant_text", "")
+                    response_placeholder.markdown(full_response)
                 response_placeholder.markdown(full_response)
             st.session_state.messages.append({"role": "assistant", "content": full_response})
-            print("API call completed successfully")  # Debug log
+            print("Message processed successfully")  # Debug log
         except Exception as e:
             st.error(f"Error processing message: {str(e)}")
             print(f"Error details: {str(e)}")  # Debug log


 def main():
     st.title("Streamlit Chat App")
     init_session_state()
     display_chat_messages()
     handle_user_input()


 if __name__ == "__main__":
     main()
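
For context, get_chat_response can hand this loop two response shapes (a sketch of what the code above expects, not additional committed code):

# 1. An iterable of OpenAI-style streaming chunks, where
#    chunk.choices[0].delta.content is a str or None.
# 2. A plain MCP result dict such as {"assistant_text": "...", "tool_calls": []}.
# Note that dicts also define __iter__, which is why openai_client.py (below)
# wraps MCP results in a generator of chunk-shaped objects instead of returning
# the raw dict to this loop.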

custom_mcp_client/__init__.py (new file)
View File

@@ -0,0 +1,5 @@
"""Custom MCP client implementation focused on OpenAI integration."""
from .client import MCPClient, run_interaction
__all__ = ["MCPClient", "run_interaction"]

custom_mcp_client/client.py (new file)
View File

@@ -0,0 +1,315 @@
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
import asyncio
import json
import os
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
class MCPClient:
"""Lightweight MCP client with JSON-RPC communication."""
def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
self.server_name = server_name
self.command = command
self.args = args or []
self.env = env or {}
self.process = None
self.tools = []
self.request_id = 0
self.responses = {}
self._shutdown = False
async def _receive_loop(self):
"""Listen for responses from the MCP server."""
try:
while not self.process.stdout.at_eof():
line = await self.process.stdout.readline()
if not line:
break
try:
message = json.loads(line.decode().strip())
if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
self.responses[message["id"]] = message
except Exception:
pass
except Exception:
pass
async def _send_message(self, message: dict) -> bool:
"""Send a JSON-RPC message to the MCP server."""
if not self.process:
return False
try:
data = json.dumps(message) + "\n"
self.process.stdin.write(data.encode())
await self.process.stdin.drain()
return True
except Exception:
return False
async def start(self) -> bool:
"""Start the MCP server process."""
# Expand ~ in paths
expanded_args = []
for a in self.args:
if isinstance(a, str) and "~" in a:
expanded_args.append(os.path.expanduser(a))
else:
expanded_args.append(a)
# Set up environment
env_vars = os.environ.copy()
if self.env:
env_vars.update(self.env)
try:
# Start the subprocess
self.process = await asyncio.create_subprocess_exec(
self.command, *expanded_args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env_vars
)
# Start the receive loop
asyncio.create_task(self._receive_loop())
# Initialize the server
return await self._initialize()
except Exception:
return False
async def _initialize(self) -> bool:
"""Initialize the MCP server connection."""
if not self.process:
return False
# Send initialize request
self.request_id += 1
req_id = self.request_id
initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
await self._send_message(initialize_req)
# Wait for response
start_time = asyncio.get_event_loop().time()
timeout = 10
while asyncio.get_event_loop().time() - start_time < timeout:
if req_id in self.responses:
resp = self.responses[req_id]
del self.responses[req_id]
if "error" in resp:
return False
# Send initialized notification
notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
await self._send_message(notify)
return True
await asyncio.sleep(0.05)
return False
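
    # For reference, a typical initialize exchange on the wire (reply shape per
    # the MCP spec; values illustrative):
    #   -> {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {...}}
    #   <- {"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05",
    #        "capabilities": {...}, "serverInfo": {"name": "...", "version": "..."}}}
    #   -> {"jsonrpc": "2.0", "method": "notifications/initialized"}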
    async def list_tools(self) -> list[dict]:
        """List available tools from the MCP server."""
        if not self.process:
            return []
        self.request_id += 1
        req_id = self.request_id
        req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
        await self._send_message(req)
        # Wait for response
        start_time = asyncio.get_event_loop().time()
        timeout = 10
        while asyncio.get_event_loop().time() - start_time < timeout:
            if req_id in self.responses:
                resp = self.responses[req_id]
                del self.responses[req_id]
                if "error" in resp:
                    return []
                if "result" in resp and "tools" in resp["result"]:
                    self.tools = resp["result"]["tools"]
                    return self.tools
            await asyncio.sleep(0.05)
        return []
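
    # A tools/list result carries the tool schemas that get forwarded to OpenAI
    # (shape per the MCP spec; the tool shown is illustrative):
    #   {"tools": [{"name": "read_file",
    #               "description": "Read a file from disk",
    #               "inputSchema": {"type": "object",
    #                               "properties": {"path": {"type": "string"}}}}]}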
    async def call_tool(self, tool_name: str, arguments: dict) -> dict:
        """Call a tool on the MCP server."""
        if not self.process:
            return {"error": "Server not started"}
        self.request_id += 1
        req_id = self.request_id
        req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
        await self._send_message(req)
        # Wait for response
        start_time = asyncio.get_event_loop().time()
        timeout = 30
        while asyncio.get_event_loop().time() - start_time < timeout:
            if req_id in self.responses:
                resp = self.responses[req_id]
                del self.responses[req_id]
                if "error" in resp:
                    return {"error": str(resp["error"])}
                if "result" in resp:
                    return resp["result"]
            await asyncio.sleep(0.05)
        return {"error": f"Tool call timed out after {timeout}s"}
    async def stop(self):
        """Stop the MCP server process."""
        if self._shutdown or not self.process:
            return
        self._shutdown = True
        try:
            # Send shutdown notification
            notify = {"jsonrpc": "2.0", "method": "shutdown"}
            await self._send_message(notify)
            await asyncio.sleep(0.5)
            # Close stdin
            if self.process.stdin:
                self.process.stdin.close()
            # Terminate the process
            self.process.terminate()
            try:
                await asyncio.wait_for(self.process.wait(), timeout=1.0)
            except asyncio.TimeoutError:
                self.process.kill()
        except Exception:
            pass
        finally:
            self.process = None
async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
    """Process a tool call from OpenAI."""
    func_name = tool_call["function"]["name"]
    try:
        func_args = json.loads(tool_call["function"].get("arguments", "{}"))
    except json.JSONDecodeError:
        return {"error": "Invalid arguments format"}
    # Parse server_name and tool_name from the function name
    parts = func_name.split("_", 1)
    if len(parts) != 2:
        return {"error": "Invalid function name format"}
    server_name, tool_name = parts
    if server_name not in servers:
        return {"error": f"Unknown server: {server_name}"}
    # Call the tool
    return await servers[server_name].call_tool(tool_name, func_args)
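
# Example of the naming convention (server and tool names illustrative): a
# function advertised to OpenAI as "filesystem_read_file" routes to server
# "filesystem", tool "read_file". Because of split("_", 1), server names must
# not contain underscores, while tool names may.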
async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator:
    """
    Run an interaction with OpenAI using MCP server tools.

    Args:
        user_query: The user's input query
        model_name: The model to use for processing
        config: Configuration dictionary
        stream: Whether to stream the response

    Returns:
        Dictionary containing the response, or an AsyncGenerator when streaming
    """
    # Get OpenAI configuration
    api_key = config["models"][0]["apiKey"]
    base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
    # Start MCP servers
    servers = {}
    all_functions = []
    for server_name, server_config in config["mcpServers"].items():
        client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
        if await client.start():
            tools = await client.list_tools()
            for tool in tools:
                # Advertise each MCP tool to OpenAI in the tools-API schema
                all_functions.append({"type": "function", "function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})}})
            servers[server_name] = client
    openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
    messages = [{"role": "user", "content": user_query}]
    if stream:
        async def response_generator():
            try:
                while True:
                    # Get OpenAI response
                    response = await openai_client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True)
                    # Process streaming response
                    full_response = ""
                    tool_calls = []
                    async for chunk in response:
                        if chunk.choices[0].delta.content:
                            content = chunk.choices[0].delta.content
                            full_response += content
                            yield {"assistant_text": content, "is_chunk": True}
                        if chunk.choices[0].delta.tool_calls:
                            for tc in chunk.choices[0].delta.tool_calls:
                                if len(tool_calls) <= tc.index:
                                    tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}})
                                else:
                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
                    # Handle tool calls
                    if tool_calls:
                        assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls}
                        messages.append(assistant_message)
                        for tc in tool_calls:
                            result = await process_tool_call(tc, servers)
                            messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
                    else:
                        break
            finally:
                # Clean up servers
                for server in servers.values():
                    await server.stop()
        return response_generator()
    else:
        try:
            while True:
                # Get OpenAI response
                response = await openai_client.chat.completions.create(model=model_name, messages=messages, tools=all_functions)
                message = response.choices[0].message
                messages.append(message)
                # Handle tool calls
                if message.tool_calls:
                    for tc in message.tool_calls:
                        result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers)
                        messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
                else:
                    return {"assistant_text": message.content or "", "tool_calls": []}
        finally:
            # Clean up servers
            for server in servers.values():
                await server.stop()
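
For reference, a minimal async driver for run_interaction might look like this (the config shape matches the example mcp_config.json above; the model name and query are illustrative):

import asyncio
import json
from custom_mcp_client import run_interaction

async def main():
    with open("config/mcp_config.json") as f:
        config = json.load(f)
    result = await run_interaction(user_query="List the files in my home directory", model_name="gpt-4o", config=config, stream=False)
    print(result["assistant_text"])

asyncio.run(main())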

src/mcp_manager.py (new file, 126 lines)
View File

@@ -0,0 +1,126 @@
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
import asyncio
import importlib.resources
import json
import threading
from custom_mcp_client import MCPClient, run_interaction
class SyncMCPManager:
"""Synchronous wrapper for managing MCP servers and interactions"""
def __init__(self, config_path: str = "config/mcp_config.json"):
self.config_path = config_path
self.config = None
self.servers = {}
self.initialized = False
self._lock = threading.Lock()
self._load_config()
def _load_config(self):
"""Load MCP configuration from JSON file using importlib"""
try:
# First try to load as a package resource
try:
with importlib.resources.files("streamlit-chat-app").joinpath(self.config_path).open("r") as f:
self.config = json.load(f)
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError):
# Fall back to direct file access
with open(self.config_path) as f:
self.config = json.load(f)
except Exception as e:
print(f"Error loading MCP config from {self.config_path}: {str(e)}")
self.config = None
def initialize(self) -> bool:
"""Initialize and start all MCP servers synchronously"""
if not self.config or not self.config.get("mcpServers"):
return False
if self.initialized:
return True
with self._lock:
if self.initialized: # Double-check after acquiring lock
return True
# Run async initialization in a new event loop
loop = asyncio.new_event_loop()
success = loop.run_until_complete(self._async_initialize())
loop.close()
self.initialized = success
return success
async def _async_initialize(self) -> bool:
"""Async implementation of server initialization"""
success = True
for server_name, server_config in self.config["mcpServers"].items():
try:
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
if await client.start():
tools = await client.list_tools()
self.servers[server_name] = {"client": client, "tools": tools}
else:
success = False
print(f"Failed to start MCP server: {server_name}")
except Exception as e:
print(f"Error initializing server {server_name}: {str(e)}")
success = False
return success
def shutdown(self):
"""Shut down all MCP servers synchronously"""
if not self.initialized:
return
with self._lock:
if not self.initialized:
return
loop = asyncio.new_event_loop()
loop.run_until_complete(self._async_shutdown())
loop.close()
self.servers = {}
self.initialized = False
async def _async_shutdown(self):
"""Async implementation of server shutdown"""
for server_info in self.servers.values():
try:
await server_info["client"].stop()
except Exception as e:
print(f"Error shutting down server: {str(e)}")
def process_query(self, query: str, model_name: str) -> dict:
"""
Process a query using MCP tools synchronously
Args:
query: The user's input query
model_name: The model to use for processing
Returns:
Dictionary containing response or error
"""
if not self.initialized and not self.initialize():
return {"error": "Failed to initialize MCP servers"}
loop = asyncio.new_event_loop()
try:
result = loop.run_until_complete(self._async_process_query(query, model_name))
return result
except Exception as e:
return {"error": f"Processing error: {str(e)}"}
finally:
loop.close()
async def _async_process_query(self, query: str, model_name: str) -> dict:
"""Async implementation of query processing"""
return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False)
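
A minimal synchronous usage sketch (model name illustrative):

from mcp_manager import SyncMCPManager

manager = SyncMCPManager("config/mcp_config.json")
if manager.initialize():
    result = manager.process_query("What files are in my home directory?", model_name="gpt-4o")
    print(result.get("assistant_text") or result.get("error"))
manager.shutdown()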

openai_client.py
View File

@@ -1,39 +1,60 @@
"""OpenAI client with custom MCP integration."""
import configparser
from openai import OpenAI
from mcp_manager import SyncMCPManager
class OpenAIClient:
def __init__(self):
self.config = configparser.ConfigParser()
self.config.read('config/config.ini')
self.config.read("config/config.ini")
# Validate configuration
if not self.config.has_section('openai'):
if not self.config.has_section("openai"):
raise Exception("Missing [openai] section in config.ini")
if not self.config['openai'].get('api_key'):
if not self.config["openai"].get("api_key"):
raise Exception("Missing api_key in config.ini")
# Configure OpenAI client
self.client = OpenAI(
api_key=self.config['openai']['api_key'],
base_url=self.config['openai']['base_url'],
default_headers={
"HTTP-Referer": "https://streamlit-chat-app.com",
"X-Title": "Streamlit Chat App"
}
api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"}
)
# Initialize MCP manager if configured
self.mcp_manager = None
if self.config.has_section("dolphin-mcp"):
mcp_config_path = self.config["dolphin-mcp"].get("servers_json", "config/mcp_config.json")
self.mcp_manager = SyncMCPManager(mcp_config_path)
def get_chat_response(self, messages):
try:
print(f"Sending request to {self.config['openai']['base_url']}") # Debug log
print(f"Using model: {self.config['openai']['model']}") # Debug log
response = self.client.chat.completions.create(
model=self.config['openai']['model'],
messages=messages,
stream=True
)
return response
# Try using MCP if available
if self.mcp_manager and self.mcp_manager.initialize():
print("Using MCP with tools...")
last_message = messages[-1]["content"]
response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"])
if "error" not in response:
# Convert to OpenAI-compatible response format
return self._wrap_mcp_response(response)
# Fall back to standard OpenAI
print(f"Using standard OpenAI API with model: {self.config['openai']['model']}")
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
except Exception as e:
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
print(error_msg) # Debug log
print(error_msg)
raise Exception(error_msg)
def _wrap_mcp_response(self, response: dict):
"""Convert MCP response to OpenAI-compatible format"""
# Create a generator to simulate streaming response
def response_generator():
yield {"choices": [{"delta": {"content": response.get("assistant_text", "")}}]}
return response_generator()
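
For completeness, a config/config.ini of the shape OpenAIClient reads might look like this (section and key names come from the code above; note the [dolphin-mcp] section name survives from the dependency removed in pyproject.toml; all values are placeholders):

[openai]
api_key = sk-...
base_url = https://api.openai.com/v1
model = gpt-4o

[dolphin-mcp]
servers_json = config/mcp_config.json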