feat: update MCP client and manager to include API key and base URL in query processing
This commit is contained in:
@@ -166,7 +166,16 @@ class MCPClient:
|
|||||||
# Send initialize request
|
# Send initialize request
|
||||||
self.request_id += 1
|
self.request_id += 1
|
||||||
req_id = self.request_id
|
req_id = self.request_id
|
||||||
initialize_req = {"jsonrpc": "2.0", "id": req_id, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}}}
|
initialize_req = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
|
||||||
|
"capabilities": {}, # Add empty capabilities object
|
||||||
|
},
|
||||||
|
}
|
||||||
if not await self._send_message(initialize_req):
|
if not await self._send_message(initialize_req):
|
||||||
self.logger.warning("Failed to send 'initialize' request.")
|
self.logger.warning("Failed to send 'initialize' request.")
|
||||||
# Continue anyway for non-compliant servers
|
# Continue anyway for non-compliant servers
|
||||||
@@ -338,53 +347,45 @@ async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> d
|
|||||||
return await servers[server_name].call_tool(tool_name, func_args)
|
return await servers[server_name].call_tool(tool_name, func_args)
|
||||||
|
|
||||||
|
|
||||||
async def run_interaction(user_query: str, model_name: str, config: dict, stream: bool = False) -> dict | AsyncGenerator:
|
async def run_interaction(
|
||||||
|
user_query: str,
|
||||||
|
model_name: str,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str | None,
|
||||||
|
mcp_config: dict,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> dict | AsyncGenerator:
|
||||||
"""
|
"""
|
||||||
Run an interaction with OpenAI using MCP server tools.
|
Run an interaction with OpenAI using MCP server tools.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_query: The user's input query
|
user_query: The user's input query.
|
||||||
model_name: The model to use for processing
|
model_name: The model to use for processing.
|
||||||
config: Configuration dictionary
|
api_key: The OpenAI API key.
|
||||||
stream: Whether to stream the response
|
base_url: The OpenAI API base URL (optional).
|
||||||
|
mcp_config: The MCP configuration dictionary (for servers).
|
||||||
|
stream: Whether to stream the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary containing response or AsyncGenerator for streaming
|
Dictionary containing response or AsyncGenerator for streaming.
|
||||||
"""
|
"""
|
||||||
# Get OpenAI configuration
|
# Validate passed arguments
|
||||||
# TODO: Handle multiple models in config?
|
|
||||||
if not config.get("models"):
|
|
||||||
logger.error("No models defined in MCP configuration.")
|
|
||||||
# This function needs to return something compatible, maybe raise error?
|
|
||||||
# For now, returning error dict for non-streaming case
|
|
||||||
if not stream:
|
|
||||||
return {"error": "No models defined in MCP configuration."}
|
|
||||||
else:
|
|
||||||
# How to handle error in async generator? Yield an error dict?
|
|
||||||
async def error_gen():
|
|
||||||
yield {"error": "No models defined in MCP configuration."}
|
|
||||||
|
|
||||||
return error_gen()
|
|
||||||
|
|
||||||
api_key = config["models"][0].get("apiKey")
|
|
||||||
base_url = config["models"][0].get("apiBase", "https://api.openai.com/v1")
|
|
||||||
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
logger.error("apiKey missing for the first model in MCP configuration.")
|
logger.error("API key is missing.")
|
||||||
if not stream:
|
if not stream:
|
||||||
return {"error": "apiKey missing for the first model in MCP configuration."}
|
return {"error": "API key is missing."}
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def error_gen():
|
async def error_gen():
|
||||||
yield {"error": "apiKey missing for the first model in MCP configuration."}
|
yield {"error": "API key is missing."}
|
||||||
|
|
||||||
return error_gen()
|
return error_gen()
|
||||||
|
|
||||||
# Start MCP servers
|
# Start MCP servers using mcp_config
|
||||||
servers = {}
|
servers = {}
|
||||||
all_functions = []
|
all_functions = []
|
||||||
if config.get("mcpServers"):
|
if mcp_config.get("mcpServers"): # Use mcp_config here
|
||||||
for server_name, server_config in config["mcpServers"].items():
|
for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here
|
||||||
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
|
||||||
|
|
||||||
if await client.start():
|
if await client.start():
|
||||||
@@ -406,7 +407,8 @@ async def run_interaction(user_query: str, model_name: str, config: dict, stream
|
|||||||
else:
|
else:
|
||||||
logger.info("No mcpServers defined in configuration.")
|
logger.info("No mcpServers defined in configuration.")
|
||||||
|
|
||||||
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
# Use passed api_key and base_url
|
||||||
|
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments
|
||||||
messages = [{"role": "user", "content": user_query}]
|
messages = [{"role": "user", "content": user_query}]
|
||||||
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
|
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
|
||||||
|
|
||||||
|
|||||||
@@ -188,29 +188,47 @@ class SyncMCPManager:
|
|||||||
logger.debug(f"Shutdown completed for server: {server_name}")
|
logger.debug(f"Shutdown completed for server: {server_name}")
|
||||||
logger.debug("_async_shutdown finished.")
|
logger.debug("_async_shutdown finished.")
|
||||||
|
|
||||||
def process_query(self, query: str, model_name: str) -> dict:
|
# Updated process_query signature
|
||||||
|
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||||
"""
|
"""
|
||||||
Process a query using MCP tools synchronously
|
Process a query using MCP tools synchronously
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The user's input query
|
query: The user's input query.
|
||||||
model_name: The model to use for processing
|
model_name: The model to use for processing.
|
||||||
|
api_key: The OpenAI API key.
|
||||||
|
base_url: The OpenAI API base URL.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary containing response or error
|
Dictionary containing response or error.
|
||||||
"""
|
"""
|
||||||
if not self.initialized and not self.initialize():
|
if not self.initialized and not self.initialize():
|
||||||
|
logger.error("process_query called but MCP manager failed to initialize.")
|
||||||
return {"error": "Failed to initialize MCP servers"}
|
return {"error": "Failed to initialize MCP servers"}
|
||||||
|
|
||||||
|
logger.debug(f"Processing query synchronously: '{query}'")
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
result = loop.run_until_complete(self._async_process_query(query, model_name))
|
# Pass api_key and base_url to _async_process_query
|
||||||
|
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
|
||||||
|
logger.debug(f"Synchronous query processing result: {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
|
||||||
return {"error": f"Processing error: {str(e)}"}
|
return {"error": f"Processing error: {str(e)}"}
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
async def _async_process_query(self, query: str, model_name: str) -> dict:
|
# Updated _async_process_query signature
|
||||||
|
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
|
||||||
"""Async implementation of query processing"""
|
"""Async implementation of query processing"""
|
||||||
return await run_interaction(user_query=query, model_name=model_name, config=self.config, stream=False)
|
# Pass api_key, base_url, and the MCP config separately to run_interaction
|
||||||
|
return await run_interaction(
|
||||||
|
user_query=query,
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
mcp_config=self.config, # self.config only contains MCP server definitions now
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
"""OpenAI client with custom MCP integration."""
|
"""OpenAI client with custom MCP integration."""
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
|
import logging # Import logging
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from mcp_manager import SyncMCPManager
|
from mcp_manager import SyncMCPManager
|
||||||
|
|
||||||
|
# Get a logger for this module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient:
|
class OpenAIClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
logger.debug("Initializing OpenAIClient...") # Add init log
|
||||||
self.config = configparser.ConfigParser()
|
self.config = configparser.ConfigParser()
|
||||||
self.config.read("config/config.ini")
|
self.config.read("config/config.ini")
|
||||||
|
|
||||||
@@ -33,21 +38,28 @@ class OpenAIClient:
|
|||||||
try:
|
try:
|
||||||
# Try using MCP if available
|
# Try using MCP if available
|
||||||
if self.mcp_manager and self.mcp_manager.initialize():
|
if self.mcp_manager and self.mcp_manager.initialize():
|
||||||
print("Using MCP with tools...")
|
logger.info("Using MCP with tools...") # Use logger
|
||||||
last_message = messages[-1]["content"]
|
last_message = messages[-1]["content"]
|
||||||
response = self.mcp_manager.process_query(last_message, model_name=self.config["openai"]["model"])
|
# Pass API key and base URL from config.ini
|
||||||
|
response = self.mcp_manager.process_query(
|
||||||
|
query=last_message,
|
||||||
|
model_name=self.config["openai"]["model"],
|
||||||
|
api_key=self.config["openai"]["api_key"],
|
||||||
|
base_url=self.config["openai"].get("base_url"), # Use .get for optional base_url
|
||||||
|
)
|
||||||
|
|
||||||
if "error" not in response:
|
if "error" not in response:
|
||||||
|
logger.debug("MCP processing successful, wrapping response.")
|
||||||
# Convert to OpenAI-compatible response format
|
# Convert to OpenAI-compatible response format
|
||||||
return self._wrap_mcp_response(response)
|
return self._wrap_mcp_response(response)
|
||||||
|
|
||||||
# Fall back to standard OpenAI
|
# Fall back to standard OpenAI
|
||||||
print(f"Using standard OpenAI API with model: {self.config['openai']['model']}")
|
logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}") # Use logger
|
||||||
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
|
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
|
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
|
||||||
print(error_msg)
|
logger.error(error_msg, exc_info=True) # Use logger
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
def _wrap_mcp_response(self, response: dict):
|
def _wrap_mcp_response(self, response: dict):
|
||||||
|
|||||||
Reference in New Issue
Block a user