feat: implement custom MCP client and integrate with OpenAI API for enhanced chat functionality
This commit is contained in:
42
src/app.py
42
src/app.py
@@ -1,46 +1,68 @@
|
||||
import atexit
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from openai_client import OpenAIClient
|
||||
|
||||
|
||||
def init_session_state():
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "client" not in st.session_state:
|
||||
st.session_state.client = OpenAIClient()
|
||||
# Register cleanup for MCP servers
|
||||
if hasattr(st.session_state.client, "mcp_manager"):
|
||||
atexit.register(st.session_state.client.mcp_manager.shutdown)
|
||||
|
||||
|
||||
def display_chat_messages():
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
|
||||
def handle_user_input():
|
||||
if prompt := st.chat_input("Type your message..."):
|
||||
print(f"User input received: {prompt}") # Debug log
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
|
||||
try:
|
||||
with st.chat_message("assistant"):
|
||||
response_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
client = OpenAIClient()
|
||||
print("Calling OpenAI API...") # Debug log
|
||||
for chunk in client.get_chat_response(st.session_state.messages):
|
||||
if chunk.choices[0].delta.content:
|
||||
full_response += chunk.choices[0].delta.content
|
||||
response_placeholder.markdown(full_response + "▌")
|
||||
|
||||
|
||||
print("Processing message...") # Debug log
|
||||
response = st.session_state.client.get_chat_response(st.session_state.messages)
|
||||
|
||||
# Handle both MCP and standard OpenAI responses
|
||||
if hasattr(response, "__iter__"):
|
||||
# Standard OpenAI streaming response
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
full_response += chunk.choices[0].delta.content
|
||||
response_placeholder.markdown(full_response + "▌")
|
||||
else:
|
||||
# MCP non-streaming response
|
||||
full_response = response.get("assistant_text", "")
|
||||
response_placeholder.markdown(full_response)
|
||||
|
||||
response_placeholder.markdown(full_response)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
print("API call completed successfully") # Debug log
|
||||
print("Message processed successfully") # Debug log
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error processing message: {str(e)}")
|
||||
print(f"Error details: {str(e)}") # Debug log
|
||||
|
||||
|
||||
def main():
|
||||
st.title("Streamlit Chat App")
|
||||
init_session_state()
|
||||
display_chat_messages()
|
||||
handle_user_input()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
5
src/custom_mcp_client/__init__.py
Normal file
5
src/custom_mcp_client/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Custom MCP client implementation focused on OpenAI integration."""
|
||||
|
||||
from .client import MCPClient, run_interaction
|
||||
|
||||
__all__ = ["MCPClient", "run_interaction"]
|
||||
315
src/custom_mcp_client/client.py
Normal file
315
src/custom_mcp_client/client.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Lightweight MCP client with JSON-RPC communication."""
|
||||
|
||||
def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
|
||||
self.server_name = server_name
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
self.env = env or {}
|
||||
self.process = None
|
||||
self.tools = []
|
||||
self.request_id = 0
|
||||
self.responses = {}
|
||||
self._shutdown = False
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Listen for responses from the MCP server."""
|
||||
try:
|
||||
while not self.process.stdout.at_eof():
|
||||
line = await self.process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
message = json.loads(line.decode().strip())
|
||||
if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
|
||||
self.responses[message["id"]] = message
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _send_message(self, message: dict) -> bool:
|
||||
"""Send a JSON-RPC message to the MCP server."""
|
||||
if not self.process:
|
||||
return False
|
||||
try:
|
||||
data = json.dumps(message) + "\n"
|
||||
self.process.stdin.write(data.encode())
|
||||
await self.process.stdin.drain()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""Start the MCP server process."""
|
||||
# Expand ~ in paths
|
||||
expanded_args = []
|
||||
for a in self.args:
|
||||
if isinstance(a, str) and "~" in a:
|
||||
expanded_args.append(os.path.expanduser(a))
|
||||
else:
|
||||
expanded_args.append(a)
|
||||
|
||||
# Set up environment
|
||||
env_vars = os.environ.copy()
|
||||
if self.env:
|
||||
env_vars.update(self.env)
|
||||
|
||||
try:
|
||||
# Start the subprocess
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
self.command, *expanded_args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env_vars
|
||||
)
|
||||
|
||||
# Start the receive loop
|
||||
asyncio.create_task(self._receive_loop())
|
||||
|
||||
# Initialize the server
|
||||
return await self._initialize()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _initialize(self) -> bool:
|
||||
"""Initialize the MCP server connection."""
|
||||
if not self.process:
|
||||
return False
|
||||
|
||||
# Send initialize request
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
|
||||
await self._send_message(initialize_req)
|
||||
|
||||
# Wait for response
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 10
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
if req_id in self.responses:
|
||||
resp = self.responses[req_id]
|
||||
del self.responses[req_id]
|
||||
|
||||
if "error" in resp:
|
||||
return False
|
||||
|
||||
# Send initialized notification
|
||||
notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
|
||||
await self._send_message(notify)
|
||||
return True
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
return False
|
||||
|
||||
async def list_tools(self) -> list[dict]:
|
||||
"""List available tools from the MCP server."""
|
||||
if not self.process:
|
||||
return []
|
||||
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
|
||||
await self._send_message(req)
|
||||
|
||||
# Wait for response
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 10
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
if req_id in self.responses:
|
||||
resp = self.responses[req_id]
|
||||
del self.responses[req_id]
|
||||
|
||||
if "error" in resp:
|
||||
return []
|
||||
|
||||
if "result" in resp and "tools" in resp["result"]:
|
||||
self.tools = resp["result"]["tools"]
|
||||
return self.tools
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
return []
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""Call a tool on the MCP server."""
|
||||
if not self.process:
|
||||
return {"error": "Server not started"}
|
||||
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
|
||||
await self._send_message(req)
|
||||
|
||||
# Wait for response
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 30
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
if req_id in self.responses:
|
||||
resp = self.responses[req_id]
|
||||
del self.responses[req_id]
|
||||
|
||||
if "error" in resp:
|
||||
return {"error": str(resp["error"])}
|
||||
|
||||
if "result" in resp:
|
||||
return resp["result"]
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
return {"error": f"Tool call timed out after {timeout}s"}
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self._shutdown or not self.process:
|
||||
return
|
||||
|
||||
self._shutdown = True
|
||||
|
||||
try:
|
||||
# Send shutdown notification
|
||||
notify = {"jsonrpc": "2.0", "method": "shutdown"}
|
||||
await self._send_message(notify)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Close stdin
|
||||
if self.process.stdin:
|
||||
self.process.stdin.close()
|
||||
|
||||
# Terminate the process
|
||||
self.process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(self.process.wait(), timeout=1.0)
|
||||
except TimeoutError:
|
||||
self.process.kill()
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
self.process = None
|
||||
|
||||
|
||||
async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
|
||||
"""Process a tool call from OpenAI."""
|
||||
func_name = tool_call["function"]["name"]
|
||||
try:
|
||||
func_args = json.loads(tool_call["function"].get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "Invalid arguments format"}
|
||||
|
||||
# Parse server_name and tool_name from function name
|
||||
parts = func_name.split("_", 1)
|
||||
if len(parts) != 2:
|
||||
return {"error": "Invalid function name format"}
|
||||
|
||||
server_name, tool_name = parts
|
||||
|
||||
if server_name not in servers:
|
||||
return {"error": f"Unknown server: {server_name}"}
|
||||
|
||||
# Call the tool
|
||||
return await servers[server_name].call_tool(tool_name, func_args)
|
||||
|
||||
|
||||
async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator:
|
||||
"""
|
||||
Run an interaction with OpenAI using MCP server tools.
|
||||
|
||||
Args:
|
||||
user_query: The user's input query
|
||||
model_name: The model to use for processing
|
||||
config: Configuration dictionary
|
||||
stream: Whether to stream the response
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or AsyncGenerator for streaming
|
||||
"""
|
||||
# Get OpenAI configuration
|
||||
api_key = config["models"][0]["apiKey"]
|
||||
base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
|
||||
|
||||
# Start MCP servers
|
||||
servers = {}
|
||||
all_functions = []
|
||||
for server_name, server_config in config["mcpServers"].items():
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
if await client.start():
|
||||
tools = await client.list_tools()
|
||||
for tool in tools:
|
||||
all_functions.append({"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})})
|
||||
servers[server_name] = client
|
||||
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
messages = [{"role": "user", "content": user_query}]
|
||||
|
||||
if stream:
|
||||
|
||||
async def response_generator():
|
||||
try:
|
||||
while True:
|
||||
# Get OpenAI response
|
||||
response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True)
|
||||
|
||||
# Process streaming response
|
||||
full_response = ""
|
||||
tool_calls = []
|
||||
async for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
full_response += content
|
||||
yield {"assistant_text": content, "is_chunk": True}
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
for tc in chunk.choices[0].delta.tool_calls:
|
||||
if len(tool_calls) <= tc.index:
|
||||
tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}})
|
||||
else:
|
||||
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
||||
|
||||
# Handle tool calls
|
||||
if tool_calls:
|
||||
assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls}
|
||||
messages.append(assistant_message)
|
||||
|
||||
for tc in tool_calls:
|
||||
result = await process_tool_call(tc, servers)
|
||||
messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
|
||||
else:
|
||||
break
|
||||
|
||||
finally:
|
||||
# Clean up servers
|
||||
for server in servers.values():
|
||||
await server.stop()
|
||||
|
||||
else:
|
||||
try:
|
||||
while True:
|
||||
# Get OpenAI response
|
||||
response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions)
|
||||
|
||||
message = response.choices[0].message
|
||||
messages.append(message)
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls:
|
||||
for tc in message.tool_calls:
|
||||
result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers)
|
||||
messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
|
||||
else:
|
||||
return {"assistant_text": message.content or "", "tool_calls": []}
|
||||
|
||||
finally:
|
||||
# Clean up servers
|
||||
for server in servers.values():
|
||||
await server.stop()
|
||||
126
src/mcp_manager.py
Normal file
126
src/mcp_manager.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
|
||||
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import json
|
||||
import threading
|
||||
|
||||
from custom_mcp_client import MCPClient, run_interaction
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
"""Synchronous wrapper for managing MCP servers and interactions"""
|
||||
|
||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||
self.config_path = config_path
|
||||
self.config = None
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load MCP configuration from JSON file using importlib"""
|
||||
try:
|
||||
# First try to load as a package resource
|
||||
try:
|
||||
with importlib.resources.files("streamlit-chat-app").joinpath(self.config_path).open("r") as f:
|
||||
self.config = json.load(f)
|
||||
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError):
|
||||
# Fall back to direct file access
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading MCP config from {self.config_path}: {str(e)}")
|
||||
self.config = None
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Initialize and start all MCP servers synchronously"""
|
||||
if not self.config or not self.config.get("mcpServers"):
|
||||
return False
|
||||
|
||||
if self.initialized:
|
||||
return True
|
||||
|
||||
with self._lock:
|
||||
if self.initialized: # Double-check after acquiring lock
|
||||
return True
|
||||
|
||||
# Run async initialization in a new event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
success = loop.run_until_complete(self._async_initialize())
|
||||
loop.close()
|
||||
|
||||
self.initialized = success
|
||||
return success
|
||||
|
||||
async def _async_initialize(self) -> bool:
|
||||
"""Async implementation of server initialization"""
|
||||
success = True
|
||||
for server_name, server_config in self.config["mcpServers"].items():
|
||||
try:
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
if await client.start():
|
||||
tools = await client.list_tools()
|
||||
self.servers[server_name] = {"client": client, "tools": tools}
|
||||
else:
|
||||
success = False
|
||||
print(f"Failed to start MCP server: {server_name}")
|
||||
except Exception as e:
|
||||
print(f"Error initializing server {server_name}: {str(e)}")
|
||||
success = False
|
||||
|
||||
return success
|
||||
|
||||
def shutdown(self):
|
||||
"""Shut down all MCP servers synchronously"""
|
||||
if not self.initialized:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
if not self.initialized:
|
||||
return
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(self._async_shutdown())
|
||||
loop.close()
|
||||
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
|
||||
async def _async_shutdown(self):
|
||||
"""Async implementation of server shutdown"""
|
||||
for server_info in self.servers.values():
|
||||
try:
|
||||
await server_info["client"].stop()
|
||||
except Exception as e:
|
||||
print(f"Error shutting down server: {str(e)}")
|
||||
|
||||
def process_query(self, query: str, model_name: str) -> dict:
|
||||
"""
|
||||
Process a query using MCP tools synchronously
|
||||
|
||||
Args:
|
||||
query: The user's input query
|
||||
model_name: The model to use for processing
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or error
|
||||
"""
|
||||
if not self.initialized and not self.initialize():
|
||||
return {"error": "Failed to initialize MCP servers"}
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
result = loop.run_until_complete(self._async_process_query(query, model_name))
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"error": f"Processing error: {str(e)}"}
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def _async_process_query(self, query: str, model_name: str) -> dict:
|
||||
"""Async implementation of query processing"""
|
||||
return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False)
|
||||
@@ -1,39 +1,60 @@
|
||||
"""OpenAI client with custom MCP integration."""
|
||||
|
||||
import configparser
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mcp_manager import SyncMCPManager
|
||||
|
||||
|
||||
class OpenAIClient:
|
||||
def __init__(self):
|
||||
self.config = configparser.ConfigParser()
|
||||
self.config.read('config/config.ini')
|
||||
|
||||
self.config.read("config/config.ini")
|
||||
|
||||
# Validate configuration
|
||||
if not self.config.has_section('openai'):
|
||||
if not self.config.has_section("openai"):
|
||||
raise Exception("Missing [openai] section in config.ini")
|
||||
if not self.config['openai'].get('api_key'):
|
||||
if not self.config["openai"].get("api_key"):
|
||||
raise Exception("Missing api_key in config.ini")
|
||||
|
||||
|
||||
# Configure OpenAI client
|
||||
self.client = OpenAI(
|
||||
api_key=self.config['openai']['api_key'],
|
||||
base_url=self.config['openai']['base_url'],
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://streamlit-chat-app.com",
|
||||
"X-Title": "Streamlit Chat App"
|
||||
}
|
||||
api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"}
|
||||
)
|
||||
|
||||
|
||||
# Initialize MCP manager if configured
|
||||
self.mcp_manager = None
|
||||
if self.config.has_section("dolphin-mcp"):
|
||||
mcp_config_path = self.config["dolphin-mcp"].get("servers_json", "config/mcp_config.json")
|
||||
self.mcp_manager = SyncMCPManager(mcp_config_path)
|
||||
|
||||
def get_chat_response(self, messages):
|
||||
try:
|
||||
print(f"Sending request to {self.config['openai']['base_url']}") # Debug log
|
||||
print(f"Using model: {self.config['openai']['model']}") # Debug log
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config['openai']['model'],
|
||||
messages=messages,
|
||||
stream=True
|
||||
)
|
||||
return response
|
||||
# Try using MCP if available
|
||||
if self.mcp_manager and self.mcp_manager.initialize():
|
||||
print("Using MCP with tools...")
|
||||
last_message = messages[-1]["content"]
|
||||
response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"])
|
||||
|
||||
if "error" not in response:
|
||||
# Convert to OpenAI-compatible response format
|
||||
return self._wrap_mcp_response(response)
|
||||
|
||||
# Fall back to standard OpenAI
|
||||
print(f"Using standard OpenAI API with model: {self.config['openai']['model']}")
|
||||
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
|
||||
print(error_msg) # Debug log
|
||||
print(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def _wrap_mcp_response(self, response: dict):
|
||||
"""Convert MCP response to OpenAI-compatible format"""
|
||||
|
||||
# Create a generator to simulate streaming response
|
||||
def response_generator():
|
||||
yield {"choices": [{"delta": {"content": response.get("assistant_text", "")}}]}
|
||||
|
||||
return response_generator()
|
||||
|
||||
Reference in New Issue
Block a user