Compare commits
3 Commits
a7d5a4cb33
...
a4683023ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
a4683023ad
|
|||
|
b4986e0eb9
|
|||
|
80ba05338f
|
@@ -1,7 +1,31 @@
|
||||
[base]
|
||||
# provider can be [ openai|openrouter|anthropic|google]
|
||||
provider = openrouter
|
||||
|
||||
[openrouter]
|
||||
api_key = YOUR_API_KEY
|
||||
base_url = https://openrouter.ai/api/v1
|
||||
model = openai/gpt-4o-2024-11-20
|
||||
context_window = 128000
|
||||
|
||||
[anthropic]
|
||||
api_key = YOUR_API_KEY
|
||||
base_url = https://api.anthropic.com/v1/messages
|
||||
model = claude-3-7-sonnet-20250219
|
||||
context_window = 128000
|
||||
|
||||
[google]
|
||||
api_key = YOUR_API_KEY
|
||||
base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
|
||||
model = gemini-2.0-flash
|
||||
context_window = 1000000
|
||||
|
||||
|
||||
[openai]
|
||||
api_key = YOUR_API_KEY
|
||||
base_url = CUSTOM_BASE_URL
|
||||
model = YOUR_MODEL_ID
|
||||
base_url = https://api.openai.com/v1
|
||||
model = openai/gpt-4o
|
||||
context_window = 128000
|
||||
|
||||
[mcp]
|
||||
servers_json = config/mcp_config.json
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"dolphin-demo-database-sqlite": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-sqlite",
|
||||
"--db-path",
|
||||
"~/.dolphin/dolphin.db"
|
||||
]
|
||||
}
|
||||
"mcpServers": {
|
||||
"mcp-server-sqlite": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-sqlite",
|
||||
"--db-path",
|
||||
"~/.mcpapp/mcpapp.db"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,9 @@ authors = [
|
||||
dependencies = [
|
||||
"streamlit",
|
||||
"python-dotenv",
|
||||
"openai"
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google-genai",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
@@ -61,6 +63,7 @@ lint.select = [
|
||||
"T10", # flake8-debugger
|
||||
"A", # flake8-builtins
|
||||
"UP", # pyupgrade
|
||||
"TID", # flake8-tidy-imports
|
||||
]
|
||||
|
||||
lint.ignore = [
|
||||
@@ -81,7 +84,7 @@ skip-magic-trailing-comma = false
|
||||
combine-as-imports = true
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
max-complexity = 12
|
||||
max-complexity = 16
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
# Disallow all relative imports.
|
||||
|
||||
187
src/app.py
187
src/app.py
@@ -1,29 +1,112 @@
|
||||
import atexit
|
||||
import configparser
|
||||
import json # For handling potential error JSON in stream
|
||||
import logging
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from openai_client import OpenAIClient
|
||||
# Updated imports
|
||||
from llm_client import LLMClient
|
||||
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||
|
||||
# Configure logging for the app
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_session_state():
|
||||
"""Initializes session state variables including clients."""
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
logger.info("Initialized session state: messages")
|
||||
|
||||
if "client" not in st.session_state:
|
||||
st.session_state.client = OpenAIClient()
|
||||
# Register cleanup for MCP servers
|
||||
if hasattr(st.session_state.client, "mcp_manager"):
|
||||
atexit.register(st.session_state.client.mcp_manager.shutdown)
|
||||
logger.info("Attempting to initialize clients...")
|
||||
try:
|
||||
config = configparser.ConfigParser()
|
||||
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
|
||||
config_files_read = config.read("config/config.ini")
|
||||
if not config_files_read:
|
||||
raise FileNotFoundError("config.ini not found or could not be read.")
|
||||
logger.info(f"Read configuration from: {config_files_read}")
|
||||
|
||||
# --- MCP Manager Setup ---
|
||||
mcp_config_path = "config/mcp_config.json" # Default
|
||||
if config.has_section("mcp") and config["mcp"].get("servers_json"):
|
||||
mcp_config_path = config["mcp"]["servers_json"]
|
||||
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
|
||||
else:
|
||||
logger.info(f"Using default MCP config path: {mcp_config_path}")
|
||||
|
||||
mcp_manager = SyncMCPManager(mcp_config_path)
|
||||
if not mcp_manager.initialize():
|
||||
# Log warning but continue - LLMClient will operate without tools
|
||||
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
|
||||
else:
|
||||
logger.info("MCP Manager initialized successfully.")
|
||||
# Register shutdown hook for MCP manager
|
||||
atexit.register(mcp_manager.shutdown)
|
||||
logger.info("Registered MCP Manager shutdown hook.")
|
||||
|
||||
# --- LLM Client Setup ---
|
||||
provider_name = None
|
||||
model_name = None
|
||||
api_key = None
|
||||
base_url = None
|
||||
|
||||
# 1. Determine provider from [base] section
|
||||
if config.has_section("base") and config["base"].get("provider"):
|
||||
provider_name = config["base"].get("provider")
|
||||
logger.info(f"Provider selected from [base] section: {provider_name}")
|
||||
else:
|
||||
# Fallback or error if [base] provider is missing? Let's error for now.
|
||||
raise ValueError("Missing 'provider' setting in [base] section of config.ini")
|
||||
|
||||
# 2. Read details from the specific provider's section
|
||||
if config.has_section(provider_name):
|
||||
provider_config = config[provider_name]
|
||||
model_name = provider_config.get("model")
|
||||
api_key = provider_config.get("api_key")
|
||||
base_url = provider_config.get("base_url") # Optional
|
||||
logger.info(f"Read configuration from [{provider_name}] section.")
|
||||
else:
|
||||
raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
|
||||
|
||||
# Validate required config
|
||||
if not api_key:
|
||||
raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
|
||||
if not model_name:
|
||||
raise ValueError(f"Missing 'model' name in [{provider_name}] section of config.ini")
|
||||
|
||||
logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
|
||||
st.session_state.client = LLMClient(
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
mcp_manager=mcp_manager,
|
||||
base_url=base_url,
|
||||
)
|
||||
st.session_state.model_name = model_name
|
||||
logger.info("LLMClient initialized successfully.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
|
||||
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
|
||||
# Stop the app if initialization fails critically
|
||||
st.stop()
|
||||
|
||||
|
||||
def display_chat_messages():
|
||||
"""Displays chat messages stored in session state."""
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
# Simple markdown display for now
|
||||
st.markdown(message["content"])
|
||||
|
||||
|
||||
def handle_user_input():
|
||||
"""Handles user input, calls LLMClient, and displays the response."""
|
||||
if prompt := st.chat_input("Type your message..."):
|
||||
print(f"User input received: {prompt}") # Debug log
|
||||
logger.info(f"User input received: '{prompt[:50]}...'")
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
@@ -32,39 +115,85 @@ def handle_user_input():
|
||||
with st.chat_message("assistant"):
|
||||
response_placeholder = st.empty()
|
||||
full_response = ""
|
||||
error_occurred = False
|
||||
|
||||
print("Processing message...") # Debug log
|
||||
response = st.session_state.client.get_chat_response(st.session_state.messages)
|
||||
logger.info("Processing message via LLMClient...")
|
||||
# Use the new client and method, always requesting stream for UI
|
||||
response_stream = st.session_state.client.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model=st.session_state.model_name, # Get model from session state
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Handle both MCP and standard OpenAI responses
|
||||
# Check if it's NOT a dict (assuming stream is not a dict)
|
||||
if not isinstance(response, dict):
|
||||
# Standard OpenAI streaming response
|
||||
for chunk in response:
|
||||
# Ensure chunk has choices and delta before accessing
|
||||
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
|
||||
full_response += chunk.choices[0].delta.content
|
||||
response_placeholder.markdown(full_response + "▌")
|
||||
# Handle the response (stream generator or error dict)
|
||||
if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
|
||||
logger.debug("Processing response stream...")
|
||||
for chunk in response_stream:
|
||||
# Check for potential error JSON yielded by the stream
|
||||
try:
|
||||
# Attempt to parse chunk as JSON only if it looks like it
|
||||
if isinstance(chunk, str) and chunk.strip().startswith("{"):
|
||||
error_data = json.loads(chunk)
|
||||
if isinstance(error_data, dict) and "error" in error_data:
|
||||
full_response = f"Error: {error_data['error']}"
|
||||
logger.error(f"Error received in stream: {full_response}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
break # Stop processing stream on error
|
||||
# If not error JSON, treat as content chunk
|
||||
if not error_occurred and isinstance(chunk, str):
|
||||
full_response += chunk
|
||||
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not JSON or not error structure, treat as content chunk
|
||||
if not error_occurred and isinstance(chunk, str):
|
||||
full_response += chunk
|
||||
response_placeholder.markdown(full_response + "▌") # Add cursor effect
|
||||
|
||||
if not error_occurred:
|
||||
response_placeholder.markdown(full_response) # Final update without cursor
|
||||
logger.debug("Stream processing complete.")
|
||||
|
||||
elif isinstance(response_stream, dict) and "error" in response_stream:
|
||||
# Handle error dict returned directly (e.g., API error before streaming)
|
||||
full_response = f"Error: {response_stream['error']}"
|
||||
logger.error(f"Error returned directly from chat_completion: {full_response}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
else:
|
||||
# MCP non-streaming response
|
||||
full_response = response.get("assistant_text", "")
|
||||
response_placeholder.markdown(full_response)
|
||||
# Unexpected response type
|
||||
full_response = "[Unexpected response format from LLMClient]"
|
||||
logger.error(f"Unexpected response type: {type(response_stream)}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
|
||||
response_placeholder.markdown(full_response)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
print("Message processed successfully") # Debug log
|
||||
# Only add non-error, non-empty responses to history
|
||||
if not error_occurred and full_response:
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
logger.info("Assistant response added to history.")
|
||||
elif error_occurred:
|
||||
logger.warning("Assistant response not added to history due to error.")
|
||||
else:
|
||||
logger.warning("Empty assistant response received, not added to history.")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error processing message: {str(e)}")
|
||||
print(f"Error details: {str(e)}") # Debug log
|
||||
logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
|
||||
st.error(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
def main():
|
||||
st.title("Streamlit Chat App")
|
||||
init_session_state()
|
||||
display_chat_messages()
|
||||
handle_user_input()
|
||||
"""Main function to run the Streamlit app."""
|
||||
st.title("MCP Chat App") # Updated title
|
||||
try:
|
||||
init_session_state()
|
||||
display_chat_messages()
|
||||
handle_user_input()
|
||||
except Exception as e:
|
||||
# Catch potential errors during rendering or handling
|
||||
logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
|
||||
st.error(f"A critical application error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting Streamlit Chat App...")
|
||||
main()
|
||||
|
||||
1
src/custom_mcp/__init__.py
Normal file
1
src/custom_mcp/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# This file makes src/mcp a Python package
|
||||
281
src/custom_mcp/client.py
Normal file
281
src/custom_mcp/client.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# src/mcp/client.py
|
||||
"""Client class for managing and interacting with a single MCP server process."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from custom_mcp import process, protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts
|
||||
LIST_TOOLS_TIMEOUT = 20.0 # Seconds (using the increased value from previous step)
|
||||
CALL_TOOL_TIMEOUT = 110.0 # Seconds
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
Manages the lifecycle and async communication with a single MCP server process.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
|
||||
"""
|
||||
Initializes the client for a specific server configuration.
|
||||
|
||||
Args:
|
||||
server_name: Unique name for the server (for logging).
|
||||
command: The command executable.
|
||||
args: List of arguments for the command.
|
||||
config_env: Server-specific environment variables.
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.config_env = config_env
|
||||
self.process: asyncio.subprocess.Process | None = None
|
||||
self.reader: asyncio.StreamReader | None = None
|
||||
self.writer: asyncio.StreamWriter | None = None
|
||||
self._stderr_task: asyncio.Task | None = None
|
||||
self._request_counter = 0
|
||||
self._is_running = False
|
||||
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger
|
||||
|
||||
async def _log_stderr(self):
|
||||
"""Logs stderr output from the server process."""
|
||||
if not self.process or not self.process.stderr:
|
||||
self.logger.debug("Stderr logging skipped: process or stderr not available.")
|
||||
return
|
||||
stderr_reader = self.process.stderr
|
||||
try:
|
||||
while not stderr_reader.at_eof():
|
||||
line = await stderr_reader.readline()
|
||||
if line:
|
||||
self.logger.warning(f"[stderr] {line.decode().strip()}")
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Stderr logging task cancelled.")
|
||||
except Exception as e:
|
||||
# Log errors but don't crash the logger task if possible
|
||||
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("Stderr logging task finished.")
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Starts the MCP server subprocess and sets up communication streams.
|
||||
|
||||
Returns:
|
||||
True if the process started successfully, False otherwise.
|
||||
"""
|
||||
if self._is_running:
|
||||
self.logger.warning("Start called but client is already running.")
|
||||
return True
|
||||
|
||||
self.logger.info("Starting MCP server process...")
|
||||
try:
|
||||
self.process = await process.start_mcp_process(self.command, self.args, self.config_env)
|
||||
self.reader = self.process.stdout
|
||||
self.writer = self.process.stdin
|
||||
|
||||
if self.reader is None or self.writer is None:
|
||||
self.logger.error("Failed to get stdout/stdin streams after process start.")
|
||||
await self.stop() # Attempt cleanup
|
||||
return False
|
||||
|
||||
# Start background task to monitor stderr
|
||||
self._stderr_task = asyncio.create_task(self._log_stderr())
|
||||
|
||||
# --- Start MCP Initialization Handshake ---
|
||||
self.logger.info("Starting MCP initialization handshake...")
|
||||
self._request_counter += 1
|
||||
init_req_id = self._request_counter
|
||||
initialize_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": init_req_id,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05", # Use a recent version
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client
|
||||
"capabilities": {}, # Client capabilities (can be empty)
|
||||
},
|
||||
}
|
||||
|
||||
# Define a timeout for initialization
|
||||
INITIALIZE_TIMEOUT = 15.0 # Seconds
|
||||
|
||||
try:
|
||||
# Send initialize request
|
||||
await protocol.send_request(self.writer, initialize_req)
|
||||
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
|
||||
|
||||
# Wait for initialize response
|
||||
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
|
||||
|
||||
if init_response and init_response.get("id") == init_req_id:
|
||||
if "error" in init_response:
|
||||
self.logger.error(f"Server returned error during initialization: {init_response['error']}")
|
||||
await self.stop()
|
||||
return False
|
||||
elif "result" in init_response:
|
||||
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided
|
||||
|
||||
# Send initialized notification (using standard method name)
|
||||
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
|
||||
await protocol.send_request(self.writer, initialized_notify)
|
||||
self.logger.info("'notifications/initialized' notification sent.")
|
||||
|
||||
self._is_running = True
|
||||
self.logger.info("MCP server process started and initialized successfully.")
|
||||
return True
|
||||
else:
|
||||
self.logger.error("Invalid 'initialize' response format (missing result/error).")
|
||||
await self.stop()
|
||||
return False
|
||||
elif init_response:
|
||||
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
|
||||
await self.stop()
|
||||
return False
|
||||
else: # Timeout case
|
||||
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
|
||||
await self.stop()
|
||||
return False
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error("Connection lost during initialization handshake. Stopping client.")
|
||||
await self.stop()
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
|
||||
await self.stop()
|
||||
return False
|
||||
# --- End MCP Initialization Handshake ---
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
|
||||
self.process = None # Ensure process is None on failure
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self._is_running = False
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
"""Stops the MCP server subprocess gracefully."""
|
||||
if not self._is_running and not self.process:
|
||||
self.logger.debug("Stop called but client is not running.")
|
||||
return
|
||||
|
||||
self.logger.info("Stopping MCP server process...")
|
||||
self._is_running = False # Mark as stopping
|
||||
|
||||
# Cancel stderr logging task
|
||||
if self._stderr_task and not self._stderr_task.done():
|
||||
self._stderr_task.cancel()
|
||||
try:
|
||||
await self._stderr_task
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Stderr task successfully cancelled.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
|
||||
self._stderr_task = None
|
||||
|
||||
# Stop the process using the utility function
|
||||
if self.process:
|
||||
await process.stop_mcp_process(self.process, self.server_name)
|
||||
|
||||
# Nullify references
|
||||
self.process = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self.logger.info("MCP server process stopped.")
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]] | None:
|
||||
"""
|
||||
Sends a 'tools/list' request and waits for the response.
|
||||
|
||||
Returns:
|
||||
A list of tool dictionaries, or None on error/timeout.
|
||||
"""
|
||||
if not self._is_running or not self.writer or not self.reader:
|
||||
self.logger.error("Cannot list tools: client not running or streams unavailable.")
|
||||
return None
|
||||
|
||||
self._request_counter += 1
|
||||
req_id = self._request_counter
|
||||
request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id}
|
||||
|
||||
try:
|
||||
await protocol.send_request(self.writer, request)
|
||||
response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT)
|
||||
|
||||
if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]:
|
||||
tools = response["result"]["tools"]
|
||||
if isinstance(tools, list):
|
||||
self.logger.info(f"Successfully listed {len(tools)} tools.")
|
||||
return tools
|
||||
else:
|
||||
self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}")
|
||||
return None
|
||||
elif response and "error" in response:
|
||||
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
|
||||
return None
|
||||
else:
|
||||
# Includes timeout case (read_response returns None)
|
||||
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
|
||||
return None
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error("Connection lost during listTools request. Stopping client.")
|
||||
await self.stop()
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
Sends a 'tools/call' request and waits for the response.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
arguments: The arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result dictionary from the server, or None on error/timeout.
|
||||
"""
|
||||
if not self._is_running or not self.writer or not self.reader:
|
||||
self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.")
|
||||
return None
|
||||
|
||||
self._request_counter += 1
|
||||
req_id = self._request_counter
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments},
|
||||
"id": req_id,
|
||||
}
|
||||
|
||||
try:
|
||||
await protocol.send_request(self.writer, request)
|
||||
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
|
||||
|
||||
if response and "result" in response:
|
||||
# Assuming result is the desired payload
|
||||
self.logger.info(f"Tool '{tool_name}' executed successfully.")
|
||||
return response["result"]
|
||||
elif response and "error" in response:
|
||||
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
|
||||
# Return the error structure itself? Or just None? Returning error dict for now.
|
||||
return {"error": response["error"]}
|
||||
else:
|
||||
# Includes timeout case
|
||||
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
|
||||
return None
|
||||
|
||||
except ConnectionResetError:
|
||||
self.logger.error(f"Connection lost during callTool '{tool_name}'. Stopping client.")
|
||||
await self.stop()
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True)
|
||||
return None
|
||||
366
src/custom_mcp/manager.py
Normal file
366
src/custom_mcp/manager.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# src/mcp/manager.py
|
||||
"""Synchronous manager for multiple MCPClient instances."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
# Use relative imports within the mcp package
|
||||
from custom_mcp.client import MCPClient
|
||||
|
||||
# Configure basic logging
|
||||
# Consider moving this to the main app entry point if not already done
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
|
||||
INITIALIZE_TIMEOUT = 60.0 # Seconds
|
||||
SHUTDOWN_TIMEOUT = 30.0 # Seconds
|
||||
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds
|
||||
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
"""
|
||||
Manages the lifecycle of multiple MCPClient instances and provides a
|
||||
synchronous interface to interact with them using a background event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||
"""
|
||||
Initializes the manager, loads config, but does not start servers yet.
|
||||
|
||||
Args:
|
||||
config_path: Path to the MCP server configuration JSON file.
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config: dict[str, Any] | None = None
|
||||
# Stores server_name -> MCPClient instance
|
||||
self.servers: dict[str, MCPClient] = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load MCP configuration from JSON file."""
|
||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||
try:
|
||||
# Using direct file access
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
logger.info("MCP configuration loaded successfully.")
|
||||
logger.debug(f"Config content: {self.config}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"MCP config file not found at {self.config_path}")
|
||||
self.config = None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
||||
self.config = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||
self.config = None
|
||||
|
||||
# --- Background Event Loop Management ---
|
||||
|
||||
def _run_event_loop(self):
|
||||
"""Target function for the background event loop thread."""
|
||||
try:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
finally:
|
||||
if self._loop and not self._loop.is_closed():
|
||||
# Clean up remaining tasks before closing
|
||||
try:
|
||||
tasks = asyncio.all_tasks(self._loop)
|
||||
if tasks:
|
||||
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
# Allow cancellation to propagate
|
||||
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
logger.debug("Outstanding tasks cancelled.")
|
||||
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
|
||||
except Exception as e:
|
||||
logger.error(f"Error during event loop cleanup: {e}")
|
||||
finally:
|
||||
self._loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
logger.info("Event loop thread finished.")
|
||||
|
||||
def _start_event_loop_thread(self):
|
||||
"""Starts the background event loop thread if not already running."""
|
||||
if self._thread is None or not self._thread.is_alive():
|
||||
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Event loop thread started.")
|
||||
# Wait briefly for the loop to become available and running
|
||||
while self._loop is None or not self._loop.is_running():
|
||||
# Use time.sleep in sync context
|
||||
import time
|
||||
|
||||
time.sleep(0.01)
|
||||
logger.debug("Event loop is running.")
|
||||
|
||||
def _stop_event_loop_thread(self):
|
||||
"""Stops the background event loop thread."""
|
||||
if self._loop and self._loop.is_running():
|
||||
logger.info("Requesting event loop stop...")
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread and self._thread.is_alive():
|
||||
logger.info("Waiting for event loop thread to join...")
|
||||
self._thread.join(timeout=5)
|
||||
if self._thread.is_alive():
|
||||
logger.warning("Event loop thread did not stop gracefully.")
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
logger.info("Event loop stopped.")
|
||||
|
||||
# --- Public Synchronous Interface ---
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
Initializes and starts all configured MCP servers synchronously.
|
||||
|
||||
Returns:
|
||||
True if all servers started successfully, False otherwise.
|
||||
"""
|
||||
logger.info("Manager initialization requested.")
|
||||
if not self.config or not self.config.get("mcpServers"):
|
||||
logger.warning("Initialization skipped: No valid configuration loaded.")
|
||||
return False
|
||||
|
||||
with self._lock:
|
||||
if self.initialized:
|
||||
logger.debug("Initialization skipped: Already initialized.")
|
||||
return True
|
||||
|
||||
self._start_event_loop_thread()
|
||||
if not self._loop:
|
||||
logger.error("Failed to start event loop for initialization.")
|
||||
return False
|
||||
|
||||
logger.info("Submitting asynchronous server initialization...")
|
||||
|
||||
# Prepare coroutine to start all clients
|
||||
|
||||
async def _async_init_all():
|
||||
tasks = []
|
||||
for server_name, server_config in self.config["mcpServers"].items():
|
||||
command = server_config.get("command")
|
||||
args = server_config.get("args", [])
|
||||
config_env = server_config.get("env", {})
|
||||
if not command:
|
||||
logger.error(f"Skipping server {server_name}: Missing 'command'.")
|
||||
continue
|
||||
|
||||
client = MCPClient(server_name, command, args, config_env)
|
||||
self.servers[server_name] = client
|
||||
tasks.append(client.start()) # Append the start coroutine
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check results - True means success, False or Exception means failure
|
||||
all_success = True
|
||||
failed_servers = []
|
||||
for i, result in enumerate(results):
|
||||
server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained
|
||||
if isinstance(result, Exception) or result is False:
|
||||
all_success = False
|
||||
failed_servers.append(server_name)
|
||||
# Remove failed client from managed servers
|
||||
if server_name in self.servers:
|
||||
del self.servers[server_name]
|
||||
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
|
||||
|
||||
if not all_success:
|
||||
logger.error(f"Initialization failed for servers: {failed_servers}")
|
||||
return all_success
|
||||
|
||||
# Run the initialization coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
|
||||
try:
|
||||
success = future.result(timeout=INITIALIZE_TIMEOUT)
|
||||
if success:
|
||||
logger.info("Asynchronous initialization completed successfully.")
|
||||
self.initialized = True
|
||||
else:
|
||||
logger.error("Asynchronous initialization failed.")
|
||||
self.initialized = False
|
||||
# Attempt to clean up any partially started servers
|
||||
self.shutdown() # Call sync shutdown
|
||||
except TimeoutError:
|
||||
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
success = False
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
success = False
|
||||
|
||||
return self.initialized
|
||||
|
||||
def shutdown(self):
|
||||
"""Shuts down all managed MCP servers synchronously."""
|
||||
logger.info("Manager shutdown requested.")
|
||||
with self._lock:
|
||||
# Check servers dict too, in case init was partial
|
||||
if not self.initialized and not self.servers:
|
||||
logger.debug("Shutdown skipped: Not initialized or no servers running.")
|
||||
# Ensure loop is stopped if it exists
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._stop_event_loop_thread()
|
||||
return
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
|
||||
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
|
||||
# This part is tricky as MCPClient.stop is async.
|
||||
# For simplicity, we might just log and rely on process termination on app exit.
|
||||
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._stop_event_loop_thread()
|
||||
return
|
||||
|
||||
logger.info("Submitting asynchronous server shutdown...")
|
||||
|
||||
# Prepare coroutine to stop all clients
|
||||
|
||||
async def _async_shutdown_all():
|
||||
tasks = [client.stop() for client in self.servers.values()]
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Run the shutdown coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
|
||||
try:
|
||||
future.result(timeout=SHUTDOWN_TIMEOUT)
|
||||
logger.info("Asynchronous shutdown completed.")
|
||||
except TimeoutError:
|
||||
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
|
||||
# Processes might still be running, OS will clean up on exit hopefully
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
|
||||
finally:
|
||||
# Always mark as uninitialized and clear servers dict
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
# Stop the background thread
|
||||
self._stop_event_loop_thread()
|
||||
|
||||
logger.info("Manager shutdown complete.")
|
||||
|
||||
def list_all_tools(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves tools from all initialized MCP servers synchronously.
|
||||
|
||||
Returns:
|
||||
A list of tool definitions in the standard internal format,
|
||||
aggregated from all servers. Returns empty list on failure.
|
||||
"""
|
||||
if not self.initialized or not self.servers:
|
||||
logger.warning("Cannot list tools: Manager not initialized or no servers running.")
|
||||
return []
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.error("Cannot list tools: Event loop not running.")
|
||||
return []
|
||||
|
||||
logger.info(f"Requesting tools from {len(self.servers)} servers...")
|
||||
|
||||
# Prepare coroutine to list tools from all clients
|
||||
async def _async_list_all():
|
||||
tasks = []
|
||||
server_names_in_order = []
|
||||
for server_name, client in self.servers.items():
|
||||
tasks.append(client.list_tools())
|
||||
server_names_in_order.append(server_name)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
all_tools = []
|
||||
for i, result in enumerate(results):
|
||||
server_name = server_names_in_order[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error listing tools for server '{server_name}': {result}")
|
||||
elif result is None:
|
||||
# MCPClient.list_tools returns None on timeout/error
|
||||
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
|
||||
elif isinstance(result, list):
|
||||
# Add server_name to each tool definition
|
||||
for tool in result:
|
||||
tool["server_name"] = server_name
|
||||
all_tools.extend(result)
|
||||
logger.debug(f"Received {len(result)} tools from {server_name}")
|
||||
else:
|
||||
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
|
||||
return all_tools
|
||||
|
||||
# Run the coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
|
||||
try:
|
||||
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
|
||||
logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.")
|
||||
return aggregated_tools
|
||||
except TimeoutError:
|
||||
logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during listing all tools future result: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
Executes a specific tool on the designated MCP server synchronously.
|
||||
|
||||
Args:
|
||||
server_name: The name of the server hosting the tool.
|
||||
tool_name: The name of the tool to execute.
|
||||
arguments: A dictionary of arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result content from the tool execution (dict),
|
||||
an error dict ({"error": ...}), or None on timeout/comm failure.
|
||||
"""
|
||||
if not self.initialized:
|
||||
logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.")
|
||||
return None
|
||||
|
||||
client = self.servers.get(server_name)
|
||||
if not client:
|
||||
logger.error(f"Cannot execute tool: Server '{server_name}' not found.")
|
||||
return None
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.")
|
||||
return None
|
||||
|
||||
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
|
||||
|
||||
# Run the client's call_tool coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
|
||||
try:
|
||||
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
|
||||
# MCPClient.call_tool returns the result dict or an error dict or None
|
||||
if result is None:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
|
||||
elif isinstance(result, dict) and "error" in result:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
|
||||
else:
|
||||
logger.info(f"Tool '{tool_name}' execution successful.")
|
||||
return result # Return result dict, error dict, or None
|
||||
except TimeoutError:
|
||||
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True)
|
||||
return None
|
||||
128
src/custom_mcp/process.py
Normal file
128
src/custom_mcp/process.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# src/mcp/process.py
|
||||
"""Async utilities for managing MCP server subprocesses."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process:
|
||||
"""
|
||||
Starts an MCP server subprocess using asyncio.create_subprocess_shell.
|
||||
|
||||
Handles argument expansion and environment merging.
|
||||
|
||||
Args:
|
||||
command: The main command executable.
|
||||
args: A list of arguments for the command.
|
||||
config_env: Server-specific environment variables from config.
|
||||
|
||||
Returns:
|
||||
The started asyncio.subprocess.Process object.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the command is not found.
|
||||
Exception: For other errors during subprocess creation.
|
||||
"""
|
||||
logger.debug(f"Preparing to start process for command: {command}")
|
||||
|
||||
# --- Add tilde expansion for arguments ---
|
||||
expanded_args = []
|
||||
try:
|
||||
for arg in args:
|
||||
if isinstance(arg, str) and "~" in arg:
|
||||
expanded_args.append(os.path.expanduser(arg))
|
||||
else:
|
||||
# Ensure all args are strings for list2cmdline
|
||||
expanded_args.append(str(arg))
|
||||
logger.debug(f"Expanded args: {expanded_args}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to expand arguments: {e}") from e
|
||||
|
||||
# --- Merge os.environ with config_env ---
|
||||
merged_env = {**os.environ, **config_env}
|
||||
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
|
||||
|
||||
# Combine command and expanded args into a single string for shell execution
|
||||
try:
|
||||
cmd_string = subprocess.list2cmdline([command] + expanded_args)
|
||||
logger.debug(f"Executing shell command: {cmd_string}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating command string: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to create command string: {e}") from e
|
||||
|
||||
# --- Start the subprocess using shell ---
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd_string,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=merged_env,
|
||||
)
|
||||
logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}")
|
||||
return process
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
|
||||
raise # Re-raise specific error
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
|
||||
raise # Re-raise other errors
|
||||
|
||||
|
||||
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
|
||||
"""
|
||||
Attempts to gracefully stop the MCP server subprocess.
|
||||
|
||||
Args:
|
||||
process: The asyncio.subprocess.Process object to stop.
|
||||
server_name: A name for logging purposes.
|
||||
"""
|
||||
if process is None or process.returncode is not None:
|
||||
logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.")
|
||||
return
|
||||
|
||||
pid = process.pid
|
||||
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
|
||||
|
||||
# Close stdin first
|
||||
if process.stdin and not process.stdin.is_closing():
|
||||
try:
|
||||
process.stdin.close()
|
||||
await process.stdin.wait_closed()
|
||||
logger.debug(f"Stdin closed for {server_name} (PID: {pid})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
|
||||
|
||||
# Attempt graceful termination
|
||||
try:
|
||||
process.terminate()
|
||||
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).")
|
||||
except TimeoutError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait() # Wait for kill to complete
|
||||
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
|
||||
except Exception as e_kill:
|
||||
logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
|
||||
except Exception as e_term:
|
||||
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
|
||||
# Attempt kill as fallback if terminate failed and process might still be running
|
||||
if process.returncode is None:
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).")
|
||||
except Exception as e_kill_fallback:
|
||||
logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}")
|
||||
75
src/custom_mcp/protocol.py
Normal file
75
src/custom_mcp/protocol.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# src/mcp/protocol.py
|
||||
"""Async utilities for MCP JSON-RPC communication over streams."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]):
|
||||
"""
|
||||
Sends a JSON-RPC request dictionary to the MCP server's stdin stream.
|
||||
|
||||
Args:
|
||||
writer: The asyncio StreamWriter connected to the process stdin.
|
||||
request_dict: The request dictionary to send.
|
||||
|
||||
Raises:
|
||||
ConnectionResetError: If the connection is lost during send.
|
||||
Exception: For other stream writing errors.
|
||||
"""
|
||||
try:
|
||||
request_json = json.dumps(request_dict) + "\n"
|
||||
writer.write(request_json.encode("utf-8"))
|
||||
await writer.drain()
|
||||
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
|
||||
except ConnectionResetError:
|
||||
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
|
||||
raise # Re-raise for the caller (MCPClient) to handle
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
|
||||
raise # Re-raise for the caller
|
||||
|
||||
|
||||
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
|
||||
"""
|
||||
Reads and parses a JSON-RPC response line from the MCP server's stdout stream.
|
||||
|
||||
Args:
|
||||
reader: The asyncio StreamReader connected to the process stdout.
|
||||
timeout: Seconds to wait for a response line.
|
||||
|
||||
Returns:
|
||||
The parsed response dictionary, or None if timeout or error occurs.
|
||||
"""
|
||||
response_str = None
|
||||
try:
|
||||
response_json = await asyncio.wait_for(reader.readline(), timeout=timeout)
|
||||
if not response_json:
|
||||
logger.warning("Received empty response line (EOF?).")
|
||||
return None
|
||||
|
||||
response_str = response_json.decode("utf-8").strip()
|
||||
if not response_str:
|
||||
logger.warning("Received empty response string after strip.")
|
||||
return None
|
||||
|
||||
logger.debug(f"Received response line: {response_str}")
|
||||
response_dict = json.loads(response_str)
|
||||
return response_dict
|
||||
|
||||
except TimeoutError:
|
||||
logger.error(f"Timeout ({timeout}s) waiting for response.")
|
||||
return None
|
||||
except asyncio.IncompleteReadError:
|
||||
logger.warning("Connection closed while waiting for response.")
|
||||
return None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading response: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Custom MCP client implementation focused on OpenAI integration."""
|
||||
|
||||
from .client import MCPClient, run_interaction
|
||||
|
||||
__all__ = ["MCPClient", "run_interaction"]
|
||||
@@ -1,550 +0,0 @@
|
||||
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Get a logger for this module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Lightweight MCP client with JSON-RPC communication."""
|
||||
|
||||
def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
|
||||
self.server_name = server_name
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
self.env = env or {}
|
||||
self.process = None
|
||||
self.tools = []
|
||||
self.request_id = 0
|
||||
self.responses = {}
|
||||
self._shutdown = False
|
||||
# Use a logger specific to this client instance
|
||||
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
|
||||
|
||||
async def _receive_loop(self):
|
||||
"""Listen for responses from the MCP server."""
|
||||
try:
|
||||
while self.process and self.process.stdout and not self.process.stdout.at_eof():
|
||||
line_bytes = await self.process.stdout.readline()
|
||||
if not line_bytes:
|
||||
self.logger.debug("STDOUT EOF reached.")
|
||||
break
|
||||
line_str = line_bytes.decode().strip()
|
||||
self.logger.debug(f"STDOUT Raw line: {line_str}")
|
||||
try:
|
||||
message = json.loads(line_str)
|
||||
if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
|
||||
self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
|
||||
self.responses[message["id"]] = message
|
||||
elif "jsonrpc" in message and "method" in message:
|
||||
self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
|
||||
else:
|
||||
self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
|
||||
except json.JSONDecodeError:
|
||||
self.logger.warning("STDOUT Failed to parse line as JSON.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("STDOUT Receive loop cancelled.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("STDOUT Receive loop finished.")
|
||||
|
||||
async def _stderr_loop(self):
|
||||
"""Listen for stderr messages from the MCP server."""
|
||||
try:
|
||||
while self.process and self.process.stderr and not self.process.stderr.at_eof():
|
||||
line_bytes = await self.process.stderr.readline()
|
||||
if not line_bytes:
|
||||
self.logger.debug("STDERR EOF reached.")
|
||||
break
|
||||
line_str = line_bytes.decode().strip()
|
||||
self.logger.warning(f"STDERR: {line_str}") # Log stderr as warning
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("STDERR Stderr loop cancelled.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("STDERR Stderr loop finished.")
|
||||
|
||||
async def _send_message(self, message: dict) -> bool:
|
||||
"""Send a JSON-RPC message to the MCP server."""
|
||||
if not self.process or not self.process.stdin:
|
||||
self.logger.warning("STDIN Cannot send message, process or stdin not available.")
|
||||
return False
|
||||
try:
|
||||
data = json.dumps(message) + "\n"
|
||||
self.logger.debug(f"STDIN Sending: {data.strip()}")
|
||||
self.process.stdin.write(data.encode())
|
||||
await self.process.stdin.drain()
|
||||
return True
|
||||
except ConnectionResetError:
|
||||
self.logger.error("STDIN Connection reset while sending message.")
|
||||
self.process = None # Mark process as dead
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""Start the MCP server process."""
|
||||
self.logger.info("Attempting to start server...")
|
||||
# Expand ~ in paths and prepare args
|
||||
expanded_args = []
|
||||
try:
|
||||
for a in self.args:
|
||||
if isinstance(a, str) and "~" in a:
|
||||
expanded_args.append(os.path.expanduser(a))
|
||||
else:
|
||||
expanded_args.append(str(a)) # Ensure all args are strings
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
# Set up environment
|
||||
env_vars = os.environ.copy()
|
||||
if self.env:
|
||||
env_vars.update(self.env)
|
||||
|
||||
self.logger.debug(f"Command: {self.command}")
|
||||
self.logger.debug(f"Expanded Args: {expanded_args}")
|
||||
# Avoid logging full env unless necessary for debugging sensitive info
|
||||
# self.logger.debug(f"Environment: {env_vars}")
|
||||
|
||||
try:
|
||||
# Start the subprocess
|
||||
self.logger.debug("Creating subprocess...")
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
self.command,
|
||||
*expanded_args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE, # Capture stderr
|
||||
env=env_vars,
|
||||
)
|
||||
self.logger.info(f"Subprocess created with PID: {self.process.pid}")
|
||||
|
||||
# Start the receive loops
|
||||
asyncio.create_task(self._receive_loop())
|
||||
asyncio.create_task(self._stderr_loop()) # Start stderr loop
|
||||
|
||||
# Initialize the server
|
||||
self.logger.debug("Attempting initialization handshake...")
|
||||
init_success = await self._initialize()
|
||||
if init_success:
|
||||
self.logger.info("Initialization handshake successful (or skipped).")
|
||||
# Add delay after successful start
|
||||
await asyncio.sleep(0.5)
|
||||
self.logger.debug("Post-initialization delay complete.")
|
||||
return True
|
||||
else:
|
||||
self.logger.error("Initialization handshake failed.")
|
||||
await self.stop() # Ensure cleanup if init fails
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _initialize(self) -> bool:
|
||||
"""Initialize the MCP server connection. Modified to not wait for response."""
|
||||
self.logger.debug("Sending 'initialize' request...")
|
||||
if not self.process:
|
||||
self.logger.warning("Cannot initialize, process not running.")
|
||||
return False
|
||||
|
||||
# Send initialize request
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
initialize_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
|
||||
"capabilities": {}, # Add empty capabilities object
|
||||
},
|
||||
}
|
||||
if not await self._send_message(initialize_req):
|
||||
self.logger.warning("Failed to send 'initialize' request.")
|
||||
# Continue anyway for non-compliant servers
|
||||
|
||||
# Send initialized notification immediately
|
||||
self.logger.debug("Sending 'initialized' notification...")
|
||||
notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
|
||||
if await self._send_message(notify):
|
||||
self.logger.debug("'initialized' notification sent.")
|
||||
else:
|
||||
self.logger.warning("Failed to send 'initialized' notification.")
|
||||
# Still return True as the server might be running
|
||||
|
||||
self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
|
||||
return True # Assume success without waiting for response
|
||||
|
||||
async def list_tools(self) -> list[dict]:
|
||||
"""List available tools from the MCP server."""
|
||||
if not self.process:
|
||||
self.logger.warning("Cannot list tools, process not running.")
|
||||
return []
|
||||
|
||||
self.logger.debug("Sending 'tools/list' request...")
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
|
||||
if not await self._send_message(req):
|
||||
self.logger.error("Failed to send 'tools/list' request.")
|
||||
return []
|
||||
|
||||
# Wait for response
|
||||
self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 10 # seconds
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
if req_id in self.responses:
|
||||
resp = self.responses.pop(req_id)
|
||||
self.logger.debug(f"Received 'tools/list' response: {resp}")
|
||||
|
||||
if "error" in resp:
|
||||
self.logger.error(f"'tools/list' error response: {resp['error']}")
|
||||
return []
|
||||
|
||||
if "result" in resp and "tools" in resp["result"]:
|
||||
self.tools = resp["result"]["tools"]
|
||||
self.logger.info(f"Successfully listed tools: {len(self.tools)}")
|
||||
return self.tools
|
||||
else:
|
||||
self.logger.error("Invalid 'tools/list' response format.")
|
||||
return []
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
|
||||
return []
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""Call a tool on the MCP server."""
|
||||
if not self.process:
|
||||
self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
|
||||
return {"error": "Server not started"}
|
||||
|
||||
self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
|
||||
if not await self._send_message(req):
|
||||
self.logger.error("Failed to send 'tools/call' request.")
|
||||
return {"error": "Failed to send tool call request"}
|
||||
|
||||
# Wait for response
|
||||
self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 30 # seconds
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
if req_id in self.responses:
|
||||
resp = self.responses.pop(req_id)
|
||||
self.logger.debug(f"Received 'tools/call' response: {resp}")
|
||||
|
||||
if "error" in resp:
|
||||
self.logger.error(f"'tools/call' error response: {resp['error']}")
|
||||
return {"error": str(resp["error"])}
|
||||
|
||||
if "result" in resp:
|
||||
self.logger.info(f"Tool '{tool_name}' executed successfully.")
|
||||
return resp["result"]
|
||||
else:
|
||||
self.logger.error("Invalid 'tools/call' response format.")
|
||||
return {"error": "Invalid tool call response format"}
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
|
||||
return {"error": f"Tool call timed out after {timeout}s"}
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the MCP server process."""
|
||||
self.logger.info("Attempting to stop server...")
|
||||
if self._shutdown or not self.process:
|
||||
self.logger.debug("Server already stopped or not running.")
|
||||
return
|
||||
|
||||
self._shutdown = True
|
||||
proc = self.process # Keep a local reference
|
||||
self.process = None # Prevent further operations
|
||||
|
||||
try:
|
||||
# Send shutdown notification
|
||||
self.logger.debug("Sending 'shutdown' notification...")
|
||||
notify = {"jsonrpc": "2.0", "method": "shutdown"}
|
||||
await self._send_message(notify) # Use the method which now handles None process
|
||||
await asyncio.sleep(0.5) # Give server time to process
|
||||
|
||||
# Close stdin
|
||||
if proc and proc.stdin:
|
||||
try:
|
||||
if not proc.stdin.is_closing():
|
||||
proc.stdin.close()
|
||||
await proc.stdin.wait_closed()
|
||||
self.logger.debug("Stdin closed.")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
|
||||
|
||||
# Terminate the process
|
||||
if proc:
|
||||
self.logger.debug(f"Terminating process {proc.pid}...")
|
||||
proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2.0)
|
||||
self.logger.info(f"Process {proc.pid} terminated gracefully.")
|
||||
except TimeoutError:
|
||||
self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
self.logger.info(f"Process {proc.pid} killed.")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("Stop sequence finished.")
|
||||
# Ensure self.process is None even if errors occurred
|
||||
self.process = None
|
||||
|
||||
|
||||
async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
|
||||
"""Process a tool call from OpenAI."""
|
||||
func_name = tool_call["function"]["name"]
|
||||
try:
|
||||
func_args = json.loads(tool_call["function"].get("arguments", "{}"))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid tool arguments format for {func_name}: {e}")
|
||||
return {"error": "Invalid arguments format"}
|
||||
|
||||
# Parse server_name and tool_name from function name
|
||||
parts = func_name.split("_", 1)
|
||||
if len(parts) != 2:
|
||||
logger.error(f"Invalid tool function name format: {func_name}")
|
||||
return {"error": "Invalid function name format"}
|
||||
|
||||
server_name, tool_name = parts
|
||||
|
||||
if server_name not in servers:
|
||||
logger.error(f"Tool call for unknown server: {server_name}")
|
||||
return {"error": f"Unknown server: {server_name}"}
|
||||
|
||||
# Call the tool
|
||||
return await servers[server_name].call_tool(tool_name, func_args)
|
||||
|
||||
|
||||
async def run_interaction(
|
||||
user_query: str,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
base_url: str | None,
|
||||
mcp_config: dict,
|
||||
stream: bool = False,
|
||||
) -> dict | AsyncGenerator:
|
||||
"""
|
||||
Run an interaction with OpenAI using MCP server tools.
|
||||
|
||||
Args:
|
||||
user_query: The user's input query.
|
||||
model_name: The model to use for processing.
|
||||
api_key: The OpenAI API key.
|
||||
base_url: The OpenAI API base URL (optional).
|
||||
mcp_config: The MCP configuration dictionary (for servers).
|
||||
stream: Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or AsyncGenerator for streaming.
|
||||
"""
|
||||
# Validate passed arguments
|
||||
if not api_key:
|
||||
logger.error("API key is missing.")
|
||||
if not stream:
|
||||
return {"error": "API key is missing."}
|
||||
else:
|
||||
|
||||
async def error_gen():
|
||||
yield {"error": "API key is missing."}
|
||||
|
||||
return error_gen()
|
||||
|
||||
# Start MCP servers using mcp_config
|
||||
servers = {}
|
||||
all_functions = []
|
||||
if mcp_config.get("mcpServers"): # Use mcp_config here
|
||||
for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
if await client.start():
|
||||
tools = await client.list_tools()
|
||||
for tool in tools:
|
||||
# Ensure parameters is a dict, default to empty if missing or not dict
|
||||
params = tool.get("inputSchema", {})
|
||||
if not isinstance(params, dict):
|
||||
logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
|
||||
params = {}
|
||||
|
||||
all_functions.append({
|
||||
"type": "function", # Explicitly set type for clarity with newer OpenAI API
|
||||
"function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
|
||||
})
|
||||
servers[server_name] = client
|
||||
else:
|
||||
logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
|
||||
else:
|
||||
logger.info("No mcpServers defined in configuration.")
|
||||
|
||||
# Use passed api_key and base_url
|
||||
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments
|
||||
messages = [{"role": "user", "content": user_query}]
|
||||
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
|
||||
|
||||
if stream:
|
||||
|
||||
async def response_generator():
|
||||
active_servers = list(servers.values()) # Keep track for cleanup
|
||||
try:
|
||||
while True:
|
||||
logger.debug(f"Calling OpenAI with messages: {messages}")
|
||||
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
|
||||
# Get OpenAI response
|
||||
try:
|
||||
response = await openai_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
|
||||
stream=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
yield {"error": f"OpenAI API error: {e}"}
|
||||
break
|
||||
|
||||
# Process streaming response
|
||||
full_response_content = ""
|
||||
tool_calls = []
|
||||
async for chunk in response:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.content:
|
||||
content = delta.content
|
||||
full_response_content += content
|
||||
yield {"assistant_text": content, "is_chunk": True}
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
# Initialize tool call structure if it's the first chunk for this index
|
||||
if tc.index >= len(tool_calls):
|
||||
tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
|
||||
|
||||
# Append parts as they arrive
|
||||
if tc.id:
|
||||
tool_calls[tc.index]["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
tool_calls[tc.index]["function"]["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
||||
|
||||
# Add assistant message with content and potential tool calls
|
||||
assistant_message = {"role": "assistant", "content": full_response_content}
|
||||
if tool_calls:
|
||||
# Filter out incomplete tool calls just in case
|
||||
valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
|
||||
if valid_tool_calls:
|
||||
assistant_message["tool_calls"] = valid_tool_calls
|
||||
else:
|
||||
logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
|
||||
|
||||
messages.append(assistant_message)
|
||||
logger.debug(f"Assistant message added: {assistant_message}")
|
||||
|
||||
# Handle tool calls if any were successfully assembled
|
||||
if "tool_calls" in assistant_message:
|
||||
tool_results = []
|
||||
for tc in assistant_message["tool_calls"]:
|
||||
logger.info(f"Processing tool call: {tc['function']['name']}")
|
||||
result = await process_tool_call(tc, servers)
|
||||
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
|
||||
|
||||
messages.extend(tool_results)
|
||||
logger.debug(f"Tool results added: {tool_results}")
|
||||
# Loop back to call OpenAI again with tool results
|
||||
else:
|
||||
# No tool calls, interaction finished for this turn
|
||||
yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming interaction: {e}", exc_info=True)
|
||||
yield {"error": f"Interaction error: {e}"}
|
||||
finally:
|
||||
# Clean up servers
|
||||
logger.debug("Cleaning up MCP servers (stream)...")
|
||||
for server in active_servers:
|
||||
await server.stop()
|
||||
logger.debug("MCP server cleanup finished (stream).")
|
||||
|
||||
return response_generator()
|
||||
|
||||
else: # Non-streaming case
|
||||
active_servers = list(servers.values()) # Keep track for cleanup
|
||||
try:
|
||||
while True:
|
||||
logger.debug(f"Calling OpenAI with messages: {messages}")
|
||||
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
|
||||
# Get OpenAI response
|
||||
try:
|
||||
response = await openai_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
return {"error": f"OpenAI API error: {e}"}
|
||||
|
||||
message = response.choices[0].message
|
||||
messages.append(message)
|
||||
logger.debug(f"OpenAI response message: {message}")
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls:
|
||||
tool_results = []
|
||||
for tc in message.tool_calls:
|
||||
logger.info(f"Processing tool call: {tc.function.name}")
|
||||
# Reconstruct dict for process_tool_call
|
||||
tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
|
||||
result = await process_tool_call(tool_call_dict, servers)
|
||||
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
|
||||
|
||||
messages.extend(tool_results)
|
||||
logger.debug(f"Tool results added: {tool_results}")
|
||||
# Loop back to call OpenAI again with tool results
|
||||
else:
|
||||
# No tool calls, interaction finished
|
||||
logger.info("Interaction finished, no tool calls.")
|
||||
return {"assistant_text": message.content or "", "tool_calls": []}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
|
||||
return {"error": f"Interaction error: {e}"}
|
||||
finally:
|
||||
# Clean up servers
|
||||
logger.debug("Cleaning up MCP servers (non-stream)...")
|
||||
for server in active_servers:
|
||||
await server.stop()
|
||||
logger.debug("MCP server cleanup finished (non-stream).")
|
||||
219
src/llm_client.py
Normal file
219
src/llm_client.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# src/llm_client.py
|
||||
"""
|
||||
Generic LLM client supporting multiple providers and MCP tool integration.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from src.custom_mcp.manager import SyncMCPManager # Updated import path
|
||||
from src.providers import BaseProvider, create_llm_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
Handles chat completion requests to various LLM providers through a unified
|
||||
interface, integrating with MCP tools via SyncMCPManager.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str,
|
||||
api_key: str,
|
||||
mcp_manager: SyncMCPManager,
|
||||
base_url: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the LLM client.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider (e.g., 'openai', 'anthropic').
|
||||
api_key: API key for the provider.
|
||||
mcp_manager: An initialized instance of SyncMCPManager.
|
||||
base_url: Optional base URL for the provider API.
|
||||
"""
|
||||
logger.info(f"Initializing LLMClient for provider: {provider_name}")
|
||||
self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
|
||||
self.mcp_manager = mcp_manager
|
||||
self.mcp_tools: list[dict[str, Any]] = []
|
||||
self._refresh_mcp_tools() # Initial tool load
|
||||
|
||||
def _refresh_mcp_tools(self):
|
||||
"""Retrieves the latest tools from the MCP manager."""
|
||||
logger.info("Refreshing MCP tools...")
|
||||
try:
|
||||
self.mcp_tools = self.mcp_manager.list_all_tools()
|
||||
logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing MCP tools: {e}", exc_info=True)
|
||||
# Keep existing tools if refresh fails
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[str, None, None] | dict[str, Any]:
|
||||
"""
|
||||
Send a chat completion request, handling potential tool calls.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
|
||||
model: Model identifier string.
|
||||
temperature: Sampling temperature.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
If stream=True: A generator yielding content chunks.
|
||||
If stream=False: A dictionary containing the final content or an error.
|
||||
e.g., {"content": "..."} or {"error": "..."}
|
||||
"""
|
||||
# Ensure tools are up-to-date (optional, could be done less frequently)
|
||||
# self._refresh_mcp_tools()
|
||||
|
||||
# Convert tools to the provider-specific format
|
||||
try:
|
||||
provider_tools = self.provider.convert_tools(self.mcp_tools)
|
||||
logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting tools for provider: {e}", exc_info=True)
|
||||
provider_tools = None # Proceed without tools if conversion fails
|
||||
|
||||
try:
|
||||
logger.info(f"Sending chat completion request to provider with model: {model}")
|
||||
response = self.provider.create_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=stream,
|
||||
tools=provider_tools,
|
||||
)
|
||||
logger.info("Received response from provider.")
|
||||
|
||||
if stream:
|
||||
# Streaming with tool calls requires more complex handling (like airflow_wingman's
|
||||
# process_tool_calls_and_follow_up). For now, we'll yield the initial stream
|
||||
# and handle tool calls *after* the stream completes if detected (less ideal UX).
|
||||
# A better approach involves checking for tool calls before streaming fully.
|
||||
# This simplified version just streams the first response.
|
||||
logger.info("Streaming response...")
|
||||
# NOTE: This simple version doesn't handle tool calls during streaming well.
|
||||
# It will stream the initial response which might *contain* the tool call request,
|
||||
# but won't execute it within the stream.
|
||||
return self._stream_generator(response)
|
||||
|
||||
else: # Non-streaming
|
||||
logger.info("Processing non-streaming response...")
|
||||
if self.provider.has_tool_calls(response):
|
||||
logger.info("Tool calls detected in response.")
|
||||
# Simplified non-streaming tool call handling (one round)
|
||||
try:
|
||||
tool_calls = self.provider.parse_tool_calls(response)
|
||||
logger.debug(f"Parsed tool calls: {tool_calls}")
|
||||
|
||||
tool_results = []
|
||||
original_message_with_calls = self.provider.get_original_message_with_calls(response) # Provider needs to implement this
|
||||
messages.append(original_message_with_calls) # Add assistant's turn with tool requests
|
||||
|
||||
for tool_call in tool_calls:
|
||||
server_name = tool_call.get("server_name") # Needs to be parsed by provider
|
||||
func_name = tool_call.get("function_name")
|
||||
func_args_str = tool_call.get("arguments")
|
||||
call_id = tool_call.get("id")
|
||||
|
||||
if not server_name or not func_name or func_args_str is None or call_id is None:
|
||||
logger.error(f"Skipping invalid tool call data: {tool_call}")
|
||||
# Add error result?
|
||||
result_content = {"error": "Invalid tool call structure from LLM"}
|
||||
else:
|
||||
try:
|
||||
# Arguments might be a JSON string, parse them
|
||||
arguments = json.loads(func_args_str)
|
||||
logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}")
|
||||
# Execute synchronously using the manager
|
||||
execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments)
|
||||
logger.debug(f"Tool execution result: {execution_result}")
|
||||
|
||||
if execution_result is None:
|
||||
result_content = {"error": f"Tool execution failed or timed out for {func_name}"}
|
||||
elif isinstance(execution_result, dict) and "error" in execution_result:
|
||||
result_content = execution_result # Propagate error from tool/server
|
||||
else:
|
||||
# Assuming result is the content payload
|
||||
result_content = execution_result
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}")
|
||||
result_content = {"error": f"Invalid arguments format for tool {func_name}"}
|
||||
except Exception as exec_err:
|
||||
logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True)
|
||||
result_content = {"error": f"Exception during tool execution: {str(exec_err)}"}
|
||||
|
||||
# Format result for the provider's follow-up message
|
||||
formatted_result = self.provider.format_tool_results(call_id, result_content)
|
||||
tool_results.append(formatted_result)
|
||||
messages.append(formatted_result) # Add tool result message
|
||||
|
||||
# Make follow-up call
|
||||
logger.info("Making follow-up request with tool results...")
|
||||
follow_up_response = self.provider.create_chat_completion(
|
||||
messages=messages, # Now includes assistant's turn and tool results
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False, # Follow-up is non-streaming here
|
||||
tools=provider_tools, # Pass tools again? Some providers might need it.
|
||||
)
|
||||
final_content = self.provider.get_content(follow_up_response)
|
||||
logger.info("Received follow-up response content.")
|
||||
return {"content": final_content}
|
||||
|
||||
except Exception as tool_handling_err:
|
||||
logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
|
||||
return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"}
|
||||
|
||||
else: # No tool calls
|
||||
logger.info("No tool calls detected.")
|
||||
content = self.provider.get_content(response)
|
||||
return {"content": content}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"LLM API Error: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
if stream:
|
||||
# How to signal error in a stream? Yield a specific error message?
|
||||
# This simple generator won't handle it well. Returning an error dict for now.
|
||||
return {"error": error_msg} # Or raise?
|
||||
else:
|
||||
return {"error": error_msg}
|
||||
|
||||
def _stream_generator(self, response: Any) -> Generator[str, None, None]:
|
||||
"""Helper to yield content from the provider's streaming method."""
|
||||
try:
|
||||
# Use yield from for cleaner and potentially more efficient delegation
|
||||
yield from self.provider.get_streaming_content(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk
|
||||
|
||||
|
||||
# Example of how a provider might need to implement get_original_message_with_calls
|
||||
# This would be in the specific provider class (e.g., openai_provider.py)
|
||||
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
|
||||
# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
|
||||
# # or in the non-streaming response's choice message
|
||||
# # Needs careful implementation based on provider's response structure
|
||||
# assistant_message = {
|
||||
# "role": "assistant",
|
||||
# "content": None, # Often null when tool calls are present
|
||||
# "tool_calls": [...] # Extracted tool calls in provider format
|
||||
# }
|
||||
# return assistant_message
|
||||
61
src/llm_models.py
Normal file
61
src/llm_models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
MODELS = {
|
||||
"openai": {
|
||||
"name": "OpenAI",
|
||||
"endpoint": "https://api.openai.com/v1",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"name": "GPT-4o",
|
||||
"default": True,
|
||||
"context_window": 128000,
|
||||
"description": "Input $5/M tokens, Output $15/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"name": "Anthropic",
|
||||
"endpoint": "https://api.anthropic.com/v1/messages",
|
||||
"models": [
|
||||
{
|
||||
"id": "claude-3-7-sonnet-20250219",
|
||||
"name": "Claude 3.7 Sonnet",
|
||||
"default": True,
|
||||
"context_window": 200000,
|
||||
"description": "Input $3/M tokens, Output $15/M tokens",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-haiku-20241022",
|
||||
"name": "Claude 3.5 Haiku",
|
||||
"default": False,
|
||||
"context_window": 200000,
|
||||
"description": "Input $0.80/M tokens, Output $4/M tokens",
|
||||
},
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"name": "Google Gemini",
|
||||
"endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
|
||||
"models": [
|
||||
{
|
||||
"id": "gemini-2.0-flash",
|
||||
"name": "Gemini 2.0 Flash",
|
||||
"default": True,
|
||||
"context_window": 1000000,
|
||||
"description": "Input $0.1/M tokens, Output $0.4/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"openrouter": {
|
||||
"name": "OpenRouter",
|
||||
"endpoint": "https://openrouter.ai/api/v1/chat/completions",
|
||||
"models": [
|
||||
{
|
||||
"id": "custom",
|
||||
"name": "Custom Model",
|
||||
"default": False,
|
||||
"context_window": 128000, # Default context window, will be updated based on model
|
||||
"description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
|
||||
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging # Import logging
|
||||
import threading
|
||||
|
||||
from custom_mcp_client import MCPClient, run_interaction
|
||||
|
||||
# Configure basic logging for the application if not already configured
|
||||
# This basic config helps ensure logs are visible during development
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
# Get a logger for this module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
"""Synchronous wrapper for managing MCP servers and interactions"""
|
||||
|
||||
def __init__(self, config_path: str = "config/mcp_config.json"):
|
||||
self.config_path = config_path
|
||||
self.config = None
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load MCP configuration from JSON file using importlib"""
|
||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||
try:
|
||||
# First try to load as a package resource
|
||||
try:
|
||||
# Try anchoring to the project name defined in pyproject.toml
|
||||
# This *might* work depending on editable installs or context.
|
||||
resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
|
||||
with resource_path.open("r") as f:
|
||||
self.config = json.load(f)
|
||||
logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
|
||||
# REMOVED: raise FileNotFoundError
|
||||
|
||||
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError
|
||||
logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
|
||||
# Fall back to direct file access relative to CWD
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
logger.debug("Loaded config via direct file access.")
|
||||
|
||||
logger.info("MCP configuration loaded successfully.")
|
||||
logger.debug(f"Config content: {self.config}") # Log content only if loaded
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"MCP config file not found at {self.config_path}")
|
||||
self.config = None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
|
||||
self.config = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||
self.config = None
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Initialize and start all MCP servers synchronously"""
|
||||
logger.info("Initialize requested.")
|
||||
if not self.config:
|
||||
logger.warning("Initialization skipped: No configuration loaded.")
|
||||
return False
|
||||
if not self.config.get("mcpServers"):
|
||||
logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
|
||||
return False
|
||||
|
||||
if self.initialized:
|
||||
logger.debug("Initialization skipped: Already initialized.")
|
||||
return True
|
||||
|
||||
with self._lock:
|
||||
if self.initialized: # Double-check after acquiring lock
|
||||
logger.debug("Initialization skipped inside lock: Already initialized.")
|
||||
return True
|
||||
|
||||
logger.info("Starting asynchronous initialization...")
|
||||
# Run async initialization in a new event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop) # Ensure this loop is used by tasks
|
||||
success = loop.run_until_complete(self._async_initialize())
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None) # Clean up
|
||||
|
||||
if success:
|
||||
logger.info("Asynchronous initialization completed successfully.")
|
||||
self.initialized = True
|
||||
else:
|
||||
logger.error("Asynchronous initialization failed.")
|
||||
self.initialized = False # Ensure state is False on failure
|
||||
|
||||
return self.initialized
|
||||
|
||||
async def _async_initialize(self) -> bool:
|
||||
"""Async implementation of server initialization"""
|
||||
logger.debug("Starting _async_initialize...")
|
||||
all_success = True
|
||||
if not self.config or not self.config.get("mcpServers"):
|
||||
logger.warning("_async_initialize: No config or mcpServers found.")
|
||||
return False
|
||||
|
||||
tasks = []
|
||||
server_names = list(self.config["mcpServers"].keys())
|
||||
|
||||
async def start_server(server_name, server_config):
|
||||
logger.info(f"Initializing server: {server_name}")
|
||||
try:
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
logger.debug(f"Attempting to start client for {server_name}...")
|
||||
if await client.start():
|
||||
logger.info(f"Client for {server_name} started successfully.")
|
||||
tools = await client.list_tools()
|
||||
logger.info(f"Tools listed for {server_name}: {len(tools)}")
|
||||
self.servers[server_name] = {"client": client, "tools": tools}
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to start MCP server: {server_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
# Start servers concurrently
|
||||
for server_name in server_names:
|
||||
server_config = self.config["mcpServers"][server_name]
|
||||
tasks.append(start_server(server_name, server_config))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Check if all servers started successfully
|
||||
all_success = all(results)
|
||||
|
||||
if all_success:
|
||||
logger.debug("_async_initialize completed: All servers started successfully.")
|
||||
else:
|
||||
failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
|
||||
logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
|
||||
# Optionally shutdown servers that did start if partial success is not desired
|
||||
# await self._async_shutdown() # Uncomment to enforce all-or-nothing startup
|
||||
|
||||
return all_success
|
||||
|
||||
def shutdown(self):
|
||||
"""Shut down all MCP servers synchronously"""
|
||||
logger.info("Shutdown requested.")
|
||||
if not self.initialized:
|
||||
logger.debug("Shutdown skipped: Not initialized.")
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
if not self.initialized:
|
||||
logger.debug("Shutdown skipped inside lock: Not initialized.")
|
||||
return
|
||||
|
||||
logger.info("Starting asynchronous shutdown...")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._async_shutdown())
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
logger.info("Shutdown complete.")
|
||||
|
||||
async def _async_shutdown(self):
|
||||
"""Async implementation of server shutdown"""
|
||||
logger.debug("Starting _async_shutdown...")
|
||||
tasks = []
|
||||
for server_name, server_info in self.servers.items():
|
||||
logger.debug(f"Initiating shutdown for server: {server_name}")
|
||||
tasks.append(server_info["client"].stop())
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, result in enumerate(results):
|
||||
server_name = list(self.servers.keys())[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
|
||||
else:
|
||||
logger.debug(f"Shutdown completed for server: {server_name}")
|
||||
logger.debug("_async_shutdown finished.")
|
||||
|
||||
# Updated process_query signature
|
||||
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""
|
||||
Process a query using MCP tools synchronously
|
||||
|
||||
Args:
|
||||
query: The user's input query.
|
||||
model_name: The model to use for processing.
|
||||
api_key: The OpenAI API key.
|
||||
base_url: The OpenAI API base URL.
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or error.
|
||||
"""
|
||||
if not self.initialized and not self.initialize():
|
||||
logger.error("process_query called but MCP manager failed to initialize.")
|
||||
return {"error": "Failed to initialize MCP servers"}
|
||||
|
||||
logger.debug(f"Processing query synchronously: '{query}'")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# Pass api_key and base_url to _async_process_query
|
||||
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
|
||||
logger.debug(f"Synchronous query processing result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
|
||||
return {"error": f"Processing error: {str(e)}"}
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Updated _async_process_query signature
|
||||
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""Async implementation of query processing"""
|
||||
# Pass api_key, base_url, and the MCP config separately to run_interaction
|
||||
return await run_interaction(
|
||||
user_query=query,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
mcp_config=self.config, # self.config only contains MCP server definitions now
|
||||
stream=False,
|
||||
)
|
||||
69
src/providers/__init__.py
Normal file
69
src/providers/__init__.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# src/providers/__init__.py
|
||||
import logging
|
||||
|
||||
from providers.anthropic_provider import AnthropicProvider
|
||||
from providers.base import BaseProvider
|
||||
from providers.openai_provider import OpenAIProvider
|
||||
|
||||
# from providers.google_provider import GoogleProvider
|
||||
# from providers.openrouter_provider import OpenRouterProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map provider names (lowercase) to their corresponding class implementations
|
||||
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
# "google": GoogleProvider,
|
||||
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
|
||||
}
|
||||
|
||||
|
||||
def register_provider(name: str, provider_class: type[BaseProvider]):
|
||||
"""Registers a provider class."""
|
||||
if name.lower() in PROVIDER_MAP:
|
||||
logger.warning(f"Provider '{name}' is already registered. Overwriting.")
|
||||
PROVIDER_MAP[name.lower()] = provider_class
|
||||
logger.info(f"Registered provider: {name}")
|
||||
|
||||
|
||||
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
|
||||
"""
|
||||
Factory function to create an instance of a specific LLM provider.
|
||||
|
||||
Args:
|
||||
provider_name: The name of the provider (e.g., 'openai', 'anthropic').
|
||||
api_key: The API key for the provider.
|
||||
base_url: Optional base URL for the provider's API.
|
||||
|
||||
Returns:
|
||||
An instance of the requested BaseProvider subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the requested provider_name is not registered.
|
||||
"""
|
||||
provider_class = PROVIDER_MAP.get(provider_name.lower())
|
||||
|
||||
if provider_class is None:
|
||||
available = ", ".join(PROVIDER_MAP.keys()) or "None"
|
||||
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
|
||||
|
||||
logger.info(f"Creating LLM provider instance for: {provider_name}")
|
||||
try:
|
||||
return provider_class(api_key=api_key, base_url=base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
|
||||
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
|
||||
|
||||
|
||||
def get_available_providers() -> list[str]:
|
||||
"""Returns a list of registered provider names."""
|
||||
return list(PROVIDER_MAP.keys())
|
||||
|
||||
|
||||
# Example of how specific providers would register themselves if structured as plugins,
|
||||
# but for now, we'll explicitly import and map them above.
|
||||
# def load_providers():
|
||||
# # Potentially load providers dynamically if designed as plugins
|
||||
# pass
|
||||
# load_providers()
|
||||
295
src/providers/anthropic_provider.py
Normal file
295
src/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# src/providers/anthropic_provider.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic.types import Message, MessageStreamEvent, TextDelta
|
||||
|
||||
# Use relative imports for modules within the same package
|
||||
from providers.base import BaseProvider
|
||||
|
||||
# Use absolute imports as per Ruff warning and user instructions
|
||||
from src.llm_models import MODELS
|
||||
from src.tools.conversion import convert_to_anthropic_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
"""Provider implementation for Anthropic Claude models."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# Anthropic client doesn't use base_url in the same way, but store it if needed
|
||||
# Use default Anthropic endpoint if base_url is not provided or relevant
|
||||
effective_base_url = base_url or MODELS.get("anthropic", {}).get("endpoint")
|
||||
super().__init__(api_key, effective_base_url) # Pass base_url to parent, though Anthropic client might ignore it
|
||||
logger.info("Initializing AnthropicProvider")
|
||||
try:
|
||||
self.client = Anthropic(api_key=self.api_key)
|
||||
# Note: Anthropic client doesn't take base_url during init
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
|
||||
"""Converts standard message format to Anthropic's format, extracting system prompt."""
|
||||
anthropic_messages = []
|
||||
system_prompt = None
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "system":
|
||||
if i == 0:
|
||||
system_prompt = content
|
||||
logger.debug("Extracted system prompt for Anthropic.")
|
||||
else:
|
||||
# Handle system message not at the start (append to previous user message or add as user)
|
||||
logger.warning("System message found not at the beginning. Treating as user message.")
|
||||
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
|
||||
continue
|
||||
|
||||
# Handle tool results specifically
|
||||
if role == "tool":
|
||||
# Find the preceding assistant message with the corresponding tool_use block
|
||||
# This requires careful handling in the follow-up logic
|
||||
tool_use_id = message.get("tool_call_id")
|
||||
tool_content = content
|
||||
# Format as a tool_result content block
|
||||
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
|
||||
continue
|
||||
|
||||
# Handle assistant message potentially containing tool_use blocks
|
||||
if role == "assistant":
|
||||
# Check if content is structured (e.g., from a previous tool call response)
|
||||
if isinstance(content, list): # Assuming tool calls might be represented as a list
|
||||
anthropic_messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
anthropic_messages.append({"role": "assistant", "content": content}) # Regular text content
|
||||
continue
|
||||
|
||||
# Regular user messages
|
||||
if role == "user":
|
||||
anthropic_messages.append({"role": "user", "content": content})
|
||||
continue
|
||||
|
||||
logger.warning(f"Unsupported role '{role}' in message conversion for Anthropic.")
|
||||
|
||||
# Ensure conversation starts with a user message if no system prompt was used
|
||||
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
|
||||
logger.warning("Anthropic conversation must start with a user message. Prepending empty user message.")
|
||||
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"}) # Or handle differently
|
||||
|
||||
return system_prompt, anthropic_messages
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None, # Anthropic requires max_tokens
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[MessageStreamEvent] | Message:
|
||||
"""Creates a chat completion using the Anthropic API."""
|
||||
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# Anthropic requires max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = 4096 # Default value if not provided
|
||||
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
|
||||
|
||||
system_prompt, anthropic_messages = self._convert_messages(messages)
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if system_prompt:
|
||||
completion_params["system"] = system_prompt
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
|
||||
|
||||
# Remove None values (though Anthropic requires max_tokens)
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
log_params = completion_params.copy()
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} for msg in log_params["messages"][-2:]]
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling Anthropic API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, System: {bool(log_params.get('system'))}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
|
||||
response = self.client.messages.create(**completion_params)
|
||||
logger.debug("Anthropic API call successful.")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_streaming_content(self, response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an Anthropic streaming response."""
|
||||
logger.debug("Processing Anthropic stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
# Iterate through events in the stream
|
||||
for event in response:
|
||||
if event.type == "content_block_delta":
|
||||
# Check if the delta is for text content before accessing .text
|
||||
if isinstance(event.delta, TextDelta):
|
||||
delta_text = event.delta.text
|
||||
if delta_text:
|
||||
full_delta += delta_text
|
||||
yield delta_text
|
||||
# Ignore other delta types like InputJSONDelta for text streaming
|
||||
# Other event types like 'message_start', 'content_block_start', etc., can be logged or handled if needed
|
||||
elif event.type == "message_start":
|
||||
logger.debug(f"Anthropic stream started. Model: {event.message.model}")
|
||||
elif event.type == "message_stop":
|
||||
# The stop_reason might be available on the 'message' object associated with the stream,
|
||||
# not directly on the stop event itself. We log that the stop event occurred.
|
||||
# Accessing the actual reason might require inspecting the final message state if needed.
|
||||
logger.debug("Anthropic stream message_stop event received.")
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
logger.debug(f"Anthropic stream detected tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
|
||||
elif event.type == "content_block_stop":
|
||||
logger.debug(f"Anthropic stream detected content block stop. Index: {event.index}")
|
||||
|
||||
logger.debug(f"Anthropic stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Anthropic stream: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
def get_content(self, response: Message) -> str:
|
||||
"""Extracts content from a non-streaming Anthropic response."""
|
||||
try:
|
||||
# Combine text content from all text blocks
|
||||
text_content = "".join([block.text for block in response.content if block.type == "text"])
|
||||
logger.debug(f"Extracted content (length {len(text_content)}) from non-streaming Anthropic response.")
|
||||
return text_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from Anthropic response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
def has_tool_calls(self, response: Stream[MessageStreamEvent] | Message) -> bool:
|
||||
"""Checks if the Anthropic response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, Message): # Non-streaming
|
||||
# Check stop reason and content blocks
|
||||
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
|
||||
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
|
||||
logger.debug(f"Non-streaming Anthropic response check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
|
||||
return has_calls
|
||||
elif isinstance(response, Stream):
|
||||
# Cannot reliably check an unconsumed stream without consuming it.
|
||||
# The LLMClient should handle this by checking after consumption or based on stop_reason if available post-stream.
|
||||
logger.warning("has_tool_calls check on an Anthropic stream is unreliable before consumption.")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"has_tool_calls received unexpected type for Anthropic: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for Anthropic tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def parse_tool_calls(self, response: Message) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming Anthropic response."""
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, Message):
|
||||
logger.error(f"parse_tool_calls expects Anthropic Message, got {type(response)}")
|
||||
return []
|
||||
|
||||
if response.stop_reason != "tool_use":
|
||||
logger.debug("No tool use indicated by stop_reason.")
|
||||
# return [] # Might still have tool_use blocks even if stop_reason isn't tool_use? Check API docs. Let's check content anyway.
|
||||
|
||||
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
|
||||
if not tool_use_blocks:
|
||||
logger.debug("No 'tool_use' content blocks found in Anthropic response.")
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks from Anthropic response.")
|
||||
for block in tool_use_blocks:
|
||||
# Adapt server/tool name splitting if needed (similar to OpenAI provider)
|
||||
# Assuming Anthropic tool names might also be prefixed like "server__tool"
|
||||
parts = block.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
logger.warning(f"Could not determine server_name from Anthropic tool name '{block.name}'.")
|
||||
server_name = None
|
||||
func_name = block.name
|
||||
|
||||
parsed_calls.append({
|
||||
"id": block.id,
|
||||
"server_name": server_name,
|
||||
"function_name": func_name,
|
||||
"arguments": json.dumps(block.input), # Anthropic input is already a dict, dump to string like OpenAI provider expects? Or keep as dict? Let's keep as dict for now.
|
||||
# "arguments": block.input, # Keep as dict? Let's try this first.
|
||||
})
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Anthropic tool calls: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an Anthropic follow-up request."""
|
||||
# Anthropic expects a 'tool_result' content block
|
||||
# The content of the result block should typically be a string.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content_str = json.dumps(result)
|
||||
else:
|
||||
content_str = str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for Anthropic {tool_call_id}: {e}")
|
||||
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting Anthropic tool result for call ID {tool_call_id}")
|
||||
# This needs to be placed inside a "user" role message's content list
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content_str,
|
||||
# Optionally add is_error=True if result indicates an error
|
||||
}
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to Anthropic's format."""
|
||||
# Use the conversion function, assuming it's correctly placed and imported
|
||||
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
|
||||
try:
|
||||
# The conversion function needs to handle the server__tool prefixing
|
||||
anthropic_tools = convert_to_anthropic_tools(tools)
|
||||
logger.debug(f"Tool conversion result: {anthropic_tools}")
|
||||
return anthropic_tools
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Anthropic tool conversion: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
# Helper needed by LLMClient's current tool handling logic (if adapting OpenAI's pattern)
|
||||
def get_original_message_with_calls(self, response: Message) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls for Anthropic."""
|
||||
try:
|
||||
if isinstance(response, Message) and any(block.type == "tool_use" for block in response.content):
|
||||
# Anthropic's response structure is different. The 'message' itself is the assistant's turn.
|
||||
# We need to return a representation of this turn, including the tool_use blocks.
|
||||
# Convert Pydantic models within content to dicts
|
||||
content_list = [block.model_dump(exclude_unset=True) for block in response.content]
|
||||
return {"role": "assistant", "content": content_list}
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from Anthropic response.")
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
140
src/providers/base.py
Normal file
140
src/providers/base.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# src/providers/base.py
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseProvider(abc.ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Defines the common interface for interacting with different LLM APIs,
|
||||
including handling chat completions and tool usage.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
"""
|
||||
Initialize the provider.
|
||||
|
||||
Args:
|
||||
api_key: The API key for the provider.
|
||||
base_url: Optional base URL for the provider's API.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Send a chat completion request to the LLM provider.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries with 'role' and 'content'.
|
||||
model: Model identifier.
|
||||
temperature: Sampling temperature (0-1).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
tools: Optional list of tools in the provider-specific format.
|
||||
|
||||
Returns:
|
||||
Provider-specific response object (e.g., API response, stream object).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||
"""
|
||||
Extracts and yields content chunks from a streaming response object.
|
||||
|
||||
Args:
|
||||
response: The streaming response object returned by create_chat_completion.
|
||||
|
||||
Yields:
|
||||
String chunks of the response content.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_content(self, response: Any) -> str:
|
||||
"""
|
||||
Extracts the complete content from a non-streaming response object.
|
||||
|
||||
Args:
|
||||
response: The non-streaming response object.
|
||||
|
||||
Returns:
|
||||
The complete response content as a string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def has_tool_calls(self, response: Any) -> bool:
|
||||
"""
|
||||
Checks if the response object contains tool calls.
|
||||
|
||||
Args:
|
||||
response: The response object (streaming or non-streaming).
|
||||
|
||||
Returns:
|
||||
True if tool calls are present, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Parses tool calls from the response object.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool calls.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, each representing a tool call with details
|
||||
like 'id', 'function_name', 'arguments'. The exact structure might
|
||||
vary slightly based on provider needs but should contain enough
|
||||
info for execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Formats the result of a tool execution into the structure expected
|
||||
by the provider for follow-up requests.
|
||||
|
||||
Args:
|
||||
tool_call_id: The unique ID of the tool call (from parse_tool_calls).
|
||||
result: The data returned by the tool execution.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the tool result in the provider's format.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Converts a list of tools from the standard internal format to the
|
||||
provider-specific format required for the API call.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions in the standard internal format.
|
||||
Each dict contains 'server_name', 'name', 'description', 'input_schema'.
|
||||
|
||||
Returns:
|
||||
List of tool definitions in the provider-specific format.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Optional: Add a method for follow-up completions if the provider API
|
||||
# requires a specific structure different from just appending messages.
|
||||
# def create_follow_up_completion(...) -> Any:
|
||||
# pass
|
||||
239
src/providers/openai_provider.py
Normal file
239
src/providers/openai_provider.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# src/providers/openai_provider.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
from providers.base import BaseProvider
|
||||
from src.llm_models import MODELS # Use absolute import
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""Provider implementation for OpenAI and compatible APIs."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# Use default OpenAI endpoint if base_url is not provided
|
||||
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
|
||||
super().__init__(api_key, effective_base_url)
|
||||
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
|
||||
try:
|
||||
# TODO: Add default headers like in original client?
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion:
|
||||
"""Creates a chat completion using the OpenAI API."""
|
||||
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
|
||||
|
||||
# Remove None values like max_tokens if not provided
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
# --- Added Debug Logging ---
|
||||
log_params = completion_params.copy()
|
||||
# Avoid logging full messages if they are too long
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [
|
||||
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
|
||||
for msg in log_params["messages"][-2:] # Log last 2 messages summary
|
||||
]
|
||||
# Specifically log tools structure if present
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
# --- End Added Debug Logging ---
|
||||
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
logger.debug("OpenAI API call successful.")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
# Re-raise for the LLMClient to handle
|
||||
raise
|
||||
|
||||
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an OpenAI streaming response."""
|
||||
logger.debug("Processing OpenAI stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
for chunk in response:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
|
||||
# Yield an error message? Or let the generator stop?
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
def get_content(self, response: ChatCompletion) -> str:
|
||||
"""Extracts content from a non-streaming OpenAI response."""
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
|
||||
return content or "" # Return empty string if content is None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper
|
||||
# This is tricky for streams. We'd need to peek at the first chunk(s)
|
||||
# or buffer the response. For simplicity, this check might be unreliable
|
||||
# for streams *before* they are consumed. LLMClient needs robust handling.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
# A more robust check would involve consuming the start of the stream
|
||||
# or relying on the structure after consumption.
|
||||
return False # Assume no for unconsumed stream for now
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
# Attempt to handle buffered stream if possible? Complex.
|
||||
return []
|
||||
|
||||
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
func_name = call.function.name
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"function_name": func_name,
|
||||
"arguments": call.function.arguments, # Arguments are already a string here
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
else:
|
||||
content = str(result) # Ensure it's a string
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
logger.debug(f"Converted tool: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
# Helper needed by LLMClient's current tool handling logic
|
||||
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
|
||||
|
||||
# Register this provider (if using the registration mechanism)
|
||||
# from . import register_provider
|
||||
# register_provider("openai", OpenAIProvider)
|
||||
6
src/tools/__init__.py
Normal file
6
src/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# src/tools/__init__.py
|
||||
# This file makes the 'tools' directory a Python package.
|
||||
|
||||
# Optionally import key functions/classes for easier access
|
||||
# from .conversion import convert_to_openai_tools, convert_to_anthropic_tools
|
||||
# from .execution import execute_tool # Assuming execution.py will exist
|
||||
177
src/tools/conversion.py
Normal file
177
src/tools/conversion.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Conversion utilities for MCP tools.
|
||||
|
||||
This module contains functions to convert between different tool formats
|
||||
for various LLM providers (OpenAI, Anthropic, etc.).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to OpenAI tool definitions.
|
||||
|
||||
Args:
|
||||
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||
|
||||
Returns:
|
||||
List of OpenAI tool definitions.
|
||||
"""
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")
|
||||
|
||||
for tool in mcp_tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name for routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
# Initialize the OpenAI tool structure
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
# Basic validation/cleaning of schema if needed could go here
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
|
||||
# Ensure basic structure if missing
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {}
|
||||
if "type" not in input_schema:
|
||||
input_schema["type"] = "object"
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
openai_tool["function"]["parameters"] = input_schema
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
|
||||
def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to Anthropic tool definitions.
|
||||
|
||||
Args:
|
||||
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||
|
||||
Returns:
|
||||
List of Anthropic tool definitions.
|
||||
"""
|
||||
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Anthropic format")
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in mcp_tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name for routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
# Initialize the Anthropic tool structure
|
||||
# Anthropic's format is quite close to JSON Schema
|
||||
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}
|
||||
|
||||
# Basic validation/cleaning of schema if needed
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
|
||||
# Ensure basic structure if missing
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {}
|
||||
if "type" not in input_schema:
|
||||
input_schema["type"] = "object"
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
anthropic_tool["input_schema"] = input_schema
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
logger.debug(f"Converted MCP tool to Anthropic: {prefixed_tool_name}")
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to Google Gemini format (dictionary structure).
|
||||
|
||||
Args:
|
||||
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||
|
||||
Returns:
|
||||
List containing one dictionary with 'function_declarations'.
|
||||
"""
|
||||
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")
|
||||
|
||||
function_declarations = []
|
||||
|
||||
for tool in mcp_tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name for routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
# Basic validation/cleaning of schema
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this.")
|
||||
# Ensure basic structure if missing
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {}
|
||||
if "type" not in input_schema:
|
||||
input_schema["type"] = "object"
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
# Google requires properties for object type, add dummy if empty
|
||||
if not input_schema["properties"]:
|
||||
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
|
||||
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder"}}
|
||||
|
||||
# Create function declaration for Google's format
|
||||
function_declaration = {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # Google uses JSON Schema directly
|
||||
}
|
||||
|
||||
function_declarations.append(function_declaration)
|
||||
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
|
||||
|
||||
# Google API expects a list containing one Tool object dict
|
||||
google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []
|
||||
|
||||
logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
|
||||
return google_tools_wrapper
|
||||
|
||||
|
||||
# Note: The _handle_schema_construct helper from the reference code is not strictly
|
||||
# needed if we assume the inputSchema is already valid JSON Schema.
|
||||
# If complex schemas (anyOf, etc.) need specific handling beyond standard JSON Schema,
|
||||
# that logic could be added here or within the provider implementations.
|
||||
Reference in New Issue
Block a user