diff --git a/pyproject.toml b/pyproject.toml index e19aee3..04b1a34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,7 @@ authors = [ dependencies = [ "streamlit", "python-dotenv", - "openai", - "dolphin-mcp" + "openai" ] classifiers = [ "Development Status :: 3 - Alpha", @@ -81,3 +80,10 @@ combine-as-imports = true [tool.ruff.lint.mccabe] max-complexity = 12 + +[tool.ruff.lint.flake8-tidy-imports] +# Disallow all relative imports. +ban-relative-imports = "all" + +[tool.streamlit-chat-app.config] +mcp_config = "config/mcp_config.json" diff --git a/src/app.py b/src/app.py index 2c44948..2bcad71 100644 --- a/src/app.py +++ b/src/app.py @@ -1,46 +1,68 @@ +import atexit + import streamlit as st + from openai_client import OpenAIClient + def init_session_state(): if "messages" not in st.session_state: st.session_state.messages = [] + if "client" not in st.session_state: + st.session_state.client = OpenAIClient() + # Register cleanup for MCP servers + if hasattr(st.session_state.client, "mcp_manager"): + atexit.register(st.session_state.client.mcp_manager.shutdown) + def display_chat_messages(): for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) + def handle_user_input(): if prompt := st.chat_input("Type your message..."): print(f"User input received: {prompt}") # Debug log st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) - + try: with st.chat_message("assistant"): response_placeholder = st.empty() full_response = "" - - client = OpenAIClient() - print("Calling OpenAI API...") # Debug log - for chunk in client.get_chat_response(st.session_state.messages): - if chunk.choices[0].delta.content: - full_response += chunk.choices[0].delta.content - response_placeholder.markdown(full_response + "▌") - + + print("Processing message...") # Debug log + response = st.session_state.client.get_chat_response(st.session_state.messages) + + # Handle both MCP and standard OpenAI responses + if hasattr(response, "__iter__"): + # Standard OpenAI streaming response + for chunk in response: + if chunk.choices[0].delta.content: + full_response += chunk.choices[0].delta.content + response_placeholder.markdown(full_response + "▌") + else: + # MCP non-streaming response + full_response = response.get("assistant_text", "") + response_placeholder.markdown(full_response) + response_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) - print("API call completed successfully") # Debug log + print("Message processed successfully") # Debug log + except Exception as e: st.error(f"Error processing message: {str(e)}") print(f"Error details: {str(e)}") # Debug log + def main(): st.title("Streamlit Chat App") init_session_state() display_chat_messages() handle_user_input() + if __name__ == "__main__": main() diff --git a/src/custom_mcp_client/__init__.py b/src/custom_mcp_client/__init__.py new file mode 100644 index 0000000..4913c64 --- /dev/null +++ b/src/custom_mcp_client/__init__.py @@ -0,0 +1,5 @@ +"""Custom MCP client implementation focused on OpenAI integration.""" + +from .client import MCPClient, run_interaction + +__all__ = ["MCPClient", "run_interaction"] diff --git a/src/custom_mcp_client/client.py b/src/custom_mcp_client/client.py new file mode 100644 index 0000000..7a0b372 --- /dev/null +++ b/src/custom_mcp_client/client.py @@ -0,0 +1,315 @@ +"""Custom MCP client implementation with JSON-RPC and OpenAI integration.""" + +import asyncio +import json +import os +from collections.abc import AsyncGenerator + +from openai import AsyncOpenAI + + +class MCPClient: + """Lightweight MCP client with JSON-RPC communication.""" + + def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None): + self.server_name = server_name + self.command = command + self.args = args or [] + self.env = env or {} + self.process = None + self.tools = [] + self.request_id = 0 + self.responses = {} + self._shutdown = False + + async def _receive_loop(self): + """Listen for responses from the MCP server.""" + try: + while not self.process.stdout.at_eof(): + line = await self.process.stdout.readline() + if not line: + break + try: + message = json.loads(line.decode().strip()) + if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message): + self.responses[message["id"]] = message + except Exception: + pass + except Exception: + pass + + async def _send_message(self, message: dict) -> bool: + """Send a JSON-RPC message to the MCP server.""" + if not self.process: + return False + try: + data = json.dumps(message) + "\n" + self.process.stdin.write(data.encode()) + await self.process.stdin.drain() + return True + except Exception: + return False + + async def start(self) -> bool: + """Start the MCP server process.""" + # Expand ~ in paths + expanded_args = [] + for a in self.args: + if isinstance(a, str) and "~" in a: + expanded_args.append(os.path.expanduser(a)) + else: + expanded_args.append(a) + + # Set up environment + env_vars = os.environ.copy() + if self.env: + env_vars.update(self.env) + + try: + # Start the subprocess + self.process = await asyncio.create_subprocess_exec( + self.command, *expanded_args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=env_vars + ) + + # Start the receive loop + asyncio.create_task(self._receive_loop()) + + # Initialize the server + return await self._initialize() + except Exception: + return False + + async def _initialize(self) -> bool: + """Initialize the MCP server connection.""" + if not self.process: + return False + + # Send initialize request + self.request_id += 1 + req_id = self.request_id + initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}} + await self._send_message(initialize_req) + + # Wait for response + start_time = asyncio.get_event_loop().time() + timeout = 10 + while asyncio.get_event_loop().time() - start_time < timeout: + if req_id in self.responses: + resp = self.responses[req_id] + del self.responses[req_id] + + if "error" in resp: + return False + + # Send initialized notification + notify = {"jsonrpc": "2.0", "method": "notifications/initialized"} + await self._send_message(notify) + return True + + await asyncio.sleep(0.05) + + return False + + async def list_tools(self) -> list[dict]: + """List available tools from the MCP server.""" + if not self.process: + return [] + + self.request_id += 1 + req_id = self.request_id + req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}} + await self._send_message(req) + + # Wait for response + start_time = asyncio.get_event_loop().time() + timeout = 10 + while asyncio.get_event_loop().time() - start_time < timeout: + if req_id in self.responses: + resp = self.responses[req_id] + del self.responses[req_id] + + if "error" in resp: + return [] + + if "result" in resp and "tools" in resp["result"]: + self.tools = resp["result"]["tools"] + return self.tools + + await asyncio.sleep(0.05) + + return [] + + async def call_tool(self, tool_name: str, arguments: dict) -> dict: + """Call a tool on the MCP server.""" + if not self.process: + return {"error": "Server not started"} + + self.request_id += 1 + req_id = self.request_id + req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}} + await self._send_message(req) + + # Wait for response + start_time = asyncio.get_event_loop().time() + timeout = 30 + while asyncio.get_event_loop().time() - start_time < timeout: + if req_id in self.responses: + resp = self.responses[req_id] + del self.responses[req_id] + + if "error" in resp: + return {"error": str(resp["error"])} + + if "result" in resp: + return resp["result"] + + await asyncio.sleep(0.05) + + return {"error": f"Tool call timed out after {timeout}s"} + + async def stop(self): + """Stop the MCP server process.""" + if self._shutdown or not self.process: + return + + self._shutdown = True + + try: + # Send shutdown notification + notify = {"jsonrpc": "2.0", "method": "shutdown"} + await self._send_message(notify) + await asyncio.sleep(0.5) + + # Close stdin + if self.process.stdin: + self.process.stdin.close() + + # Terminate the process + self.process.terminate() + try: + await asyncio.wait_for(self.process.wait(), timeout=1.0) + except TimeoutError: + self.process.kill() + + except Exception: + pass + + finally: + self.process = None + + +async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict: + """Process a tool call from OpenAI.""" + func_name = tool_call["function"]["name"] + try: + func_args = json.loads(tool_call["function"].get("arguments", "{}")) + except json.JSONDecodeError: + return {"error": "Invalid arguments format"} + + # Parse server_name and tool_name from function name + parts = func_name.split("_", 1) + if len(parts) != 2: + return {"error": "Invalid function name format"} + + server_name, tool_name = parts + + if server_name not in servers: + return {"error": f"Unknown server: {server_name}"} + + # Call the tool + return await servers[server_name].call_tool(tool_name, func_args) + + +async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator: + """ + Run an interaction with OpenAI using MCP server tools. + + Args: + user_query: The user's input query + model_name: The model to use for processing + config: Configuration dictionary + stream: Whether to stream the response + + Returns: + Dictionary containing response or AsyncGenerator for streaming + """ + # Get OpenAI configuration + api_key = config["models"][0]["apiKey"] + base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1") + + # Start MCP servers + servers = {} + all_functions = [] + for server_name, server_config in config["mcpServers"].items(): + client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) + + if await client.start(): + tools = await client.list_tools() + for tool in tools: + all_functions.append({"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": tool.get("inputSchema", {})}) + servers[server_name] = client + + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + messages = [{"role": "user", "content": user_query}] + + if stream: + + async def response_generator(): + try: + while True: + # Get OpenAI response + response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions, stream=True) + + # Process streaming response + full_response = "" + tool_calls = [] + async for chunk in response: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + full_response += content + yield {"assistant_text": content, "is_chunk": True} + + if chunk.choices[0].delta.tool_calls: + for tc in chunk.choices[0].delta.tool_calls: + if len(tool_calls) <= tc.index: + tool_calls.append({"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}) + else: + tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments + + # Handle tool calls + if tool_calls: + assistant_message = {"role": "assistant", "content": full_response, "tool_calls": tool_calls} + messages.append(assistant_message) + + for tc in tool_calls: + result = await process_tool_call(tc, servers) + messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]}) + else: + break + + finally: + # Clean up servers + for server in servers.values(): + await server.stop() + + else: + try: + while True: + # Get OpenAI response + response = await client.chat.completions.create(model=model_name, messages=messages, tools=all_functions) + + message = response.choices[0].message + messages.append(message) + + # Handle tool calls + if message.tool_calls: + for tc in message.tool_calls: + result = await process_tool_call({"id": tc.id, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}, servers) + messages.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id}) + else: + return {"assistant_text": message.content or "", "tool_calls": []} + + finally: + # Clean up servers + for server in servers.values(): + await server.stop() diff --git a/src/mcp_manager.py b/src/mcp_manager.py new file mode 100644 index 0000000..f6a6af7 --- /dev/null +++ b/src/mcp_manager.py @@ -0,0 +1,126 @@ +"""Synchronous wrapper for managing MCP servers using our custom implementation.""" + +import asyncio +import importlib.resources +import json +import threading + +from custom_mcp_client import MCPClient, run_interaction + + +class SyncMCPManager: + """Synchronous wrapper for managing MCP servers and interactions""" + + def __init__(self, config_path: str = "config/mcp_config.json"): + self.config_path = config_path + self.config = None + self.servers = {} + self.initialized = False + self._lock = threading.Lock() + self._load_config() + + def _load_config(self): + """Load MCP configuration from JSON file using importlib""" + try: + # First try to load as a package resource + try: + with importlib.resources.files("streamlit-chat-app").joinpath(self.config_path).open("r") as f: + self.config = json.load(f) + except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError): + # Fall back to direct file access + with open(self.config_path) as f: + self.config = json.load(f) + + except Exception as e: + print(f"Error loading MCP config from {self.config_path}: {str(e)}") + self.config = None + + def initialize(self) -> bool: + """Initialize and start all MCP servers synchronously""" + if not self.config or not self.config.get("mcpServers"): + return False + + if self.initialized: + return True + + with self._lock: + if self.initialized: # Double-check after acquiring lock + return True + + # Run async initialization in a new event loop + loop = asyncio.new_event_loop() + success = loop.run_until_complete(self._async_initialize()) + loop.close() + + self.initialized = success + return success + + async def _async_initialize(self) -> bool: + """Async implementation of server initialization""" + success = True + for server_name, server_config in self.config["mcpServers"].items(): + try: + client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {})) + + if await client.start(): + tools = await client.list_tools() + self.servers[server_name] = {"client": client, "tools": tools} + else: + success = False + print(f"Failed to start MCP server: {server_name}") + except Exception as e: + print(f"Error initializing server {server_name}: {str(e)}") + success = False + + return success + + def shutdown(self): + """Shut down all MCP servers synchronously""" + if not self.initialized: + return + + with self._lock: + if not self.initialized: + return + + loop = asyncio.new_event_loop() + loop.run_until_complete(self._async_shutdown()) + loop.close() + + self.servers = {} + self.initialized = False + + async def _async_shutdown(self): + """Async implementation of server shutdown""" + for server_info in self.servers.values(): + try: + await server_info["client"].stop() + except Exception as e: + print(f"Error shutting down server: {str(e)}") + + def process_query(self, query: str, model_name: str) -> dict: + """ + Process a query using MCP tools synchronously + + Args: + query: The user's input query + model_name: The model to use for processing + + Returns: + Dictionary containing response or error + """ + if not self.initialized and not self.initialize(): + return {"error": "Failed to initialize MCP servers"} + + loop = asyncio.new_event_loop() + try: + result = loop.run_until_complete(self._async_process_query(query, model_name)) + return result + except Exception as e: + return {"error": f"Processing error: {str(e)}"} + finally: + loop.close() + + async def _async_process_query(self, query: str, model_name: str) -> dict: + """Async implementation of query processing""" + return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False) diff --git a/src/openai_client.py b/src/openai_client.py index 3817043..79d5e63 100644 --- a/src/openai_client.py +++ b/src/openai_client.py @@ -1,39 +1,60 @@ +"""OpenAI client with custom MCP integration.""" + import configparser + from openai import OpenAI +from mcp_manager import SyncMCPManager + + class OpenAIClient: def __init__(self): self.config = configparser.ConfigParser() - self.config.read('config/config.ini') - + self.config.read("config/config.ini") + # Validate configuration - if not self.config.has_section('openai'): + if not self.config.has_section("openai"): raise Exception("Missing [openai] section in config.ini") - if not self.config['openai'].get('api_key'): + if not self.config["openai"].get("api_key"): raise Exception("Missing api_key in config.ini") - + # Configure OpenAI client self.client = OpenAI( - api_key=self.config['openai']['api_key'], - base_url=self.config['openai']['base_url'], - default_headers={ - "HTTP-Referer": "https://streamlit-chat-app.com", - "X-Title": "Streamlit Chat App" - } + api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"} ) - + + # Initialize MCP manager if configured + self.mcp_manager = None + if self.config.has_section("dolphin-mcp"): + mcp_config_path = self.config["dolphin-mcp"].get("servers_json", "config/mcp_config.json") + self.mcp_manager = SyncMCPManager(mcp_config_path) + def get_chat_response(self, messages): try: - print(f"Sending request to {self.config['openai']['base_url']}") # Debug log - print(f"Using model: {self.config['openai']['model']}") # Debug log - - response = self.client.chat.completions.create( - model=self.config['openai']['model'], - messages=messages, - stream=True - ) - return response + # Try using MCP if available + if self.mcp_manager and self.mcp_manager.initialize(): + print("Using MCP with tools...") + last_message = messages[-1]["content"] + response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"]) + + if "error" not in response: + # Convert to OpenAI-compatible response format + return self._wrap_mcp_response(response) + + # Fall back to standard OpenAI + print(f"Using standard OpenAI API with model: {self.config['openai']['model']}") + return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True) + except Exception as e: error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}" - print(error_msg) # Debug log + print(error_msg) raise Exception(error_msg) + + def _wrap_mcp_response(self, response: dict): + """Convert MCP response to OpenAI-compatible format""" + + # Create a generator to simulate streaming response + def response_generator(): + yield {"choices": [{"delta": {"content": response.get("assistant_text", "")}}]} + + return response_generator()