feat: update MCP client and manager to include API key and base URL in query processing
This commit is contained in:
@@ -166,7 +166,16 @@ class MCPClient:
|
||||
# Send initialize request
|
||||
self.request_id += 1
|
||||
req_id = self.request_id
|
||||
initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
|
||||
initialize_req = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
|
||||
"capabilities": {}, # Add empty capabilities object
|
||||
},
|
||||
}
|
||||
if not await self._send_message(initialize_req):
|
||||
self.logger.warning("Failed to send 'initialize' request.")
|
||||
# Continue anyway for non-compliant servers
|
||||
@@ -338,53 +347,45 @@ async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> d
|
||||
return await servers[server_name].call_tool(tool_name, func_args)
|
||||
|
||||
|
||||
async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator:
|
||||
async def run_interaction(
|
||||
user_query: str,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
base_url: str | None,
|
||||
mcp_config: dict,
|
||||
stream: bool = False,
|
||||
) -> dict | AsyncGenerator:
|
||||
"""
|
||||
Run an interaction with OpenAI using MCP server tools.
|
||||
|
||||
Args:
|
||||
user_query: The user's input query
|
||||
model_name: The model to use for processing
|
||||
config: Configuration dictionary
|
||||
stream: Whether to stream the response
|
||||
user_query: The user's input query.
|
||||
model_name: The model to use for processing.
|
||||
api_key: The OpenAI API key.
|
||||
base_url: The OpenAI API base URL (optional).
|
||||
mcp_config: The MCP configuration dictionary (for servers).
|
||||
stream: Whether to stream the response.
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or AsyncGenerator for streaming
|
||||
Dictionary containing response or AsyncGenerator for streaming.
|
||||
"""
|
||||
# Get OpenAI configuration
|
||||
# TODO: Handle multiple models in config?
|
||||
if not config.get("models"):
|
||||
logger.error("No models defined in MCP configuration.")
|
||||
# This function needs to return something compatible, maybe raise error?
|
||||
# For now, returning error dict for non-streaming case
|
||||
if not stream:
|
||||
return {"error": "No models defined in MCP configuration."}
|
||||
else:
|
||||
# How to handle error in async generator? Yield an error dict?
|
||||
async def error_gen():
|
||||
yield {"error": "No models defined in MCP configuration."}
|
||||
|
||||
return error_gen()
|
||||
|
||||
api_key = config["models"][0].get("apiKey")
|
||||
base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
|
||||
|
||||
# Validate passed arguments
|
||||
if not api_key:
|
||||
logger.error("apiKey missing for the first model in MCP configuration.")
|
||||
logger.error("API key is missing.")
|
||||
if not stream:
|
||||
return {"error": "apiKey missing for the first model in MCP configuration."}
|
||||
return {"error": "API key is missing."}
|
||||
else:
|
||||
|
||||
async def error_gen():
|
||||
yield {"error": "apiKey missing for the first model in MCP configuration."}
|
||||
yield {"error": "API key is missing."}
|
||||
|
||||
return error_gen()
|
||||
|
||||
# Start MCP servers
|
||||
# Start MCP servers using mcp_config
|
||||
servers = {}
|
||||
all_functions = []
|
||||
if config.get("mcpServers"):
|
||||
for server_name, server_config in config["mcpServers"].items():
|
||||
if mcp_config.get("mcpServers"): # Use mcp_config here
|
||||
for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here
|
||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||
|
||||
if await client.start():
|
||||
@@ -406,7 +407,8 @@ async def run_interaction(user_query: str, model_name: str, config: dict, stream
|
||||
else:
|
||||
logger.info("No mcpServers defined in configuration.")
|
||||
|
||||
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
# Use passed api_key and base_url
|
||||
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments
|
||||
messages = [{"role": "user", "content": user_query}]
|
||||
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
|
||||
|
||||
|
||||
@@ -188,29 +188,47 @@ class SyncMCPManager:
|
||||
logger.debug(f"Shutdown completed for server: {server_name}")
|
||||
logger.debug("_async_shutdown finished.")
|
||||
|
||||
def process_query(self, query: str, model_name: str) -> dict:
|
||||
# Updated process_query signature
|
||||
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""
|
||||
Process a query using MCP tools synchronously
|
||||
|
||||
Args:
|
||||
query: The user's input query
|
||||
model_name: The model to use for processing
|
||||
query: The user's input query.
|
||||
model_name: The model to use for processing.
|
||||
api_key: The OpenAI API key.
|
||||
base_url: The OpenAI API base URL.
|
||||
|
||||
Returns:
|
||||
Dictionary containing response or error
|
||||
Dictionary containing response or error.
|
||||
"""
|
||||
if not self.initialized and not self.initialize():
|
||||
logger.error("process_query called but MCP manager failed to initialize.")
|
||||
return {"error": "Failed to initialize MCP servers"}
|
||||
|
||||
logger.debug(f"Processing query synchronously: '{query}'")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(self._async_process_query(query, model_name))
|
||||
# Pass api_key and base_url to _async_process_query
|
||||
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
|
||||
logger.debug(f"Synchronous query processing result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
|
||||
return {"error": f"Processing error: {str(e)}"}
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def _async_process_query(self, query: str, model_name: str) -> dict:
|
||||
# Updated _async_process_query signature
|
||||
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||
"""Async implementation of query processing"""
|
||||
return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False)
|
||||
# Pass api_key, base_url, and the MCP config separately to run_interaction
|
||||
return await run_interaction(
|
||||
user_query=query,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
mcp_config=self.config, # self.config only contains MCP server definitions now
|
||||
stream=False,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""OpenAI client with custom MCP integration."""
|
||||
|
||||
import configparser
|
||||
import logging # Import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mcp_manager import SyncMCPManager
|
||||
|
||||
# Get a logger for this module
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient:
|
||||
def __init__(self):
|
||||
logger.debug("Initializing OpenAIClient...") # Add init log
|
||||
self.config = configparser.ConfigParser()
|
||||
self.config.read("config/config.ini")
|
||||
|
||||
@@ -33,21 +38,28 @@ class OpenAIClient:
|
||||
try:
|
||||
# Try using MCP if available
|
||||
if self.mcp_manager and self.mcp_manager.initialize():
|
||||
print("Using MCP with tools...")
|
||||
logger.info("Using MCP with tools...") # Use logger
|
||||
last_message = messages[-1]["content"]
|
||||
response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"])
|
||||
# Pass API key and base URL from config.ini
|
||||
response = self.mcp_manager.process_query(
|
||||
query=last_message,
|
||||
model_name=self.config["openai"]["model"],
|
||||
api_key=self.config["openai"]["api_key"],
|
||||
base_url=self.config["openai"].get("base_url"), # Use .get for optional base_url
|
||||
)
|
||||
|
||||
if "error" not in response:
|
||||
logger.debug("MCP processing successful, wrapping response.")
|
||||
# Convert to OpenAI-compatible response format
|
||||
return self._wrap_mcp_response(response)
|
||||
|
||||
# Fall back to standard OpenAI
|
||||
print(f"Using standard OpenAI API with model: {self.config['openai']['model']}")
|
||||
logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}") # Use logger
|
||||
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
|
||||
print(error_msg)
|
||||
logger.error(error_msg, exc_info=True) # Use logger
|
||||
raise Exception(error_msg)
|
||||
|
||||
def _wrap_mcp_response(self, response: dict):
|
||||
|
||||
Reference in New Issue
Block a user