feat: Implement async utilities for MCP server management and JSON-RPC communication
- Added `process.py` for managing MCP server subprocesses with async capabilities. - Introduced `protocol.py` for handling JSON-RPC communication over streams. - Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools. - Defined model configurations in `llm_models.py` for different LLM providers. - Removed the synchronous `mcp_manager.py` in favor of a more modular approach. - Established a provider framework in `providers` directory with a base class and specific implementations. - Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool call handling.
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"dolphin-demo-database-sqlite": {
|
||||
"mcp-server-sqlite": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-sqlite",
|
||||
"--db-path",
|
||||
"~/.dolphin/dolphin.db"
|
||||
"~/.mcpapp/mcpapp.db"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ lint.select = [
|
||||
"T10", # flake8-debugger
|
||||
"A", # flake8-builtins
|
||||
"UP", # pyupgrade
|
||||
"TID", # flake8-tidy-imports
|
||||
]
|
||||
|
||||
lint.ignore = [
|
||||
@@ -81,7 +82,7 @@ skip-magic-trailing-comma = false
|
||||
combine-as-imports = true
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
max-complexity = 12
|
||||
max-complexity = 16
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
# Disallow all relative imports.
|
||||
|
||||
178
src/app.py
178
src/app.py
@@ -1,29 +1,111 @@
|
||||
import atexit
|
||||
import configparser
|
||||
import json # For handling potential error JSON in stream
|
||||
import logging
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from openai_client import OpenAIClient
|
||||
# Updated imports
|
||||
from llm_client import LLMClient
|
||||
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||
|
||||
# Configure logging for the app
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_session_state():
|
||||
"""Initializes session state variables including clients."""
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
logger.info("Initialized session state: messages")
|
||||
|
||||
if "client" not in st.session_state:
|
||||
st.session_state.client = OpenAIClient()
|
||||
# Register cleanup for MCP servers
|
||||
if hasattr(st.session_state.client, "mcp_manager"):
|
||||
atexit.register(st.session_state.client.mcp_manager.shutdown)
|
||||
logger.info("Attempting to initialize clients...")
|
||||
try:
|
||||
config = configparser.ConfigParser()
|
||||
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
|
||||
config_files_read = config.read("config/config.ini")
|
||||
if not config_files_read:
|
||||
raise FileNotFoundError("config.ini not found or could not be read.")
|
||||
logger.info(f"Read configuration from: {config_files_read}")
|
||||
|
||||
# --- MCP Manager Setup ---
|
||||
mcp_config_path = "config/mcp_config.json" # Default
|
||||
if config.has_section("mcp") and config["mcp"].get("servers_json"):
|
||||
mcp_config_path = config["mcp"]["servers_json"]
|
||||
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
|
||||
else:
|
||||
logger.info(f"Using default MCP config path: {mcp_config_path}")
|
||||
|
||||
mcp_manager = SyncMCPManager(mcp_config_path)
|
||||
if not mcp_manager.initialize():
|
||||
# Log warning but continue - LLMClient will operate without tools
|
||||
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
|
||||
else:
|
||||
logger.info("MCP Manager initialized successfully.")
|
||||
# Register shutdown hook for MCP manager
|
||||
atexit.register(mcp_manager.shutdown)
|
||||
logger.info("Registered MCP Manager shutdown hook.")
|
||||
|
||||
# --- LLM Client Setup ---
|
||||
# Determine provider details (e.g., from an [llm] section)
|
||||
provider_name = "openai" # Default
|
||||
model_name = None # Must be provided in config
|
||||
api_key = None
|
||||
base_url = None
|
||||
|
||||
# Prioritize [llm] section, fallback to [openai] for compatibility
|
||||
if config.has_section("llm"):
|
||||
logger.info("Reading configuration from [llm] section.")
|
||||
provider_name = config["llm"].get("provider", provider_name)
|
||||
model_name = config["llm"].get("model")
|
||||
api_key = config["llm"].get("api_key")
|
||||
base_url = config["llm"].get("base_url") # Optional
|
||||
elif config.has_section("openai"):
|
||||
logger.warning("Using legacy [openai] section for configuration.")
|
||||
provider_name = "openai" # Force openai if using this section
|
||||
model_name = config["openai"].get("model")
|
||||
api_key = config["openai"].get("api_key")
|
||||
base_url = config["openai"].get("base_url") # Optional
|
||||
else:
|
||||
raise ValueError("Missing [llm] or [openai] section in config.ini")
|
||||
|
||||
# Validate required config
|
||||
if not api_key:
|
||||
raise ValueError("Missing 'api_key' in config.ini ([llm] or [openai] section)")
|
||||
if not model_name:
|
||||
raise ValueError("Missing 'model' name in config.ini ([llm] or [openai] section)")
|
||||
|
||||
logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
|
||||
st.session_state.client = LLMClient(
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
mcp_manager=mcp_manager,
|
||||
base_url=base_url, # Pass None if not provided
|
||||
)
|
||||
st.session_state.model_name = model_name # Store model name for chat requests
|
||||
logger.info("LLMClient initialized successfully.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
|
||||
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
|
||||
# Stop the app if initialization fails critically
|
||||
st.stop()
|
||||
|
||||
|
||||
def display_chat_messages():
|
||||
"""Displays chat messages stored in session state."""
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
# Simple markdown display for now
|
||||
st.markdown(message["content"])
|
||||
|
||||
|
||||
def handle_user_input():
|
||||
"""Handles user input, calls LLMClient, and displays the response."""
|
||||
if prompt := st.chat_input("Type your message..."):
|
||||
print(f"User input received: {prompt}") # Debug log
|
||||
logger.info(f"User input received: '{prompt[:50]}...'")
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
@@ -32,39 +114,85 @@ def handle_user_input():
|
||||
with st.chat_message("assistant"):
|
||||
response_placeholder = st.empty()
|
||||
full_response = ""
|
||||
error_occurred = False
|
||||
|
||||
print("Processing message...") # Debug log
|
||||
response = st.session_state.client.get_chat_response(st.session_state.messages)
|
||||
logger.info("Processing message via LLMClient...")
|
||||
# Use the new client and method, always requesting stream for UI
|
||||
response_stream = st.session_state.client.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model=st.session_state.model_name, # Get model from session state
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Handle both MCP and standard OpenAI responses
|
||||
# Check if it's NOT a dict (assuming stream is not a dict)
|
||||
if not isinstance(response, dict):
|
||||
# Standard OpenAI streaming response
|
||||
for chunk in response:
|
||||
# Ensure chunk has choices and delta before accessing
|
||||
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
|
||||
full_response += chunk.choices[0].delta.content
|
||||
response_placeholder.markdown(full_response + "▌")
|
||||
# Handle the response (stream generator or error dict)
|
||||
if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
|
||||
logger.debug("Processing response stream...")
|
||||
for chunk in response_stream:
|
||||
# Check for potential error JSON yielded by the stream
|
||||
try:
|
||||
# Attempt to parse chunk as JSON only if it looks like it
|
||||
if isinstance(chunk, str) and chunk.strip().startswith("{"):
|
||||
error_data = json.loads(chunk)
|
||||
if isinstance(error_data, dict) and "error" in error_data:
|
||||
full_response = f"Error: {error_data['error']}"
|
||||
logger.error(f"Error received in stream: {full_response}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
break # Stop processing stream on error
|
||||
# If not error JSON, treat as content chunk
|
||||
if not error_occurred and isinstance(chunk, str):
|
||||
full_response += chunk
|
||||
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not JSON or not error structure, treat as content chunk
|
||||
if not error_occurred and isinstance(chunk, str):
|
||||
full_response += chunk
|
||||
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||
|
||||
if not error_occurred:
|
||||
response_placeholder.markdown(full_response) # Final update without cursor
|
||||
logger.debug("Stream processing complete.")
|
||||
|
||||
elif isinstance(response_stream, dict) and "error" in response_stream:
|
||||
# Handle error dict returned directly (e.g., API error before streaming)
|
||||
full_response = f"Error: {response_stream['error']}"
|
||||
logger.error(f"Error returned directly from chat_completion: {full_response}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
else:
|
||||
# MCP non-streaming response
|
||||
full_response = response.get("assistant_text", "")
|
||||
response_placeholder.markdown(full_response)
|
||||
# Unexpected response type
|
||||
full_response = "[Unexpected response format from LLMClient]"
|
||||
logger.error(f"Unexpected response type: {type(response_stream)}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
|
||||
response_placeholder.markdown(full_response)
|
||||
# Only add non-error, non-empty responses to history
|
||||
if not error_occurred and full_response:
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
print("Message processed successfully") # Debug log
|
||||
logger.info("Assistant response added to history.")
|
||||
elif error_occurred:
|
||||
logger.warning("Assistant response not added to history due to error.")
|
||||
else:
|
||||
logger.warning("Empty assistant response received, not added to history.")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error processing message: {str(e)}")
|
||||
print(f"Error details: {str(e)}") # Debug log
|
||||
logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
|
||||
st.error(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
st.title("Streamlit Chat App")
|
||||
"""Main function to run the Streamlit app."""
|
||||
st.title("MCP Chat App") # Updated title
|
||||
try:
|
||||
init_session_state()
|
||||
display_chat_messages()
|
||||
handle_user_input()
|
||||
except Exception as e:
|
||||
# Catch potential errors during rendering or handling
|
||||
logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
|
||||
st.error(f"A critical application error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting Streamlit Chat App...")
|
||||
main()
|
||||
|
||||
1
src/custom_mcp/__init__.py
Normal file
1
src/custom_mcp/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# This file makes src/mcp a Python package
|
||||
281
src/custom_mcp/client.py
Normal file
281
src/custom_mcp/client.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# src/mcp/client.py
|
||||
"""Client class for managing and interacting with a single MCP server process."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from custom_mcp import process, protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts
|
||||
LIST_TOOLS_TIMEOUT = 20.0 # Seconds (using the increased value from previous step)
|
||||
CALL_TOOL_TIMEOUT = 110.0 # Seconds
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
Manages the lifecycle and async communication with a single MCP server process.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
|
||||
"""
|
||||
Initializes the client for a specific server configuration.
|
||||
|
||||
Args:
|
||||
server_name: Unique name for the server (for logging).
|
||||
command: The command executable.
|
||||
args: List of arguments for the command.
|
||||
config_env: Server-specific environment variables.
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.config_env = config_env
|
||||
self.process: asyncio.subprocess.Process | None = None
|
||||
self.reader: asyncio.StreamReader | None = None
|
||||
self.writer: asyncio.StreamWriter | None = None
|
||||
self._stderr_task: asyncio.Task | None = None
|
||||
self._request_counter = 0
|
||||
self._is_running = False
|
||||
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger
|
||||
|
||||
async def _log_stderr(self):
|
||||
"""Logs stderr output from the server process."""
|
||||
if not self.process or not self.process.stderr:
|
||||
self.logger.debug("Stderr logging skipped: process or stderr not available.")
|
||||
return
|
||||
stderr_reader = self.process.stderr
|
||||
try:
|
||||
while not stderr_reader.at_eof():
|
||||
line = await stderr_reader.readline()
|
||||
if line:
|
||||
self.logger.warning(f"[stderr] {line.decode().strip()}")
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Stderr logging task cancelled.")
|
||||
except Exception as e:
|
||||
# Log errors but don't crash the logger task if possible
|
||||
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("Stderr logging task finished.")
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Starts the MCP server subprocess and sets up communication streams.
|
||||
|
||||
Returns:
|
||||
True if the process started successfully, False otherwise.
|
||||
"""
|
||||
if self._is_running:
|
||||
self.logger.warning("Start called but client is already running.")
|
||||
return True
|
||||
|
||||
self.logger.info("Starting MCP server process...")
|
||||
try:
|
||||
self.process = await process.start_mcp_process(self.command, self.args, self.config_env)
|
||||
self.reader = self.process.stdout
|
||||
self.writer = self.process.stdin
|
||||
|
||||
if self.reader is None or self.writer is None:
|
||||
self.logger.error("Failed to get stdout/stdin streams after process start.")
|
||||
await self.stop() # Attempt cleanup
|
||||
return False
|
||||
|
||||
# Start background task to monitor stderr
|
||||
self._stderr_task = asyncio.create_task(self._log_stderr())
|
||||
|
||||
# --- Start MCP Initialization Handshake ---
|
||||
self.logger.info("Starting MCP initialization handshake...")
|
||||
self._request_counter += 1
|
||||
init_req_id = self._request_counter
|
||||
initialize_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": init_req_id,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05", # Use a recent version
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client
|
||||
"capabilities": {}, # Client capabilities (can be empty)
|
||||
},
|
||||
}
|
||||
|
||||
# Define a timeout for initialization
|
||||
INITIALIZE_TIMEOUT = 15.0 # Seconds
|
||||
|
||||
try:
|
||||
# Send initialize request
|
||||
await protocol.send_request(self.writer, initialize_req)
|
||||
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
|
||||
|
||||
# Wait for initialize response
|
||||
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
|
||||
|
||||
if init_response and init_response.get("id") == init_req_id:
|
||||
if "error" in init_response:
|
||||
self.logger.error(f"Server returned error during initialization: {init_response['error']}")
|
||||
await self.stop()
|
||||
return False
|
||||
elif "result" in init_response:
|
||||
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided
|
||||
|
||||
# Send initialized notification (using standard method name)
|
||||
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
|
||||
await protocol.send_request(self.writer, initialized_notify)
|
||||
self.logger.info("'notifications/initialized' notification sent.")
|
||||
|
||||
self._is_running = True
|
||||
self.logger.info("MCP server process started and initialized successfully.")
|
||||
return True
|
||||
else:
|
||||
self.logger.error("Invalid 'initialize' response format (missing result/error).")
|
||||
await self.stop()
|
||||
return False
|
||||
elif init_response:
|
||||
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
|
||||
await self.stop()
|
||||
return False
|
||||
else: # Timeout case
|
||||
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
|
||||
await self.stop()
|
||||
return False
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error("Connection lost during initialization handshake. Stopping client.")
|
||||
await self.stop()
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
|
||||
await self.stop()
|
||||
return False
|
||||
# --- End MCP Initialization Handshake ---
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
|
||||
self.process = None # Ensure process is None on failure
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self._is_running = False
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
"""Stops the MCP server subprocess gracefully."""
|
||||
if not self._is_running and not self.process:
|
||||
self.logger.debug("Stop called but client is not running.")
|
||||
return
|
||||
|
||||
self.logger.info("Stopping MCP server process...")
|
||||
self._is_running = False # Mark as stopping
|
||||
|
||||
# Cancel stderr logging task
|
||||
if self._stderr_task and not self._stderr_task.done():
|
||||
self._stderr_task.cancel()
|
||||
try:
|
||||
await self._stderr_task
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Stderr task successfully cancelled.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
|
||||
self._stderr_task = None
|
||||
|
||||
# Stop the process using the utility function
|
||||
if self.process:
|
||||
await process.stop_mcp_process(self.process, self.server_name)
|
||||
|
||||
# Nullify references
|
||||
self.process = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self.logger.info("MCP server process stopped.")
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]] | None:
|
||||
"""
|
||||
Sends a 'tools/list' request and waits for the response.
|
||||
|
||||
Returns:
|
||||
A list of tool dictionaries, or None on error/timeout.
|
||||
"""
|
||||
if not self._is_running or not self.writer or not self.reader:
|
||||
self.logger.error("Cannot list tools: client not running or streams unavailable.")
|
||||
return None
|
||||
|
||||
self._request_counter += 1
|
||||
req_id = self._request_counter
|
||||
request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id}
|
||||
|
||||
try:
|
||||
await protocol.send_request(self.writer, request)
|
||||
response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT)
|
||||
|
||||
if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]:
|
||||
tools = response["result"]["tools"]
|
||||
if isinstance(tools, list):
|
||||
self.logger.info(f"Successfully listed {len(tools)} tools.")
|
||||
return tools
|
||||
else:
|
||||
self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}")
|
||||
return None
|
||||
elif response and "error" in response:
|
||||
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
|
||||
return None
|
||||
else:
|
||||
# Includes timeout case (read_response returns None)
|
||||
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
|
||||
return None
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error("Connection lost during listTools request. Stopping client.")
|
||||
await self.stop()
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
Sends a 'tools/call' request and waits for the response.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
arguments: The arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result dictionary from the server, or None on error/timeout.
|
||||
"""
|
||||
if not self._is_running or not self.writer or not self.reader:
|
||||
self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.")
|
||||
return None
|
||||
|
||||
self._request_counter += 1
|
||||
req_id = self._request_counter
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments},
|
||||
"id": req_id,
|
||||
}
|
||||
|
||||
try:
|
||||
await protocol.send_request(self.writer, request)
|
||||
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
|
||||
|
||||
if response and "result" in response:
|
||||
# Assuming result is the desired payload
|
||||
self.logger.info(f"Tool '{tool_name}' executed successfully.")
|
||||
return response["result"]
|
||||
elif response and "error" in response:
|
||||
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
|
||||
# Return the error structure itself? Or just None? Returning error dict for now.
|
||||
return {"error": response["error"]}
|
||||
else:
|
||||
# Includes timeout case
|
||||
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
|
||||
return None
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error(f"Connection lost during callTool '{tool_name}'. Stopping client.")
|
||||
await self.stop()
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True)
|
||||
return None
|
||||
366
src/custom_mcp/manager.py
Normal file
366
src/custom_mcp/manager.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# src/mcp/manager.py
|
||||
"""Synchronous manager for multiple MCPClient instances."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
# Use relative imports within the mcp package
|
||||
from custom_mcp.client import MCPClient
|
||||
|
||||
# Configure basic logging
|
||||
# Consider moving this to the main app entry point if not already done
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
|
||||
INITIALIZE_TIMEOUT = 60.0 # Seconds
|
||||
SHUTDOWN_TIMEOUT = 30.0 # Seconds
|
||||
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds
|
||||
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
"""
|
||||
Manages the lifecycle of multiple MCPClient instances and provides a
|
||||
synchronous interface to interact with them using a background event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||
"""
|
||||
Initializes the manager, loads config, but does not start servers yet.
|
||||
|
||||
Args:
|
||||
config_path: Path to the MCP server configuration JSON file.
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config: dict[str, Any] | None = None
|
||||
# Stores server_name -> MCPClient instance
|
||||
self.servers: dict[str, MCPClient] = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load MCP configuration from JSON file."""
|
||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||
try:
|
||||
# Using direct file access
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
logger.info("MCP configuration loaded successfully.")
|
||||
logger.debug(f"Config content: {self.config}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"MCP config file not found at {self.config_path}")
|
||||
self.config = None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
||||
self.config = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||
self.config = None
|
||||
|
||||
# --- Background Event Loop Management ---
|
||||
|
||||
def _run_event_loop(self):
|
||||
"""Target function for the background event loop thread."""
|
||||
try:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
finally:
|
||||
if self._loop and not self._loop.is_closed():
|
||||
# Clean up remaining tasks before closing
|
||||
try:
|
||||
tasks = asyncio.all_tasks(self._loop)
|
||||
if tasks:
|
||||
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
# Allow cancellation to propagate
|
||||
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
logger.debug("Outstanding tasks cancelled.")
|
||||
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
|
||||
except Exception as e:
|
||||
logger.error(f"Error during event loop cleanup: {e}")
|
||||
finally:
|
||||
self._loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
logger.info("Event loop thread finished.")
|
||||
|
||||
def _start_event_loop_thread(self):
|
||||
"""Starts the background event loop thread if not already running."""
|
||||
if self._thread is None or not self._thread.is_alive():
|
||||
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Event loop thread started.")
|
||||
# Wait briefly for the loop to become available and running
|
||||
while self._loop is None or not self._loop.is_running():
|
||||
# Use time.sleep in sync context
|
||||
import time
|
||||
|
||||
time.sleep(0.01)
|
||||
logger.debug("Event loop is running.")
|
||||
|
||||
def _stop_event_loop_thread(self):
|
||||
"""Stops the background event loop thread."""
|
||||
if self._loop and self._loop.is_running():
|
||||
logger.info("Requesting event loop stop...")
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread and self._thread.is_alive():
|
||||
logger.info("Waiting for event loop thread to join...")
|
||||
self._thread.join(timeout=5)
|
||||
if self._thread.is_alive():
|
||||
logger.warning("Event loop thread did not stop gracefully.")
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
logger.info("Event loop stopped.")
|
||||
|
||||
# --- Public Synchronous Interface ---
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
Initializes and starts all configured MCP servers synchronously.
|
||||
|
||||
Returns:
|
||||
True if all servers started successfully, False otherwise.
|
||||
"""
|
||||
logger.info("Manager initialization requested.")
|
||||
if not self.config or not self.config.get("mcpServers"):
|
||||
logger.warning("Initialization skipped: No valid configuration loaded.")
|
||||
return False
|
||||
|
||||
with self._lock:
|
||||
if self.initialized:
|
||||
logger.debug("Initialization skipped: Already initialized.")
|
||||
return True
|
||||
|
||||
self._start_event_loop_thread()
|
||||
if not self._loop:
|
||||
logger.error("Failed to start event loop for initialization.")
|
||||
return False
|
||||
|
||||
logger.info("Submitting asynchronous server initialization...")
|
||||
|
||||
# Prepare coroutine to start all clients
|
||||
|
||||
async def _async_init_all():
|
||||
tasks = []
|
||||
for server_name, server_config in self.config["mcpServers"].items():
|
||||
command = server_config.get("command")
|
||||
args = server_config.get("args", [])
|
||||
config_env = server_config.get("env", {})
|
||||
if not command:
|
||||
logger.error(f"Skipping server {server_name}: Missing 'command'.")
|
||||
continue
|
||||
|
||||
client = MCPClient(server_name, command, args, config_env)
|
||||
self.servers[server_name] = client
|
||||
tasks.append(client.start()) # Append the start coroutine
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check results - True means success, False or Exception means failure
|
||||
all_success = True
|
||||
failed_servers = []
|
||||
for i, result in enumerate(results):
|
||||
server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained
|
||||
if isinstance(result, Exception) or result is False:
|
||||
all_success = False
|
||||
failed_servers.append(server_name)
|
||||
# Remove failed client from managed servers
|
||||
if server_name in self.servers:
|
||||
del self.servers[server_name]
|
||||
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
|
||||
|
||||
if not all_success:
|
||||
logger.error(f"Initialization failed for servers: {failed_servers}")
|
||||
return all_success
|
||||
|
||||
# Run the initialization coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
|
||||
try:
|
||||
success = future.result(timeout=INITIALIZE_TIMEOUT)
|
||||
if success:
|
||||
logger.info("Asynchronous initialization completed successfully.")
|
||||
self.initialized = True
|
||||
else:
|
||||
logger.error("Asynchronous initialization failed.")
|
||||
self.initialized = False
|
||||
# Attempt to clean up any partially started servers
|
||||
self.shutdown() # Call sync shutdown
|
||||
except TimeoutError:
|
||||
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
success = False
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
success = False
|
||||
|
||||
return self.initialized
|
||||
|
||||
def shutdown(self):
|
||||
"""Shuts down all managed MCP servers synchronously."""
|
||||
logger.info("Manager shutdown requested.")
|
||||
with self._lock:
|
||||
# Check servers dict too, in case init was partial
|
||||
if not self.initialized and not self.servers:
|
||||
logger.debug("Shutdown skipped: Not initialized or no servers running.")
|
||||
# Ensure loop is stopped if it exists
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._stop_event_loop_thread()
|
||||
return
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
|
||||
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
|
||||
# This part is tricky as MCPClient.stop is async.
|
||||
# For simplicity, we might just log and rely on process termination on app exit.
|
||||
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._stop_event_loop_thread()
|
||||
return
|
||||
|
||||
logger.info("Submitting asynchronous server shutdown...")
|
||||
|
||||
# Prepare coroutine to stop all clients
|
||||
|
||||
async def _async_shutdown_all():
|
||||
tasks = [client.stop() for client in self.servers.values()]
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Run the shutdown coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
|
||||
try:
|
||||
future.result(timeout=SHUTDOWN_TIMEOUT)
|
||||
logger.info("Asynchronous shutdown completed.")
|
||||
except TimeoutError:
|
||||
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
|
||||
# Processes might still be running, OS will clean up on exit hopefully
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
|
||||
finally:
|
||||
# Always mark as uninitialized and clear servers dict
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
# Stop the background thread
|
||||
self._stop_event_loop_thread()
|
||||
|
||||
logger.info("Manager shutdown complete.")
|
||||
|
||||
def list_all_tools(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves tools from all initialized MCP servers synchronously.
|
||||
|
||||
Returns:
|
||||
A list of tool definitions in the standard internal format,
|
||||
aggregated from all servers. Returns empty list on failure.
|
||||
"""
|
||||
if not self.initialized or not self.servers:
|
||||
logger.warning("Cannot list tools: Manager not initialized or no servers running.")
|
||||
return []
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.error("Cannot list tools: Event loop not running.")
|
||||
return []
|
||||
|
||||
logger.info(f"Requesting tools from {len(self.servers)} servers...")
|
||||
|
||||
# Prepare coroutine to list tools from all clients
|
||||
async def _async_list_all():
|
||||
tasks = []
|
||||
server_names_in_order = []
|
||||
for server_name, client in self.servers.items():
|
||||
tasks.append(client.list_tools())
|
||||
server_names_in_order.append(server_name)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
all_tools = []
|
||||
for i, result in enumerate(results):
|
||||
server_name = server_names_in_order[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error listing tools for server '{server_name}': {result}")
|
||||
elif result is None:
|
||||
# MCPClient.list_tools returns None on timeout/error
|
||||
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
|
||||
elif isinstance(result, list):
|
||||
# Add server_name to each tool definition
|
||||
for tool in result:
|
||||
tool["server_name"] = server_name
|
||||
all_tools.extend(result)
|
||||
logger.debug(f"Received {len(result)} tools from {server_name}")
|
||||
else:
|
||||
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
|
||||
return all_tools
|
||||
|
||||
# Run the coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
|
||||
try:
|
||||
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
|
||||
logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.")
|
||||
return aggregated_tools
|
||||
except TimeoutError:
|
||||
logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during listing all tools future result: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
Executes a specific tool on the designated MCP server synchronously.
|
||||
|
||||
Args:
|
||||
server_name: The name of the server hosting the tool.
|
||||
tool_name: The name of the tool to execute.
|
||||
arguments: A dictionary of arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result content from the tool execution (dict),
|
||||
an error dict ({"error": ...}), or None on timeout/comm failure.
|
||||
"""
|
||||
if not self.initialized:
|
||||
logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.")
|
||||
return None
|
||||
|
||||
client = self.servers.get(server_name)
|
||||
if not client:
|
||||
logger.error(f"Cannot execute tool: Server '{server_name}' not found.")
|
||||
return None
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.")
|
||||
return None
|
||||
|
||||
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
|
||||
|
||||
# Run the client's call_tool coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
|
||||
try:
|
||||
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
|
||||
# MCPClient.call_tool returns the result dict or an error dict or None
|
||||
if result is None:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
|
||||
elif isinstance(result, dict) and "error" in result:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
|
||||
else:
|
||||
logger.info(f"Tool '{tool_name}' execution successful.")
|
||||
return result # Return result dict, error dict, or None
|
||||
except TimeoutError:
|
||||
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True)
|
||||
return None
|
||||
128
src/custom_mcp/process.py
Normal file
128
src/custom_mcp/process.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# src/mcp/process.py
|
||||
"""Async utilities for managing MCP server subprocesses."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process:
|
||||
"""
|
||||
Starts an MCP server subprocess using asyncio.create_subprocess_shell.
|
||||
|
||||
Handles argument expansion and environment merging.
|
||||
|
||||
Args:
|
||||
command: The main command executable.
|
||||
args: A list of arguments for the command.
|
||||
config_env: Server-specific environment variables from config.
|
||||
|
||||
Returns:
|
||||
The started asyncio.subprocess.Process object.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the command is not found.
|
||||
Exception: For other errors during subprocess creation.
|
||||
"""
|
||||
logger.debug(f"Preparing to start process for command: {command}")
|
||||
|
||||
# --- Add tilde expansion for arguments ---
|
||||
expanded_args = []
|
||||
try:
|
||||
for arg in args:
|
||||
if isinstance(arg, str) and "~" in arg:
|
||||
expanded_args.append(os.path.expanduser(arg))
|
||||
else:
|
||||
# Ensure all args are strings for list2cmdline
|
||||
expanded_args.append(str(arg))
|
||||
logger.debug(f"Expanded args: {expanded_args}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to expand arguments: {e}") from e
|
||||
|
||||
# --- Merge os.environ with config_env ---
|
||||
merged_env = {**os.environ, **config_env}
|
||||
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
|
||||
|
||||
# Combine command and expanded args into a single string for shell execution
|
||||
try:
|
||||
cmd_string = subprocess.list2cmdline([command] + expanded_args)
|
||||
logger.debug(f"Executing shell command: {cmd_string}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating command string: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to create command string: {e}") from e
|
||||
|
||||
# --- Start the subprocess using shell ---
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd_string,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=merged_env,
|
||||
)
|
||||
logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}")
|
||||
return process
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
|
||||
raise # Re-raise specific error
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
|
||||
raise # Re-raise other errors
|
||||
|
||||
|
||||
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
|
||||
"""
|
||||
Attempts to gracefully stop the MCP server subprocess.
|
||||
|
||||
Args:
|
||||
process: The asyncio.subprocess.Process object to stop.
|
||||
server_name: A name for logging purposes.
|
||||
"""
|
||||
if process is None or process.returncode is not None:
|
||||
logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.")
|
||||
return
|
||||
|
||||
pid = process.pid
|
||||
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
|
||||
|
||||
# Close stdin first
|
||||
if process.stdin and not process.stdin.is_closing():
|
||||
try:
|
||||
process.stdin.close()
|
||||
await process.stdin.wait_closed()
|
||||
logger.debug(f"Stdin closed for {server_name} (PID: {pid})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
|
||||
|
||||
# Attempt graceful termination
|
||||
try:
|
||||
process.terminate()
|
||||
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).")
|
||||
except TimeoutError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait() # Wait for kill to complete
|
||||
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
|
||||
except Exception as e_kill:
|
||||
logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
|
||||
except Exception as e_term:
|
||||
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
|
||||
# Attempt kill as fallback if terminate failed and process might still be running
|
||||
if process.returncode is None:
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).")
|
||||
except Exception as e_kill_fallback:
|
||||
logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}")
|
||||
75
src/custom_mcp/protocol.py
Normal file
75
src/custom_mcp/protocol.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# src/mcp/protocol.py
|
||||
"""Async utilities for MCP JSON-RPC communication over streams."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]):
|
||||
"""
|
||||
Sends a JSON-RPC request dictionary to the MCP server's stdin stream.
|
||||
|
||||
Args:
|
||||
writer: The asyncio StreamWriter connected to the process stdin.
|
||||
request_dict: The request dictionary to send.
|
||||
|
||||
Raises:
|
||||
ConnectionResetError: If the connection is lost during send.
|
||||
Exception: For other stream writing errors.
|
||||
"""
|
||||
try:
|
||||
request_json = json.dumps(request_dict) + "\n"
|
||||
writer.write(request_json.encode("utf-8"))
|
||||
await writer.drain()
|
||||
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
|
||||
except ConnectionResetError:
|
||||
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
|
||||
raise # Re-raise for the caller (MCPClient) to handle
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
|
||||
raise # Re-raise for the caller
|
||||
|
||||
|
||||
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
|
||||
"""
|
||||
Reads and parses a JSON-RPC response line from the MCP server's stdout stream.
|
||||
|
||||
Args:
|
||||
reader: The asyncio StreamReader connected to the process stdout.
|
||||
timeout: Seconds to wait for a response line.
|
||||
|
||||
Returns:
|
||||
The parsed response dictionary, or None if timeout or error occurs.
|
||||
"""
|
||||
response_str = None
|
||||
try:
|
||||
response_json = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||
if not response_json:
|
||||
logger.warning("Received empty response line (EOF?).")
|
||||
return None
|
||||
|
||||
response_str = response_json.decode("utf-8").strip()
|
||||
if not response_str:
|
||||
logger.warning("Received empty response string after strip.")
|
||||
return None
|
||||
|
||||
logger.debug(f"Received response line: {response_str}")
|
||||
response_dict = json.loads(response_str)
|
||||
return response_dict
|
||||
|
||||
except TimeoutError:
|
||||
logger.error(f"Timeout ({timeout}s) waiting for response.")
|
||||
return None
|
||||
except asyncio.IncompleteReadError:
|
||||
logger.warning("Connection closed while waiting for response.")
|
||||
return None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading response: {e}", exc_info=True)
|
||||
return None
|
||||
219
src/llm_client.py
Normal file
219
src/llm_client.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# src/llm_client.py
|
||||
"""
|
||||
Generic LLM client supporting multiple providers and MCP tool integration.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||
from src.providers import BaseProvider, create_llm_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
Handles chat completion requests to various LLM providers through a unified
|
||||
interface, integrating with MCP tools via SyncMCPManager.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str,
|
||||
api_key: str,
|
||||
mcp_manager: SyncMCPManager,
|
||||
base_url: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the LLM client.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider (e.g., 'openai', 'anthropic').
|
||||
api_key: API key for the provider.
|
||||
mcp_manager: An initialized instance of SyncMCPManager.
|
||||
base_url: Optional base URL for the provider API.
|
||||
"""
|
||||
logger.info(f"Initializing LLMClient for provider: {provider_name}")
|
||||
self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
|
||||
self.mcp_manager = mcp_manager
|
||||
self.mcp_tools: list[dict[str, Any]] = []
|
||||
self._refresh_mcp_tools() # Initial tool load
|
||||
|
||||
def _refresh_mcp_tools(self):
|
||||
"""Retrieves the latest tools from the MCP manager."""
|
||||
logger.info("Refreshing MCP tools...")
|
||||
try:
|
||||
self.mcp_tools = self.mcp_manager.list_all_tools()
|
||||
logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing MCP tools: {e}", exc_info=True)
|
||||
# Keep existing tools if refresh fails
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[str, None, None] | dict[str, Any]:
|
||||
"""
|
||||
Send a chat completion request, handling potential tool calls.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
|
||||
model: Model identifier string.
|
||||
temperature: Sampling temperature.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
If stream=True: A generator yielding content chunks.
|
||||
If stream=False: A dictionary containing the final content or an error.
|
||||
e.g., {"content": "..."} or {"error": "..."}
|
||||
"""
|
||||
# Ensure tools are up-to-date (optional, could be done less frequently)
|
||||
# self._refresh_mcp_tools()
|
||||
|
||||
# Convert tools to the provider-specific format
|
||||
try:
|
||||
provider_tools = self.provider.convert_tools(self.mcp_tools)
|
||||
logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting tools for provider: {e}", exc_info=True)
|
||||
provider_tools = None # Proceed without tools if conversion fails
|
||||
|
||||
try:
|
||||
logger.info(f"Sending chat completion request to provider with model: {model}")
|
||||
response = self.provider.create_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=stream,
|
||||
tools=provider_tools,
|
||||
)
|
||||
logger.info("Received response from provider.")
|
||||
|
||||
if stream:
|
||||
# Streaming with tool calls requires more complex handling (like airflow_wingman's
|
||||
# process_tool_calls_and_follow_up). For now, we'll yield the initial stream
|
||||
# and handle tool calls *after* the stream completes if detected (less ideal UX).
|
||||
# A better approach involves checking for tool calls before streaming fully.
|
||||
# This simplified version just streams the first response.
|
||||
logger.info("Streaming response...")
|
||||
# NOTE: This simple version doesn't handle tool calls during streaming well.
|
||||
# It will stream the initial response which might *contain* the tool call request,
|
||||
# but won't execute it within the stream.
|
||||
return self._stream_generator(response)
|
||||
|
||||
else: # Non-streaming
|
||||
logger.info("Processing non-streaming response...")
|
||||
if self.provider.has_tool_calls(response):
|
||||
logger.info("Tool calls detected in response.")
|
||||
# Simplified non-streaming tool call handling (one round)
|
||||
try:
|
||||
tool_calls = self.provider.parse_tool_calls(response)
|
||||
logger.debug(f"Parsed tool calls: {tool_calls}")
|
||||
|
||||
tool_results = []
|
||||
original_message_with_calls = self.provider.get_original_message_with_calls(response) # Provider needs to implement this
|
||||
messages.append(original_message_with_calls) # Add assistant's turn with tool requests
|
||||
|
||||
for tool_call in tool_calls:
|
||||
server_name = tool_call.get("server_name") # Needs to be parsed by provider
|
||||
func_name = tool_call.get("function_name")
|
||||
func_args_str = tool_call.get("arguments")
|
||||
call_id = tool_call.get("id")
|
||||
|
||||
if not server_name or not func_name or func_args_str is None or call_id is None:
|
||||
logger.error(f"Skipping invalid tool call data: {tool_call}")
|
||||
# Add error result?
|
||||
result_content = {"error": "Invalid tool call structure from LLM"}
|
||||
else:
|
||||
try:
|
||||
# Arguments might be a JSON string, parse them
|
||||
arguments = json.loads(func_args_str)
|
||||
logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}")
|
||||
# Execute synchronously using the manager
|
||||
execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments)
|
||||
logger.debug(f"Tool execution result: {execution_result}")
|
||||
|
||||
if execution_result is None:
|
||||
result_content = {"error": f"Tool execution failed or timed out for {func_name}"}
|
||||
elif isinstance(execution_result, dict) and "error" in execution_result:
|
||||
result_content = execution_result # Propagate error from tool/server
|
||||
else:
|
||||
# Assuming result is the content payload
|
||||
result_content = execution_result
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}")
|
||||
result_content = {"error": f"Invalid arguments format for tool {func_name}"}
|
||||
except Exception as exec_err:
|
||||
logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True)
|
||||
result_content = {"error": f"Exception during tool execution: {str(exec_err)}"}
|
||||
|
||||
# Format result for the provider's follow-up message
|
||||
formatted_result = self.provider.format_tool_results(call_id, result_content)
|
||||
tool_results.append(formatted_result)
|
||||
messages.append(formatted_result) # Add tool result message
|
||||
|
||||
# Make follow-up call
|
||||
logger.info("Making follow-up request with tool results...")
|
||||
follow_up_response = self.provider.create_chat_completion(
|
||||
messages=messages, # Now includes assistant's turn and tool results
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False, # Follow-up is non-streaming here
|
||||
tools=provider_tools, # Pass tools again? Some providers might need it.
|
||||
)
|
||||
final_content = self.provider.get_content(follow_up_response)
|
||||
logger.info("Received follow-up response content.")
|
||||
return {"content": final_content}
|
||||
|
||||
except Exception as tool_handling_err:
|
||||
logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
|
||||
return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"}
|
||||
|
||||
else: # No tool calls
|
||||
logger.info("No tool calls detected.")
|
||||
content = self.provider.get_content(response)
|
||||
return {"content": content}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"LLM API Error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
if stream:
|
||||
# How to signal error in a stream? Yield a specific error message?
|
||||
# This simple generator won't handle it well. Returning an error dict for now.
|
||||
return {"error": error_msg} # Or raise?
|
||||
else:
|
||||
return {"error": error_msg}
|
||||
|
||||
def _stream_generator(self, response: Any) -> Generator[str, None, None]:
|
||||
"""Helper to yield content from the provider's streaming method."""
|
||||
try:
|
||||
# Use yield from for cleaner and potentially more efficient delegation
|
||||
yield from self.provider.get_streaming_content(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk
|
||||
|
||||
|
||||
# Example of how a provider might need to implement get_original_message_with_calls
|
||||
# This would be in the specific provider class (e.g., openai_provider.py)
|
||||
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
|
||||
# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
|
||||
# # or in the non-streaming response's choice message
|
||||
# # Needs careful implementation based on provider's response structure
|
||||
# assistant_message = {
|
||||
# "role": "assistant",
|
||||
# "content": None, # Often null when tool calls are present
|
||||
# "tool_calls": [...] # Extracted tool calls in provider format
|
||||
# }
|
||||
# return assistant_message
|
||||
61
src/llm_models.py
Normal file
61
src/llm_models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
MODELS = {
|
||||
"openai": {
|
||||
"name": "OpenAI",
|
||||
"endpoint": "https://api.openai.com/v1",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"name": "GPT-4o",
|
||||
"default": True,
|
||||
"context_window": 128000,
|
||||
"description": "Input $5/M tokens, Output $15/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"name": "Anthropic",
|
||||
"endpoint": "https://api.anthropic.com/v1/messages",
|
||||
"models": [
|
||||
{
|
||||
"id": "claude-3-7-sonnet-20250219",
|
||||
"name": "Claude 3.7 Sonnet",
|
||||
"default": True,
|
||||
"context_window": 200000,
|
||||
"description": "Input $3/M tokens, Output $15/M tokens",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-haiku-20241022",
|
||||
"name": "Claude 3.5 Haiku",
|
||||
"default": False,
|
||||
"context_window": 200000,
|
||||
"description": "Input $0.80/M tokens, Output $4/M tokens",
|
||||
},
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"name": "Google Gemini",
|
||||
"endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
|
||||
"models": [
|
||||
{
|
||||
"id": "gemini-2.0-flash",
|
||||
"name": "Gemini 2.0 Flash",
|
||||
"default": True,
|
||||
"context_window": 1000000,
|
||||
"description": "Input $0.1/M tokens, Output $0.4/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"openrouter": {
|
||||
"name": "OpenRouter",
|
||||
"endpoint": "https://openrouter.ai/api/v1/chat/completions",
|
||||
"models": [
|
||||
{
|
||||
"id": "custom",
|
||||
"name": "Custom Model",
|
||||
"default": False,
|
||||
"context_window": 128000, # Default context window, will be updated based on model
|
||||
"description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
|
||||
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging # Import logging
|
||||
import threading
|
||||
|
||||
from custom_mcp_client import MCPClient, run_interaction
|
||||
|
||||
# Configure basic logging for the application if not already configured
|
||||
# This basic config helps ensure logs are visible during development
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
# Get a logger for this module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
"""Synchronous wrapper for managing MCP servers and interactions"""
|
||||
|
||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||
self.config_path = config_path
|
||||
self.config = None
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load MCP configuration from JSON file using importlib"""
|
||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||
try:
|
||||
# First try to load as a package resource
|
||||
try:
|
||||
# Try anchoring to the project name defined in pyproject.toml
|
||||
# This *might* work depending on editable installs or context.
|
||||
resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
|
||||
with resource_path.open("r") as f:
|
||||
self.config = json.load(f)
|
||||
logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
|
||||
# REMOVED: raise FileNotFoundError
|
||||
|
||||
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError
|
||||
logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
|
||||
# Fall back to direct file access relative to CWD
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
logger.debug("Loaded config via direct file access.")
|
||||
|
||||
logger.info("MCP configuration loaded successfully.")
|
||||
logger.debug(f"Config content: {self.config}") # Log content only if loaded
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"MCP config file not found at {self.config_path}")
|
||||
self.config = None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
||||
self.config = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||
self.config = None
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Initialize and start all MCP servers synchronously"""
|
||||
logger.info("Initialize requested.")
|
||||
if not self.config:
|
||||
logger.warning("Initialization skipped: No configuration loaded.")
|
||||
return False
|
||||
if not self.config.get("mcpServers"):
|
||||
logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
|
||||
return False
|
||||
|
||||
if self.initialized:
|
||||
logger.debug("Initialization skipped: Already initialized.")
|
||||
return True
|
||||
|
||||
with self._lock:
|
||||
if self.initialized: # Double-check after acquiring lock
|
||||
logger.debug("Initialization skipped inside lock: Already initialized.")
|
||||
return True
|
||||
|
||||
logger.info("Starting asynchronous initialization...")
|
||||
# Run async initialization in a new event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop) # Ensure this loop is used by tasks
|
||||
success = loop.run_until_complete(self._async_initialize())
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None) # Clean up
|
||||
|
||||
if success:
|
||||
logger.info("Asynchronous initialization completed successfully.")
|
||||
self.initialized = True
|
||||
else:
|
||||
logger.error("Asynchronous initialization failed.")
|
||||
self.initialized = False # Ensure state is False on failure
|
||||
|
||||
return self.initialized
|
||||
|
||||
async def _async_initialize(self) -> bool:
|
||||
"""Async implementation of server initialization"""
|
||||
logger.debug("Starting _async_initialize...")
|
||||
all_success = True
|
||||
if not self.config or not self.config.get("mcpServers"):
|
||||
logger.warning("_async_initialize: No config or mcpServers found.")
|
||||
return False
|
||||
|
||||
tasks = []
|
||||
server_names = list(self.config["mcpServers"].keys())
|
||||
|
||||
async def start_server(server_name, server_config):
|
||||
logger.info(f"Initializing server: {server_name}")
|
||||
try:
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
logger.debug(f"Attempting to start client for {server_name}...")
|
||||
if await client.start():
|
||||
logger.info(f"Client for {server_name} started successfully.")
|
||||
tools = await client.list_tools()
|
||||
logger.info(f"Tools listed for {server_name}: {len(tools)}")
|
||||
self.servers[server_name] = {"client": client, "tools": tools}
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to start MCP server: {server_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
# Start servers concurrently
|
||||
for server_name in server_names:
|
||||
server_config = self.config["mcpServers"][server_name]
|
||||
tasks.append(start_server(server_name, server_config))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Check if all servers started successfully
|
||||
all_success = all(results)
|
||||
|
||||
if all_success:
|
||||
logger.debug("_async_initialize completed: All servers started successfully.")
|
||||
else:
|
||||
failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
|
||||
logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
|
||||
# Optionally shutdown servers that did start if partial success is not desired
|
||||
# await self._async_shutdown() # Uncomment to enforce all-or-nothing startup
|
||||
|
||||
return all_success
|
||||
|
||||
def shutdown(self):
|
||||
"""Shut down all MCP servers synchronously"""
|
||||
logger.info("Shutdown requested.")
|
||||
if not self.initialized:
|
||||
logger.debug("Shutdown skipped: Not initialized.")
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
if not self.initialized:
|
||||
logger.debug("Shutdown skipped inside lock: Not initialized.")
|
||||
return
|
||||
|
||||
logger.info("Starting asynchronous shutdown...")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._async_shutdown())
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
logger.info("Shutdown complete.")
|
||||
|
||||
async def _async_shutdown(self):
|
||||
"""Async implementation of server shutdown"""
|
||||
logger.debug("Starting _async_shutdown...")
|
||||
tasks = []
|
||||
for server_name, server_info in self.servers.items():
|
||||
logger.debug(f"Initiating shutdown for server: {server_name}")
|
||||
tasks.append(server_info["client"].stop())
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, result in enumerate(results):
|
||||
server_name = list(self.servers.keys())[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
|
||||
else:
|
||||
logger.debug(f"Shutdown completed for server: {server_name}")
|
||||
logger.debug("_async_shutdown finished.")
|
||||
|
||||
# Updated process_query signature
|
||||
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""
|
||||
Process a query using MCP tools synchronously
|
||||
|
||||
Args:
|
||||
query: The user's input query.
|
||||
model_name: The model to use for processing.
|
||||
api_key: The OpenAI API key.
|
||||
base_url: The OpenAI API base URL.
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or error.
|
||||
"""
|
||||
if not self.initialized and not self.initialize():
|
||||
logger.error("process_query called but MCP manager failed to initialize.")
|
||||
return {"error": "Failed to initialize MCP servers"}
|
||||
|
||||
logger.debug(f"Processing query synchronously: '{query}'")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# Pass api_key and base_url to _async_process_query
|
||||
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
|
||||
logger.debug(f"Synchronous query processing result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
|
||||
return {"error": f"Processing error: {str(e)}"}
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Updated _async_process_query signature
|
||||
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""Async implementation of query processing"""
|
||||
# Pass api_key, base_url, and the MCP config separately to run_interaction
|
||||
return await run_interaction(
|
||||
user_query=query,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
mcp_config=self.config, # self.config only contains MCP server definitions now
|
||||
stream=False,
|
||||
)
|
||||
71
src/providers/__init__.py
Normal file
71
src/providers/__init__.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# src/providers/__init__.py
|
||||
import logging
|
||||
|
||||
from providers.base import BaseProvider
|
||||
|
||||
# Import specific provider implementations here as they are created
|
||||
from providers.openai_provider import OpenAIProvider
|
||||
|
||||
# from .anthropic_provider import AnthropicProvider
|
||||
# from .google_provider import GoogleProvider
|
||||
# from .openrouter_provider import OpenRouterProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map provider names (lowercase) to their corresponding class implementations
|
||||
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
|
||||
"openai": OpenAIProvider,
|
||||
# "anthropic": AnthropicProvider,
|
||||
# "google": GoogleProvider,
|
||||
# "openrouter": OpenRouterProvider,
|
||||
}
|
||||
|
||||
|
||||
def register_provider(name: str, provider_class: type[BaseProvider]):
|
||||
"""Registers a provider class."""
|
||||
if name.lower() in PROVIDER_MAP:
|
||||
logger.warning(f"Provider '{name}' is already registered. Overwriting.")
|
||||
PROVIDER_MAP[name.lower()] = provider_class
|
||||
logger.info(f"Registered provider: {name}")
|
||||
|
||||
|
||||
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
|
||||
"""
|
||||
Factory function to create an instance of a specific LLM provider.
|
||||
|
||||
Args:
|
||||
provider_name: The name of the provider (e.g., 'openai', 'anthropic').
|
||||
api_key: The API key for the provider.
|
||||
base_url: Optional base URL for the provider's API.
|
||||
|
||||
Returns:
|
||||
An instance of the requested BaseProvider subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the requested provider_name is not registered.
|
||||
"""
|
||||
provider_class = PROVIDER_MAP.get(provider_name.lower())
|
||||
|
||||
if provider_class is None:
|
||||
available = ", ".join(PROVIDER_MAP.keys()) or "None"
|
||||
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
|
||||
|
||||
logger.info(f"Creating LLM provider instance for: {provider_name}")
|
||||
try:
|
||||
return provider_class(api_key=api_key, base_url=base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
|
||||
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
|
||||
|
||||
|
||||
def get_available_providers() -> list[str]:
|
||||
"""Returns a list of registered provider names."""
|
||||
return list(PROVIDER_MAP.keys())
|
||||
|
||||
|
||||
# Example of how specific providers would register themselves if structured as plugins,
|
||||
# but for now, we'll explicitly import and map them above.
|
||||
# def load_providers():
|
||||
# # Potentially load providers dynamically if designed as plugins
|
||||
# pass
|
||||
# load_providers()
|
||||
140
src/providers/base.py
Normal file
140
src/providers/base.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# src/providers/base.py
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseProvider(abc.ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Defines the common interface for interacting with different LLM APIs,
|
||||
including handling chat completions and tool usage.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
"""
|
||||
Initialize the provider.
|
||||
|
||||
Args:
|
||||
api_key: The API key for the provider.
|
||||
base_url: Optional base URL for the provider's API.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send a chat completion request to the LLM provider.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries with 'role' and 'content'.
|
||||
model: Model identifier.
|
||||
temperature: Sampling temperature (0-1).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
tools: Optional list of tools in the provider-specific format.
|
||||
|
||||
Returns:
|
||||
Provider-specific response object (e.g., API response, stream object).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||
"""
|
||||
Extracts and yields content chunks from a streaming response object.
|
||||
|
||||
Args:
|
||||
response: The streaming response object returned by create_chat_completion.
|
||||
|
||||
Yields:
|
||||
String chunks of the response content.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_content(self, response: Any) -> str:
|
||||
"""
|
||||
Extracts the complete content from a non-streaming response object.
|
||||
|
||||
Args:
|
||||
response: The non-streaming response object.
|
||||
|
||||
Returns:
|
||||
The complete response content as a string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def has_tool_calls(self, response: Any) -> bool:
|
||||
"""
|
||||
Checks if the response object contains tool calls.
|
||||
|
||||
Args:
|
||||
response: The response object (streaming or non-streaming).
|
||||
|
||||
Returns:
|
||||
True if tool calls are present, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Parses tool calls from the response object.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool calls.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, each representing a tool call with details
|
||||
like 'id', 'function_name', 'arguments'. The exact structure might
|
||||
vary slightly based on provider needs but should contain enough
|
||||
info for execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Formats the result of a tool execution into the structure expected
|
||||
by the provider for follow-up requests.
|
||||
|
||||
Args:
|
||||
tool_call_id: The unique ID of the tool call (from parse_tool_calls).
|
||||
result: The data returned by the tool execution.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the tool result in the provider's format.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Converts a list of tools from the standard internal format to the
|
||||
provider-specific format required for the API call.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions in the standard internal format.
|
||||
Each dict contains 'server_name', 'name', 'description', 'input_schema'.
|
||||
|
||||
Returns:
|
||||
List of tool definitions in the provider-specific format.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Optional: Add a method for follow-up completions if the provider API
|
||||
# requires a specific structure different from just appending messages.
|
||||
# def create_follow_up_completion(...) -> Any:
|
||||
# pass
|
||||
239
src/providers/openai_provider.py
Normal file
239
src/providers/openai_provider.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# src/providers/openai_provider.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
from providers.base import BaseProvider
|
||||
from src.llm_models import MODELS # Use absolute import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""Provider implementation for OpenAI and compatible APIs."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# Use default OpenAI endpoint if base_url is not provided
|
||||
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
|
||||
super().__init__(api_key, effective_base_url)
|
||||
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
|
||||
try:
|
||||
# TODO: Add default headers like in original client?
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion:
|
||||
"""Creates a chat completion using the OpenAI API."""
|
||||
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
|
||||
|
||||
# Remove None values like max_tokens if not provided
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
# --- Added Debug Logging ---
|
||||
log_params = completion_params.copy()
|
||||
# Avoid logging full messages if they are too long
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [
|
||||
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
|
||||
for msg in log_params["messages"][-2:] # Log last 2 messages summary
|
||||
]
|
||||
# Specifically log tools structure if present
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
# --- End Added Debug Logging ---
|
||||
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
logger.debug("OpenAI API call successful.")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
# Re-raise for the LLMClient to handle
|
||||
raise
|
||||
|
||||
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an OpenAI streaming response."""
|
||||
logger.debug("Processing OpenAI stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
for chunk in response:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
|
||||
# Yield an error message? Or let the generator stop?
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
def get_content(self, response: ChatCompletion) -> str:
|
||||
"""Extracts content from a non-streaming OpenAI response."""
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
|
||||
return content or "" # Return empty string if content is None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper
|
||||
# This is tricky for streams. We'd need to peek at the first chunk(s)
|
||||
# or buffer the response. For simplicity, this check might be unreliable
|
||||
# for streams *before* they are consumed. LLMClient needs robust handling.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
# A more robust check would involve consuming the start of the stream
|
||||
# or relying on the structure after consumption.
|
||||
return False # Assume no for unconsumed stream for now
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
# Attempt to handle buffered stream if possible? Complex.
|
||||
return []
|
||||
|
||||
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
func_name = call.function.name
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"function_name": func_name,
|
||||
"arguments": call.function.arguments, # Arguments are already a string here
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
else:
|
||||
content = str(result) # Ensure it's a string
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
logger.debug(f"Converted tool: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
# Helper needed by LLMClient's current tool handling logic
|
||||
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
|
||||
|
||||
# Register this provider (if using the registration mechanism)
|
||||
# from . import register_provider
|
||||
# register_provider("openai", OpenAIProvider)
|
||||
Reference in New Issue
Block a user