Compare commits


3 Commits

Author SHA1 Message Date
a4683023ad feat: add support for Anthropic provider, including configuration and conversion utilities 2025-03-26 11:57:52 +00:00
b4986e0eb9 refactor: remove custom MCP client implementation files 2025-03-26 11:00:43 +00:00
80ba05338f feat: Implement async utilities for MCP server management and JSON-RPC communication
- Added `process.py` for managing MCP server subprocesses with async capabilities.
- Introduced `protocol.py` for handling JSON-RPC communication over streams.
- Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools.
- Defined model configurations in `llm_models.py` for different LLM providers.
- Removed the synchronous `mcp_manager.py` in favor of a more modular approach.
- Established a provider framework in `providers` directory with a base class and specific implementations.
- Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool call handling.
2025-03-26 11:00:20 +00:00
20 changed files with 2255 additions and 831 deletions
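At a high level, the modules introduced by the last commit compose as sketched below. This is an illustrative sketch only, using the class names, import roots, and the chat_completion signature that appear in the diffs further down; it is not code from the commit.

# Illustrative composition sketch (names and imports taken from the diffs below)
from src.custom_mcp.manager import SyncMCPManager  # spawns MCP servers, exposes a sync interface
from llm_client import LLMClient                    # provider-agnostic chat client

mcp_manager = SyncMCPManager("config/mcp_config.json")
mcp_manager.initialize()                            # start configured MCP servers (tools are optional)
client = LLMClient(provider_name="openrouter", api_key="YOUR_API_KEY",
                   mcp_manager=mcp_manager, base_url="https://openrouter.ai/api/v1")
for chunk in client.chat_completion(
        messages=[{"role": "user", "content": "hello"}],
        model="openai/gpt-4o-2024-11-20",
        stream=True):
    print(chunk, end="")                            # chunks are plain text strings
mcp_manager.shutdown()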

config/config.ini

@@ -1,7 +1,31 @@
[base]
# provider can be [openai|openrouter|anthropic|google]
provider = openrouter
[openrouter]
api_key = YOUR_API_KEY
base_url = https://openrouter.ai/api/v1
model = openai/gpt-4o-2024-11-20
context_window = 128000
[anthropic]
api_key = YOUR_API_KEY
base_url = https://api.anthropic.com/v1/messages
model = claude-3-7-sonnet-20250219
context_window = 128000
[google]
api_key = YOUR_API_KEY
base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
model = gemini-2.0-flash
context_window = 1000000
[openai]
api_key = YOUR_API_KEY
base_url = CUSTOM_BASE_URL
model = YOUR_MODEL_ID
base_url = https://api.openai.com/v1
model = openai/gpt-4o
context_window = 128000
[mcp]
servers_json = config/mcp_config.json
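The [base] provider value selects which section supplies the API key, base URL, and model. Below is a minimal sketch of resolving it with configparser, mirroring the app initialization further down; the helper name is illustrative, not part of the commit.

# Illustrative helper: resolve the active provider section from config/config.ini
import configparser

def load_provider_settings(path: str = "config/config.ini") -> dict:
    config = configparser.ConfigParser()
    if not config.read(path):
        raise FileNotFoundError(f"{path} not found or could not be read")
    provider = config["base"]["provider"]        # e.g. "openrouter"
    section = config[provider]                   # [openrouter], [anthropic], [google], or [openai]
    return {
        "provider": provider,
        "api_key": section.get("api_key"),
        "base_url": section.get("base_url"),     # optional
        "model": section.get("model"),
        "context_window": section.get("context_window"),
    }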

config/mcp_config.json

@@ -1,12 +1,12 @@
{
"mcpServers": {
"dolphin-demo-database-sqlite": {
"command": "uvx",
"args": [
"mcp-server-sqlite",
"--db-path",
"~/.dolphin/dolphin.db"
]
}
"mcpServers": {
"mcp-server-sqlite": {
"command": "uvx",
"args": [
"mcp-server-sqlite",
"--db-path",
"~/.mcpapp/mcpapp.db"
]
}
}
}

pyproject.toml

@@ -10,7 +10,9 @@ authors = [
dependencies = [
"streamlit",
"python-dotenv",
"openai"
"openai",
"anthropic",
"google-genai",
]
classifiers = [
"Development Status :: 3 - Alpha",
@@ -61,6 +63,7 @@ lint.select = [
"T10", # flake8-debugger
"A", # flake8-builtins
"UP", # pyupgrade
"TID", # flake8-tidy-imports
]
lint.ignore = [
@@ -81,7 +84,7 @@ skip-magic-trailing-comma = false
combine-as-imports = true
[tool.ruff.lint.mccabe]
max-complexity = 12
max-complexity = 16
[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
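With relative imports disallowed, the new TID rules flag in-package relative imports. A hypothetical illustration of what gets reported and the preferred form:

# Hypothetical illustration of the flake8-tidy-imports (TID252) rule
# from .client import MCPClient            # flagged: relative import
from custom_mcp.client import MCPClient    # preferred: absolute import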

View File

@@ -1,29 +1,112 @@
import atexit
import configparser
import json # For handling potential error JSON in stream
import logging
import streamlit as st
from openai_client import OpenAIClient
# Updated imports
from llm_client import LLMClient
from src.custom_mcp.manager import SyncMCPManager # Updated import path
# Configure logging for the app
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def init_session_state():
"""Initializes session state variables including clients."""
if "messages" not in st.session_state:
st.session_state.messages = []
logger.info("Initialized session state: messages")
if "client" not in st.session_state:
st.session_state.client = OpenAIClient()
# Register cleanup for MCP servers
if hasattr(st.session_state.client, "mcp_manager"):
atexit.register(st.session_state.client.mcp_manager.shutdown)
logger.info("Attempting to initialize clients...")
try:
config = configparser.ConfigParser()
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
config_files_read = config.read("config/config.ini")
if not config_files_read:
raise FileNotFoundError("config.ini not found or could not be read.")
logger.info(f"Read configuration from: {config_files_read}")
# --- MCP Manager Setup ---
mcp_config_path = "config/mcp_config.json" # Default
if config.has_section("mcp") and config["mcp"].get("servers_json"):
mcp_config_path = config["mcp"]["servers_json"]
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
else:
logger.info(f"Using default MCP config path: {mcp_config_path}")
mcp_manager = SyncMCPManager(mcp_config_path)
if not mcp_manager.initialize():
# Log warning but continue - LLMClient will operate without tools
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
else:
logger.info("MCP Manager initialized successfully.")
# Register shutdown hook for MCP manager
atexit.register(mcp_manager.shutdown)
logger.info("Registered MCP Manager shutdown hook.")
# --- LLM Client Setup ---
provider_name = None
model_name = None
api_key = None
base_url = None
# 1. Determine provider from [base] section
if config.has_section("base") and config["base"].get("provider"):
provider_name = config["base"].get("provider")
logger.info(f"Provider selected from [base] section: {provider_name}")
else:
# Fallback or error if [base] provider is missing? Let's error for now.
raise ValueError("Missing 'provider' setting in [base] section of config.ini")
# 2. Read details from the specific provider's section
if config.has_section(provider_name):
provider_config = config[provider_name]
model_name = provider_config.get("model")
api_key = provider_config.get("api_key")
base_url = provider_config.get("base_url") # Optional
logger.info(f"Read configuration from [{provider_name}] section.")
else:
raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
# Validate required config
if not api_key:
raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
if not model_name:
raise ValueError(f"Missing 'model' name in [{provider_name}] section of config.ini")
logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
st.session_state.client = LLMClient(
provider_name=provider_name,
api_key=api_key,
mcp_manager=mcp_manager,
base_url=base_url,
)
st.session_state.model_name = model_name
logger.info("LLMClient initialized successfully.")
except Exception as e:
logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
# Stop the app if initialization fails critically
st.stop()
def display_chat_messages():
"""Displays chat messages stored in session state."""
for message in st.session_state.messages:
with st.chat_message(message["role"]):
# Simple markdown display for now
st.markdown(message["content"])
def handle_user_input():
"""Handles user input, calls LLMClient, and displays the response."""
if prompt := st.chat_input("Type your message..."):
print(f"User input received: {prompt}") # Debug log
logger.info(f"User input received: '{prompt[:50]}...'")
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
@@ -32,39 +115,85 @@ def handle_user_input():
with st.chat_message("assistant"):
response_placeholder = st.empty()
full_response = ""
error_occurred = False
print("Processing message...") # Debug log
response = st.session_state.client.get_chat_response(st.session_state.messages)
logger.info("Processing message via LLMClient...")
# Use the new client and method, always requesting stream for UI
response_stream = st.session_state.client.chat_completion(
messages=st.session_state.messages,
model=st.session_state.model_name, # Get model from session state
stream=True,
)
# Handle both MCP and standard OpenAI responses
# Check if it's NOT a dict (assuming stream is not a dict)
if not isinstance(response, dict):
# Standard OpenAI streaming response
for chunk in response:
# Ensure chunk has choices and delta before accessing
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
full_response += chunk.choices[0].delta.content
response_placeholder.markdown(full_response + "▌")
# Handle the response (stream generator or error dict)
if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
logger.debug("Processing response stream...")
for chunk in response_stream:
# Check for potential error JSON yielded by the stream
try:
# Attempt to parse chunk as JSON only if it looks like it
if isinstance(chunk, str) and chunk.strip().startswith("{"):
error_data = json.loads(chunk)
if isinstance(error_data, dict) and "error" in error_data:
full_response = f"Error: {error_data['error']}"
logger.error(f"Error received in stream: {full_response}")
st.error(full_response)
error_occurred = True
break # Stop processing stream on error
# If not error JSON, treat as content chunk
if not error_occurred and isinstance(chunk, str):
full_response += chunk
response_placeholder.markdown(full_response + "▌") # Add cursor effect
except (json.JSONDecodeError, TypeError):
# Not JSON or not error structure, treat as content chunk
if not error_occurred and isinstance(chunk, str):
full_response += chunk
response_placeholder.markdown(full_response + "▌") # Add cursor effect
if not error_occurred:
response_placeholder.markdown(full_response) # Final update without cursor
logger.debug("Stream processing complete.")
elif isinstance(response_stream, dict) and "error" in response_stream:
# Handle error dict returned directly (e.g., API error before streaming)
full_response = f"Error: {response_stream['error']}"
logger.error(f"Error returned directly from chat_completion: {full_response}")
st.error(full_response)
error_occurred = True
else:
# MCP non-streaming response
full_response = response.get("assistant_text", "")
response_placeholder.markdown(full_response)
# Unexpected response type
full_response = "[Unexpected response format from LLMClient]"
logger.error(f"Unexpected response type: {type(response_stream)}")
st.error(full_response)
error_occurred = True
response_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
print("Message processed successfully") # Debug log
# Only add non-error, non-empty responses to history
if not error_occurred and full_response:
st.session_state.messages.append({"role": "assistant", "content": full_response})
logger.info("Assistant response added to history.")
elif error_occurred:
logger.warning("Assistant response not added to history due to error.")
else:
logger.warning("Empty assistant response received, not added to history.")
except Exception as e:
st.error(f"Error processing message: {str(e)}")
print(f"Error details: {str(e)}") # Debug log
logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
st.error(f"An unexpected error occurred: {str(e)}")
def main():
st.title("Streamlit Chat App")
init_session_state()
display_chat_messages()
handle_user_input()
"""Main function to run the Streamlit app."""
st.title("MCP Chat App") # Updated title
try:
init_session_state()
display_chat_messages()
handle_user_input()
except Exception as e:
# Catch potential errors during rendering or handling
logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
st.error(f"A critical application error occurred: {e}")
if __name__ == "__main__":
logger.info("Starting Streamlit Chat App...")
main()
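The handler above accepts two shapes from LLMClient.chat_completion: an iterator of plain-text chunks, where a JSON string containing an "error" key aborts the stream, or an error dict returned before any streaming starts. The stub below illustrates that contract; it is inferred from the branching logic, not taken from the provider implementations.

# Stub illustrating the response contract the handler above expects
def stub_chat_completion(messages, model, stream=True):
    if not stream:
        return {"error": "API error before streaming"}   # handled as a direct error dict
    def chunks():
        yield "Hello "                                    # plain text is appended to the reply
        yield '{"error": "rate limit exceeded"}'          # a JSON error string stops the stream
    return chunks()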

View File

@@ -0,0 +1 @@
# This file makes src/mcp a Python package

src/custom_mcp/client.py (new file, 281 lines)

@@ -0,0 +1,281 @@
# src/custom_mcp/client.py
"""Client class for managing and interacting with a single MCP server process."""
import asyncio
import logging
from typing import Any
from custom_mcp import process, protocol
logger = logging.getLogger(__name__)
# Define reasonable timeouts
LIST_TOOLS_TIMEOUT = 20.0 # Seconds (using the increased value from previous step)
CALL_TOOL_TIMEOUT = 110.0 # Seconds
class MCPClient:
"""
Manages the lifecycle and async communication with a single MCP server process.
"""
def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
"""
Initializes the client for a specific server configuration.
Args:
server_name: Unique name for the server (for logging).
command: The command executable.
args: List of arguments for the command.
config_env: Server-specific environment variables.
"""
self.server_name = server_name
self.command = command
self.args = args
self.config_env = config_env
self.process: asyncio.subprocess.Process | None = None
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None
self._stderr_task: asyncio.Task | None = None
self._request_counter = 0
self._is_running = False
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger
async def _log_stderr(self):
"""Logs stderr output from the server process."""
if not self.process or not self.process.stderr:
self.logger.debug("Stderr logging skipped: process or stderr not available.")
return
stderr_reader = self.process.stderr
try:
while not stderr_reader.at_eof():
line = await stderr_reader.readline()
if line:
self.logger.warning(f"[stderr] {line.decode().strip()}")
except asyncio.CancelledError:
self.logger.debug("Stderr logging task cancelled.")
except Exception as e:
# Log errors but don't crash the logger task if possible
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
finally:
self.logger.debug("Stderr logging task finished.")
async def start(self) -> bool:
"""
Starts the MCP server subprocess and sets up communication streams.
Returns:
True if the process started successfully, False otherwise.
"""
if self._is_running:
self.logger.warning("Start called but client is already running.")
return True
self.logger.info("Starting MCP server process...")
try:
self.process = await process.start_mcp_process(self.command, self.args, self.config_env)
self.reader = self.process.stdout
self.writer = self.process.stdin
if self.reader is None or self.writer is None:
self.logger.error("Failed to get stdout/stdin streams after process start.")
await self.stop() # Attempt cleanup
return False
# Start background task to monitor stderr
self._stderr_task = asyncio.create_task(self._log_stderr())
# --- Start MCP Initialization Handshake ---
self.logger.info("Starting MCP initialization handshake...")
self._request_counter += 1
init_req_id = self._request_counter
initialize_req = {
"jsonrpc": "2.0",
"id": init_req_id,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05", # Use a recent version
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client
"capabilities": {}, # Client capabilities (can be empty)
},
}
# Define a timeout for initialization
INITIALIZE_TIMEOUT = 15.0 # Seconds
try:
# Send initialize request
await protocol.send_request(self.writer, initialize_req)
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
# Wait for initialize response
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
if init_response and init_response.get("id") == init_req_id:
if "error" in init_response:
self.logger.error(f"Server returned error during initialization: {init_response['error']}")
await self.stop()
return False
elif "result" in init_response:
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided
# Send initialized notification (using standard method name)
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
await protocol.send_request(self.writer, initialized_notify)
self.logger.info("'notifications/initialized' notification sent.")
self._is_running = True
self.logger.info("MCP server process started and initialized successfully.")
return True
else:
self.logger.error("Invalid 'initialize' response format (missing result/error).")
await self.stop()
return False
elif init_response:
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
await self.stop()
return False
else: # Timeout case
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
await self.stop()
return False
except ConnectionResetError:
self.logger.error("Connection lost during initialization handshake. Stopping client.")
await self.stop()
return False
except Exception as e:
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
await self.stop()
return False
# --- End MCP Initialization Handshake ---
except Exception as e:
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
self.process = None # Ensure process is None on failure
self.reader = None
self.writer = None
self._is_running = False
return False
async def stop(self):
"""Stops the MCP server subprocess gracefully."""
if not self._is_running and not self.process:
self.logger.debug("Stop called but client is not running.")
return
self.logger.info("Stopping MCP server process...")
self._is_running = False # Mark as stopping
# Cancel stderr logging task
if self._stderr_task and not self._stderr_task.done():
self._stderr_task.cancel()
try:
await self._stderr_task
except asyncio.CancelledError:
self.logger.debug("Stderr task successfully cancelled.")
except Exception as e:
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
self._stderr_task = None
# Stop the process using the utility function
if self.process:
await process.stop_mcp_process(self.process, self.server_name)
# Nullify references
self.process = None
self.reader = None
self.writer = None
self.logger.info("MCP server process stopped.")
async def list_tools(self) -> list[dict[str, Any]] | None:
"""
Sends a 'tools/list' request and waits for the response.
Returns:
A list of tool dictionaries, or None on error/timeout.
"""
if not self._is_running or not self.writer or not self.reader:
self.logger.error("Cannot list tools: client not running or streams unavailable.")
return None
self._request_counter += 1
req_id = self._request_counter
request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id}
try:
await protocol.send_request(self.writer, request)
response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT)
if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]:
tools = response["result"]["tools"]
if isinstance(tools, list):
self.logger.info(f"Successfully listed {len(tools)} tools.")
return tools
else:
self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}")
return None
elif response and "error" in response:
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
return None
else:
# Includes timeout case (read_response returns None)
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
return None
except ConnectionResetError:
self.logger.error("Connection lost during listTools request. Stopping client.")
await self.stop()
return None
except Exception as e:
self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True)
return None
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
"""
Sends a 'tools/call' request and waits for the response.
Args:
tool_name: The name of the tool to call.
arguments: The arguments for the tool.
Returns:
The result dictionary from the server, or None on error/timeout.
"""
if not self._is_running or not self.writer or not self.reader:
self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.")
return None
self._request_counter += 1
req_id = self._request_counter
request = {
"jsonrpc": "2.0",
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments},
"id": req_id,
}
try:
await protocol.send_request(self.writer, request)
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
if response and "result" in response:
# Assuming result is the desired payload
self.logger.info(f"Tool '{tool_name}' executed successfully.")
return response["result"]
elif response and "error" in response:
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
# Return the error structure itself? Or just None? Returning error dict for now.
return {"error": response["error"]}
else:
# Includes timeout case
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
return None
except ConnectionResetError:
self.logger.error(f"Connection lost during callTool '{tool_name}'. Stopping client.")
await self.stop()
return None
except Exception as e:
self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True)
return None
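A minimal async usage sketch for MCPClient, assuming the sqlite server from config/mcp_config.json; this is illustrative and not part of the commit.

# Illustrative usage of MCPClient
import asyncio
from custom_mcp.client import MCPClient

async def demo():
    client = MCPClient(
        server_name="mcp-server-sqlite",
        command="uvx",
        args=["mcp-server-sqlite", "--db-path", "~/.mcpapp/mcpapp.db"],
        config_env={},
    )
    if await client.start():                 # spawns the process and runs the initialize handshake
        tools = await client.list_tools()    # list of tool dicts, or None on error/timeout
        print([t["name"] for t in tools or []])
        await client.stop()

asyncio.run(demo())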

src/custom_mcp/manager.py (new file, 366 lines)

@@ -0,0 +1,366 @@
# src/custom_mcp/manager.py
"""Synchronous manager for multiple MCPClient instances."""
import asyncio
import json
import logging
import threading
from typing import Any
# Use absolute imports within the custom_mcp package (relative imports are disallowed by the lint config)
from custom_mcp.client import MCPClient
# Configure basic logging
# Consider moving this to the main app entry point if not already done
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
INITIALIZE_TIMEOUT = 60.0 # Seconds
SHUTDOWN_TIMEOUT = 30.0 # Seconds
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
class SyncMCPManager:
"""
Manages the lifecycle of multiple MCPClient instances and provides a
synchronous interface to interact with them using a background event loop.
"""
def __init__(self, config_path: str = "config/mcp_config.json"):
"""
Initializes the manager, loads config, but does not start servers yet.
Args:
config_path: Path to the MCP server configuration JSON file.
"""
self.config_path = config_path
self.config: dict[str, Any] | None = None
# Stores server_name -> MCPClient instance
self.servers: dict[str, MCPClient] = {}
self.initialized = False
self._lock = threading.Lock()
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
self._load_config()
def _load_config(self):
"""Load MCP configuration from JSON file."""
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
try:
# Using direct file access
with open(self.config_path) as f:
self.config = json.load(f)
logger.info("MCP configuration loaded successfully.")
logger.debug(f"Config content: {self.config}")
except FileNotFoundError:
logger.error(f"MCP config file not found at {self.config_path}")
self.config = None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
self.config = None
except Exception as e:
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
self.config = None
# --- Background Event Loop Management ---
def _run_event_loop(self):
"""Target function for the background event loop thread."""
try:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
finally:
if self._loop and not self._loop.is_closed():
# Clean up remaining tasks before closing
try:
tasks = asyncio.all_tasks(self._loop)
if tasks:
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
for task in tasks:
task.cancel()
# Allow cancellation to propagate
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
logger.debug("Outstanding tasks cancelled.")
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
except Exception as e:
logger.error(f"Error during event loop cleanup: {e}")
finally:
self._loop.close()
asyncio.set_event_loop(None)
logger.info("Event loop thread finished.")
def _start_event_loop_thread(self):
"""Starts the background event loop thread if not already running."""
if self._thread is None or not self._thread.is_alive():
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
self._thread.start()
logger.info("Event loop thread started.")
# Wait briefly for the loop to become available and running
while self._loop is None or not self._loop.is_running():
# Use time.sleep in sync context
import time
time.sleep(0.01)
logger.debug("Event loop is running.")
def _stop_event_loop_thread(self):
"""Stops the background event loop thread."""
if self._loop and self._loop.is_running():
logger.info("Requesting event loop stop...")
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread and self._thread.is_alive():
logger.info("Waiting for event loop thread to join...")
self._thread.join(timeout=5)
if self._thread.is_alive():
logger.warning("Event loop thread did not stop gracefully.")
self._loop = None
self._thread = None
logger.info("Event loop stopped.")
# --- Public Synchronous Interface ---
def initialize(self) -> bool:
"""
Initializes and starts all configured MCP servers synchronously.
Returns:
True if all servers started successfully, False otherwise.
"""
logger.info("Manager initialization requested.")
if not self.config or not self.config.get("mcpServers"):
logger.warning("Initialization skipped: No valid configuration loaded.")
return False
with self._lock:
if self.initialized:
logger.debug("Initialization skipped: Already initialized.")
return True
self._start_event_loop_thread()
if not self._loop:
logger.error("Failed to start event loop for initialization.")
return False
logger.info("Submitting asynchronous server initialization...")
# Prepare coroutine to start all clients
async def _async_init_all():
tasks = []
started_names = []  # Track names alongside start coroutines so results align even if a server is skipped
for server_name, server_config in self.config["mcpServers"].items():
command = server_config.get("command")
args = server_config.get("args", [])
config_env = server_config.get("env", {})
if not command:
logger.error(f"Skipping server {server_name}: Missing 'command'.")
continue
client = MCPClient(server_name, command, args, config_env)
self.servers[server_name] = client
tasks.append(client.start())  # Append the start coroutine
started_names.append(server_name)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Check results - True means success, False or Exception means failure
all_success = True
failed_servers = []
for i, result in enumerate(results):
server_name = started_names[i]
if isinstance(result, Exception) or result is False:
all_success = False
failed_servers.append(server_name)
# Remove failed client from managed servers
if server_name in self.servers:
del self.servers[server_name]
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
if not all_success:
logger.error(f"Initialization failed for servers: {failed_servers}")
return all_success
# Run the initialization coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
try:
success = future.result(timeout=INITIALIZE_TIMEOUT)
if success:
logger.info("Asynchronous initialization completed successfully.")
self.initialized = True
else:
logger.error("Asynchronous initialization failed.")
self.initialized = False
# Attempt to clean up any partially started servers
self.shutdown() # Call sync shutdown
except TimeoutError:
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
self.initialized = False
self.shutdown() # Clean up
success = False
except Exception as e:
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
self.initialized = False
self.shutdown() # Clean up
success = False
return self.initialized
def shutdown(self):
"""Shuts down all managed MCP servers synchronously."""
logger.info("Manager shutdown requested.")
with self._lock:
# Check servers dict too, in case init was partial
if not self.initialized and not self.servers:
logger.debug("Shutdown skipped: Not initialized or no servers running.")
# Ensure loop is stopped if it exists
if self._thread and self._thread.is_alive():
self._stop_event_loop_thread()
return
if not self._loop or not self._loop.is_running():
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
# This part is tricky as MCPClient.stop is async.
# For simplicity, we might just log and rely on process termination on app exit.
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
self.servers = {}
self.initialized = False
if self._thread and self._thread.is_alive():
self._stop_event_loop_thread()
return
logger.info("Submitting asynchronous server shutdown...")
# Prepare coroutine to stop all clients
async def _async_shutdown_all():
tasks = [client.stop() for client in self.servers.values()]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
# Run the shutdown coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
try:
future.result(timeout=SHUTDOWN_TIMEOUT)
logger.info("Asynchronous shutdown completed.")
except TimeoutError:
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
# Processes might still be running, OS will clean up on exit hopefully
except Exception as e:
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
finally:
# Always mark as uninitialized and clear servers dict
self.servers = {}
self.initialized = False
# Stop the background thread
self._stop_event_loop_thread()
logger.info("Manager shutdown complete.")
def list_all_tools(self) -> list[dict[str, Any]]:
"""
Retrieves tools from all initialized MCP servers synchronously.
Returns:
A list of tool definitions in the standard internal format,
aggregated from all servers. Returns empty list on failure.
"""
if not self.initialized or not self.servers:
logger.warning("Cannot list tools: Manager not initialized or no servers running.")
return []
if not self._loop or not self._loop.is_running():
logger.error("Cannot list tools: Event loop not running.")
return []
logger.info(f"Requesting tools from {len(self.servers)} servers...")
# Prepare coroutine to list tools from all clients
async def _async_list_all():
tasks = []
server_names_in_order = []
for server_name, client in self.servers.items():
tasks.append(client.list_tools())
server_names_in_order.append(server_name)
results = await asyncio.gather(*tasks, return_exceptions=True)
all_tools = []
for i, result in enumerate(results):
server_name = server_names_in_order[i]
if isinstance(result, Exception):
logger.error(f"Error listing tools for server '{server_name}': {result}")
elif result is None:
# MCPClient.list_tools returns None on timeout/error
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
elif isinstance(result, list):
# Add server_name to each tool definition
for tool in result:
tool["server_name"] = server_name
all_tools.extend(result)
logger.debug(f"Received {len(result)} tools from {server_name}")
else:
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
return all_tools
# Run the coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
try:
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.")
return aggregated_tools
except TimeoutError:
logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.")
return []
except Exception as e:
logger.error(f"Exception during listing all tools future result: {e}", exc_info=True)
return []
def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
"""
Executes a specific tool on the designated MCP server synchronously.
Args:
server_name: The name of the server hosting the tool.
tool_name: The name of the tool to execute.
arguments: A dictionary of arguments for the tool.
Returns:
The result content from the tool execution (dict),
an error dict ({"error": ...}), or None on timeout/comm failure.
"""
if not self.initialized:
logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.")
return None
client = self.servers.get(server_name)
if not client:
logger.error(f"Cannot execute tool: Server '{server_name}' not found.")
return None
if not self._loop or not self._loop.is_running():
logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.")
return None
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
# Run the client's call_tool coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
try:
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
# MCPClient.call_tool returns the result dict or an error dict or None
if result is None:
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
elif isinstance(result, dict) and "error" in result:
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
else:
logger.info(f"Tool '{tool_name}' execution successful.")
return result # Return result dict, error dict, or None
except TimeoutError:
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
return None
except Exception as e:
logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True)
return None
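A minimal synchronous usage sketch for SyncMCPManager, mirroring how the app wires it up; the tool name below is hypothetical.

# Illustrative usage of SyncMCPManager
from src.custom_mcp.manager import SyncMCPManager

manager = SyncMCPManager("config/mcp_config.json")
if manager.initialize():                                   # starts all configured servers on a background loop
    tools = manager.list_all_tools()                       # aggregated definitions, each tagged with server_name
    result = manager.execute_tool(
        server_name="mcp-server-sqlite",
        tool_name="list_tables",                           # hypothetical tool name
        arguments={},
    )
manager.shutdown()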

src/custom_mcp/process.py (new file, 128 lines)

@@ -0,0 +1,128 @@
# src/custom_mcp/process.py
"""Async utilities for managing MCP server subprocesses."""
import asyncio
import logging
import os
import subprocess
logger = logging.getLogger(__name__)
async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process:
"""
Starts an MCP server subprocess using asyncio.create_subprocess_shell.
Handles argument expansion and environment merging.
Args:
command: The main command executable.
args: A list of arguments for the command.
config_env: Server-specific environment variables from config.
Returns:
The started asyncio.subprocess.Process object.
Raises:
FileNotFoundError: If the command is not found.
Exception: For other errors during subprocess creation.
"""
logger.debug(f"Preparing to start process for command: {command}")
# --- Add tilde expansion for arguments ---
expanded_args = []
try:
for arg in args:
if isinstance(arg, str) and "~" in arg:
expanded_args.append(os.path.expanduser(arg))
else:
# Ensure all args are strings for list2cmdline
expanded_args.append(str(arg))
logger.debug(f"Expanded args: {expanded_args}")
except Exception as e:
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
raise ValueError(f"Failed to expand arguments: {e}") from e
# --- Merge os.environ with config_env ---
merged_env = {**os.environ, **config_env}
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
# Combine command and expanded args into a single string for shell execution
try:
cmd_string = subprocess.list2cmdline([command] + expanded_args)
logger.debug(f"Executing shell command: {cmd_string}")
except Exception as e:
logger.error(f"Error creating command string: {e}", exc_info=True)
raise ValueError(f"Failed to create command string: {e}") from e
# --- Start the subprocess using shell ---
try:
process = await asyncio.create_subprocess_shell(
cmd_string,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=merged_env,
)
logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}")
return process
except FileNotFoundError:
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
raise # Re-raise specific error
except Exception as e:
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
raise # Re-raise other errors
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
"""
Attempts to gracefully stop the MCP server subprocess.
Args:
process: The asyncio.subprocess.Process object to stop.
server_name: A name for logging purposes.
"""
if process is None or process.returncode is not None:
logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.")
return
pid = process.pid
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
# Close stdin first
if process.stdin and not process.stdin.is_closing():
try:
process.stdin.close()
await process.stdin.wait_closed()
logger.debug(f"Stdin closed for {server_name} (PID: {pid})")
except Exception as e:
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
# Attempt graceful termination
try:
process.terminate()
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
await asyncio.wait_for(process.wait(), timeout=5.0)
logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).")
except TimeoutError:
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
try:
process.kill()
await process.wait() # Wait for kill to complete
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
except ProcessLookupError:
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
except Exception as e_kill:
logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}")
except ProcessLookupError:
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
except Exception as e_term:
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
# Attempt kill as fallback if terminate failed and process might still be running
if process.returncode is None:
try:
process.kill()
await process.wait()
logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).")
except Exception as e_kill_fallback:
logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}")

src/custom_mcp/protocol.py (new file, 75 lines)

@@ -0,0 +1,75 @@
# src/custom_mcp/protocol.py
"""Async utilities for MCP JSON-RPC communication over streams."""
import asyncio
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]):
"""
Sends a JSON-RPC request dictionary to the MCP server's stdin stream.
Args:
writer: The asyncio StreamWriter connected to the process stdin.
request_dict: The request dictionary to send.
Raises:
ConnectionResetError: If the connection is lost during send.
Exception: For other stream writing errors.
"""
try:
request_json = json.dumps(request_dict) + "\n"
writer.write(request_json.encode("utf-8"))
await writer.drain()
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
except ConnectionResetError:
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
raise # Re-raise for the caller (MCPClient) to handle
except Exception as e:
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
raise # Re-raise for the caller
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
"""
Reads and parses a JSON-RPC response line from the MCP server's stdout stream.
Args:
reader: The asyncio StreamReader connected to the process stdout.
timeout: Seconds to wait for a response line.
Returns:
The parsed response dictionary, or None if timeout or error occurs.
"""
response_str = None
try:
response_json = await asyncio.wait_for(reader.readline(), timeout=timeout)
if not response_json:
logger.warning("Received empty response line (EOF?).")
return None
response_str = response_json.decode("utf-8").strip()
if not response_str:
logger.warning("Received empty response string after strip.")
return None
logger.debug(f"Received response line: {response_str}")
response_dict = json.loads(response_str)
return response_dict
except TimeoutError:
logger.error(f"Timeout ({timeout}s) waiting for response.")
return None
except asyncio.IncompleteReadError:
logger.warning("Connection closed while waiting for response.")
return None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'")
return None
except Exception as e:
logger.error(f"Error reading response: {e}", exc_info=True)
return None
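A minimal round-trip sketch with these helpers: one JSON object per line on stdin, one per line back on stdout. Illustrative only; it omits the initialize handshake that MCPClient performs before tools/list.

# Illustrative JSON-RPC round trip over a server's stdio streams
import asyncio
from custom_mcp import process, protocol

async def list_tools_once():
    proc = await process.start_mcp_process("uvx", ["mcp-server-sqlite"], {})
    request = {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}
    await protocol.send_request(proc.stdin, request)            # newline-delimited JSON request
    response = await protocol.read_response(proc.stdout, timeout=20.0)
    await process.stop_mcp_process(proc)
    return (response or {}).get("result", {}).get("tools", [])

asyncio.run(list_tools_once())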

View File

@@ -1,5 +0,0 @@
"""Custom MCP client implementation focused on OpenAI integration."""
from .client import MCPClient, run_interaction
__all__ = ["MCPClient", "run_interaction"]

View File

@@ -1,550 +0,0 @@
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
# Get a logger for this module
logger = logging.getLogger(__name__)
class MCPClient:
"""Lightweight MCP client with JSON-RPC communication."""
def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
self.server_name = server_name
self.command = command
self.args = args or []
self.env = env or {}
self.process = None
self.tools = []
self.request_id = 0
self.responses = {}
self._shutdown = False
# Use a logger specific to this client instance
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
async def _receive_loop(self):
"""Listen for responses from the MCP server."""
try:
while self.process and self.process.stdout and not self.process.stdout.at_eof():
line_bytes = await self.process.stdout.readline()
if not line_bytes:
self.logger.debug("STDOUT EOF reached.")
break
line_str = line_bytes.decode().strip()
self.logger.debug(f"STDOUT Raw line: {line_str}")
try:
message = json.loads(line_str)
if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
self.responses[message["id"]] = message
elif "jsonrpc" in message and "method" in message:
self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
else:
self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
except json.JSONDecodeError:
self.logger.warning("STDOUT Failed to parse line as JSON.")
except Exception as e:
self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
except asyncio.CancelledError:
self.logger.debug("STDOUT Receive loop cancelled.")
except Exception as e:
self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
finally:
self.logger.debug("STDOUT Receive loop finished.")
async def _stderr_loop(self):
"""Listen for stderr messages from the MCP server."""
try:
while self.process and self.process.stderr and not self.process.stderr.at_eof():
line_bytes = await self.process.stderr.readline()
if not line_bytes:
self.logger.debug("STDERR EOF reached.")
break
line_str = line_bytes.decode().strip()
self.logger.warning(f"STDERR: {line_str}") # Log stderr as warning
except asyncio.CancelledError:
self.logger.debug("STDERR Stderr loop cancelled.")
except Exception as e:
self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True)
finally:
self.logger.debug("STDERR Stderr loop finished.")
async def _send_message(self, message: dict) -> bool:
"""Send a JSON-RPC message to the MCP server."""
if not self.process or not self.process.stdin:
self.logger.warning("STDIN Cannot send message, process or stdin not available.")
return False
try:
data = json.dumps(message) + "\n"
self.logger.debug(f"STDIN Sending: {data.strip()}")
self.process.stdin.write(data.encode())
await self.process.stdin.drain()
return True
except ConnectionResetError:
self.logger.error("STDIN Connection reset while sending message.")
self.process = None # Mark process as dead
return False
except Exception as e:
self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
return False
async def start(self) -> bool:
"""Start the MCP server process."""
self.logger.info("Attempting to start server...")
# Expand ~ in paths and prepare args
expanded_args = []
try:
for a in self.args:
if isinstance(a, str) and "~" in a:
expanded_args.append(os.path.expanduser(a))
else:
expanded_args.append(str(a)) # Ensure all args are strings
except Exception as e:
self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
return False
# Set up environment
env_vars = os.environ.copy()
if self.env:
env_vars.update(self.env)
self.logger.debug(f"Command: {self.command}")
self.logger.debug(f"Expanded Args: {expanded_args}")
# Avoid logging full env unless necessary for debugging sensitive info
# self.logger.debug(f"Environment: {env_vars}")
try:
# Start the subprocess
self.logger.debug("Creating subprocess...")
self.process = await asyncio.create_subprocess_exec(
self.command,
*expanded_args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, # Capture stderr
env=env_vars,
)
self.logger.info(f"Subprocess created with PID: {self.process.pid}")
# Start the receive loops
asyncio.create_task(self._receive_loop())
asyncio.create_task(self._stderr_loop()) # Start stderr loop
# Initialize the server
self.logger.debug("Attempting initialization handshake...")
init_success = await self._initialize()
if init_success:
self.logger.info("Initialization handshake successful (or skipped).")
# Add delay after successful start
await asyncio.sleep(0.5)
self.logger.debug("Post-initialization delay complete.")
return True
else:
self.logger.error("Initialization handshake failed.")
await self.stop() # Ensure cleanup if init fails
return False
except FileNotFoundError:
self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
return False
except Exception as e:
self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
return False
async def _initialize(self) -> bool:
"""Initialize the MCP server connection. Modified to not wait for response."""
self.logger.debug("Sending 'initialize' request...")
if not self.process:
self.logger.warning("Cannot initialize, process not running.")
return False
# Send initialize request
self.request_id += 1
req_id = self.request_id
initialize_req = {
"jsonrpc": "2.0",
"id": req_id,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
"capabilities": {}, # Add empty capabilities object
},
}
if not await self._send_message(initialize_req):
self.logger.warning("Failed to send 'initialize' request.")
# Continue anyway for non-compliant servers
# Send initialized notification immediately
self.logger.debug("Sending 'initialized' notification...")
notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
if await self._send_message(notify):
self.logger.debug("'initialized' notification sent.")
else:
self.logger.warning("Failed to send 'initialized' notification.")
# Still return True as the server might be running
self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
return True # Assume success without waiting for response
async def list_tools(self) -> list[dict]:
"""List available tools from the MCP server."""
if not self.process:
self.logger.warning("Cannot list tools, process not running.")
return []
self.logger.debug("Sending 'tools/list' request...")
self.request_id += 1
req_id = self.request_id
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
if not await self._send_message(req):
self.logger.error("Failed to send 'tools/list' request.")
return []
# Wait for response
self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
start_time = asyncio.get_event_loop().time()
timeout = 10 # seconds
while asyncio.get_event_loop().time() - start_time < timeout:
if req_id in self.responses:
resp = self.responses.pop(req_id)
self.logger.debug(f"Received 'tools/list' response: {resp}")
if "error" in resp:
self.logger.error(f"'tools/list' error response: {resp['error']}")
return []
if "result" in resp and "tools" in resp["result"]:
self.tools = resp["result"]["tools"]
self.logger.info(f"Successfully listed tools: {len(self.tools)}")
return self.tools
else:
self.logger.error("Invalid 'tools/list' response format.")
return []
await asyncio.sleep(0.05)
self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
return []
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
"""Call a tool on the MCP server."""
if not self.process:
self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
return {"error": "Server not started"}
self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
self.request_id += 1
req_id = self.request_id
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
if not await self._send_message(req):
self.logger.error("Failed to send 'tools/call' request.")
return {"error": "Failed to send tool call request"}
# Wait for response
self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
start_time = asyncio.get_event_loop().time()
timeout = 30 # seconds
while asyncio.get_event_loop().time() - start_time < timeout:
if req_id in self.responses:
resp = self.responses.pop(req_id)
self.logger.debug(f"Received 'tools/call' response: {resp}")
if "error" in resp:
self.logger.error(f"'tools/call' error response: {resp['error']}")
return {"error": str(resp["error"])}
if "result" in resp:
self.logger.info(f"Tool '{tool_name}' executed successfully.")
return resp["result"]
else:
self.logger.error("Invalid 'tools/call' response format.")
return {"error": "Invalid tool call response format"}
await asyncio.sleep(0.05)
self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
return {"error": f"Tool call timed out after {timeout}s"}
async def stop(self):
"""Stop the MCP server process."""
self.logger.info("Attempting to stop server...")
if self._shutdown or not self.process:
self.logger.debug("Server already stopped or not running.")
return
self._shutdown = True
proc = self.process # Keep a local reference
self.process = None # Prevent further operations
try:
# Send shutdown notification
self.logger.debug("Sending 'shutdown' notification...")
notify = {"jsonrpc": "2.0", "method": "shutdown"}
await self._send_message(notify) # Use the method which now handles None process
await asyncio.sleep(0.5) # Give server time to process
# Close stdin
if proc and proc.stdin:
try:
if not proc.stdin.is_closing():
proc.stdin.close()
await proc.stdin.wait_closed()
self.logger.debug("Stdin closed.")
except Exception as e:
self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
# Terminate the process
if proc:
self.logger.debug(f"Terminating process {proc.pid}...")
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
self.logger.info(f"Process {proc.pid} terminated gracefully.")
except TimeoutError:
self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
proc.kill()
await proc.wait()
self.logger.info(f"Process {proc.pid} killed.")
except Exception as e:
self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
except Exception as e:
self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
finally:
self.logger.debug("Stop sequence finished.")
# Ensure self.process is None even if errors occurred
self.process = None
async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
"""Process a tool call from OpenAI."""
func_name = tool_call["function"]["name"]
try:
func_args = json.loads(tool_call["function"].get("arguments", "{}"))
except json.JSONDecodeError as e:
logger.error(f"Invalid tool arguments format for {func_name}: {e}")
return {"error": "Invalid arguments format"}
# Parse server_name and tool_name from function name
parts = func_name.split("_", 1)
if len(parts) != 2:
logger.error(f"Invalid tool function name format: {func_name}")
return {"error": "Invalid function name format"}
server_name, tool_name = parts
if server_name not in servers:
logger.error(f"Tool call for unknown server: {server_name}")
return {"error": f"Unknown server: {server_name}"}
# Call the tool
return await servers[server_name].call_tool(tool_name, func_args)
async def run_interaction(
user_query: str,
model_name: str,
api_key: str,
base_url: str | None,
mcp_config: dict,
stream: bool = False,
) -> dict | AsyncGenerator:
"""
Run an interaction with OpenAI using MCP server tools.
Args:
user_query: The user's input query.
model_name: The model to use for processing.
api_key: The OpenAI API key.
base_url: The OpenAI API base URL (optional).
mcp_config: The MCP configuration dictionary (for servers).
stream: Whether to stream the response.
Returns:
Dictionary containing response or AsyncGenerator for streaming.
"""
# Validate passed arguments
if not api_key:
logger.error("API key is missing.")
if not stream:
return {"error": "API key is missing."}
else:
async def error_gen():
yield {"error": "API key is missing."}
return error_gen()
# Start MCP servers using mcp_config
servers = {}
all_functions = []
if mcp_config.get("mcpServers"): # Use mcp_config here
for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
if await client.start():
tools = await client.list_tools()
for tool in tools:
# Ensure parameters is a dict, default to empty if missing or not dict
params = tool.get("inputSchema", {})
if not isinstance(params, dict):
logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
params = {}
all_functions.append({
"type": "function", # Explicitly set type for clarity with newer OpenAI API
"function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
})
servers[server_name] = client
else:
logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
else:
logger.info("No mcpServers defined in configuration.")
# Use passed api_key and base_url
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments
messages = [{"role": "user", "content": user_query}]
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
if stream:
async def response_generator():
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
stream=True,
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
yield {"error": f"OpenAI API error: {e}"}
break
# Process streaming response
full_response_content = ""
tool_calls = []
async for chunk in response:
delta = chunk.choices[0].delta
if delta.content:
content = delta.content
full_response_content += content
yield {"assistant_text": content, "is_chunk": True}
if delta.tool_calls:
for tc in delta.tool_calls:
# Initialize tool call structure if it's the first chunk for this index
if tc.index >= len(tool_calls):
tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
# Append parts as they arrive
if tc.id:
tool_calls[tc.index]["id"] = tc.id
if tc.function and tc.function.name:
tool_calls[tc.index]["function"]["name"] = tc.function.name
if tc.function and tc.function.arguments:
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
# Add assistant message with content and potential tool calls
assistant_message = {"role": "assistant", "content": full_response_content}
if tool_calls:
# Filter out incomplete tool calls just in case
valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
if valid_tool_calls:
assistant_message["tool_calls"] = valid_tool_calls
else:
logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
messages.append(assistant_message)
logger.debug(f"Assistant message added: {assistant_message}")
# Handle tool calls if any were successfully assembled
if "tool_calls" in assistant_message:
tool_results = []
for tc in assistant_message["tool_calls"]:
logger.info(f"Processing tool call: {tc['function']['name']}")
result = await process_tool_call(tc, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished for this turn
yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk
break
except Exception as e:
logger.error(f"Error during streaming interaction: {e}", exc_info=True)
yield {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
logger.debug("Cleaning up MCP servers (stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (stream).")
return response_generator()
else: # Non-streaming case
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
return {"error": f"OpenAI API error: {e}"}
message = response.choices[0].message
messages.append(message)
logger.debug(f"OpenAI response message: {message}")
# Handle tool calls
if message.tool_calls:
tool_results = []
for tc in message.tool_calls:
logger.info(f"Processing tool call: {tc.function.name}")
# Reconstruct dict for process_tool_call
tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
result = await process_tool_call(tool_call_dict, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished
logger.info("Interaction finished, no tool calls.")
return {"assistant_text": message.content or "", "tool_calls": []}
except Exception as e:
logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
return {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
logger.debug("Cleaning up MCP servers (non-stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (non-stream).")

219
src/llm_client.py Normal file
View File

@@ -0,0 +1,219 @@
# src/llm_client.py
"""
Generic LLM client supporting multiple providers and MCP tool integration.
"""
import json
import logging
from collections.abc import Generator
from typing import Any
from src.custom_mcp.manager import SyncMCPManager # Updated import path
from src.providers import BaseProvider, create_llm_provider
logger = logging.getLogger(__name__)
class LLMClient:
"""
Handles chat completion requests to various LLM providers through a unified
interface, integrating with MCP tools via SyncMCPManager.
"""
def __init__(
self,
provider_name: str,
api_key: str,
mcp_manager: SyncMCPManager,
base_url: str | None = None,
):
"""
Initialize the LLM client.
Args:
provider_name: Name of the provider (e.g., 'openai', 'anthropic').
api_key: API key for the provider.
mcp_manager: An initialized instance of SyncMCPManager.
base_url: Optional base URL for the provider API.
"""
logger.info(f"Initializing LLMClient for provider: {provider_name}")
self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
self.mcp_manager = mcp_manager
self.mcp_tools: list[dict[str, Any]] = []
self._refresh_mcp_tools() # Initial tool load
def _refresh_mcp_tools(self):
"""Retrieves the latest tools from the MCP manager."""
logger.info("Refreshing MCP tools...")
try:
self.mcp_tools = self.mcp_manager.list_all_tools()
logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.")
except Exception as e:
logger.error(f"Error refreshing MCP tools: {e}", exc_info=True)
# Keep existing tools if refresh fails
def chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
) -> Generator[str, None, None] | dict[str, Any]:
"""
Send a chat completion request, handling potential tool calls.
Args:
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
model: Model identifier string.
temperature: Sampling temperature.
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
Returns:
If stream=True: A generator yielding content chunks.
If stream=False: A dictionary containing the final content or an error.
e.g., {"content": "..."} or {"error": "..."}
"""
# Ensure tools are up-to-date (optional, could be done less frequently)
# self._refresh_mcp_tools()
# Convert tools to the provider-specific format
try:
provider_tools = self.provider.convert_tools(self.mcp_tools)
logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}")
except Exception as e:
logger.error(f"Error converting tools for provider: {e}", exc_info=True)
provider_tools = None # Proceed without tools if conversion fails
try:
logger.info(f"Sending chat completion request to provider with model: {model}")
response = self.provider.create_chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
tools=provider_tools,
)
logger.info("Received response from provider.")
if stream:
# NOTE: Streaming with tool calls requires more elaborate handling (like airflow_wingman's
# process_tool_calls_and_follow_up): detect the tool call, execute it, then stream a
# follow-up completion. This simplified version only streams the initial response; if
# that response requests a tool call, the call is not executed within the stream.
logger.info("Streaming response...")
return self._stream_generator(response)
else: # Non-streaming
logger.info("Processing non-streaming response...")
if self.provider.has_tool_calls(response):
logger.info("Tool calls detected in response.")
# Simplified non-streaming tool call handling (one round)
try:
tool_calls = self.provider.parse_tool_calls(response)
logger.debug(f"Parsed tool calls: {tool_calls}")
tool_results = []
original_message_with_calls = self.provider.get_original_message_with_calls(response) # Provider needs to implement this
messages.append(original_message_with_calls) # Add assistant's turn with tool requests
for tool_call in tool_calls:
server_name = tool_call.get("server_name") # Needs to be parsed by provider
func_name = tool_call.get("function_name")
func_args_str = tool_call.get("arguments")
call_id = tool_call.get("id")
if not server_name or not func_name or func_args_str is None or call_id is None:
logger.error(f"Skipping invalid tool call data: {tool_call}")
result_content = {"error": "Invalid tool call structure from LLM"}
else:
try:
# Arguments might be a JSON string, parse them
arguments = json.loads(func_args_str)
logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}")
# Execute synchronously using the manager
execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments)
logger.debug(f"Tool execution result: {execution_result}")
if execution_result is None:
result_content = {"error": f"Tool execution failed or timed out for {func_name}"}
elif isinstance(execution_result, dict) and "error" in execution_result:
result_content = execution_result # Propagate error from tool/server
else:
# Assuming result is the content payload
result_content = execution_result
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}")
result_content = {"error": f"Invalid arguments format for tool {func_name}"}
except Exception as exec_err:
logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True)
result_content = {"error": f"Exception during tool execution: {str(exec_err)}"}
# Format result for the provider's follow-up message
formatted_result = self.provider.format_tool_results(call_id, result_content)
tool_results.append(formatted_result)
messages.append(formatted_result) # Add tool result message
# Make follow-up call
logger.info("Making follow-up request with tool results...")
follow_up_response = self.provider.create_chat_completion(
messages=messages, # Now includes assistant's turn and tool results
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False, # Follow-up is non-streaming here
tools=provider_tools, # Pass tools again? Some providers might need it.
)
final_content = self.provider.get_content(follow_up_response)
logger.info("Received follow-up response content.")
return {"content": final_content}
except Exception as tool_handling_err:
logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"}
else: # No tool calls
logger.info("No tool calls detected.")
content = self.provider.get_content(response)
return {"content": content}
except Exception as e:
error_msg = f"LLM API Error: {str(e)}"
logger.error(error_msg, exc_info=True)
if stream:
# How to signal error in a stream? Yield a specific error message?
# This simple generator won't handle it well. Returning an error dict for now.
return {"error": error_msg} # Or raise?
else:
return {"error": error_msg}
def _stream_generator(self, response: Any) -> Generator[str, None, None]:
"""Helper to yield content from the provider's streaming method."""
try:
# Use yield from for cleaner and potentially more efficient delegation
yield from self.provider.get_streaming_content(response)
except Exception as e:
logger.error(f"Error during streaming: {e}", exc_info=True)
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk
# Example of how a provider might need to implement get_original_message_with_calls
# This would be in the specific provider class (e.g., openai_provider.py)
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
# # or in the non-streaming response's choice message
# # Needs careful implementation based on provider's response structure
# assistant_message = {
# "role": "assistant",
# "content": None, # Often null when tool calls are present
# "tool_calls": [...] # Extracted tool calls in provider format
# }
# return assistant_message
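As an aside, a minimal usage sketch (not part of the diff) for the pieces above. It assumes the new SyncMCPManager in src/custom_mcp/manager.py keeps a config_path constructor plus initialize()/shutdown(), like the manager it replaces; the API key, model id defaults, and query are placeholders.
from src.custom_mcp.manager import SyncMCPManager
from src.llm_client import LLMClient  # or `from llm_client import LLMClient` if src/ is on sys.path

# Assumed interface: config_path constructor, initialize()/shutdown() lifecycle.
manager = SyncMCPManager(config_path="config/mcp_config.json")
manager.initialize()  # start the configured MCP servers

client = LLMClient(
    provider_name="anthropic",              # or "openai"
    api_key="YOUR_API_KEY",                 # placeholder
    mcp_manager=manager,
)

result = client.chat_completion(
    messages=[{"role": "user", "content": "What tables exist in the database?"}],
    model="claude-3-7-sonnet-20250219",
    stream=False,                           # returns {"content": ...} or {"error": ...}
)
print(result.get("content") or result.get("error"))

manager.shutdown()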

61
src/llm_models.py Normal file
View File

@@ -0,0 +1,61 @@
MODELS = {
"openai": {
"name": "OpenAI",
"endpoint": "https://api.openai.com/v1",
"models": [
{
"id": "gpt-4o",
"name": "GPT-4o",
"default": True,
"context_window": 128000,
"description": "Input $5/M tokens, Output $15/M tokens",
}
],
},
"anthropic": {
"name": "Anthropic",
"endpoint": "https://api.anthropic.com/v1/messages",
"models": [
{
"id": "claude-3-7-sonnet-20250219",
"name": "Claude 3.7 Sonnet",
"default": True,
"context_window": 200000,
"description": "Input $3/M tokens, Output $15/M tokens",
},
{
"id": "claude-3-5-haiku-20241022",
"name": "Claude 3.5 Haiku",
"default": False,
"context_window": 200000,
"description": "Input $0.80/M tokens, Output $4/M tokens",
},
],
},
"google": {
"name": "Google Gemini",
"endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
"models": [
{
"id": "gemini-2.0-flash",
"name": "Gemini 2.0 Flash",
"default": True,
"context_window": 1000000,
"description": "Input $0.1/M tokens, Output $0.4/M tokens",
}
],
},
"openrouter": {
"name": "OpenRouter",
"endpoint": "https://openrouter.ai/api/v1/chat/completions",
"models": [
{
"id": "custom",
"name": "Custom Model",
"default": False,
"context_window": 128000, # Default context window, will be updated based on model
"description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
},
],
},
}
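For illustration only (not part of the diff), a small helper sketch showing how this MODELS table might be queried for a provider's default model id; the import path is an assumption about how src/ is laid out on the path.
from src.llm_models import MODELS  # import path assumed; adjust to your sys.path setup

def default_model_id(provider: str) -> str | None:
    """Return the id of the model flagged as default for the given provider."""
    for model in MODELS.get(provider, {}).get("models", []):
        if model.get("default"):
            return model["id"]
    return None

assert default_model_id("anthropic") == "claude-3-7-sonnet-20250219"
assert default_model_id("openrouter") is None  # the "custom" entry is not marked default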

View File

@@ -1,234 +0,0 @@
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
import asyncio
import importlib.resources
import json
import logging # Import logging
import threading
from custom_mcp_client import MCPClient, run_interaction
# Configure basic logging for the application if not already configured
# This basic config helps ensure logs are visible during development
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Get a logger for this module
logger = logging.getLogger(__name__)
class SyncMCPManager:
"""Synchronous wrapper for managing MCP servers and interactions"""
def __init__(self, config_path: str = "config/mcp_config.json"):
self.config_path = config_path
self.config = None
self.servers = {}
self.initialized = False
self._lock = threading.Lock()
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
self._load_config()
def _load_config(self):
"""Load MCP configuration from JSON file using importlib"""
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
try:
# First try to load as a package resource
try:
# Try anchoring to the project name defined in pyproject.toml
# This *might* work depending on editable installs or context.
resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
with resource_path.open("r") as f:
self.config = json.load(f)
logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
# REMOVED: raise FileNotFoundError
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError
logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
# Fall back to direct file access relative to CWD
with open(self.config_path) as f:
self.config = json.load(f)
logger.debug("Loaded config via direct file access.")
logger.info("MCP configuration loaded successfully.")
logger.debug(f"Config content: {self.config}") # Log content only if loaded
except FileNotFoundError:
logger.error(f"MCP config file not found at {self.config_path}")
self.config = None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
self.config = None
except Exception as e:
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
self.config = None
def initialize(self) -> bool:
"""Initialize and start all MCP servers synchronously"""
logger.info("Initialize requested.")
if not self.config:
logger.warning("Initialization skipped: No configuration loaded.")
return False
if not self.config.get("mcpServers"):
logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
return False
if self.initialized:
logger.debug("Initialization skipped: Already initialized.")
return True
with self._lock:
if self.initialized: # Double-check after acquiring lock
logger.debug("Initialization skipped inside lock: Already initialized.")
return True
logger.info("Starting asynchronous initialization...")
# Run async initialization in a new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # Ensure this loop is used by tasks
success = loop.run_until_complete(self._async_initialize())
loop.close()
asyncio.set_event_loop(None) # Clean up
if success:
logger.info("Asynchronous initialization completed successfully.")
self.initialized = True
else:
logger.error("Asynchronous initialization failed.")
self.initialized = False # Ensure state is False on failure
return self.initialized
async def _async_initialize(self) -> bool:
"""Async implementation of server initialization"""
logger.debug("Starting _async_initialize...")
all_success = True
if not self.config or not self.config.get("mcpServers"):
logger.warning("_async_initialize: No config or mcpServers found.")
return False
tasks = []
server_names = list(self.config["mcpServers"].keys())
async def start_server(server_name, server_config):
logger.info(f"Initializing server: {server_name}")
try:
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
logger.debug(f"Attempting to start client for {server_name}...")
if await client.start():
logger.info(f"Client for {server_name} started successfully.")
tools = await client.list_tools()
logger.info(f"Tools listed for {server_name}: {len(tools)}")
self.servers[server_name] = {"client": client, "tools": tools}
return True
else:
logger.error(f"Failed to start MCP server: {server_name}")
return False
except Exception as e:
logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
return False
# Start servers concurrently
for server_name in server_names:
server_config = self.config["mcpServers"][server_name]
tasks.append(start_server(server_name, server_config))
results = await asyncio.gather(*tasks)
# Check if all servers started successfully
all_success = all(results)
if all_success:
logger.debug("_async_initialize completed: All servers started successfully.")
else:
failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
# Optionally shutdown servers that did start if partial success is not desired
# await self._async_shutdown() # Uncomment to enforce all-or-nothing startup
return all_success
def shutdown(self):
"""Shut down all MCP servers synchronously"""
logger.info("Shutdown requested.")
if not self.initialized:
logger.debug("Shutdown skipped: Not initialized.")
return
with self._lock:
if not self.initialized:
logger.debug("Shutdown skipped inside lock: Not initialized.")
return
logger.info("Starting asynchronous shutdown...")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._async_shutdown())
loop.close()
asyncio.set_event_loop(None)
self.servers = {}
self.initialized = False
logger.info("Shutdown complete.")
async def _async_shutdown(self):
"""Async implementation of server shutdown"""
logger.debug("Starting _async_shutdown...")
tasks = []
for server_name, server_info in self.servers.items():
logger.debug(f"Initiating shutdown for server: {server_name}")
tasks.append(server_info["client"].stop())
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
server_name = list(self.servers.keys())[i]
if isinstance(result, Exception):
logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
else:
logger.debug(f"Shutdown completed for server: {server_name}")
logger.debug("_async_shutdown finished.")
# Updated process_query signature
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
"""
Process a query using MCP tools synchronously
Args:
query: The user's input query.
model_name: The model to use for processing.
api_key: The OpenAI API key.
base_url: The OpenAI API base URL.
Returns:
Dictionary containing response or error.
"""
if not self.initialized and not self.initialize():
logger.error("process_query called but MCP manager failed to initialize.")
return {"error": "Failed to initialize MCP servers"}
logger.debug(f"Processing query synchronously: '{query}'")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Pass api_key and base_url to _async_process_query
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
logger.debug(f"Synchronous query processing result: {result}")
return result
except Exception as e:
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
return {"error": f"Processing error: {str(e)}"}
finally:
loop.close()
# Updated _async_process_query signature
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
"""Async implementation of query processing"""
# Pass api_key, base_url, and the MCP config separately to run_interaction
return await run_interaction(
user_query=query,
model_name=model_name,
api_key=api_key,
base_url=base_url,
mcp_config=self.config, # self.config only contains MCP server definitions now
stream=False,
)

69
src/providers/__init__.py Normal file
View File

@@ -0,0 +1,69 @@
# src/providers/__init__.py
import logging
from src.providers.anthropic_provider import AnthropicProvider
from src.providers.base import BaseProvider
from src.providers.openai_provider import OpenAIProvider
# from src.providers.google_provider import GoogleProvider
# from src.providers.openrouter_provider import OpenRouterProvider
logger = logging.getLogger(__name__)
# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
# "google": GoogleProvider,
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
}
def register_provider(name: str, provider_class: type[BaseProvider]):
"""Registers a provider class."""
if name.lower() in PROVIDER_MAP:
logger.warning(f"Provider '{name}' is already registered. Overwriting.")
PROVIDER_MAP[name.lower()] = provider_class
logger.info(f"Registered provider: {name}")
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
"""
Factory function to create an instance of a specific LLM provider.
Args:
provider_name: The name of the provider (e.g., 'openai', 'anthropic').
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
Returns:
An instance of the requested BaseProvider subclass.
Raises:
ValueError: If the requested provider_name is not registered.
"""
provider_class = PROVIDER_MAP.get(provider_name.lower())
if provider_class is None:
available = ", ".join(PROVIDER_MAP.keys()) or "None"
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
logger.info(f"Creating LLM provider instance for: {provider_name}")
try:
return provider_class(api_key=api_key, base_url=base_url)
except Exception as e:
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
def get_available_providers() -> list[str]:
"""Returns a list of registered provider names."""
return list(PROVIDER_MAP.keys())
# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
# # Potentially load providers dynamically if designed as plugins
# pass
# load_providers()
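For illustration (not part of the diff), how the factory and registry above could be used. Reusing OpenAIProvider for OpenRouter with a custom base_url follows the hint in the PROVIDER_MAP comment; the API key is a placeholder.
from src.providers import create_llm_provider, get_available_providers, register_provider
from src.providers.openai_provider import OpenAIProvider

print(get_available_providers())              # ['openai', 'anthropic']

# OpenRouter speaks an OpenAI-compatible API, so register it with OpenAIProvider.
register_provider("openrouter", OpenAIProvider)

provider = create_llm_provider(
    provider_name="openrouter",
    api_key="YOUR_API_KEY",                   # placeholder
    base_url="https://openrouter.ai/api/v1",  # OpenRouter's OpenAI-compatible endpoint
)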

295
src/providers/anthropic_provider.py Normal file
View File

@@ -0,0 +1,295 @@
# src/providers/anthropic_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any
from anthropic import Anthropic, Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
# Use absolute imports rooted at `src` (relative imports are banned by the Ruff TID rules)
from src.providers.base import BaseProvider
from src.llm_models import MODELS
from src.tools.conversion import convert_to_anthropic_tools
logger = logging.getLogger(__name__)
class AnthropicProvider(BaseProvider):
"""Provider implementation for Anthropic Claude models."""
def __init__(self, api_key: str, base_url: str | None = None):
# Anthropic client doesn't use base_url in the same way, but store it if needed
# Use default Anthropic endpoint if base_url is not provided or relevant
effective_base_url = base_url or MODELS.get("anthropic", {}).get("endpoint")
super().__init__(api_key, effective_base_url) # Pass base_url to parent, though Anthropic client might ignore it
logger.info("Initializing AnthropicProvider")
try:
self.client = Anthropic(api_key=self.api_key)
# Note: Anthropic client doesn't take base_url during init
except Exception as e:
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
"""Converts standard message format to Anthropic's format, extracting system prompt."""
anthropic_messages = []
system_prompt = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Anthropic.")
else:
# Handle system message not at the start (append to previous user message or add as user)
logger.warning("System message found not at the beginning. Treating as user message.")
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
continue
# Handle tool results specifically
if role == "tool":
# Find the preceding assistant message with the corresponding tool_use block
# This requires careful handling in the follow-up logic
tool_use_id = message.get("tool_call_id")
tool_content = content
# Format as a tool_result content block
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
continue
# Handle assistant message potentially containing tool_use blocks
if role == "assistant":
# Content may be plain text or a structured list of content blocks
# (e.g., tool_use blocks from a previous assistant turn); Anthropic accepts both.
anthropic_messages.append({"role": "assistant", "content": content})
continue
# Regular user messages
if role == "user":
anthropic_messages.append({"role": "user", "content": content})
continue
logger.warning(f"Unsupported role '{role}' in message conversion for Anthropic.")
# Anthropic requires the conversation to start with a user message (the system
# prompt is passed separately), so prepend one if needed.
if anthropic_messages and anthropic_messages[0]["role"] != "user":
logger.warning("Anthropic conversation must start with a user message. Prepending empty user message.")
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"}) # Or handle differently
return system_prompt, anthropic_messages
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None, # Anthropic requires max_tokens
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[MessageStreamEvent] | Message:
"""Creates a chat completion using the Anthropic API."""
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
# Anthropic requires max_tokens
if max_tokens is None:
max_tokens = 4096 # Default value if not provided
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
system_prompt, anthropic_messages = self._convert_messages(messages)
try:
completion_params = {
"model": model,
"messages": anthropic_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if system_prompt:
completion_params["system"] = system_prompt
if tools:
completion_params["tools"] = tools
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
# Remove None values (though Anthropic requires max_tokens)
completion_params = {k: v for k, v in completion_params.items() if v is not None}
log_params = completion_params.copy()
if "messages" in log_params:
log_params["messages"] = [{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} for msg in log_params["messages"][-2:]]
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling Anthropic API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, System: {bool(log_params.get('system'))}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
response = self.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
return response
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise
def get_streaming_content(self, response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
"""Yields content chunks from an Anthropic streaming response."""
logger.debug("Processing Anthropic stream...")
full_delta = ""
try:
# Iterate through events in the stream
for event in response:
if event.type == "content_block_delta":
# Check if the delta is for text content before accessing .text
if isinstance(event.delta, TextDelta):
delta_text = event.delta.text
if delta_text:
full_delta += delta_text
yield delta_text
# Ignore other delta types like InputJSONDelta for text streaming
# Other event types like 'message_start', 'content_block_start', etc., can be logged or handled if needed
elif event.type == "message_start":
logger.debug(f"Anthropic stream started. Model: {event.message.model}")
elif event.type == "message_stop":
# The stop_reason might be available on the 'message' object associated with the stream,
# not directly on the stop event itself. We log that the stop event occurred.
# Accessing the actual reason might require inspecting the final message state if needed.
logger.debug("Anthropic stream message_stop event received.")
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
logger.debug(f"Anthropic stream detected tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
elif event.type == "content_block_stop":
logger.debug(f"Anthropic stream detected content block stop. Index: {event.index}")
logger.debug(f"Anthropic stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing Anthropic stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: Message) -> str:
"""Extracts content from a non-streaming Anthropic response."""
try:
# Combine text content from all text blocks
text_content = "".join([block.text for block in response.content if block.type == "text"])
logger.debug(f"Extracted content (length {len(text_content)}) from non-streaming Anthropic response.")
return text_content
except Exception as e:
logger.error(f"Error extracting content from Anthropic response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Stream[MessageStreamEvent] | Message) -> bool:
"""Checks if the Anthropic response contains tool calls."""
try:
if isinstance(response, Message): # Non-streaming
# Check stop reason and content blocks
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
logger.debug(f"Non-streaming Anthropic response check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
return has_calls
elif isinstance(response, Stream):
# Cannot reliably check an unconsumed stream without consuming it.
# The LLMClient should handle this by checking after consumption or based on stop_reason if available post-stream.
logger.warning("has_tool_calls check on an Anthropic stream is unreliable before consumption.")
return False
else:
logger.warning(f"has_tool_calls received unexpected type for Anthropic: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for Anthropic tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: Message) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming Anthropic response."""
parsed_calls = []
try:
if not isinstance(response, Message):
logger.error(f"parse_tool_calls expects Anthropic Message, got {type(response)}")
return []
if response.stop_reason != "tool_use":
logger.debug("No tool use indicated by stop_reason.")
# return [] # Might still have tool_use blocks even if stop_reason isn't tool_use? Check API docs. Let's check content anyway.
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
if not tool_use_blocks:
logger.debug("No 'tool_use' content blocks found in Anthropic response.")
return []
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks from Anthropic response.")
for block in tool_use_blocks:
# Adapt server/tool name splitting if needed (similar to OpenAI provider)
# Assuming Anthropic tool names might also be prefixed like "server__tool"
parts = block.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Anthropic tool name '{block.name}'.")
server_name = None
func_name = block.name
parsed_calls.append({
"id": block.id,
"server_name": server_name,
"function_name": func_name,
"arguments": json.dumps(block.input), # Anthropic input is already a dict, dump to string like OpenAI provider expects? Or keep as dict? Let's keep as dict for now.
# "arguments": block.input, # Keep as dict? Let's try this first.
})
return parsed_calls
except Exception as e:
logger.error(f"Error parsing Anthropic tool calls: {e}", exc_info=True)
return []
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an Anthropic follow-up request."""
# Anthropic expects a 'tool_result' content block
# The content of the result block should typically be a string.
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for Anthropic {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting Anthropic tool result for call ID {tool_call_id}")
# This needs to be placed inside a "user" role message's content list
return {
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content_str,
# Optionally add is_error=True if result indicates an error
}
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to Anthropic's format."""
# Use the conversion function, assuming it's correctly placed and imported
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
try:
# The conversion function needs to handle the server__tool prefixing
anthropic_tools = convert_to_anthropic_tools(tools)
logger.debug(f"Tool conversion result: {anthropic_tools}")
return anthropic_tools
except Exception as e:
logger.error(f"Error during Anthropic tool conversion: {e}", exc_info=True)
return []
# Helper needed by LLMClient's current tool handling logic (if adapting OpenAI's pattern)
def get_original_message_with_calls(self, response: Message) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls for Anthropic."""
try:
if isinstance(response, Message) and any(block.type == "tool_use" for block in response.content):
# Anthropic's response structure is different. The 'message' itself is the assistant's turn.
# We need to return a representation of this turn, including the tool_use blocks.
# Convert Pydantic models within content to dicts
content_list = [block.model_dump(exclude_unset=True) for block in response.content]
return {"role": "assistant", "content": content_list}
else:
logger.warning("Could not extract original message with tool calls from Anthropic response.")
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

140
src/providers/base.py Normal file
View File

@@ -0,0 +1,140 @@
# src/providers/base.py
import abc
from collections.abc import Generator
from typing import Any
class BaseProvider(abc.ABC):
"""
Abstract base class for LLM providers.
Defines the common interface for interacting with different LLM APIs,
including handling chat completions and tool usage.
"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
Initialize the provider.
Args:
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
"""
self.api_key = api_key
self.base_url = base_url
@abc.abstractmethod
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Send a chat completion request to the LLM provider.
Args:
messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier.
temperature: Sampling temperature (0-1).
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format.
Returns:
Provider-specific response object (e.g., API response, stream object).
"""
pass
@abc.abstractmethod
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""
Extracts and yields content chunks from a streaming response object.
Args:
response: The streaming response object returned by create_chat_completion.
Yields:
String chunks of the response content.
"""
pass
@abc.abstractmethod
def get_content(self, response: Any) -> str:
"""
Extracts the complete content from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
The complete response content as a string.
"""
pass
@abc.abstractmethod
def has_tool_calls(self, response: Any) -> bool:
"""
Checks if the response object contains tool calls.
Args:
response: The response object (streaming or non-streaming).
Returns:
True if tool calls are present, False otherwise.
"""
pass
@abc.abstractmethod
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
"""
Parses tool calls from the response object.
Args:
response: The response object containing tool calls.
Returns:
A list of dictionaries, each representing a tool call with details
like 'id', 'function_name', 'arguments'. The exact structure might
vary slightly based on provider needs but should contain enough
info for execution.
"""
pass
@abc.abstractmethod
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""
Formats the result of a tool execution into the structure expected
by the provider for follow-up requests.
Args:
tool_call_id: The unique ID of the tool call (from parse_tool_calls).
result: The data returned by the tool execution.
Returns:
A dictionary representing the tool result in the provider's format.
"""
pass
@abc.abstractmethod
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Converts a list of tools from the standard internal format to the
provider-specific format required for the API call.
Args:
tools: List of tool definitions in the standard internal format.
Each dict contains 'server_name', 'name', 'description', 'inputSchema'.
Returns:
List of tool definitions in the provider-specific format.
"""
pass
# Optional: Add a method for follow-up completions if the provider API
# requires a specific structure different from just appending messages.
# def create_follow_up_completion(...) -> Any:
# pass
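For reference (not part of the diff), the standard internal tool shape that the convert_tools() implementations in this change expect, with hypothetical values:
example_mcp_tool = {
    "server_name": "my-server",                            # which MCP server exposes the tool
    "name": "list_tables",                                 # tool name on that server
    "description": "List all tables in the database.",
    "inputSchema": {"type": "object", "properties": {}},   # JSON Schema for the arguments
}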

239
src/providers/openai_provider.py Normal file
View File

@@ -0,0 +1,239 @@
# src/providers/openai_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from src.providers.base import BaseProvider
from src.llm_models import MODELS # Use absolute import
logger = logging.getLogger(__name__)
class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""
def __init__(self, api_key: str, base_url: str | None = None):
# Use default OpenAI endpoint if base_url is not provided
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
try:
# TODO: Add default headers like in original client?
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API."""
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
try:
completion_params = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None}
# --- Added Debug Logging ---
log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:] # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
# --- End Added Debug Logging ---
response = self.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
return response
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or "" # Return empty string if content is None
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
return bool(response.choices[0].message.tool_calls)
elif isinstance(response, Stream):  # Streaming response
# Tool calls cannot be detected reliably on an unconsumed stream without peeking at
# or buffering the first chunk(s); LLMClient needs robust handling after consumption.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False  # Assume no tool calls for an unconsumed stream
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
# Attempt to handle buffered stream if possible? Complex.
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default?
func_name = call.function.name
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"function_name": func_name,
"arguments": call.function.arguments, # Arguments are already a string here
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
else:
content = str(result) # Ensure it's a string
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
# Helper needed by LLMClient's current tool handling logic
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)

6
src/tools/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
# src/tools/__init__.py
# This file makes the 'tools' directory a Python package.
# Optionally import key functions/classes for easier access
# from .conversion import convert_to_openai_tools, convert_to_anthropic_tools
# from .execution import execute_tool # Assuming execution.py will exist

177
src/tools/conversion.py Normal file
View File

@@ -0,0 +1,177 @@
"""
Conversion utilities for MCP tools.
This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to OpenAI tool definitions.
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List of OpenAI tool definitions.
"""
openai_tools = []
logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"
# Initialize the OpenAI tool structure
openai_tool = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
# Basic validation/cleaning of schema if needed could go here
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
openai_tool["function"]["parameters"] = input_schema
openai_tools.append(openai_tool)
logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")
return openai_tools
def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Anthropic tool definitions.
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List of Anthropic tool definitions.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Anthropic format")
anthropic_tools = []
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"
# Initialize the Anthropic tool structure
# Anthropic's format is quite close to JSON Schema
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}
# Basic validation/cleaning of schema if needed
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
anthropic_tool["input_schema"] = input_schema
anthropic_tools.append(anthropic_tool)
logger.debug(f"Converted MCP tool to Anthropic: {prefixed_tool_name}")
return anthropic_tools
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Google Gemini format (dictionary structure).
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List containing one dictionary with 'function_declarations'.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")
function_declarations = []
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"
# Basic validation/cleaning of schema
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
# Google requires properties for object type, add dummy if empty
if not input_schema["properties"]:
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder"}}
# Create function declaration for Google's format
function_declaration = {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # Google uses JSON Schema directly
}
function_declarations.append(function_declaration)
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
# Google API expects a list containing one Tool object dict
google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []
logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
return google_tools_wrapper
# Note: The _handle_schema_construct helper from the reference code is not strictly
# needed if we assume the inputSchema is already valid JSON Schema.
# If complex schemas (anyOf, etc.) need specific handling beyond standard JSON Schema,
# that logic could be added here or within the provider implementations.
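Finally, an illustrative sketch (hypothetical tool definition) of what the three converters above produce for a single MCP tool:
from src.tools.conversion import (
    convert_to_anthropic_tools,
    convert_to_google_tools,
    convert_to_openai_tools,
)

mcp_tools = [{
    "server_name": "my-server",
    "name": "list_tables",
    "description": "List all tables in the database.",
    "inputSchema": {"type": "object", "properties": {"limit": {"type": "integer"}}},
}]

openai_tools = convert_to_openai_tools(mcp_tools)
# -> [{"type": "function", "function": {"name": "my-server__list_tables", ...}}]

anthropic_tools = convert_to_anthropic_tools(mcp_tools)
# -> [{"name": "my-server__list_tables", "description": ..., "input_schema": {...}}]

google_tools = convert_to_google_tools(mcp_tools)
# -> [{"function_declarations": [{"name": "my-server__list_tables", "parameters": {...}}]}]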