feat: Implement async utilities for MCP server management and JSON-RPC communication
- Added `process.py` for managing MCP server subprocesses with async capabilities.
- Introduced `protocol.py` for handling JSON-RPC communication over streams.
- Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools.
- Defined model configurations in `llm_models.py` for different LLM providers.
- Removed the synchronous `mcp_manager.py` in favor of a more modular approach.
- Established a provider framework in the `providers` directory with a base class and specific implementations.
- Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool-call handling.
@@ -1,11 +1,11 @@
 {
   "mcpServers": {
-    "dolphin-demo-database-sqlite": {
+    "mcp-server-sqlite": {
       "command": "uvx",
       "args": [
         "mcp-server-sqlite",
         "--db-path",
-        "~/.dolphin/dolphin.db"
+        "~/.mcpapp/mcpapp.db"
       ]
     }
   }
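For reference, `SyncMCPManager` (added below in `src/custom_mcp/manager.py`) reads this file as a JSON object mapping each server name to a `command`, an `args` list, and an optional `env` map of extra environment variables. A minimal sketch of the full file after this rename (the empty `env` block is illustrative and may be omitted):

{
  "mcpServers": {
    "mcp-server-sqlite": {
      "command": "uvx",
      "args": ["mcp-server-sqlite", "--db-path", "~/.mcpapp/mcpapp.db"],
      "env": {}
    }
  }
}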
@@ -61,6 +61,7 @@ lint.select = [
     "T10", # flake8-debugger
     "A", # flake8-builtins
     "UP", # pyupgrade
+    "TID", # flake8-tidy-imports
 ]

 lint.ignore = [
@@ -81,7 +82,7 @@ skip-magic-trailing-comma = false
 combine-as-imports = true

 [tool.ruff.lint.mccabe]
-max-complexity = 12
+max-complexity = 16

 [tool.ruff.lint.flake8-tidy-imports]
 # Disallow all relative imports.
178 src/app.py
@@ -1,29 +1,111 @@
|
|||||||
import atexit
|
import atexit
|
||||||
|
import configparser
|
||||||
|
import json # For handling potential error JSON in stream
|
||||||
|
import logging
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from openai_client import OpenAIClient
|
# Updated imports
|
||||||
|
from llm_client import LLMClient
|
||||||
|
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||||
|
|
||||||
|
# Configure logging for the app
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_session_state():
|
def init_session_state():
|
||||||
|
"""Initializes session state variables including clients."""
|
||||||
if "messages" not in st.session_state:
|
if "messages" not in st.session_state:
|
||||||
st.session_state.messages = []
|
st.session_state.messages = []
|
||||||
|
logger.info("Initialized session state: messages")
|
||||||
|
|
||||||
if "client" not in st.session_state:
|
if "client" not in st.session_state:
|
||||||
st.session_state.client = OpenAIClient()
|
logger.info("Attempting to initialize clients...")
|
||||||
# Register cleanup for MCP servers
|
try:
|
||||||
if hasattr(st.session_state.client, "mcp_manager"):
|
config = configparser.ConfigParser()
|
||||||
atexit.register(st.session_state.client.mcp_manager.shutdown)
|
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
|
||||||
|
config_files_read = config.read("config/config.ini")
|
||||||
|
if not config_files_read:
|
||||||
|
raise FileNotFoundError("config.ini not found or could not be read.")
|
||||||
|
logger.info(f"Read configuration from: {config_files_read}")
|
||||||
|
|
||||||
|
# --- MCP Manager Setup ---
|
||||||
|
mcp_config_path = "config/mcp_config.json" # Default
|
||||||
|
if config.has_section("mcp") and config["mcp"].get("servers_json"):
|
||||||
|
mcp_config_path = config["mcp"]["servers_json"]
|
||||||
|
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Using default MCP config path: {mcp_config_path}")
|
||||||
|
|
||||||
|
mcp_manager = SyncMCPManager(mcp_config_path)
|
||||||
|
if not mcp_manager.initialize():
|
||||||
|
# Log warning but continue - LLMClient will operate without tools
|
||||||
|
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
|
||||||
|
else:
|
||||||
|
logger.info("MCP Manager initialized successfully.")
|
||||||
|
# Register shutdown hook for MCP manager
|
||||||
|
atexit.register(mcp_manager.shutdown)
|
||||||
|
logger.info("Registered MCP Manager shutdown hook.")
|
||||||
|
|
||||||
|
# --- LLM Client Setup ---
|
||||||
|
# Determine provider details (e.g., from an [llm] section)
|
||||||
|
provider_name = "openai" # Default
|
||||||
|
model_name = None # Must be provided in config
|
||||||
|
api_key = None
|
||||||
|
base_url = None
|
||||||
|
|
||||||
|
# Prioritize [llm] section, fallback to [openai] for compatibility
|
||||||
|
if config.has_section("llm"):
|
||||||
|
logger.info("Reading configuration from [llm] section.")
|
||||||
|
provider_name = config["llm"].get("provider", provider_name)
|
||||||
|
model_name = config["llm"].get("model")
|
||||||
|
api_key = config["llm"].get("api_key")
|
||||||
|
base_url = config["llm"].get("base_url") # Optional
|
||||||
|
elif config.has_section("openai"):
|
||||||
|
logger.warning("Using legacy [openai] section for configuration.")
|
||||||
|
provider_name = "openai" # Force openai if using this section
|
||||||
|
model_name = config["openai"].get("model")
|
||||||
|
api_key = config["openai"].get("api_key")
|
||||||
|
base_url = config["openai"].get("base_url") # Optional
|
||||||
|
else:
|
||||||
|
raise ValueError("Missing [llm] or [openai] section in config.ini")
|
||||||
|
|
||||||
|
# Validate required config
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("Missing 'api_key' in config.ini ([llm] or [openai] section)")
|
||||||
|
if not model_name:
|
||||||
|
raise ValueError("Missing 'model' name in config.ini ([llm] or [openai] section)")
|
||||||
|
|
||||||
|
logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
|
||||||
|
st.session_state.client = LLMClient(
|
||||||
|
provider_name=provider_name,
|
||||||
|
api_key=api_key,
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
base_url=base_url, # Pass None if not provided
|
||||||
|
)
|
||||||
|
st.session_state.model_name = model_name # Store model name for chat requests
|
||||||
|
logger.info("LLMClient initialized successfully.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
|
||||||
|
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
|
||||||
|
# Stop the app if initialization fails critically
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
|
||||||
def display_chat_messages():
|
def display_chat_messages():
|
||||||
|
"""Displays chat messages stored in session state."""
|
||||||
for message in st.session_state.messages:
|
for message in st.session_state.messages:
|
||||||
with st.chat_message(message["role"]):
|
with st.chat_message(message["role"]):
|
||||||
|
# Simple markdown display for now
|
||||||
st.markdown(message["content"])
|
st.markdown(message["content"])
|
||||||
|
|
||||||
|
|
||||||
def handle_user_input():
|
def handle_user_input():
|
||||||
|
"""Handles user input, calls LLMClient, and displays the response."""
|
||||||
if prompt := st.chat_input("Type your message..."):
|
if prompt := st.chat_input("Type your message..."):
|
||||||
print(f"User input received: {prompt}") # Debug log
|
logger.info(f"User input received: '{prompt[:50]}...'")
|
||||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||||
with st.chat_message("user"):
|
with st.chat_message("user"):
|
||||||
st.markdown(prompt)
|
st.markdown(prompt)
|
||||||
@@ -32,39 +114,85 @@ def handle_user_input():
|
|||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
response_placeholder = st.empty()
|
response_placeholder = st.empty()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
error_occurred = False
|
||||||
|
|
||||||
print("Processing message...") # Debug log
|
logger.info("Processing message via LLMClient...")
|
||||||
response = st.session_state.client.get_chat_response(st.session_state.messages)
|
# Use the new client and method, always requesting stream for UI
|
||||||
|
response_stream = st.session_state.client.chat_completion(
|
||||||
|
messages=st.session_state.messages,
|
||||||
|
model=st.session_state.model_name, # Get model from session state
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Handle both MCP and standard OpenAI responses
|
# Handle the response (stream generator or error dict)
|
||||||
# Check if it's NOT a dict (assuming stream is not a dict)
|
if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
|
||||||
if not isinstance(response, dict):
|
logger.debug("Processing response stream...")
|
||||||
# Standard OpenAI streaming response
|
for chunk in response_stream:
|
||||||
for chunk in response:
|
# Check for potential error JSON yielded by the stream
|
||||||
# Ensure chunk has choices and delta before accessing
|
try:
|
||||||
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
|
# Attempt to parse chunk as JSON only if it looks like it
|
||||||
full_response += chunk.choices[0].delta.content
|
if isinstance(chunk, str) and chunk.strip().startswith("{"):
|
||||||
response_placeholder.markdown(full_response + "▌")
|
error_data = json.loads(chunk)
|
||||||
|
if isinstance(error_data, dict) and "error" in error_data:
|
||||||
|
full_response = f"Error: {error_data['error']}"
|
||||||
|
logger.error(f"Error received in stream: {full_response}")
|
||||||
|
st.error(full_response)
|
||||||
|
error_occurred = True
|
||||||
|
break # Stop processing stream on error
|
||||||
|
# If not error JSON, treat as content chunk
|
||||||
|
if not error_occurred and isinstance(chunk, str):
|
||||||
|
full_response += chunk
|
||||||
|
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
# Not JSON or not error structure, treat as content chunk
|
||||||
|
if not error_occurred and isinstance(chunk, str):
|
||||||
|
full_response += chunk
|
||||||
|
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||||
|
|
||||||
|
if not error_occurred:
|
||||||
|
response_placeholder.markdown(full_response) # Final update without cursor
|
||||||
|
logger.debug("Stream processing complete.")
|
||||||
|
|
||||||
|
elif isinstance(response_stream, dict) and "error" in response_stream:
|
||||||
|
# Handle error dict returned directly (e.g., API error before streaming)
|
||||||
|
full_response = f"Error: {response_stream['error']}"
|
||||||
|
logger.error(f"Error returned directly from chat_completion: {full_response}")
|
||||||
|
st.error(full_response)
|
||||||
|
error_occurred = True
|
||||||
else:
|
else:
|
||||||
# MCP non-streaming response
|
# Unexpected response type
|
||||||
full_response = response.get("assistant_text", "")
|
full_response = "[Unexpected response format from LLMClient]"
|
||||||
response_placeholder.markdown(full_response)
|
logger.error(f"Unexpected response type: {type(response_stream)}")
|
||||||
|
st.error(full_response)
|
||||||
|
error_occurred = True
|
||||||
|
|
||||||
response_placeholder.markdown(full_response)
|
# Only add non-error, non-empty responses to history
|
||||||
|
if not error_occurred and full_response:
|
||||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||||
print("Message processed successfully") # Debug log
|
logger.info("Assistant response added to history.")
|
||||||
|
elif error_occurred:
|
||||||
|
logger.warning("Assistant response not added to history due to error.")
|
||||||
|
else:
|
||||||
|
logger.warning("Empty assistant response received, not added to history.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
st.error(f"Error processing message: {str(e)}")
|
logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
|
||||||
print(f"Error details: {str(e)}") # Debug log
|
st.error(f"An unexpected error occurred: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
st.title("Streamlit Chat App")
|
"""Main function to run the Streamlit app."""
|
||||||
|
st.title("MCP Chat App") # Updated title
|
||||||
|
try:
|
||||||
init_session_state()
|
init_session_state()
|
||||||
display_chat_messages()
|
display_chat_messages()
|
||||||
handle_user_input()
|
handle_user_input()
|
||||||
|
except Exception as e:
|
||||||
|
# Catch potential errors during rendering or handling
|
||||||
|
logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
|
||||||
|
st.error(f"A critical application error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
logger.info("Starting Streamlit Chat App...")
|
||||||
main()
|
main()
|
||||||
|
|||||||
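For reference, `init_session_state` above reads `config/config.ini` for both the MCP and LLM settings. A minimal sketch of that file, assuming only the keys the code actually looks up (`servers_json`, `provider`, `model`, `api_key`, `base_url`); all values are placeholders:

[mcp]
servers_json = config/mcp_config.json

[llm]
provider = openai
model = gpt-4o
api_key = sk-REPLACE_ME
; base_url is optional, e.g. for an OpenAI-compatible proxy
; base_url = http://localhost:8000/v1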
1 src/custom_mcp/__init__.py Normal file
@@ -0,0 +1 @@
|
|||||||
|
# This file makes src/custom_mcp a Python package
|
||||||
281 src/custom_mcp/client.py Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# src/custom_mcp/client.py
|
||||||
|
"""Client class for managing and interacting with a single MCP server process."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from custom_mcp import process, protocol
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Define reasonable timeouts
|
||||||
|
LIST_TOOLS_TIMEOUT = 20.0  # Seconds
|
||||||
|
CALL_TOOL_TIMEOUT = 110.0 # Seconds
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""
|
||||||
|
Manages the lifecycle and async communication with a single MCP server process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
|
||||||
|
"""
|
||||||
|
Initializes the client for a specific server configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Unique name for the server (for logging).
|
||||||
|
command: The command executable.
|
||||||
|
args: List of arguments for the command.
|
||||||
|
config_env: Server-specific environment variables.
|
||||||
|
"""
|
||||||
|
self.server_name = server_name
|
||||||
|
self.command = command
|
||||||
|
self.args = args
|
||||||
|
self.config_env = config_env
|
||||||
|
self.process: asyncio.subprocess.Process | None = None
|
||||||
|
self.reader: asyncio.StreamReader | None = None
|
||||||
|
self.writer: asyncio.StreamWriter | None = None
|
||||||
|
self._stderr_task: asyncio.Task | None = None
|
||||||
|
self._request_counter = 0
|
||||||
|
self._is_running = False
|
||||||
|
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger
|
||||||
|
|
||||||
|
async def _log_stderr(self):
|
||||||
|
"""Logs stderr output from the server process."""
|
||||||
|
if not self.process or not self.process.stderr:
|
||||||
|
self.logger.debug("Stderr logging skipped: process or stderr not available.")
|
||||||
|
return
|
||||||
|
stderr_reader = self.process.stderr
|
||||||
|
try:
|
||||||
|
while not stderr_reader.at_eof():
|
||||||
|
line = await stderr_reader.readline()
|
||||||
|
if line:
|
||||||
|
self.logger.warning(f"[stderr] {line.decode().strip()}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.debug("Stderr logging task cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
# Log errors but don't crash the logger task if possible
|
||||||
|
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
self.logger.debug("Stderr logging task finished.")
|
||||||
|
|
||||||
|
async def start(self) -> bool:
|
||||||
|
"""
|
||||||
|
Starts the MCP server subprocess and sets up communication streams.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the process started successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
if self._is_running:
|
||||||
|
self.logger.warning("Start called but client is already running.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.logger.info("Starting MCP server process...")
|
||||||
|
try:
|
||||||
|
self.process = await process.start_mcp_process(self.command, self.args, self.config_env)
|
||||||
|
self.reader = self.process.stdout
|
||||||
|
self.writer = self.process.stdin
|
||||||
|
|
||||||
|
if self.reader is None or self.writer is None:
|
||||||
|
self.logger.error("Failed to get stdout/stdin streams after process start.")
|
||||||
|
await self.stop() # Attempt cleanup
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Start background task to monitor stderr
|
||||||
|
self._stderr_task = asyncio.create_task(self._log_stderr())
|
||||||
|
|
||||||
|
# --- Start MCP Initialization Handshake ---
|
||||||
|
self.logger.info("Starting MCP initialization handshake...")
|
||||||
|
self._request_counter += 1
|
||||||
|
init_req_id = self._request_counter
|
||||||
|
initialize_req = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": init_req_id,
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05", # Use a recent version
|
||||||
|
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client
|
||||||
|
"capabilities": {}, # Client capabilities (can be empty)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Define a timeout for initialization
|
||||||
|
INITIALIZE_TIMEOUT = 15.0 # Seconds
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send initialize request
|
||||||
|
await protocol.send_request(self.writer, initialize_req)
|
||||||
|
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
|
||||||
|
|
||||||
|
# Wait for initialize response
|
||||||
|
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
|
||||||
|
|
||||||
|
if init_response and init_response.get("id") == init_req_id:
|
||||||
|
if "error" in init_response:
|
||||||
|
self.logger.error(f"Server returned error during initialization: {init_response['error']}")
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
elif "result" in init_response:
|
||||||
|
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided
|
||||||
|
|
||||||
|
# Send initialized notification (using standard method name)
|
||||||
|
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
|
||||||
|
await protocol.send_request(self.writer, initialized_notify)
|
||||||
|
self.logger.info("'notifications/initialized' notification sent.")
|
||||||
|
|
||||||
|
self._is_running = True
|
||||||
|
self.logger.info("MCP server process started and initialized successfully.")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.logger.error("Invalid 'initialize' response format (missing result/error).")
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
elif init_response:
|
||||||
|
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
else: # Timeout case
|
||||||
|
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.logger.error("Connection lost during initialization handshake. Stopping client.")
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
|
||||||
|
await self.stop()
|
||||||
|
return False
|
||||||
|
# --- End MCP Initialization Handshake ---
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
|
||||||
|
self.process = None # Ensure process is None on failure
|
||||||
|
self.reader = None
|
||||||
|
self.writer = None
|
||||||
|
self._is_running = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stops the MCP server subprocess gracefully."""
|
||||||
|
if not self._is_running and not self.process:
|
||||||
|
self.logger.debug("Stop called but client is not running.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.info("Stopping MCP server process...")
|
||||||
|
self._is_running = False # Mark as stopping
|
||||||
|
|
||||||
|
# Cancel stderr logging task
|
||||||
|
if self._stderr_task and not self._stderr_task.done():
|
||||||
|
self._stderr_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._stderr_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.debug("Stderr task successfully cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
|
||||||
|
self._stderr_task = None
|
||||||
|
|
||||||
|
# Stop the process using the utility function
|
||||||
|
if self.process:
|
||||||
|
await process.stop_mcp_process(self.process, self.server_name)
|
||||||
|
|
||||||
|
# Nullify references
|
||||||
|
self.process = None
|
||||||
|
self.reader = None
|
||||||
|
self.writer = None
|
||||||
|
self.logger.info("MCP server process stopped.")
|
||||||
|
|
||||||
|
async def list_tools(self) -> list[dict[str, Any]] | None:
|
||||||
|
"""
|
||||||
|
Sends a 'tools/list' request and waits for the response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tool dictionaries, or None on error/timeout.
|
||||||
|
"""
|
||||||
|
if not self._is_running or not self.writer or not self.reader:
|
||||||
|
self.logger.error("Cannot list tools: client not running or streams unavailable.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._request_counter += 1
|
||||||
|
req_id = self._request_counter
|
||||||
|
request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await protocol.send_request(self.writer, request)
|
||||||
|
response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT)
|
||||||
|
|
||||||
|
if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]:
|
||||||
|
tools = response["result"]["tools"]
|
||||||
|
if isinstance(tools, list):
|
||||||
|
self.logger.info(f"Successfully listed {len(tools)} tools.")
|
||||||
|
return tools
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}")
|
||||||
|
return None
|
||||||
|
elif response and "error" in response:
|
||||||
|
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# Includes timeout case (read_response returns None)
|
||||||
|
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.logger.error("Connection lost during listTools request. Stopping client.")
|
||||||
|
await self.stop()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Sends a 'tools/call' request and waits for the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: The name of the tool to call.
|
||||||
|
arguments: The arguments for the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result dictionary from the server, or None on error/timeout.
|
||||||
|
"""
|
||||||
|
if not self._is_running or not self.writer or not self.reader:
|
||||||
|
self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._request_counter += 1
|
||||||
|
req_id = self._request_counter
|
||||||
|
request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {"name": tool_name, "arguments": arguments},
|
||||||
|
"id": req_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await protocol.send_request(self.writer, request)
|
||||||
|
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
|
||||||
|
|
||||||
|
if response and "result" in response:
|
||||||
|
# Assuming result is the desired payload
|
||||||
|
self.logger.info(f"Tool '{tool_name}' executed successfully.")
|
||||||
|
return response["result"]
|
||||||
|
elif response and "error" in response:
|
||||||
|
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
|
||||||
|
# Return the error structure itself? Or just None? Returning error dict for now.
|
||||||
|
return {"error": response["error"]}
|
||||||
|
else:
|
||||||
|
# Includes timeout case
|
||||||
|
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.logger.error(f"Connection lost during callTool '{tool_name}'. Stopping client.")
|
||||||
|
await self.stop()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True)
|
||||||
|
return None
|
||||||
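As a usage sketch (not part of this commit), `MCPClient` is driven entirely from async code. A minimal round trip against the sqlite server configured above might look like the following; the tool name and arguments passed to `call_tool` are illustrative and depend on the server:

import asyncio

from custom_mcp.client import MCPClient


async def demo() -> None:
    client = MCPClient(
        server_name="mcp-server-sqlite",
        command="uvx",
        args=["mcp-server-sqlite", "--db-path", "~/.mcpapp/mcpapp.db"],
        config_env={},
    )
    # start() spawns the subprocess and performs the initialize handshake.
    if not await client.start():
        return
    try:
        tools = await client.list_tools()  # JSON-RPC 'tools/list'
        print(tools)
        result = await client.call_tool("read_query", {"query": "SELECT 1"})  # illustrative tool
        print(result)
    finally:
        await client.stop()  # closes stdin and terminates the subprocess


asyncio.run(demo())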
366 src/custom_mcp/manager.py Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
# src/custom_mcp/manager.py
|
||||||
|
"""Synchronous manager for multiple MCPClient instances."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Use absolute imports within the custom_mcp package (relative imports are disallowed by the TID lint rule)
|
||||||
|
from custom_mcp.client import MCPClient
|
||||||
|
|
||||||
|
# Configure basic logging
|
||||||
|
# Consider moving this to the main app entry point if not already done
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
|
||||||
|
INITIALIZE_TIMEOUT = 60.0 # Seconds
|
||||||
|
SHUTDOWN_TIMEOUT = 30.0 # Seconds
|
||||||
|
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds
|
||||||
|
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
|
||||||
|
|
||||||
|
|
||||||
|
class SyncMCPManager:
|
||||||
|
"""
|
||||||
|
Manages the lifecycle of multiple MCPClient instances and provides a
|
||||||
|
synchronous interface to interact with them using a background event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||||
|
"""
|
||||||
|
Initializes the manager, loads config, but does not start servers yet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the MCP server configuration JSON file.
|
||||||
|
"""
|
||||||
|
self.config_path = config_path
|
||||||
|
self.config: dict[str, Any] | None = None
|
||||||
|
# Stores server_name -> MCPClient instance
|
||||||
|
self.servers: dict[str, MCPClient] = {}
|
||||||
|
self.initialized = False
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
||||||
|
self._load_config()
|
||||||
|
|
||||||
|
def _load_config(self):
|
||||||
|
"""Load MCP configuration from JSON file."""
|
||||||
|
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||||
|
try:
|
||||||
|
# Using direct file access
|
||||||
|
with open(self.config_path) as f:
|
||||||
|
self.config = json.load(f)
|
||||||
|
logger.info("MCP configuration loaded successfully.")
|
||||||
|
logger.debug(f"Config content: {self.config}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"MCP config file not found at {self.config_path}")
|
||||||
|
self.config = None
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
||||||
|
self.config = None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||||
|
self.config = None
|
||||||
|
|
||||||
|
# --- Background Event Loop Management ---
|
||||||
|
|
||||||
|
def _run_event_loop(self):
|
||||||
|
"""Target function for the background event loop thread."""
|
||||||
|
try:
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
self._loop.run_forever()
|
||||||
|
finally:
|
||||||
|
if self._loop and not self._loop.is_closed():
|
||||||
|
# Clean up remaining tasks before closing
|
||||||
|
try:
|
||||||
|
tasks = asyncio.all_tasks(self._loop)
|
||||||
|
if tasks:
|
||||||
|
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
# Allow cancellation to propagate
|
||||||
|
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||||
|
logger.debug("Outstanding tasks cancelled.")
|
||||||
|
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during event loop cleanup: {e}")
|
||||||
|
finally:
|
||||||
|
self._loop.close()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
logger.info("Event loop thread finished.")
|
||||||
|
|
||||||
|
def _start_event_loop_thread(self):
|
||||||
|
"""Starts the background event loop thread if not already running."""
|
||||||
|
if self._thread is None or not self._thread.is_alive():
|
||||||
|
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
logger.info("Event loop thread started.")
|
||||||
|
# Wait briefly for the loop to become available and running
|
||||||
|
while self._loop is None or not self._loop.is_running():
|
||||||
|
# Use time.sleep in sync context
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.01)
|
||||||
|
logger.debug("Event loop is running.")
|
||||||
|
|
||||||
|
def _stop_event_loop_thread(self):
|
||||||
|
"""Stops the background event loop thread."""
|
||||||
|
if self._loop and self._loop.is_running():
|
||||||
|
logger.info("Requesting event loop stop...")
|
||||||
|
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
logger.info("Waiting for event loop thread to join...")
|
||||||
|
self._thread.join(timeout=5)
|
||||||
|
if self._thread.is_alive():
|
||||||
|
logger.warning("Event loop thread did not stop gracefully.")
|
||||||
|
self._loop = None
|
||||||
|
self._thread = None
|
||||||
|
logger.info("Event loop stopped.")
|
||||||
|
|
||||||
|
# --- Public Synchronous Interface ---
|
||||||
|
|
||||||
|
def initialize(self) -> bool:
|
||||||
|
"""
|
||||||
|
Initializes and starts all configured MCP servers synchronously.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all servers started successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
logger.info("Manager initialization requested.")
|
||||||
|
if not self.config or not self.config.get("mcpServers"):
|
||||||
|
logger.warning("Initialization skipped: No valid configuration loaded.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self.initialized:
|
||||||
|
logger.debug("Initialization skipped: Already initialized.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
self._start_event_loop_thread()
|
||||||
|
if not self._loop:
|
||||||
|
logger.error("Failed to start event loop for initialization.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("Submitting asynchronous server initialization...")
|
||||||
|
|
||||||
|
# Prepare coroutine to start all clients
|
||||||
|
|
||||||
|
async def _async_init_all():
|
||||||
|
tasks = []
|
||||||
|
for server_name, server_config in self.config["mcpServers"].items():
|
||||||
|
command = server_config.get("command")
|
||||||
|
args = server_config.get("args", [])
|
||||||
|
config_env = server_config.get("env", {})
|
||||||
|
if not command:
|
||||||
|
logger.error(f"Skipping server {server_name}: Missing 'command'.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
client = MCPClient(server_name, command, args, config_env)
|
||||||
|
self.servers[server_name] = client
|
||||||
|
tasks.append(client.start()) # Append the start coroutine
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Check results - True means success, False or Exception means failure
|
||||||
|
all_success = True
|
||||||
|
failed_servers = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained
|
||||||
|
if isinstance(result, Exception) or result is False:
|
||||||
|
all_success = False
|
||||||
|
failed_servers.append(server_name)
|
||||||
|
# Remove failed client from managed servers
|
||||||
|
if server_name in self.servers:
|
||||||
|
del self.servers[server_name]
|
||||||
|
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
|
||||||
|
|
||||||
|
if not all_success:
|
||||||
|
logger.error(f"Initialization failed for servers: {failed_servers}")
|
||||||
|
return all_success
|
||||||
|
|
||||||
|
# Run the initialization coroutine in the background loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
|
||||||
|
try:
|
||||||
|
success = future.result(timeout=INITIALIZE_TIMEOUT)
|
||||||
|
if success:
|
||||||
|
logger.info("Asynchronous initialization completed successfully.")
|
||||||
|
self.initialized = True
|
||||||
|
else:
|
||||||
|
logger.error("Asynchronous initialization failed.")
|
||||||
|
self.initialized = False
|
||||||
|
# Attempt to clean up any partially started servers
|
||||||
|
self.shutdown() # Call sync shutdown
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
|
||||||
|
self.initialized = False
|
||||||
|
self.shutdown() # Clean up
|
||||||
|
success = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
|
||||||
|
self.initialized = False
|
||||||
|
self.shutdown() # Clean up
|
||||||
|
success = False
|
||||||
|
|
||||||
|
return self.initialized
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Shuts down all managed MCP servers synchronously."""
|
||||||
|
logger.info("Manager shutdown requested.")
|
||||||
|
with self._lock:
|
||||||
|
# Check servers dict too, in case init was partial
|
||||||
|
if not self.initialized and not self.servers:
|
||||||
|
logger.debug("Shutdown skipped: Not initialized or no servers running.")
|
||||||
|
# Ensure loop is stopped if it exists
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
self._stop_event_loop_thread()
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._loop or not self._loop.is_running():
|
||||||
|
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
|
||||||
|
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
|
||||||
|
# This part is tricky as MCPClient.stop is async.
|
||||||
|
# For simplicity, we might just log and rely on process termination on app exit.
|
||||||
|
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
|
||||||
|
self.servers = {}
|
||||||
|
self.initialized = False
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
self._stop_event_loop_thread()
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Submitting asynchronous server shutdown...")
|
||||||
|
|
||||||
|
# Prepare coroutine to stop all clients
|
||||||
|
|
||||||
|
async def _async_shutdown_all():
|
||||||
|
tasks = [client.stop() for client in self.servers.values()]
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Run the shutdown coroutine in the background loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
|
||||||
|
try:
|
||||||
|
future.result(timeout=SHUTDOWN_TIMEOUT)
|
||||||
|
logger.info("Asynchronous shutdown completed.")
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
|
||||||
|
# Processes might still be running, OS will clean up on exit hopefully
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
# Always mark as uninitialized and clear servers dict
|
||||||
|
self.servers = {}
|
||||||
|
self.initialized = False
|
||||||
|
# Stop the background thread
|
||||||
|
self._stop_event_loop_thread()
|
||||||
|
|
||||||
|
logger.info("Manager shutdown complete.")
|
||||||
|
|
||||||
|
def list_all_tools(self) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieves tools from all initialized MCP servers synchronously.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tool definitions in the standard internal format,
|
||||||
|
aggregated from all servers. Returns empty list on failure.
|
||||||
|
"""
|
||||||
|
if not self.initialized or not self.servers:
|
||||||
|
logger.warning("Cannot list tools: Manager not initialized or no servers running.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not self._loop or not self._loop.is_running():
|
||||||
|
logger.error("Cannot list tools: Event loop not running.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"Requesting tools from {len(self.servers)} servers...")
|
||||||
|
|
||||||
|
# Prepare coroutine to list tools from all clients
|
||||||
|
async def _async_list_all():
|
||||||
|
tasks = []
|
||||||
|
server_names_in_order = []
|
||||||
|
for server_name, client in self.servers.items():
|
||||||
|
tasks.append(client.list_tools())
|
||||||
|
server_names_in_order.append(server_name)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
all_tools = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
server_name = server_names_in_order[i]
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"Error listing tools for server '{server_name}': {result}")
|
||||||
|
elif result is None:
|
||||||
|
# MCPClient.list_tools returns None on timeout/error
|
||||||
|
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
|
||||||
|
elif isinstance(result, list):
|
||||||
|
# Add server_name to each tool definition
|
||||||
|
for tool in result:
|
||||||
|
tool["server_name"] = server_name
|
||||||
|
all_tools.extend(result)
|
||||||
|
logger.debug(f"Received {len(result)} tools from {server_name}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
|
||||||
|
return all_tools
|
||||||
|
|
||||||
|
# Run the coroutine in the background loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
|
||||||
|
try:
|
||||||
|
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
|
||||||
|
logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.")
|
||||||
|
return aggregated_tools
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during listing all tools future result: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Executes a specific tool on the designated MCP server synchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: The name of the server hosting the tool.
|
||||||
|
tool_name: The name of the tool to execute.
|
||||||
|
arguments: A dictionary of arguments for the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result content from the tool execution (dict),
|
||||||
|
an error dict ({"error": ...}), or None on timeout/comm failure.
|
||||||
|
"""
|
||||||
|
if not self.initialized:
|
||||||
|
logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
client = self.servers.get(server_name)
|
||||||
|
if not client:
|
||||||
|
logger.error(f"Cannot execute tool: Server '{server_name}' not found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self._loop or not self._loop.is_running():
|
||||||
|
logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
|
||||||
|
|
||||||
|
# Run the client's call_tool coroutine in the background loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
|
||||||
|
try:
|
||||||
|
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
|
||||||
|
# MCPClient.call_tool returns the result dict or an error dict or None
|
||||||
|
if result is None:
|
||||||
|
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
|
||||||
|
elif isinstance(result, dict) and "error" in result:
|
||||||
|
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Tool '{tool_name}' execution successful.")
|
||||||
|
return result # Return result dict, error dict, or None
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
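A short usage sketch of the synchronous interface, mirroring what `src/app.py` does; the tool name and arguments passed to `execute_tool` are illustrative placeholders:

from src.custom_mcp.manager import SyncMCPManager

manager = SyncMCPManager("config/mcp_config.json")
if manager.initialize():  # starts the background event loop and all configured servers
    tools = manager.list_all_tools()  # aggregated tool definitions, each tagged with server_name
    print(f"{len(tools)} tools available")
    result = manager.execute_tool(
        server_name="mcp-server-sqlite",
        tool_name="read_query",  # illustrative
        arguments={"query": "SELECT 1"},
    )
    print(result)
manager.shutdown()  # stops the servers and the event loop thread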
128 src/custom_mcp/process.py Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
# src/custom_mcp/process.py
|
||||||
|
"""Async utilities for managing MCP server subprocesses."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process:
|
||||||
|
"""
|
||||||
|
Starts an MCP server subprocess using asyncio.create_subprocess_shell.
|
||||||
|
|
||||||
|
Handles argument expansion and environment merging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: The main command executable.
|
||||||
|
args: A list of arguments for the command.
|
||||||
|
config_env: Server-specific environment variables from config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The started asyncio.subprocess.Process object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the command is not found.
|
||||||
|
Exception: For other errors during subprocess creation.
|
||||||
|
"""
|
||||||
|
logger.debug(f"Preparing to start process for command: {command}")
|
||||||
|
|
||||||
|
# --- Add tilde expansion for arguments ---
|
||||||
|
expanded_args = []
|
||||||
|
try:
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, str) and "~" in arg:
|
||||||
|
expanded_args.append(os.path.expanduser(arg))
|
||||||
|
else:
|
||||||
|
# Ensure all args are strings for list2cmdline
|
||||||
|
expanded_args.append(str(arg))
|
||||||
|
logger.debug(f"Expanded args: {expanded_args}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"Failed to expand arguments: {e}") from e
|
||||||
|
|
||||||
|
# --- Merge os.environ with config_env ---
|
||||||
|
merged_env = {**os.environ, **config_env}
|
||||||
|
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
|
||||||
|
|
||||||
|
# Combine command and expanded args into a single string for shell execution
|
||||||
|
try:
|
||||||
|
cmd_string = subprocess.list2cmdline([command] + expanded_args)
|
||||||
|
logger.debug(f"Executing shell command: {cmd_string}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating command string: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"Failed to create command string: {e}") from e
|
||||||
|
|
||||||
|
# --- Start the subprocess using shell ---
|
||||||
|
try:
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
cmd_string,
|
||||||
|
stdin=subprocess.PIPE,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
env=merged_env,
|
||||||
|
)
|
||||||
|
logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}")
|
||||||
|
return process
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
|
||||||
|
raise # Re-raise specific error
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
|
||||||
|
raise # Re-raise other errors
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
|
||||||
|
"""
|
||||||
|
Attempts to gracefully stop the MCP server subprocess.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
process: The asyncio.subprocess.Process object to stop.
|
||||||
|
server_name: A name for logging purposes.
|
||||||
|
"""
|
||||||
|
if process is None or process.returncode is not None:
|
||||||
|
logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.")
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = process.pid
|
||||||
|
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
|
||||||
|
|
||||||
|
# Close stdin first
|
||||||
|
if process.stdin and not process.stdin.is_closing():
|
||||||
|
try:
|
||||||
|
process.stdin.close()
|
||||||
|
await process.stdin.wait_closed()
|
||||||
|
logger.debug(f"Stdin closed for {server_name} (PID: {pid})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
|
||||||
|
|
||||||
|
# Attempt graceful termination
|
||||||
|
try:
|
||||||
|
process.terminate()
|
||||||
|
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
|
||||||
|
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||||
|
logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).")
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
|
||||||
|
try:
|
||||||
|
process.kill()
|
||||||
|
await process.wait() # Wait for kill to complete
|
||||||
|
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
|
||||||
|
except ProcessLookupError:
|
||||||
|
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
|
||||||
|
except Exception as e_kill:
|
||||||
|
logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}")
|
||||||
|
except ProcessLookupError:
|
||||||
|
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
|
||||||
|
except Exception as e_term:
|
||||||
|
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
|
||||||
|
# Attempt kill as fallback if terminate failed and process might still be running
|
||||||
|
if process.returncode is None:
|
||||||
|
try:
|
||||||
|
process.kill()
|
||||||
|
await process.wait()
|
||||||
|
logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).")
|
||||||
|
except Exception as e_kill_fallback:
|
||||||
|
logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}")
|
||||||
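To make the argument handling concrete: `start_mcp_process` expands `~` in string arguments and then joins the command and arguments into a single shell string with `subprocess.list2cmdline`. A small sketch of that transformation, using the server command from the config above:

import os
import subprocess

command = "uvx"
args = ["mcp-server-sqlite", "--db-path", "~/.mcpapp/mcpapp.db"]

# Same expansion rule as start_mcp_process: expanduser only on string args containing "~".
expanded_args = [os.path.expanduser(a) if isinstance(a, str) and "~" in a else str(a) for a in args]
cmd_string = subprocess.list2cmdline([command] + expanded_args)
print(cmd_string)  # e.g. uvx mcp-server-sqlite --db-path /home/user/.mcpapp/mcpapp.db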
75 src/custom_mcp/protocol.py Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# src/custom_mcp/protocol.py
|
||||||
|
"""Async utilities for MCP JSON-RPC communication over streams."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Sends a JSON-RPC request dictionary to the MCP server's stdin stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
writer: The asyncio StreamWriter connected to the process stdin.
|
||||||
|
request_dict: The request dictionary to send.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionResetError: If the connection is lost during send.
|
||||||
|
Exception: For other stream writing errors.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
request_json = json.dumps(request_dict) + "\n"
|
||||||
|
writer.write(request_json.encode("utf-8"))
|
||||||
|
await writer.drain()
|
||||||
|
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
|
||||||
|
except ConnectionResetError:
|
||||||
|
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
|
||||||
|
raise # Re-raise for the caller (MCPClient) to handle
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
|
||||||
|
raise # Re-raise for the caller
|
||||||
|
|
||||||
|
|
||||||
|
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Reads and parses a JSON-RPC response line from the MCP server's stdout stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: The asyncio StreamReader connected to the process stdout.
|
||||||
|
timeout: Seconds to wait for a response line.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parsed response dictionary, or None if timeout or error occurs.
|
||||||
|
"""
|
||||||
|
response_str = None
|
||||||
|
try:
|
||||||
|
response_json = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||||
|
if not response_json:
|
||||||
|
logger.warning("Received empty response line (EOF?).")
|
||||||
|
return None
|
||||||
|
|
||||||
|
response_str = response_json.decode("utf-8").strip()
|
||||||
|
if not response_str:
|
||||||
|
logger.warning("Received empty response string after strip.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"Received response line: {response_str}")
|
||||||
|
response_dict = json.loads(response_str)
|
||||||
|
return response_dict
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Timeout ({timeout}s) waiting for response.")
|
||||||
|
return None
|
||||||
|
except asyncio.IncompleteReadError:
|
||||||
|
logger.warning("Connection closed while waiting for response.")
|
||||||
|
return None
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading response: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
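The framing used by `send_request` and `read_response` is one JSON-RPC message per line over the server's stdio. For reference, the handshake performed by `MCPClient.start` followed by a tool listing looks roughly like this on the wire; the server replies shown here are illustrative, not captured output:

client -> server: {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, "capabilities": {}}}
server -> client: {"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, "serverInfo": {"name": "example-server", "version": "0.0.0"}}}
client -> server: {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
client -> server: {"jsonrpc": "2.0", "method": "tools/list", "id": 2}
server -> client: {"jsonrpc": "2.0", "id": 2, "result": {"tools": [{"name": "read_query", "description": "...", "inputSchema": {}}]}}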
219 src/llm_client.py Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
# src/llm_client.py
|
||||||
|
"""
|
||||||
|
Generic LLM client supporting multiple providers and MCP tool integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||||
|
from src.providers import BaseProvider, create_llm_provider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
class LLMClient:
    """
    Handles chat completion requests to various LLM providers through a unified
    interface, integrating with MCP tools via SyncMCPManager.
    """

    def __init__(
        self,
        provider_name: str,
        api_key: str,
        mcp_manager: SyncMCPManager,
        base_url: str | None = None,
    ):
        """
        Initialize the LLM client.

        Args:
            provider_name: Name of the provider (e.g., 'openai', 'anthropic').
            api_key: API key for the provider.
            mcp_manager: An initialized instance of SyncMCPManager.
            base_url: Optional base URL for the provider API.
        """
        logger.info(f"Initializing LLMClient for provider: {provider_name}")
        self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
        self.mcp_manager = mcp_manager
        self.mcp_tools: list[dict[str, Any]] = []
        self._refresh_mcp_tools()  # Initial tool load

    def _refresh_mcp_tools(self):
        """Retrieves the latest tools from the MCP manager."""
        logger.info("Refreshing MCP tools...")
        try:
            self.mcp_tools = self.mcp_manager.list_all_tools()
            logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.")
        except Exception as e:
            logger.error(f"Error refreshing MCP tools: {e}", exc_info=True)
            # Keep existing tools if refresh fails

    def chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
    ) -> Generator[str, None, None] | dict[str, Any]:
        """
        Send a chat completion request, handling potential tool calls.

        Args:
            messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
            model: Model identifier string.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.

        Returns:
            If stream=True: A generator yielding content chunks.
            If stream=False: A dictionary containing the final content or an error,
            e.g., {"content": "..."} or {"error": "..."}.
        """
        # Ensure tools are up-to-date (optional, could be done less frequently)
        # self._refresh_mcp_tools()

        # Convert tools to the provider-specific format
        try:
            provider_tools = self.provider.convert_tools(self.mcp_tools)
            logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}")
        except Exception as e:
            logger.error(f"Error converting tools for provider: {e}", exc_info=True)
            provider_tools = None  # Proceed without tools if conversion fails

        try:
            logger.info(f"Sending chat completion request to provider with model: {model}")
            response = self.provider.create_chat_completion(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=stream,
                tools=provider_tools,
            )
            logger.info("Received response from provider.")

            if stream:
                # Streaming with tool calls requires more complex handling (like airflow_wingman's
                # process_tool_calls_and_follow_up). For now, we'll yield the initial stream
                # and handle tool calls *after* the stream completes if detected (less ideal UX).
                # A better approach involves checking for tool calls before streaming fully.
                # This simplified version just streams the first response.
                logger.info("Streaming response...")
                # NOTE: This simple version doesn't handle tool calls during streaming well.
                # It will stream the initial response which might *contain* the tool call request,
                # but won't execute it within the stream.
                return self._stream_generator(response)

            else:  # Non-streaming
                logger.info("Processing non-streaming response...")
                if self.provider.has_tool_calls(response):
                    logger.info("Tool calls detected in response.")
                    # Simplified non-streaming tool call handling (one round)
                    try:
                        tool_calls = self.provider.parse_tool_calls(response)
                        logger.debug(f"Parsed tool calls: {tool_calls}")

                        tool_results = []
                        original_message_with_calls = self.provider.get_original_message_with_calls(response)  # Provider needs to implement this
                        messages.append(original_message_with_calls)  # Add assistant's turn with tool requests

                        for tool_call in tool_calls:
                            server_name = tool_call.get("server_name")  # Needs to be parsed by provider
                            func_name = tool_call.get("function_name")
                            func_args_str = tool_call.get("arguments")
                            call_id = tool_call.get("id")

                            if not server_name or not func_name or func_args_str is None or call_id is None:
                                logger.error(f"Skipping invalid tool call data: {tool_call}")
                                # Add error result?
                                result_content = {"error": "Invalid tool call structure from LLM"}
                            else:
                                try:
                                    # Arguments might be a JSON string, parse them
                                    arguments = json.loads(func_args_str)
                                    logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}")
                                    # Execute synchronously using the manager
                                    execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments)
                                    logger.debug(f"Tool execution result: {execution_result}")

                                    if execution_result is None:
                                        result_content = {"error": f"Tool execution failed or timed out for {func_name}"}
                                    elif isinstance(execution_result, dict) and "error" in execution_result:
                                        result_content = execution_result  # Propagate error from tool/server
                                    else:
                                        # Assuming result is the content payload
                                        result_content = execution_result

                                except json.JSONDecodeError:
                                    logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}")
                                    result_content = {"error": f"Invalid arguments format for tool {func_name}"}
                                except Exception as exec_err:
                                    logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True)
                                    result_content = {"error": f"Exception during tool execution: {str(exec_err)}"}

                            # Format result for the provider's follow-up message
                            formatted_result = self.provider.format_tool_results(call_id, result_content)
                            tool_results.append(formatted_result)
                            messages.append(formatted_result)  # Add tool result message

                        # Make follow-up call
                        logger.info("Making follow-up request with tool results...")
                        follow_up_response = self.provider.create_chat_completion(
                            messages=messages,  # Now includes assistant's turn and tool results
                            model=model,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            stream=False,  # Follow-up is non-streaming here
                            tools=provider_tools,  # Pass tools again? Some providers might need it.
                        )
                        final_content = self.provider.get_content(follow_up_response)
                        logger.info("Received follow-up response content.")
                        return {"content": final_content}

                    except Exception as tool_handling_err:
                        logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
                        return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"}

                else:  # No tool calls
                    logger.info("No tool calls detected.")
                    content = self.provider.get_content(response)
                    return {"content": content}

        except Exception as e:
            error_msg = f"LLM API Error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            if stream:
                # How to signal error in a stream? Yield a specific error message?
                # This simple generator won't handle it well. Returning an error dict for now.
                return {"error": error_msg}  # Or raise?
            else:
                return {"error": error_msg}

    def _stream_generator(self, response: Any) -> Generator[str, None, None]:
        """Helper to yield content from the provider's streaming method."""
        try:
            # Use yield from for cleaner and potentially more efficient delegation
            yield from self.provider.get_streaming_content(response)
        except Exception as e:
            logger.error(f"Error during streaming: {e}", exc_info=True)
            yield json.dumps({"error": f"Streaming error: {str(e)}"})  # Yield error as JSON chunk


# Example of how a provider might need to implement get_original_message_with_calls
# This would be in the specific provider class (e.g., openai_provider.py)
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
#     # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
#     # or in the non-streaming response's choice message
#     # Needs careful implementation based on provider's response structure
#     assistant_message = {
#         "role": "assistant",
#         "content": None,  # Often null when tool calls are present
#         "tool_calls": [...]  # Extracted tool calls in provider format
#     }
#     return assistant_message
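A rough usage sketch of the class above, for orientation only. The LLMClient and SyncMCPManager imports match this commit's app.py; the config path, API key handling, and the assumption that the new manager keeps the old `initialize()`/`shutdown()` interface are illustrative, not part of the commit.

# Illustrative sketch, not part of this commit.
import os

from llm_client import LLMClient
from src.custom_mcp.manager import SyncMCPManager

manager = SyncMCPManager(config_path="config/mcp_config.json")  # assumed default path
if not manager.initialize():  # initialize()/shutdown() assumed from the removed wrapper's interface
    raise RuntimeError("MCP servers failed to start")

client = LLMClient(
    provider_name="openai",
    api_key=os.environ["OPENAI_API_KEY"],  # placeholder key source
    mcp_manager=manager,
)

# Non-streaming call: returns {"content": ...} or {"error": ...}
result = client.chat_completion(
    messages=[{"role": "user", "content": "List the tables in the database."}],
    model="gpt-4o",
    stream=False,
)
print(result.get("content") or result.get("error"))

manager.shutdown()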
src/llm_models.py (new file, 61 lines)
@@ -0,0 +1,61 @@
MODELS = {
    "openai": {
        "name": "OpenAI",
        "endpoint": "https://api.openai.com/v1",
        "models": [
            {
                "id": "gpt-4o",
                "name": "GPT-4o",
                "default": True,
                "context_window": 128000,
                "description": "Input $5/M tokens, Output $15/M tokens",
            }
        ],
    },
    "anthropic": {
        "name": "Anthropic",
        "endpoint": "https://api.anthropic.com/v1/messages",
        "models": [
            {
                "id": "claude-3-7-sonnet-20250219",
                "name": "Claude 3.7 Sonnet",
                "default": True,
                "context_window": 200000,
                "description": "Input $3/M tokens, Output $15/M tokens",
            },
            {
                "id": "claude-3-5-haiku-20241022",
                "name": "Claude 3.5 Haiku",
                "default": False,
                "context_window": 200000,
                "description": "Input $0.80/M tokens, Output $4/M tokens",
            },
        ],
    },
    "google": {
        "name": "Google Gemini",
        "endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
        "models": [
            {
                "id": "gemini-2.0-flash",
                "name": "Gemini 2.0 Flash",
                "default": True,
                "context_window": 1000000,
                "description": "Input $0.1/M tokens, Output $0.4/M tokens",
            }
        ],
    },
    "openrouter": {
        "name": "OpenRouter",
        "endpoint": "https://openrouter.ai/api/v1/chat/completions",
        "models": [
            {
                "id": "custom",
                "name": "Custom Model",
                "default": False,
                "context_window": 128000,  # Default context window, will be updated based on model
                "description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
            },
        ],
    },
}
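Since MODELS is a plain nested dict, callers can derive UI defaults from it directly. A small illustrative helper (not part of this commit):

def get_default_model(provider: str) -> str | None:
    """Return the id of the provider's default model, e.g. 'gpt-4o' for 'openai'."""
    provider_info = MODELS.get(provider, {})
    for model in provider_info.get("models", []):
        if model.get("default"):
            return model["id"]
    return None

# get_default_model("anthropic") -> "claude-3-7-sonnet-20250219"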
mcp_manager.py (deleted file, 234 lines)
@@ -1,234 +0,0 @@
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import importlib.resources
|
|
||||||
import json
|
|
||||||
import logging # Import logging
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from custom_mcp_client import MCPClient, run_interaction
|
|
||||||
|
|
||||||
# Configure basic logging for the application if not already configured
|
|
||||||
# This basic config helps ensure logs are visible during development
|
|
||||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
||||||
|
|
||||||
# Get a logger for this module
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SyncMCPManager:
|
|
||||||
"""Synchronous wrapper for managing MCP servers and interactions"""
|
|
||||||
|
|
||||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
|
||||||
self.config_path = config_path
|
|
||||||
self.config = None
|
|
||||||
self.servers = {}
|
|
||||||
self.initialized = False
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
|
||||||
self._load_config()
|
|
||||||
|
|
||||||
def _load_config(self):
|
|
||||||
"""Load MCP configuration from JSON file using importlib"""
|
|
||||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
|
||||||
try:
|
|
||||||
# First try to load as a package resource
|
|
||||||
try:
|
|
||||||
# Try anchoring to the project name defined in pyproject.toml
|
|
||||||
# This *might* work depending on editable installs or context.
|
|
||||||
resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
|
|
||||||
with resource_path.open("r") as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
|
|
||||||
# REMOVED: raise FileNotFoundError
|
|
||||||
|
|
||||||
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError
|
|
||||||
logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
|
|
||||||
# Fall back to direct file access relative to CWD
|
|
||||||
with open(self.config_path) as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
logger.debug("Loaded config via direct file access.")
|
|
||||||
|
|
||||||
logger.info("MCP configuration loaded successfully.")
|
|
||||||
logger.debug(f"Config content: {self.config}") # Log content only if loaded
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error(f"MCP config file not found at {self.config_path}")
|
|
||||||
self.config = None
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
|
||||||
self.config = None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
|
||||||
self.config = None
|
|
||||||
|
|
||||||
def initialize(self) -> bool:
|
|
||||||
"""Initialize and start all MCP servers synchronously"""
|
|
||||||
logger.info("Initialize requested.")
|
|
||||||
if not self.config:
|
|
||||||
logger.warning("Initialization skipped: No configuration loaded.")
|
|
||||||
return False
|
|
||||||
if not self.config.get("mcpServers"):
|
|
||||||
logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.initialized:
|
|
||||||
logger.debug("Initialization skipped: Already initialized.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if self.initialized: # Double-check after acquiring lock
|
|
||||||
logger.debug("Initialization skipped inside lock: Already initialized.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info("Starting asynchronous initialization...")
|
|
||||||
# Run async initialization in a new event loop
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop) # Ensure this loop is used by tasks
|
|
||||||
success = loop.run_until_complete(self._async_initialize())
|
|
||||||
loop.close()
|
|
||||||
asyncio.set_event_loop(None) # Clean up
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info("Asynchronous initialization completed successfully.")
|
|
||||||
self.initialized = True
|
|
||||||
else:
|
|
||||||
logger.error("Asynchronous initialization failed.")
|
|
||||||
self.initialized = False # Ensure state is False on failure
|
|
||||||
|
|
||||||
return self.initialized
|
|
||||||
|
|
||||||
async def _async_initialize(self) -> bool:
|
|
||||||
"""Async implementation of server initialization"""
|
|
||||||
logger.debug("Starting _async_initialize...")
|
|
||||||
all_success = True
|
|
||||||
if not self.config or not self.config.get("mcpServers"):
|
|
||||||
logger.warning("_async_initialize: No config or mcpServers found.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
tasks = []
|
|
||||||
server_names = list(self.config["mcpServers"].keys())
|
|
||||||
|
|
||||||
async def start_server(server_name, server_config):
|
|
||||||
logger.info(f"Initializing server: {server_name}")
|
|
||||||
try:
|
|
||||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
|
||||||
|
|
||||||
logger.debug(f"Attempting to start client for {server_name}...")
|
|
||||||
if await client.start():
|
|
||||||
logger.info(f"Client for {server_name} started successfully.")
|
|
||||||
tools = await client.list_tools()
|
|
||||||
logger.info(f"Tools listed for {server_name}: {len(tools)}")
|
|
||||||
self.servers[server_name] = {"client": client, "tools": tools}
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed to start MCP server: {server_name}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Start servers concurrently
|
|
||||||
for server_name in server_names:
|
|
||||||
server_config = self.config["mcpServers"][server_name]
|
|
||||||
tasks.append(start_server(server_name, server_config))
|
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
# Check if all servers started successfully
|
|
||||||
all_success = all(results)
|
|
||||||
|
|
||||||
if all_success:
|
|
||||||
logger.debug("_async_initialize completed: All servers started successfully.")
|
|
||||||
else:
|
|
||||||
failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
|
|
||||||
logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
|
|
||||||
# Optionally shutdown servers that did start if partial success is not desired
|
|
||||||
# await self._async_shutdown() # Uncomment to enforce all-or-nothing startup
|
|
||||||
|
|
||||||
return all_success
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
"""Shut down all MCP servers synchronously"""
|
|
||||||
logger.info("Shutdown requested.")
|
|
||||||
if not self.initialized:
|
|
||||||
logger.debug("Shutdown skipped: Not initialized.")
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if not self.initialized:
|
|
||||||
logger.debug("Shutdown skipped inside lock: Not initialized.")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Starting asynchronous shutdown...")
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(self._async_shutdown())
|
|
||||||
loop.close()
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
|
|
||||||
self.servers = {}
|
|
||||||
self.initialized = False
|
|
||||||
logger.info("Shutdown complete.")
|
|
||||||
|
|
||||||
async def _async_shutdown(self):
|
|
||||||
"""Async implementation of server shutdown"""
|
|
||||||
logger.debug("Starting _async_shutdown...")
|
|
||||||
tasks = []
|
|
||||||
for server_name, server_info in self.servers.items():
|
|
||||||
logger.debug(f"Initiating shutdown for server: {server_name}")
|
|
||||||
tasks.append(server_info["client"].stop())
|
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
server_name = list(self.servers.keys())[i]
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
|
|
||||||
else:
|
|
||||||
logger.debug(f"Shutdown completed for server: {server_name}")
|
|
||||||
logger.debug("_async_shutdown finished.")
|
|
||||||
|
|
||||||
# Updated process_query signature
|
|
||||||
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
|
||||||
"""
|
|
||||||
Process a query using MCP tools synchronously
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The user's input query.
|
|
||||||
model_name: The model to use for processing.
|
|
||||||
api_key: The OpenAI API key.
|
|
||||||
base_url: The OpenAI API base URL.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing response or error.
|
|
||||||
"""
|
|
||||||
if not self.initialized and not self.initialize():
|
|
||||||
logger.error("process_query called but MCP manager failed to initialize.")
|
|
||||||
return {"error": "Failed to initialize MCP servers"}
|
|
||||||
|
|
||||||
logger.debug(f"Processing query synchronously: '{query}'")
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
# Pass api_key and base_url to _async_process_query
|
|
||||||
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
|
|
||||||
logger.debug(f"Synchronous query processing result: {result}")
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
|
|
||||||
return {"error": f"Processing error: {str(e)}"}
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
# Updated _async_process_query signature
|
|
||||||
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
|
||||||
"""Async implementation of query processing"""
|
|
||||||
# Pass api_key, base_url, and the MCP config separately to run_interaction
|
|
||||||
return await run_interaction(
|
|
||||||
user_query=query,
|
|
||||||
model_name=model_name,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
mcp_config=self.config, # self.config only contains MCP server definitions now
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
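The removed wrapper bridges synchronous callers to async MCP code by spinning up a throwaway event loop per call (initialize, shutdown, process_query all follow the same shape). Reduced to a sketch for clarity; asyncio.run would be the modern shorthand where no loop is already running:

import asyncio

def run_sync(coro):
    """Run a coroutine to completion from synchronous code on a fresh event loop."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
        asyncio.set_event_loop(None)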
src/providers/__init__.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# src/providers/__init__.py
import logging

from providers.base import BaseProvider

# Import specific provider implementations here as they are created
from providers.openai_provider import OpenAIProvider

# from .anthropic_provider import AnthropicProvider
# from .google_provider import GoogleProvider
# from .openrouter_provider import OpenRouterProvider

logger = logging.getLogger(__name__)

# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
    "openai": OpenAIProvider,
    # "anthropic": AnthropicProvider,
    # "google": GoogleProvider,
    # "openrouter": OpenRouterProvider,
}


def register_provider(name: str, provider_class: type[BaseProvider]):
    """Registers a provider class."""
    if name.lower() in PROVIDER_MAP:
        logger.warning(f"Provider '{name}' is already registered. Overwriting.")
    PROVIDER_MAP[name.lower()] = provider_class
    logger.info(f"Registered provider: {name}")


def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
    """
    Factory function to create an instance of a specific LLM provider.

    Args:
        provider_name: The name of the provider (e.g., 'openai', 'anthropic').
        api_key: The API key for the provider.
        base_url: Optional base URL for the provider's API.

    Returns:
        An instance of the requested BaseProvider subclass.

    Raises:
        ValueError: If the requested provider_name is not registered.
    """
    provider_class = PROVIDER_MAP.get(provider_name.lower())

    if provider_class is None:
        available = ", ".join(PROVIDER_MAP.keys()) or "None"
        raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")

    logger.info(f"Creating LLM provider instance for: {provider_name}")
    try:
        return provider_class(api_key=api_key, base_url=base_url)
    except Exception as e:
        logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
        raise RuntimeError(f"Could not create provider '{provider_name}'.") from e


def get_available_providers() -> list[str]:
    """Returns a list of registered provider names."""
    return list(PROVIDER_MAP.keys())


# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
#     # Potentially load providers dynamically if designed as plugins
#     pass
# load_providers()
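A short sketch of how the register_provider hook and factory fit together. It reuses OpenAIProvider under a made-up alias ("openai-compatible") purely for illustration; the alias, placeholder key, and local base URL are assumptions, not part of this commit:

# Illustrative sketch, not part of this commit.
from providers import create_llm_provider, get_available_providers, register_provider
from providers.openai_provider import OpenAIProvider

register_provider("openai-compatible", OpenAIProvider)
print(get_available_providers())  # e.g. ['openai', 'openai-compatible']

provider = create_llm_provider(
    "openai-compatible",
    api_key="sk-...",                     # placeholder
    base_url="http://localhost:8000/v1",  # e.g. a local OpenAI-compatible server
)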
src/providers/base.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# src/providers/base.py
import abc
from collections.abc import Generator
from typing import Any


class BaseProvider(abc.ABC):
    """
    Abstract base class for LLM providers.

    Defines the common interface for interacting with different LLM APIs,
    including handling chat completions and tool usage.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the provider.

        Args:
            api_key: The API key for the provider.
            base_url: Optional base URL for the provider's API.
        """
        self.api_key = api_key
        self.base_url = base_url

    @abc.abstractmethod
    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Send a chat completion request to the LLM provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: Model identifier.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.
            tools: Optional list of tools in the provider-specific format.

        Returns:
            Provider-specific response object (e.g., API response, stream object).
        """
        pass

    @abc.abstractmethod
    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """
        Extracts and yields content chunks from a streaming response object.

        Args:
            response: The streaming response object returned by create_chat_completion.

        Yields:
            String chunks of the response content.
        """
        pass

    @abc.abstractmethod
    def get_content(self, response: Any) -> str:
        """
        Extracts the complete content from a non-streaming response object.

        Args:
            response: The non-streaming response object.

        Returns:
            The complete response content as a string.
        """
        pass

    @abc.abstractmethod
    def has_tool_calls(self, response: Any) -> bool:
        """
        Checks if the response object contains tool calls.

        Args:
            response: The response object (streaming or non-streaming).

        Returns:
            True if tool calls are present, False otherwise.
        """
        pass

    @abc.abstractmethod
    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        """
        Parses tool calls from the response object.

        Args:
            response: The response object containing tool calls.

        Returns:
            A list of dictionaries, each representing a tool call with details
            like 'id', 'function_name', 'arguments'. The exact structure might
            vary slightly based on provider needs but should contain enough
            info for execution.
        """
        pass

    @abc.abstractmethod
    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """
        Formats the result of a tool execution into the structure expected
        by the provider for follow-up requests.

        Args:
            tool_call_id: The unique ID of the tool call (from parse_tool_calls).
            result: The data returned by the tool execution.

        Returns:
            A dictionary representing the tool result in the provider's format.
        """
        pass

    @abc.abstractmethod
    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Converts a list of tools from the standard internal format to the
        provider-specific format required for the API call.

        Args:
            tools: List of tool definitions in the standard internal format.
                Each dict contains 'server_name', 'name', 'description', 'inputSchema'.

        Returns:
            List of tool definitions in the provider-specific format.
        """
        pass

    # Optional: Add a method for follow-up completions if the provider API
    # requires a specific structure different from just appending messages.
    # def create_follow_up_completion(...) -> Any:
    #     pass
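For reference, the "standard internal format" consumed by convert_tools mirrors MCP tool listings. A representative entry might look like the dict below; the tool name, description, and schema are illustrative (loosely based on the sqlite server from the sample config), not values guaranteed by this commit:

# Illustrative internal tool definition (assumed values).
example_tool = {
    "server_name": "mcp-server-sqlite",
    "name": "read_query",
    "description": "Execute a SELECT query on the SQLite database",
    "inputSchema": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
}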
src/providers/openai_provider.py (new file, 239 lines)
@@ -0,0 +1,239 @@
# src/providers/openai_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

from providers.base import BaseProvider
from src.llm_models import MODELS  # Use absolute import

logger = logging.getLogger(__name__)


class OpenAIProvider(BaseProvider):
    """Provider implementation for OpenAI and compatible APIs."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # Use default OpenAI endpoint if base_url is not provided
        effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)
        logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
        try:
            # TODO: Add default headers like in original client?
            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
            raise

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Stream[ChatCompletionChunk] | ChatCompletion:
        """Creates a chat completion using the OpenAI API."""
        logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
        try:
            completion_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
            }
            if tools:
                completion_params["tools"] = tools
                completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools

            # Remove None values like max_tokens if not provided
            completion_params = {k: v for k, v in completion_params.items() if v is not None}

            # --- Added Debug Logging ---
            log_params = completion_params.copy()
            # Avoid logging full messages if they are too long
            if "messages" in log_params:
                log_params["messages"] = [
                    {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
                    for msg in log_params["messages"][-2:]  # Log last 2 messages summary
                ]
            # Specifically log tools structure if present
            tools_log = log_params.get("tools", "Not Present")
            logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
            logger.debug(f"Full API Params (messages summarized): {log_params}")
            # --- End Added Debug Logging ---

            response = self.client.chat.completions.create(**completion_params)
            logger.debug("OpenAI API call successful.")
            return response
        except Exception as e:
            logger.error(f"OpenAI API error: {e}", exc_info=True)
            # Re-raise for the LLMClient to handle
            raise

    def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
        """Yields content chunks from an OpenAI streaming response."""
        logger.debug("Processing OpenAI stream...")
        full_delta = ""
        try:
            for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    full_delta += delta
                    yield delta
            logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
            # Yield an error message? Or let the generator stop?
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: ChatCompletion) -> str:
        """Extracts content from a non-streaming OpenAI response."""
        try:
            content = response.choices[0].message.content
            logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
            return content or ""  # Return empty string if content is None
        except Exception as e:
            logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
        """Checks if the OpenAI response contains tool calls."""
        try:
            if isinstance(response, ChatCompletion):  # Non-streaming
                return bool(response.choices[0].message.tool_calls)
            elif hasattr(response, "_iterator"):  # Check if it looks like our stream wrapper
                # This is tricky for streams. We'd need to peek at the first chunk(s)
                # or buffer the response. For simplicity, this check might be unreliable
                # for streams *before* they are consumed. LLMClient needs robust handling.
                logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
                # A more robust check would involve consuming the start of the stream
                # or relying on the structure after consumption.
                return False  # Assume no for unconsumed stream for now
            else:
                # If it's already consumed stream or unexpected type
                logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
                return False
        except Exception as e:
            logger.error(f"Error checking for tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
        """Parses tool calls from a non-streaming OpenAI response."""
        # This implementation assumes a non-streaming response or a fully buffered stream
        parsed_calls = []
        try:
            if not isinstance(response, ChatCompletion):
                logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
                # Attempt to handle buffered stream if possible? Complex.
                return []

            tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
            if not tool_calls:
                return []

            logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
            for call in tool_calls:
                if call.type == "function":
                    # Attempt to parse server_name from function name if prefixed
                    # e.g., "server-name__actual-tool-name"
                    parts = call.function.name.split("__", 1)
                    if len(parts) == 2:
                        server_name, func_name = parts
                    else:
                        # If no prefix, how do we know the server? Needs refinement.
                        # Defaulting to None or a default server? Log warning.
                        logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
                        server_name = None  # Or raise error, or use a default?
                        func_name = call.function.name

                    parsed_calls.append({
                        "id": call.id,
                        "server_name": server_name,  # May be None if not prefixed
                        "function_name": func_name,
                        "arguments": call.function.arguments,  # Arguments are already a string here
                    })
                else:
                    logger.warning(f"Unsupported tool call type: {call.type}")

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
            return []  # Return empty list on error

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for an OpenAI follow-up request."""
        # Result might be a dict (including potential errors) or simple string/number.
        # OpenAI expects the content to be a string, often JSON.
        try:
            if isinstance(result, dict):
                content = json.dumps(result)
            else:
                content = str(result)  # Ensure it's a string
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
            content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content,
        }

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Converts internal tool format to OpenAI's format."""
        openai_tools = []
        logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
        for tool in tools:
            server_name = tool.get("server_name")
            tool_name = tool.get("name")
            description = tool.get("description")
            input_schema = tool.get("inputSchema")

            if not server_name or not tool_name or not description or not input_schema:
                logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
                continue

            # Prefix tool name with server name to avoid clashes and allow routing
            prefixed_tool_name = f"{server_name}__{tool_name}"

            openai_tool_format = {
                "type": "function",
                "function": {
                    "name": prefixed_tool_name,
                    "description": description,
                    "parameters": input_schema,  # OpenAI uses JSON Schema directly
                },
            }
            openai_tools.append(openai_tool_format)
            logger.debug(f"Converted tool: {prefixed_tool_name}")

        return openai_tools

    # Helper needed by LLMClient's current tool handling logic
    def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls."""
        try:
            if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
                message = response.choices[0].message
                # Convert Pydantic model to dict for message history
                return message.model_dump(exclude_unset=True)
            else:
                logger.warning("Could not extract original message with tool calls from response.")
                # Return a placeholder or raise error?
                return {"role": "assistant", "content": "[Could not extract tool calls message]"}
        except Exception as e:
            logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}


# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)
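The `server__tool` prefixing in convert_tools and the matching split in parse_tool_calls are what let LLMClient route a call back to the right MCP server. A tiny round-trip illustration (the server and tool names are examples, not part of this commit):

# convert_tools() exposes the tool to OpenAI under a prefixed name...
prefixed = "mcp-server-sqlite" + "__" + "read_query"  # -> "mcp-server-sqlite__read_query"

# ...and parse_tool_calls() recovers both halves from the model's tool call.
server_name, func_name = prefixed.split("__", 1)
assert (server_name, func_name) == ("mcp-server-sqlite", "read_query")
# Note: a server name that itself contains "__" would break this split.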