Compare commits

...

15 Commits

Author SHA1 Message Date
247835e595 Refactor Google and OpenAI provider response handling and tool utilities
- Improved error handling and logging in Google response processing.
- Simplified streaming content extraction and error detection in Google provider.
- Enhanced content extraction logic in OpenAI provider to handle edge cases.
- Streamlined tool conversion functions for both Google and OpenAI providers.
- Removed redundant comments and improved code readability across multiple files.
- Updated context window retrieval and message truncation logic for better performance.
- Ensured consistent handling of tool calls and arguments in OpenAI responses.
2025-03-28 04:20:39 +00:00
51e3058961 fix: update temperature parameter to 0.6 across multiple providers and add debugging output 2025-03-27 19:02:52 +00:00
ccf750fed4 fix: correct logging error message for Google Generative AI SDK 2025-03-27 15:22:19 +00:00
2fb6c5af3c refactor: remove OpenAIClient implementation to streamline codebase 2025-03-27 11:13:32 +00:00
6b390a35f8 feat: Implement GoogleProvider for Google Generative AI integration
- Added GoogleProvider class to handle chat completions with Google Gemini API.
- Implemented client initialization and response handling for streaming and non-streaming responses.
- Created utility functions for tool conversion, response parsing, and content extraction.
- Removed legacy tool conversion utilities from the tools module.
- Enhanced logging for better traceability of API interactions and error handling.
2025-03-27 11:11:56 +00:00
678f395649 feat: implement OpenAIProvider with client initialization, message handling, and utility functions 2025-03-26 19:59:01 +00:00
bae517a322 refactor: move convert_to_anthropic_tools function to tools.py for better organization 2025-03-26 19:06:21 +00:00
ab8d5fe074 feat: implement AnthropicProvider with client initialization, message handling, and utility functions 2025-03-26 19:02:26 +00:00
246d921743 feat: add GoogleProvider implementation and update conversion utilities for Google tools 2025-03-26 18:18:10 +00:00
15ecb9fc48 feat: enhance token usage tracking and context management for LLM providers 2025-03-26 17:27:41 +00:00
49aebc12d5 refactor: update application name and enhance header display in Streamlit app 2025-03-26 12:27:00 +00:00
bd56cc839d Refactor code structure for improved readability and maintainability 2025-03-26 12:14:58 +00:00
a4683023ad feat: add support for Anthropic provider, including configuration and conversion utilities 2025-03-26 11:57:52 +00:00
b4986e0eb9 refactor: remove custom MCP client implementation files 2025-03-26 11:00:43 +00:00
80ba05338f feat: Implement async utilities for MCP server management and JSON-RPC communication
- Added `process.py` for managing MCP server subprocesses with async capabilities.
- Introduced `protocol.py` for handling JSON-RPC communication over streams.
- Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools.
- Defined model configurations in `llm_models.py` for different LLM providers.
- Removed the synchronous `mcp_manager.py` in favor of a more modular approach.
- Established a provider framework in `providers` directory with a base class and specific implementations.
- Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool call handling.
2025-03-26 11:00:20 +00:00
40 changed files with 4664 additions and 906 deletions

.gitignore vendored

@@ -5,6 +5,7 @@ __pycache__/
 # Virtual environment
 env/
+.venv/
 # Configuration
 config/config.ini
@@ -20,4 +21,10 @@ config/mcp_config.json
 # resources
 resources/
-# __pycache__/
+# Ruff
+.ruff_cache/
+# Distribution / packaging
+dist/
+build/
+*.egg-info/

README.md

@@ -67,7 +67,7 @@ servers_json = config/mcp_config.json
 Start the application:
 ```bash
-streamlit run src/app.py
+uv run mcpapp
 ```
 The app will be available at `http://localhost:8501`
@@ -82,9 +82,6 @@ Key components:
 ## Development
-### Running Tests
-```bash
-pytest
 ```
 ### Code Formatting
@@ -94,7 +91,7 @@ ruff check . --fix
 ### Building
 ```bash
-python -m build
+uv build
 ```
 ## License

config/config.ini

@@ -1,7 +1,36 @@
+[base]
+# provider can be [ openai|openrouter|anthropic|google]
+provider = openrouter
+streamlit_headless = true
+
+[openrouter]
+api_key = YOUR_API_KEY
+base_url = https://openrouter.ai/api/v1
+model = openai/gpt-4o-2024-11-20
+context_window = 128000
+temperature = 0.6
+
+[anthropic]
+api_key = YOUR_API_KEY
+base_url = https://api.anthropic.com/v1/messages
+model = claude-3-7-sonnet-20250219
+context_window = 128000
+temperature = 0.6
+
+[google]
+api_key = YOUR_API_KEY
+base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
+model = gemini-2.0-flash
+context_window = 1000000
+temperature = 0.6
+
 [openai]
 api_key = YOUR_API_KEY
-base_url = CUSTOM_BASE_URL
-model = YOUR_MODEL_ID
+base_url = https://api.openai.com/v1
+model = openai/gpt-4o
+context_window = 128000
+temperature = 0.6
 [mcp]
 servers_json = config/mcp_config.json

config/mcp_config.json

@@ -1,12 +1,12 @@
 {
   "mcpServers": {
-    "dolphin-demo-database-sqlite": {
+    "mcp-server-sqlite": {
       "command": "uvx",
       "args": [
         "mcp-server-sqlite",
         "--db-path",
-        "~/.dolphin/dolphin.db"
+        "~/.mcpapp/mcpapp.db"
       ]
     }
   }
 }

project_planning/updates.md Normal file

@@ -0,0 +1,106 @@
What is the google-genai Module?
The google-genai module is part of the Google Gen AI Python SDK, a software development kit provided by Google to enable developers to integrate Google's generative AI models into Python applications. This SDK is distinct from the older, deprecated google-generativeai package. The google-genai package represents the newer, unified SDK designed to work with Google's latest generative AI offerings, such as the Gemini models.
Installation
To use the google-genai module, you need to install it via pip. The package name on PyPI is google-genai, and you can install it with the following command:
```bash
pip install google-genai
```
This installs the necessary dependencies and makes the module available in your Python environment.
Correct Import Statement
The standard import statement for the google-genai SDK, as per the official documentation and examples, is:
```python
from google import genai
```
This differs from the older SDK's import style, which was:
```python
import google.generativeai as genai
```
When you install google-genai, it provides a module structure where genai is a submodule of the google package. Thus, from google import genai is the correct way to access its functionality.
Usage in the New SDK
Unlike the older google-generativeai SDK, which used a configure method (e.g., genai.configure(api_key='YOUR_API_KEY')) to set up the API key globally, the new google-genai SDK adopts a client-based approach. You create a Client instance with your API key and use it to interact with the models. Here's a basic example:
```python
from google import genai
# Initialize the client with your API key
client = genai.Client(api_key='YOUR_API_KEY')
# Example: Generate content using a model
response = client.models.generate_content(
model='gemini-2.0-flash-001', # Specify the model name
contents='Why is the sky blue?'
)
print(response.text)
```
Key points about this usage:
- No configure Method: The new SDK does not have a genai.configure method directly on the genai module. Instead, you pass the API key when creating a Client instance.
- Client-Based Interaction: All interactions with the generative models (e.g., generating content) are performed through the client object.
Official Documentation
The official documentation for the google-genai SDK can be found on Google's API documentation site. Specifically:
- Google Gen AI Python SDK Documentation: Hosted at https://googleapis.github.io/python-genai/, this site provides detailed guides, API references, and code examples. It confirms the use of from google import genai and the client-based approach.
- GitHub Repository: The source code and additional examples are available at https://github.com/googleapis/python-genai. The repository documentation reinforces that from google import genai is the import style for the new SDK.
Why Your Import is Correct
You mentioned that your import is correct, which I assume refers to from google import genai. This is indeed the proper import for the google-genai package, aligning with the new SDK's design. If you're encountering issues (e.g., an error like module 'google.genai' has no attribute 'configure'), it's likely because the code is trying to use methods from the older SDK (like genai.configure) that don't exist in the new one. To resolve this, you should update the code to use the client-based approach shown above.
Troubleshooting Common Issues
If you're seeing errors with from google import genai, here are some things to check:
1. Correct Package Installed:
- Ensure google-genai is installed (pip install google-genai).
- If google-generativeai is installed instead, uninstall it with pip uninstall google-generativeai to avoid conflicts, then install google-genai.
2. Code Compatibility:
- If your code uses genai.configure or assumes the older SDK's structure, you'll need to refactor it. Replace configuration calls with genai.Client(api_key='...') and adjust model interactions to use the client object (see the sketch after this list).
3. Environment Verification:
- Run pip show google-genai to confirm the package is installed and check its version. This ensures you're working with the intended SDK.
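As a rough sketch of the refactor described in item 2, reusing the model name and prompt from the earlier example:
```python
# Before (older, deprecated google-generativeai SDK):
# import google.generativeai as genai
# genai.configure(api_key='YOUR_API_KEY')
# model = genai.GenerativeModel('gemini-2.0-flash')
# response = model.generate_content('Why is the sky blue?')

# After (new google-genai SDK, client-based):
from google import genai

client = genai.Client(api_key='YOUR_API_KEY')
response = client.models.generate_content(
    model='gemini-2.0-flash-001',
    contents='Why is the sky blue?',
)
print(response.text)
```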
Additional Resources
- PyPI Page: The google-genai package on PyPI (https://pypi.org/project/google-genai/) provides installation instructions and links to the GitHub repository.
- Examples: The GitHub repository includes sample code demonstrating how to use the SDK with from google import genai.
Conclusion
Your import, from google import genai, aligns with the google-genai module from the new Google Gen AI Python SDK. The documentation and online resources confirm this as the correct approach for the current, unified SDK. If import google.generativeai was previously suggested or used, it pertains to the older, deprecated SDK, which explains why it might be considered incorrect in your context. To fully leverage google-genai, ensure your code uses the client-based API as outlined, and you should be able to interact with Google's generative AI models effectively. If you're still facing specific errors, feel free to share them, and I can assist further!

pyproject.toml

@@ -1,5 +1,5 @@
 [project]
-name = "streamlit-chat-app"
+name = "macpapp"
 version = "0.1.0"
 description = "Streamlit chat app with MCP"
 readme = "README.md"
@@ -10,7 +10,9 @@ authors = [
 dependencies = [
     "streamlit",
     "python-dotenv",
-    "openai"
+    "openai",
+    "anthropic",
+    "google-genai",
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -25,6 +27,9 @@ license-files = ["LICEN[CS]E*"]
 GitHub = "https://git.bhakat.dev/abhishekbhakat/mcpapp"
 Issues = "https://git.bhakat.dev/abhishekbhakat/mcpapp/issues"
+[project.scripts]
+mcpapp = "run_app:run_streamlit_app"
 [project.optional-dependencies]
 dev = [
     "build>=1.2.2",
@@ -61,6 +66,7 @@ lint.select = [
     "T10", # flake8-debugger
     "A", # flake8-builtins
     "UP", # pyupgrade
+    "TID", # flake8-tidy-imports
 ]
 lint.ignore = [
@@ -81,7 +87,7 @@ skip-magic-trailing-comma = false
 combine-as-imports = true
 [tool.ruff.lint.mccabe]
-max-complexity = 12
+max-complexity = 30
 [tool.ruff.lint.flake8-tidy-imports]
 # Disallow all relative imports.

run_app.py Normal file

@@ -0,0 +1,57 @@
import configparser
import os
import subprocess
import sys
def run_streamlit_app():
"""
Reads the configuration file and launches the Streamlit app,
optionally in headless mode.
"""
config_path = "config/config.ini"
headless = False
try:
if os.path.exists(config_path):
config = configparser.ConfigParser()
config.read(config_path)
if config.has_section("base"):
headless = config.getboolean("base", "streamlit_headless", fallback=False)
if headless:
print(f"INFO: Headless mode enabled via {config_path}.")
else:
print(f"INFO: Headless mode disabled via {config_path}.")
else:
print(f"WARNING: [base] section not found in {config_path}. Defaulting to non-headless.")
else:
print(f"WARNING: Configuration file not found at {config_path}. Defaulting to non-headless.")
except Exception as e:
print(f"ERROR: Could not read headless config from {config_path}: {e}. Defaulting to non-headless.")
headless = False # Ensure default on error
# Construct the command
command = [sys.executable, "-m", "streamlit", "run", "src/app.py"]
if headless:
command.extend(["--server.headless", "true"])
print(f"Running command: {' '.join(command)}")
try:
# Run Streamlit using subprocess.run which waits for completion
# Use check=True to raise an error if Streamlit fails
# Capture output might be useful for debugging but can be complex with interactive apps
process = subprocess.Popen(command)
process.wait() # Wait for the Streamlit process to exit
print(f"Streamlit process finished with exit code: {process.returncode}")
except FileNotFoundError:
print("ERROR: 'streamlit' command not found. Make sure Streamlit is installed and in your PATH.")
sys.exit(1)
except Exception as e:
print(f"ERROR: Failed to run Streamlit: {e}")
sys.exit(1)
if __name__ == "__main__":
run_streamlit_app()
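The [project.scripts] entry added in pyproject.toml (mcpapp = "run_app:run_streamlit_app") points at this function, so invoking the console script is equivalent to the following minimal snippet:
```python
# Equivalent of running the `mcpapp` console script defined in pyproject.toml.
from run_app import run_streamlit_app

run_streamlit_app()
```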

src/app.py

@@ -1,29 +1,110 @@
 import atexit
+import configparser
+import logging
 import streamlit as st
-from openai_client import OpenAIClient
+from llm_client import LLMClient
+from src.custom_mcp.manager import SyncMCPManager
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 def init_session_state():
+    """Initializes session state variables including clients."""
     if "messages" not in st.session_state:
         st.session_state.messages = []
+        logger.info("Initialized session state: messages")
     if "client" not in st.session_state:
-        st.session_state.client = OpenAIClient()
-        # Register cleanup for MCP servers
-        if hasattr(st.session_state.client, "mcp_manager"):
-            atexit.register(st.session_state.client.mcp_manager.shutdown)
+        logger.info("Attempting to initialize clients...")
+        try:
+            config = configparser.ConfigParser()
+            config_files_read = config.read("config/config.ini")
+            if not config_files_read:
+                raise FileNotFoundError("config.ini not found or could not be read.")
+            logger.info(f"Read configuration from: {config_files_read}")
+            mcp_config_path = "config/mcp_config.json"
+            if config.has_section("mcp") and config["mcp"].get("servers_json"):
+                mcp_config_path = config["mcp"]["servers_json"]
+                logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
+            else:
+                logger.info(f"Using default MCP config path: {mcp_config_path}")
+            mcp_manager = SyncMCPManager(mcp_config_path)
+            if not mcp_manager.initialize():
+                logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
+            else:
+                logger.info("MCP Manager initialized successfully.")
+                atexit.register(mcp_manager.shutdown)
+                logger.info("Registered MCP Manager shutdown hook.")
+            provider_name = None
+            model_name = None
+            api_key = None
+            base_url = None
+            if config.has_section("base") and config["base"].get("provider"):
+                provider_name = config["base"].get("provider")
+                logger.info(f"Provider selected from [base] section: {provider_name}")
+            else:
+                raise ValueError("Missing 'provider' setting in [base] section of config.ini")
+            if config.has_section(provider_name):
+                provider_config = config[provider_name]
+                model_name = provider_config.get("model")
+                api_key = provider_config.get("api_key")
+                base_url = provider_config.get("base_url")
+                provider_temperature = provider_config.getfloat("temperature", 0.6)
+                if "temperature" not in provider_config:
+                    logger.warning(f"Temperature not found in [{provider_name}] section, defaulting to {provider_temperature}")
+                else:
+                    logger.info(f"Loaded temperature for {provider_name}: {provider_temperature}")
+                logger.info(f"Read configuration from [{provider_name}] section.")
+            else:
+                raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
+            if not api_key:
+                raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
+            if not model_name:
+                raise ValueError(f"Missing 'model' name in [{provider_name}] section of config.ini")
+            logger.info(f"Configuring LLMClient for provider: {provider_name}, model: {model_name}")
+            st.session_state.client = LLMClient(
+                provider_name=provider_name,
+                api_key=api_key,
+                mcp_manager=mcp_manager,
+                base_url=base_url,
+                temperature=provider_temperature,
+            )
+            st.session_state.model_name = model_name
+            st.session_state.provider_name = provider_name
+            logger.info("LLMClient initialized successfully.")
+        except Exception as e:
+            logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
+            st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
+            st.stop()
 def display_chat_messages():
+    """Displays chat messages stored in session state."""
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
+            if message["role"] == "assistant" and "usage" in message:
+                usage = message["usage"]
+                prompt_tokens = usage.get("prompt_tokens", "N/A")
+                completion_tokens = usage.get("completion_tokens", "N/A")
+                st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}")
 def handle_user_input():
+    """Handles user input, calls LLMClient, and displays the response."""
     if prompt := st.chat_input("Type your message..."):
-        print(f"User input received: {prompt}") # Debug log
+        logger.info(f"User input received: '{prompt[:50]}...'")
         st.session_state.messages.append({"role": "user", "content": prompt})
         with st.chat_message("user"):
             st.markdown(prompt)
@@ -32,39 +113,84 @@ def handle_user_input():
         with st.chat_message("assistant"):
             response_placeholder = st.empty()
             full_response = ""
+            error_occurred = False
+            response_usage = None
-            print("Processing message...") # Debug log
+            logger.info("Processing message via LLMClient...")
-            response = st.session_state.client.get_chat_response(st.session_state.messages)
+            response_data = st.session_state.client.chat_completion(
+                messages=st.session_state.messages,
+                model=st.session_state.model_name,
+                stream=False,
+            )
+            if isinstance(response_data, dict):
+                if "error" in response_data:
+                    full_response = f"Error: {response_data['error']}"
+                    logger.error(f"Error returned from chat_completion: {full_response}")
+                    st.error(full_response)
+                    error_occurred = True
+                else:
+                    full_response = response_data.get("content", "")
+                    response_usage = response_data.get("usage")
+                    if not full_response and not error_occurred:
+                        logger.warning("Empty content received from LLMClient.")
+                    response_placeholder.markdown(full_response)
+                    logger.debug("Non-streaming response processed.")
-            # Handle both MCP and standard OpenAI responses
-            # Check if it's NOT a dict (assuming stream is not a dict)
-            if not isinstance(response, dict):
-                # Standard OpenAI streaming response
-                for chunk in response:
-                    # Ensure chunk has choices and delta before accessing
-                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
-                        full_response += chunk.choices[0].delta.content
-                        response_placeholder.markdown(full_response + "")
-            else:
-                # MCP non-streaming response
-                full_response = response.get("assistant_text", "")
-                response_placeholder.markdown(full_response)
-            response_placeholder.markdown(full_response)
-            st.session_state.messages.append({"role": "assistant", "content": full_response})
-            print("Message processed successfully") # Debug log
+            else:
+                full_response = "[Unexpected response format from LLMClient]"
+                logger.error(f"Unexpected response type: {type(response_data)}")
+                st.error(full_response)
+                error_occurred = True
+            if not error_occurred and full_response:
+                assistant_message = {"role": "assistant", "content": full_response}
+                if response_usage:
+                    assistant_message["usage"] = response_usage
+                    logger.info(f"Assistant response usage: {response_usage}")
+                st.session_state.messages.append(assistant_message)
+                logger.info("Assistant response added to history.")
+            elif error_occurred:
+                logger.warning("Assistant response not added to history due to error.")
+            else:
+                logger.warning("Empty assistant response received, not added to history.")
         except Exception as e:
-            st.error(f"Error processing message: {str(e)}")
-            print(f"Error details: {str(e)}") # Debug log
+            logger.error(f"Error during chat handling: {str(e)}", exc_info=True)
+            st.error(f"An unexpected error occurred: {str(e)}")
 def main():
-    st.title("Streamlit Chat App")
-    init_session_state()
-    display_chat_messages()
-    handle_user_input()
+    """Main function to run the Streamlit app."""
+    try:
+        init_session_state()
+        provider_name = st.session_state.get("provider_name", "Unknown Provider")
+        model_name = st.session_state.get("model_name", "Unknown Model")
+        mcp_manager = st.session_state.client.mcp_manager
+        server_count = 0
+        tool_count = 0
+        if mcp_manager and mcp_manager.initialized:
+            server_count = len(mcp_manager.servers)
+            try:
+                tool_count = len(mcp_manager.list_all_tools())
+            except Exception as e:
+                logger.warning(f"Could not retrieve tool count for header: {e}")
+                tool_count = "N/A"
+        st.markdown(f"# Say Hi to **{provider_name.capitalize()}**!")
+        st.write(f"MCP Servers: **{server_count}** | Tools: **{tool_count}**")
+        st.write(f"Model: **{model_name}**")
+        st.divider()
+        display_chat_messages()
+        handle_user_input()
+    except Exception as e:
+        logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
+        st.error(f"A critical application error occurred: {e}")
 if __name__ == "__main__":
+    logger.info("Starting Streamlit Chat App...")
     main()


src/custom_mcp/client.py Normal file

@@ -0,0 +1,263 @@
"""Client class for managing and interacting with a single MCP server process."""
import asyncio
import logging
from typing import Any
from custom_mcp import process, protocol
logger = logging.getLogger(__name__)
LIST_TOOLS_TIMEOUT = 20.0
CALL_TOOL_TIMEOUT = 110.0
class MCPClient:
"""
Manages the lifecycle and async communication with a single MCP server process.
"""
def __init__(self, server_name: str, command: str, args: list[str], config_env: dict[str, str]):
"""
Initializes the client for a specific server configuration.
Args:
server_name: Unique name for the server (for logging).
command: The command executable.
args: List of arguments for the command.
config_env: Server-specific environment variables.
"""
self.server_name = server_name
self.command = command
self.args = args
self.config_env = config_env
self.process: asyncio.subprocess.Process | None = None
self.reader: asyncio.StreamReader | None = None
self.writer: asyncio.StreamWriter | None = None
self._stderr_task: asyncio.Task | None = None
self._request_counter = 0
self._is_running = False
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
async def _log_stderr(self):
"""Logs stderr output from the server process."""
if not self.process or not self.process.stderr:
self.logger.debug("Stderr logging skipped: process or stderr not available.")
return
stderr_reader = self.process.stderr
try:
while not stderr_reader.at_eof():
line = await stderr_reader.readline()
if line:
self.logger.warning(f"[stderr] {line.decode().strip()}")
except asyncio.CancelledError:
self.logger.debug("Stderr logging task cancelled.")
except Exception as e:
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
finally:
self.logger.debug("Stderr logging task finished.")
async def start(self) -> bool:
"""
Starts the MCP server subprocess and sets up communication streams.
Returns:
True if the process started successfully, False otherwise.
"""
if self._is_running:
self.logger.warning("Start called but client is already running.")
return True
self.logger.info("Starting MCP server process...")
try:
self.process = await process.start_mcp_process(self.command, self.args, self.config_env)
self.reader = self.process.stdout
self.writer = self.process.stdin
if self.reader is None or self.writer is None:
self.logger.error("Failed to get stdout/stdin streams after process start.")
await self.stop()
return False
self._stderr_task = asyncio.create_task(self._log_stderr())
self.logger.info("Starting MCP initialization handshake...")
self._request_counter += 1
init_req_id = self._request_counter
initialize_req = {
"jsonrpc": "2.0",
"id": init_req_id,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
"capabilities": {},
},
}
INITIALIZE_TIMEOUT = 15.0
try:
await protocol.send_request(self.writer, initialize_req)
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
if init_response and init_response.get("id") == init_req_id:
if "error" in init_response:
self.logger.error(f"Server returned error during initialization: {init_response['error']}")
await self.stop()
return False
elif "result" in init_response:
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}")
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
await protocol.send_request(self.writer, initialized_notify)
self.logger.info("'notifications/initialized' notification sent.")
self._is_running = True
self.logger.info("MCP server process started and initialized successfully.")
return True
else:
self.logger.error("Invalid 'initialize' response format (missing result/error).")
await self.stop()
return False
elif init_response:
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
await self.stop()
return False
else:
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
await self.stop()
return False
except ConnectionResetError:
self.logger.error("Connection lost during initialization handshake. Stopping client.")
await self.stop()
return False
except Exception as e:
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
await self.stop()
return False
except Exception as e:
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
self.process = None
self.reader = None
self.writer = None
self._is_running = False
return False
async def stop(self):
if not self._is_running and not self.process:
self.logger.debug("Stop called but client is not running.")
return
self.logger.info("Stopping MCP server process...")
self._is_running = False
if self._stderr_task and not self._stderr_task.done():
self._stderr_task.cancel()
try:
await self._stderr_task
except asyncio.CancelledError:
self.logger.debug("Stderr task successfully cancelled.")
except Exception as e:
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
self._stderr_task = None
if self.process:
await process.stop_mcp_process(self.process, self.server_name)
self.process = None
self.reader = None
self.writer = None
self.logger.info("MCP server process stopped.")
async def list_tools(self) -> list[dict[str, Any]] | None:
"""
Sends a 'tools/list' request and waits for the response.
Returns:
A list of tool dictionaries, or None on error/timeout.
"""
if not self._is_running or not self.writer or not self.reader:
self.logger.error("Cannot list tools: client not running or streams unavailable.")
return None
self._request_counter += 1
req_id = self._request_counter
request = {"jsonrpc": "2.0", "method": "tools/list", "id": req_id}
try:
await protocol.send_request(self.writer, request)
response = await protocol.read_response(self.reader, LIST_TOOLS_TIMEOUT)
if response and "result" in response and isinstance(response["result"], dict) and "tools" in response["result"]:
tools = response["result"]["tools"]
if isinstance(tools, list):
self.logger.info(f"Successfully listed {len(tools)} tools.")
return tools
else:
self.logger.error(f"Invalid 'tools' format in response ID {req_id}: {type(tools)}")
return None
elif response and "error" in response:
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
return None
else:
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
return None
except ConnectionResetError:
self.logger.error("Connection lost during listTools request. Stopping client.")
await self.stop()
return None
except Exception as e:
self.logger.error(f"Unexpected error during listTools: {e}", exc_info=True)
return None
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
"""
Sends a 'tools/call' request and waits for the response.
Args:
tool_name: The name of the tool to call.
arguments: The arguments for the tool.
Returns:
The result dictionary from the server, or None on error/timeout.
"""
if not self._is_running or not self.writer or not self.reader:
self.logger.error(f"Cannot call tool '{tool_name}': client not running or streams unavailable.")
return None
self._request_counter += 1
req_id = self._request_counter
request = {
"jsonrpc": "2.0",
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments},
"id": req_id,
}
try:
await protocol.send_request(self.writer, request)
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
if response and "result" in response:
self.logger.info(f"Tool '{tool_name}' executed successfully.")
return response["result"]
elif response and "error" in response:
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
return {"error": response["error"]}
else:
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
return None
except ConnectionResetError:
self.logger.error(f"Connection lost during callTool '{tool_name}'. Stopping client.")
await self.stop()
return None
except Exception as e:
self.logger.error(f"Unexpected error during callTool '{tool_name}': {e}", exc_info=True)
return None

src/custom_mcp/manager.py Normal file

@@ -0,0 +1,327 @@
"""Synchronous manager for multiple MCPClient instances."""
import asyncio
import json
import logging
import threading
from typing import Any
from custom_mcp.client import MCPClient
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
INITIALIZE_TIMEOUT = 60.0
SHUTDOWN_TIMEOUT = 30.0
LIST_ALL_TOOLS_TIMEOUT = 30.0
EXECUTE_TOOL_TIMEOUT = 120.0
class SyncMCPManager:
"""
Manages the lifecycle of multiple MCPClient instances and provides a
synchronous interface to interact with them using a background event loop.
"""
def __init__(self, config_path: str = "config/mcp_config.json"):
"""
Initializes the manager, loads config, but does not start servers yet.
Args:
config_path: Path to the MCP server configuration JSON file.
"""
self.config_path = config_path
self.config: dict[str, Any] | None = None
self.servers: dict[str, MCPClient] = {}
self.initialized = False
self._lock = threading.Lock()
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
self._load_config()
def _load_config(self):
"""Load MCP configuration from JSON file."""
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
try:
with open(self.config_path) as f:
self.config = json.load(f)
logger.info("MCP configuration loaded successfully.")
logger.debug(f"Config content: {self.config}")
except FileNotFoundError:
logger.error(f"MCP config file not found at {self.config_path}")
self.config = None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
self.config = None
except Exception as e:
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
self.config = None
def _run_event_loop(self):
"""Target function for the background event loop thread."""
try:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
finally:
if self._loop and not self._loop.is_closed():
try:
tasks = asyncio.all_tasks(self._loop)
if tasks:
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
for task in tasks:
task.cancel()
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
logger.debug("Outstanding tasks cancelled.")
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
except Exception as e:
logger.error(f"Error during event loop cleanup: {e}")
finally:
self._loop.close()
asyncio.set_event_loop(None)
logger.info("Event loop thread finished.")
def _start_event_loop_thread(self):
"""Starts the background event loop thread if not already running."""
if self._thread is None or not self._thread.is_alive():
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
self._thread.start()
logger.info("Event loop thread started.")
while self._loop is None or not self._loop.is_running():
import time
time.sleep(0.01)
logger.debug("Event loop is running.")
def _stop_event_loop_thread(self):
"""Stops the background event loop thread."""
if self._loop and self._loop.is_running():
logger.info("Requesting event loop stop...")
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread and self._thread.is_alive():
logger.info("Waiting for event loop thread to join...")
self._thread.join(timeout=5)
if self._thread.is_alive():
logger.warning("Event loop thread did not stop gracefully.")
self._loop = None
self._thread = None
logger.info("Event loop stopped.")
def initialize(self) -> bool:
"""
Initializes and starts all configured MCP servers synchronously.
Returns:
True if all servers started successfully, False otherwise.
"""
logger.info("Manager initialization requested.")
if not self.config or not self.config.get("mcpServers"):
logger.warning("Initialization skipped: No valid configuration loaded.")
return False
with self._lock:
if self.initialized:
logger.debug("Initialization skipped: Already initialized.")
return True
self._start_event_loop_thread()
if not self._loop:
logger.error("Failed to start event loop for initialization.")
return False
logger.info("Submitting asynchronous server initialization...")
async def _async_init_all():
tasks = []
for server_name, server_config in self.config["mcpServers"].items():
command = server_config.get("command")
args = server_config.get("args", [])
config_env = server_config.get("env", {})
if not command:
logger.error(f"Skipping server {server_name}: Missing 'command'.")
continue
client = MCPClient(server_name, command, args, config_env)
self.servers[server_name] = client
tasks.append(client.start())
results = await asyncio.gather(*tasks, return_exceptions=True)
all_success = True
failed_servers = []
for i, result in enumerate(results):
server_name = list(self.config["mcpServers"].keys())[i]
if isinstance(result, Exception) or result is False:
all_success = False
failed_servers.append(server_name)
if server_name in self.servers:
del self.servers[server_name]
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
if not all_success:
logger.error(f"Initialization failed for servers: {failed_servers}")
return all_success
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
try:
success = future.result(timeout=INITIALIZE_TIMEOUT)
if success:
logger.info("Asynchronous initialization completed successfully.")
self.initialized = True
else:
logger.error("Asynchronous initialization failed.")
self.initialized = False
self.shutdown()
except TimeoutError:
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
self.initialized = False
self.shutdown()
success = False
except Exception as e:
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
self.initialized = False
self.shutdown()
success = False
return self.initialized
def shutdown(self):
"""Shuts down all managed MCP servers synchronously."""
logger.info("Manager shutdown requested.")
with self._lock:
if not self.initialized and not self.servers:
logger.debug("Shutdown skipped: Not initialized or no servers running.")
if self._thread and self._thread.is_alive():
self._stop_event_loop_thread()
return
if not self._loop or not self._loop.is_running():
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
self.servers = {}
self.initialized = False
if self._thread and self._thread.is_alive():
self._stop_event_loop_thread()
return
logger.info("Submitting asynchronous server shutdown...")
async def _async_shutdown_all():
tasks = [client.stop() for client in self.servers.values()]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
try:
future.result(timeout=SHUTDOWN_TIMEOUT)
logger.info("Asynchronous shutdown completed.")
except TimeoutError:
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
except Exception as e:
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
finally:
self.servers = {}
self.initialized = False
self._stop_event_loop_thread()
logger.info("Manager shutdown complete.")
def list_all_tools(self) -> list[dict[str, Any]]:
"""
Retrieves tools from all initialized MCP servers synchronously.
Returns:
A list of tool definitions in the standard internal format,
aggregated from all servers. Returns empty list on failure.
"""
if not self.initialized or not self.servers:
logger.warning("Cannot list tools: Manager not initialized or no servers running.")
return []
if not self._loop or not self._loop.is_running():
logger.error("Cannot list tools: Event loop not running.")
return []
logger.info(f"Requesting tools from {len(self.servers)} servers...")
async def _async_list_all():
tasks = []
server_names_in_order = []
for server_name, client in self.servers.items():
tasks.append(client.list_tools())
server_names_in_order.append(server_name)
results = await asyncio.gather(*tasks, return_exceptions=True)
all_tools = []
for i, result in enumerate(results):
server_name = server_names_in_order[i]
if isinstance(result, Exception):
logger.error(f"Error listing tools for server '{server_name}': {result}")
elif result is None:
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
elif isinstance(result, list):
for tool in result:
tool["server_name"] = server_name
all_tools.extend(result)
logger.debug(f"Received {len(result)} tools from {server_name}")
else:
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
return all_tools
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
try:
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
logger.info(f"Aggregated {len(aggregated_tools)} tools from all servers.")
return aggregated_tools
except TimeoutError:
logger.error(f"Listing all tools timed out after {LIST_ALL_TOOLS_TIMEOUT}s.")
return []
except Exception as e:
logger.error(f"Exception during listing all tools future result: {e}", exc_info=True)
return []
def execute_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any] | None:
"""
Executes a specific tool on the designated MCP server synchronously.
Args:
server_name: The name of the server hosting the tool.
tool_name: The name of the tool to execute.
arguments: A dictionary of arguments for the tool.
Returns:
The result content from the tool execution (dict),
an error dict ({"error": ...}), or None on timeout/comm failure.
"""
if not self.initialized:
logger.warning(f"Cannot execute tool '{tool_name}' on {server_name}: Manager not initialized.")
return None
client = self.servers.get(server_name)
if not client:
logger.error(f"Cannot execute tool: Server '{server_name}' not found.")
return None
if not self._loop or not self._loop.is_running():
logger.error(f"Cannot execute tool '{tool_name}': Event loop not running.")
return None
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
try:
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
if result is None:
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
elif isinstance(result, dict) and "error" in result:
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
else:
logger.info(f"Tool '{tool_name}' execution successful.")
return result
except TimeoutError:
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
return None
except Exception as e:
logger.error(f"Exception during tool execution future result for '{tool_name}' on {server_name}: {e}", exc_info=True)
return None
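For orientation, a minimal sketch of how SyncMCPManager is meant to be driven from synchronous code (the import path follows app.py; the first listed tool and the empty argument dict are placeholders):
```python
from src.custom_mcp.manager import SyncMCPManager

manager = SyncMCPManager("config/mcp_config.json")
if manager.initialize():
    tools = manager.list_all_tools()
    print(f"{len(tools)} tool(s) available from {len(manager.servers)} server(s)")
    if tools:
        result = manager.execute_tool(
            server_name=tools[0]["server_name"],  # tag added by list_all_tools
            tool_name=tools[0]["name"],           # tool name as reported by the server
            arguments={},                         # placeholder; depends on the tool's schema
        )
        print(result)
    manager.shutdown()
```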

src/custom_mcp/process.py Normal file

@@ -0,0 +1,118 @@
"""Async utilities for managing MCP server subprocesses."""
import asyncio
import logging
import os
import subprocess
logger = logging.getLogger(__name__)
async def start_mcp_process(command: str, args: list[str], config_env: dict[str, str]) -> asyncio.subprocess.Process:
"""
Starts an MCP server subprocess using asyncio.create_subprocess_shell.
Handles argument expansion and environment merging.
Args:
command: The main command executable.
args: A list of arguments for the command.
config_env: Server-specific environment variables from config.
Returns:
The started asyncio.subprocess.Process object.
Raises:
FileNotFoundError: If the command is not found.
Exception: For other errors during subprocess creation.
"""
logger.debug(f"Preparing to start process for command: {command}")
expanded_args = []
try:
for arg in args:
if isinstance(arg, str) and "~" in arg:
expanded_args.append(os.path.expanduser(arg))
else:
expanded_args.append(str(arg))
logger.debug(f"Expanded args: {expanded_args}")
except Exception as e:
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
raise ValueError(f"Failed to expand arguments: {e}") from e
merged_env = {**os.environ, **config_env}
try:
cmd_string = subprocess.list2cmdline([command] + expanded_args)
logger.debug(f"Executing shell command: {cmd_string}")
except Exception as e:
logger.error(f"Error creating command string: {e}", exc_info=True)
raise ValueError(f"Failed to create command string: {e}") from e
try:
process = await asyncio.create_subprocess_shell(
cmd_string,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=merged_env,
)
logger.info(f"Subprocess started (PID: {process.pid}) for command: {command}")
return process
except FileNotFoundError:
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
raise
except Exception as e:
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
raise
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
"""
Attempts to gracefully stop the MCP server subprocess.
Args:
process: The asyncio.subprocess.Process object to stop.
server_name: A name for logging purposes.
"""
if process is None or process.returncode is not None:
logger.debug(f"Process {server_name} (PID: {process.pid if process else 'N/A'}) already stopped or not started.")
return
pid = process.pid
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
if process.stdin and not process.stdin.is_closing():
try:
process.stdin.close()
await process.stdin.wait_closed()
logger.debug(f"Stdin closed for {server_name} (PID: {pid})")
except Exception as e:
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
try:
process.terminate()
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
await asyncio.wait_for(process.wait(), timeout=5.0)
logger.info(f"Process {server_name} (PID: {pid}) terminated gracefully (return code: {process.returncode}).")
except TimeoutError:
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
try:
process.kill()
await process.wait()
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
except ProcessLookupError:
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
except Exception as e_kill:
logger.error(f"Error killing process {server_name} (PID: {pid}): {e_kill}")
except ProcessLookupError:
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
except Exception as e_term:
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
if process.returncode is None:
try:
process.kill()
await process.wait()
logger.info(f"Process {server_name} (PID: {pid}) killed after termination error (return code: {process.returncode}).")
except Exception as e_kill_fallback:
logger.error(f"Error killing process {server_name} (PID: {pid}) after termination error: {e_kill_fallback}")

src/custom_mcp/protocol.py Normal file

@@ -0,0 +1,74 @@
"""Async utilities for MCP JSON-RPC communication over streams."""
import asyncio
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any]):
"""
Sends a JSON-RPC request dictionary to the MCP server's stdin stream.
Args:
writer: The asyncio StreamWriter connected to the process stdin.
request_dict: The request dictionary to send.
Raises:
ConnectionResetError: If the connection is lost during send.
Exception: For other stream writing errors.
"""
try:
request_json = json.dumps(request_dict) + "\n"
writer.write(request_json.encode("utf-8"))
await writer.drain()
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
except ConnectionResetError:
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
raise
except Exception as e:
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
raise
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
"""
Reads and parses a JSON-RPC response line from the MCP server's stdout stream.
Args:
reader: The asyncio StreamReader connected to the process stdout.
timeout: Seconds to wait for a response line.
Returns:
The parsed response dictionary, or None if timeout or error occurs.
"""
response_str = None
try:
response_json = await asyncio.wait_for(reader.readline(), timeout=timeout)
if not response_json:
logger.warning("Received empty response line (EOF?).")
return None
response_str = response_json.decode("utf-8").strip()
if not response_str:
logger.warning("Received empty response string after strip.")
return None
logger.debug(f"Received response line: {response_str}")
response_dict = json.loads(response_str)
return response_dict
except TimeoutError:
logger.error(f"Timeout ({timeout}s) waiting for response.")
return None
except asyncio.IncompleteReadError:
logger.warning("Connection closed while waiting for response.")
return None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON response: {e}. Response: '{response_str}'")
return None
except Exception as e:
logger.error(f"Error reading response: {e}", exc_info=True)
return None
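For orientation, a minimal round-trip sketch combining the process and protocol helpers (server command, database path, and timeout are illustrative; the MCP initialize handshake performed by MCPClient.start is omitted for brevity, so some servers may reject the request):
```python
import asyncio

from custom_mcp import process, protocol


async def main():
    proc = await process.start_mcp_process(
        "uvx", ["mcp-server-sqlite", "--db-path", "~/.mcpapp/mcpapp.db"], {}
    )
    try:
        # Send a tools/list request and read one JSON-RPC response line back.
        await protocol.send_request(proc.stdin, {"jsonrpc": "2.0", "id": 1, "method": "tools/list"})
        print(await protocol.read_response(proc.stdout, timeout=20.0))
    finally:
        await process.stop_mcp_process(proc, "mcp-server-sqlite")


asyncio.run(main())
```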


@@ -1,5 +0,0 @@
"""Custom MCP client implementation focused on OpenAI integration."""
from .client import MCPClient, run_interaction
__all__ = ["MCPClient", "run_interaction"]


@@ -1,550 +0,0 @@
"""Custom MCP client implementation with JSON-RPC and OpenAI integration."""
import asyncio
import json
import logging
import os
from collections.abc import AsyncGenerator
from openai import AsyncOpenAI
# Get a logger for this module
logger = logging.getLogger(__name__)
class MCPClient:
"""Lightweight MCP client with JSON-RPC communication."""
def __init__(self, server_name: str, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
self.server_name = server_name
self.command = command
self.args = args or []
self.env = env or {}
self.process = None
self.tools = []
self.request_id = 0
self.responses = {}
self._shutdown = False
# Use a logger specific to this client instance
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
async def _receive_loop(self):
"""Listen for responses from the MCP server."""
try:
while self.process and self.process.stdout and not self.process.stdout.at_eof():
line_bytes = await self.process.stdout.readline()
if not line_bytes:
self.logger.debug("STDOUT EOF reached.")
break
line_str = line_bytes.decode().strip()
self.logger.debug(f"STDOUT Raw line: {line_str}")
try:
message = json.loads(line_str)
if "jsonrpc" in message and "id" in message and ("result" in message or "error" in message):
self.logger.debug(f"STDOUT Parsed response for ID {message['id']}")
self.responses[message["id"]] = message
elif "jsonrpc" in message and "method" in message:
self.logger.debug(f"STDOUT Received notification: {message.get('method')}")
else:
self.logger.debug(f"STDOUT Parsed non-response/notification JSON: {message}")
except json.JSONDecodeError:
self.logger.warning("STDOUT Failed to parse line as JSON.")
except Exception as e:
self.logger.error(f"STDOUT Error processing line: {e}", exc_info=True)
except asyncio.CancelledError:
self.logger.debug("STDOUT Receive loop cancelled.")
except Exception as e:
self.logger.error(f"STDOUT Receive loop error: {e}", exc_info=True)
finally:
self.logger.debug("STDOUT Receive loop finished.")
async def _stderr_loop(self):
"""Listen for stderr messages from the MCP server."""
try:
while self.process and self.process.stderr and not self.process.stderr.at_eof():
line_bytes = await self.process.stderr.readline()
if not line_bytes:
self.logger.debug("STDERR EOF reached.")
break
line_str = line_bytes.decode().strip()
self.logger.warning(f"STDERR: {line_str}") # Log stderr as warning
except asyncio.CancelledError:
self.logger.debug("STDERR Stderr loop cancelled.")
except Exception as e:
self.logger.error(f"STDERR Stderr loop error: {e}", exc_info=True)
finally:
self.logger.debug("STDERR Stderr loop finished.")
async def _send_message(self, message: dict) -> bool:
"""Send a JSON-RPC message to the MCP server."""
if not self.process or not self.process.stdin:
self.logger.warning("STDIN Cannot send message, process or stdin not available.")
return False
try:
data = json.dumps(message) + "\n"
self.logger.debug(f"STDIN Sending: {data.strip()}")
self.process.stdin.write(data.encode())
await self.process.stdin.drain()
return True
except ConnectionResetError:
self.logger.error("STDIN Connection reset while sending message.")
self.process = None # Mark process as dead
return False
except Exception as e:
self.logger.error(f"STDIN Error sending message: {e}", exc_info=True)
return False
async def start(self) -> bool:
"""Start the MCP server process."""
self.logger.info("Attempting to start server...")
# Expand ~ in paths and prepare args
expanded_args = []
try:
for a in self.args:
if isinstance(a, str) and "~" in a:
expanded_args.append(os.path.expanduser(a))
else:
expanded_args.append(str(a)) # Ensure all args are strings
except Exception as e:
self.logger.error(f"Error expanding arguments: {e}", exc_info=True)
return False
# Set up environment
env_vars = os.environ.copy()
if self.env:
env_vars.update(self.env)
self.logger.debug(f"Command: {self.command}")
self.logger.debug(f"Expanded Args: {expanded_args}")
# Avoid logging full env unless necessary for debugging sensitive info
# self.logger.debug(f"Environment: {env_vars}")
try:
# Start the subprocess
self.logger.debug("Creating subprocess...")
self.process = await asyncio.create_subprocess_exec(
self.command,
*expanded_args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, # Capture stderr
env=env_vars,
)
self.logger.info(f"Subprocess created with PID: {self.process.pid}")
# Start the receive loops
asyncio.create_task(self._receive_loop())
asyncio.create_task(self._stderr_loop()) # Start stderr loop
# Initialize the server
self.logger.debug("Attempting initialization handshake...")
init_success = await self._initialize()
if init_success:
self.logger.info("Initialization handshake successful (or skipped).")
# Add delay after successful start
await asyncio.sleep(0.5)
self.logger.debug("Post-initialization delay complete.")
return True
else:
self.logger.error("Initialization handshake failed.")
await self.stop() # Ensure cleanup if init fails
return False
except FileNotFoundError:
self.logger.error(f"Error starting subprocess: Command not found: '{self.command}'")
return False
except Exception as e:
self.logger.error(f"Error starting subprocess: {e}", exc_info=True)
return False
async def _initialize(self) -> bool:
"""Initialize the MCP server connection. Modified to not wait for response."""
self.logger.debug("Sending 'initialize' request...")
if not self.process:
self.logger.warning("Cannot initialize, process not running.")
return False
# Send initialize request
self.request_id += 1
req_id = self.request_id
initialize_req = {
"jsonrpc": "2.0",
"id": req_id,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
"capabilities": {}, # Add empty capabilities object
},
}
if not await self._send_message(initialize_req):
self.logger.warning("Failed to send 'initialize' request.")
# Continue anyway for non-compliant servers
# Send initialized notification immediately
self.logger.debug("Sending 'initialized' notification...")
notify = {"jsonrpc": "2.0", "method": "notifications/initialized"}
if await self._send_message(notify):
self.logger.debug("'initialized' notification sent.")
else:
self.logger.warning("Failed to send 'initialized' notification.")
# Still return True as the server might be running
self.logger.info("Skipping wait for 'initialize' response (assuming non-compliant server).")
return True # Assume success without waiting for response
async def list_tools(self) -> list[dict]:
"""List available tools from the MCP server."""
if not self.process:
self.logger.warning("Cannot list tools, process not running.")
return []
self.logger.debug("Sending 'tools/list' request...")
self.request_id += 1
req_id = self.request_id
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/list", "params": {}}
if not await self._send_message(req):
self.logger.error("Failed to send 'tools/list' request.")
return []
# Wait for response
self.logger.debug(f"Waiting for 'tools/list' response (ID: {req_id})...")
start_time = asyncio.get_event_loop().time()
timeout = 10 # seconds
while asyncio.get_event_loop().time() - start_time < timeout:
if req_id in self.responses:
resp = self.responses.pop(req_id)
self.logger.debug(f"Received 'tools/list' response: {resp}")
if "error" in resp:
self.logger.error(f"'tools/list' error response: {resp['error']}")
return []
if "result" in resp and "tools" in resp["result"]:
self.tools = resp["result"]["tools"]
self.logger.info(f"Successfully listed tools: {len(self.tools)}")
return self.tools
else:
self.logger.error("Invalid 'tools/list' response format.")
return []
await asyncio.sleep(0.05)
self.logger.error(f"'tools/list' request timed out after {timeout} seconds.")
return []
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
"""Call a tool on the MCP server."""
if not self.process:
self.logger.warning(f"Cannot call tool '{tool_name}', process not running.")
return {"error": "Server not started"}
self.logger.debug(f"Sending 'tools/call' request for tool '{tool_name}'...")
self.request_id += 1
req_id = self.request_id
req = {"jsonrpc": "2.0", "id": req_id, "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}}
if not await self._send_message(req):
self.logger.error("Failed to send 'tools/call' request.")
return {"error": "Failed to send tool call request"}
# Wait for response
self.logger.debug(f"Waiting for 'tools/call' response (ID: {req_id})...")
start_time = asyncio.get_event_loop().time()
timeout = 30 # seconds
while asyncio.get_event_loop().time() - start_time < timeout:
if req_id in self.responses:
resp = self.responses.pop(req_id)
self.logger.debug(f"Received 'tools/call' response: {resp}")
if "error" in resp:
self.logger.error(f"'tools/call' error response: {resp['error']}")
return {"error": str(resp["error"])}
if "result" in resp:
self.logger.info(f"Tool '{tool_name}' executed successfully.")
return resp["result"]
else:
self.logger.error("Invalid 'tools/call' response format.")
return {"error": "Invalid tool call response format"}
await asyncio.sleep(0.05)
self.logger.error(f"Tool call '{tool_name}' timed out after {timeout} seconds.")
return {"error": f"Tool call timed out after {timeout}s"}
async def stop(self):
"""Stop the MCP server process."""
self.logger.info("Attempting to stop server...")
if self._shutdown or not self.process:
self.logger.debug("Server already stopped or not running.")
return
self._shutdown = True
proc = self.process # Keep a local reference
self.process = None # Prevent further operations
try:
# Send shutdown notification
self.logger.debug("Sending 'shutdown' notification...")
notify = {"jsonrpc": "2.0", "method": "shutdown"}
await self._send_message(notify) # Use the method which now handles None process
await asyncio.sleep(0.5) # Give server time to process
# Close stdin
if proc and proc.stdin:
try:
if not proc.stdin.is_closing():
proc.stdin.close()
await proc.stdin.wait_closed()
self.logger.debug("Stdin closed.")
except Exception as e:
self.logger.warning(f"Error closing stdin: {e}", exc_info=True)
# Terminate the process
if proc:
self.logger.debug(f"Terminating process {proc.pid}...")
proc.terminate()
try:
await asyncio.wait_for(proc.wait(), timeout=2.0)
self.logger.info(f"Process {proc.pid} terminated gracefully.")
except TimeoutError:
self.logger.warning(f"Process {proc.pid} did not terminate gracefully, killing...")
proc.kill()
await proc.wait()
self.logger.info(f"Process {proc.pid} killed.")
except Exception as e:
self.logger.error(f"Error waiting for process termination: {e}", exc_info=True)
except Exception as e:
self.logger.error(f"Error during shutdown sequence: {e}", exc_info=True)
finally:
self.logger.debug("Stop sequence finished.")
# Ensure self.process is None even if errors occurred
self.process = None
async def process_tool_call(tool_call: dict, servers: dict[str, MCPClient]) -> dict:
"""Process a tool call from OpenAI."""
func_name = tool_call["function"]["name"]
try:
func_args = json.loads(tool_call["function"].get("arguments", "{}"))
except json.JSONDecodeError as e:
logger.error(f"Invalid tool arguments format for {func_name}: {e}")
return {"error": "Invalid arguments format"}
# Parse server_name and tool_name from function name
parts = func_name.split("_", 1)
if len(parts) != 2:
logger.error(f"Invalid tool function name format: {func_name}")
return {"error": "Invalid function name format"}
server_name, tool_name = parts
if server_name not in servers:
logger.error(f"Tool call for unknown server: {server_name}")
return {"error": f"Unknown server: {server_name}"}
# Call the tool
return await servers[server_name].call_tool(tool_name, func_args)
async def run_interaction(
user_query: str,
model_name: str,
api_key: str,
base_url: str | None,
mcp_config: dict,
stream: bool = False,
) -> dict | AsyncGenerator:
"""
Run an interaction with OpenAI using MCP server tools.
Args:
user_query: The user's input query.
model_name: The model to use for processing.
api_key: The OpenAI API key.
base_url: The OpenAI API base URL (optional).
mcp_config: The MCP configuration dictionary (for servers).
stream: Whether to stream the response.
Returns:
Dictionary containing response or AsyncGenerator for streaming.
"""
# Validate passed arguments
if not api_key:
logger.error("API key is missing.")
if not stream:
return {"error": "API key is missing."}
else:
async def error_gen():
yield {"error": "API key is missing."}
return error_gen()
# Start MCP servers using mcp_config
servers = {}
all_functions = []
if mcp_config.get("mcpServers"): # Use mcp_config here
for server_name, server_config in mcp_config["mcpServers"].items(): # Use mcp_config here
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
if await client.start():
tools = await client.list_tools()
for tool in tools:
# Ensure parameters is a dict; default to empty if missing or not a dict
params = tool.get("inputSchema", {})
if not isinstance(params, dict):
logger.warning(f"Tool '{tool['name']}' for server '{server_name}' has non-dict inputSchema, defaulting to empty.")
params = {}
all_functions.append({
"type": "function", # Explicitly set type for clarity with newer OpenAI API
"function": {"name": f"{server_name}_{tool['name']}", "description": tool.get("description", ""), "parameters": params},
})
servers[server_name] = client
else:
logger.warning(f"Failed to start MCP server '{server_name}', it will be unavailable.")
else:
logger.info("No mcpServers defined in configuration.")
# Use passed api_key and base_url
openai_client = AsyncOpenAI(api_key=api_key, base_url=base_url) # Use arguments
messages = [{"role": "user", "content": user_query}]
tool_defs = [{"type": "function", "function": f["function"]} for f in all_functions] if all_functions else None
if stream:
async def response_generator():
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
stream=True,
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
yield {"error": f"OpenAI API error: {e}"}
break
# Process streaming response
full_response_content = ""
tool_calls = []
async for chunk in response:
delta = chunk.choices[0].delta
if delta.content:
content = delta.content
full_response_content += content
yield {"assistant_text": content, "is_chunk": True}
if delta.tool_calls:
for tc in delta.tool_calls:
# Initialize tool call structure if it's the first chunk for this index
if tc.index >= len(tool_calls):
tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
# Append parts as they arrive
if tc.id:
tool_calls[tc.index]["id"] = tc.id
if tc.function and tc.function.name:
tool_calls[tc.index]["function"]["name"] = tc.function.name
if tc.function and tc.function.arguments:
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
# Add assistant message with content and potential tool calls
assistant_message = {"role": "assistant", "content": full_response_content}
if tool_calls:
# Filter out incomplete tool calls just in case
valid_tool_calls = [tc for tc in tool_calls if tc["id"] and tc["function"]["name"]]
if valid_tool_calls:
assistant_message["tool_calls"] = valid_tool_calls
else:
logger.warning("Received tool call chunks but couldn't assemble valid tool calls.")
messages.append(assistant_message)
logger.debug(f"Assistant message added: {assistant_message}")
# Handle tool calls if any were successfully assembled
if "tool_calls" in assistant_message:
tool_results = []
for tc in assistant_message["tool_calls"]:
logger.info(f"Processing tool call: {tc['function']['name']}")
result = await process_tool_call(tc, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc["id"]})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished for this turn
yield {"assistant_text": full_response_content, "is_chunk": False, "final": True} # Signal final chunk
break
except Exception as e:
logger.error(f"Error during streaming interaction: {e}", exc_info=True)
yield {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
logger.debug("Cleaning up MCP servers (stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (stream).")
return response_generator()
else: # Non-streaming case
active_servers = list(servers.values()) # Keep track for cleanup
try:
while True:
logger.debug(f"Calling OpenAI with messages: {messages}")
logger.debug(f"Calling OpenAI with tools: {tool_defs}")
# Get OpenAI response
try:
response = await openai_client.chat.completions.create(
model=model_name,
messages=messages,
tools=tool_defs,
tool_choice="auto" if tool_defs else None, # Only set tool_choice if tools exist
)
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
return {"error": f"OpenAI API error: {e}"}
message = response.choices[0].message
messages.append(message)
logger.debug(f"OpenAI response message: {message}")
# Handle tool calls
if message.tool_calls:
tool_results = []
for tc in message.tool_calls:
logger.info(f"Processing tool call: {tc.function.name}")
# Reconstruct dict for process_tool_call
tool_call_dict = {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
result = await process_tool_call(tool_call_dict, servers)
tool_results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tc.id})
messages.extend(tool_results)
logger.debug(f"Tool results added: {tool_results}")
# Loop back to call OpenAI again with tool results
else:
# No tool calls, interaction finished
logger.info("Interaction finished, no tool calls.")
return {"assistant_text": message.content or "", "tool_calls": []}
except Exception as e:
logger.error(f"Error during non-streaming interaction: {e}", exc_info=True)
return {"error": f"Interaction error: {e}"}
finally:
# Clean up servers
logger.debug("Cleaning up MCP servers (non-stream)...")
for server in active_servers:
await server.stop()
logger.debug("MCP server cleanup finished (non-stream).")

221
src/llm_client.py Normal file
View File

@@ -0,0 +1,221 @@
"""
Generic LLM client supporting multiple providers and MCP tool integration.
"""
import json
import logging
from collections.abc import Generator
from typing import Any
from src.custom_mcp.manager import SyncMCPManager # Updated import path
from src.providers import BaseProvider, create_llm_provider
logger = logging.getLogger(__name__)
class LLMClient:
"""
Handles chat completion requests to various LLM providers through a unified
interface, integrating with MCP tools via SyncMCPManager.
"""
def __init__(
self,
provider_name: str,
api_key: str,
mcp_manager: SyncMCPManager,
base_url: str | None = None,
temperature: float = 0.6,  # default sampling temperature configured on the provider
):
"""
Initialize the LLM client.
Args:
provider_name: Name of the provider (e.g., 'openai', 'anthropic').
api_key: API key for the provider.
mcp_manager: An initialized instance of SyncMCPManager.
base_url: Optional base URL for the provider API.
temperature: Default temperature to configure the provider with.
"""
logger.info(f"Initializing LLMClient for provider: {provider_name}")
self.provider: BaseProvider = create_llm_provider(
provider_name,
api_key,
base_url,
temperature=temperature, # Pass temperature to provider factory
)
self.mcp_manager = mcp_manager
self.mcp_tools: list[dict[str, Any]] = []
self._refresh_mcp_tools() # Initial tool load
def _refresh_mcp_tools(self):
"""Retrieves the latest tools from the MCP manager."""
logger.info("Refreshing MCP tools...")
try:
self.mcp_tools = self.mcp_manager.list_all_tools()
logger.info(f"Refreshed {len(self.mcp_tools)} MCP tools.")
except Exception as e:
logger.error(f"Error refreshing MCP tools: {e}", exc_info=True)
# Keep existing tools if refresh fails
def chat_completion(
self,
messages: list[dict[str, str]],
model: str,
max_tokens: int | None = None,
stream: bool = True,
) -> Generator[str, None, None] | dict[str, Any]:
"""
Send a chat completion request, handling potential tool calls.
Args:
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
model: Model identifier string.
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
Returns:
If stream=True: A generator yielding content chunks.
If stream=False: A dictionary containing the final content, usage, or an error.
e.g., {"content": "...", "usage": {"prompt_tokens": ..., "completion_tokens": ...}}
or {"error": "..."}
"""
# Ensure tools are up-to-date (optional, could be done less frequently)
# self._refresh_mcp_tools()
# Convert tools to the provider-specific format
try:
provider_tools = self.provider.convert_tools(self.mcp_tools)
logger.debug(f"Converted {len(self.mcp_tools)} tools for provider {self.provider.__class__.__name__}")
except Exception as e:
logger.error(f"Error converting tools for provider: {e}", exc_info=True)
provider_tools = None # Proceed without tools if conversion fails
try:
logger.info(f"Sending chat completion request to provider with model: {model}")
response = self.provider.create_chat_completion(
messages=messages,
model=model,
max_tokens=max_tokens,
stream=stream,
tools=provider_tools,
)
print(f"Response: {response}") # Debugging line to check the response
logger.info("Received response from provider.")
if stream:
# Streaming with tool calls requires more complex handling (like airflow_wingman's
# process_tool_calls_and_follow_up). For now, we'll yield the initial stream
# and handle tool calls *after* the stream completes if detected (less ideal UX).
# A better approach involves checking for tool calls before streaming fully.
# This simplified version just streams the first response.
logger.info("Streaming response...")
# NOTE: This simple version doesn't handle tool calls during streaming well.
# It will stream the initial response which might *contain* the tool call request,
# but won't execute it within the stream.
return self._stream_generator(response)
else: # Non-streaming
logger.info("Processing non-streaming response...")
if self.provider.has_tool_calls(response):
logger.info("Tool calls detected in response.")
# Simplified non-streaming tool call handling (one round)
try:
tool_calls = self.provider.parse_tool_calls(response)
logger.debug(f"Parsed tool calls: {tool_calls}")
tool_results = []
original_message_with_calls = self.provider.get_original_message_with_calls(response) # Provider needs to implement this
messages.append(original_message_with_calls) # Add assistant's turn with tool requests
for tool_call in tool_calls:
server_name = tool_call.get("server_name") # Needs to be parsed by provider
func_name = tool_call.get("function_name")
func_args_str = tool_call.get("arguments")
call_id = tool_call.get("id")
if not server_name or not func_name or func_args_str is None or call_id is None:
logger.error(f"Skipping invalid tool call data: {tool_call}")
# Record an error result for this malformed call
result_content = {"error": "Invalid tool call structure from LLM"}
else:
try:
# Arguments might be a JSON string, parse them
arguments = json.loads(func_args_str)
logger.info(f"Executing tool '{func_name}' on server '{server_name}' with args: {arguments}")
# Execute synchronously using the manager
execution_result = self.mcp_manager.execute_tool(server_name, func_name, arguments)
logger.debug(f"Tool execution result: {execution_result}")
if execution_result is None:
result_content = {"error": f"Tool execution failed or timed out for {func_name}"}
elif isinstance(execution_result, dict) and "error" in execution_result:
result_content = execution_result # Propagate error from tool/server
else:
# Assuming result is the content payload
result_content = execution_result
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments for tool {func_name}: {func_args_str}")
result_content = {"error": f"Invalid arguments format for tool {func_name}"}
except Exception as exec_err:
logger.error(f"Error executing tool {func_name}: {exec_err}", exc_info=True)
result_content = {"error": f"Exception during tool execution: {str(exec_err)}"}
# Format result for the provider's follow-up message
formatted_result = self.provider.format_tool_results(call_id, result_content)
tool_results.append(formatted_result)
messages.append(formatted_result) # Add tool result message
# Make follow-up call
logger.info("Making follow-up request with tool results...")
follow_up_response = self.provider.create_chat_completion(
messages=messages, # Now includes assistant's turn and tool results
model=model,
max_tokens=max_tokens,
stream=False, # Follow-up is non-streaming here
tools=provider_tools,  # Pass tools again; some providers expect the tool definitions on follow-up calls
)
final_content = self.provider.get_content(follow_up_response)
final_usage = self.provider.get_usage(follow_up_response) # Get usage from follow-up
logger.info("Received follow-up response content.")
result_dict = {"content": final_content}
if final_usage:
result_dict["usage"] = final_usage
return result_dict
except Exception as tool_handling_err:
logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
return {"error": f"Failed to handle tool calls: {str(tool_handling_err)}"}
else: # No tool calls
logger.info("No tool calls detected.")
content = self.provider.get_content(response)
usage = self.provider.get_usage(response) # Get usage from initial response
result_dict = {"content": content}
if usage:
result_dict["usage"] = usage
return result_dict
except Exception as e:
error_msg = f"LLM API Error: {str(e)}"
logger.error(error_msg, exc_info=True)
if stream:
# Signaling an error mid-stream is awkward with this simple generator,
# so return an error dict for now; callers must check the return type.
return {"error": error_msg}
else:
return {"error": error_msg}
def _stream_generator(self, response: Any) -> Generator[str, None, None]:
"""Helper to yield content from the provider's streaming method."""
try:
# Use yield from for cleaner and potentially more efficient delegation
yield from self.provider.get_streaming_content(response)
except Exception as e:
logger.error(f"Error during streaming: {e}", exc_info=True)
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk

View File

@@ -1,234 +0,0 @@
"""Synchronous wrapper for managing MCP servers using our custom implementation."""
import asyncio
import importlib.resources
import json
import logging # Import logging
import threading
from custom_mcp_client import MCPClient, run_interaction
# Configure basic logging for the application if not already configured
# This basic config helps ensure logs are visible during development
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Get a logger for this module
logger = logging.getLogger(__name__)
class SyncMCPManager:
"""Synchronous wrapper for managing MCP servers and interactions"""
def __init__(self, config_path: str = "config/mcp_config.json"):
self.config_path = config_path
self.config = None
self.servers = {}
self.initialized = False
self._lock = threading.Lock()
logger.info(f"Initializing SyncMCPManager with config path: {config_path}")
self._load_config()
def _load_config(self):
"""Load MCP configuration from JSON file using importlib"""
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
try:
# First try to load as a package resource
try:
# Try anchoring to the project name defined in pyproject.toml
# This *might* work depending on editable installs or context.
resource_path = importlib.resources.files("streamlit-chat-app").joinpath(self.config_path)
with resource_path.open("r") as f:
self.config = json.load(f)
logger.debug("Loaded config via importlib.resources anchored to 'streamlit-chat-app'.")
# REMOVED: raise FileNotFoundError
except (ImportError, ModuleNotFoundError, TypeError, FileNotFoundError, NotADirectoryError): # Added NotADirectoryError
logger.debug("Failed to load via importlib.resources, falling back to direct file access.")
# Fall back to direct file access relative to CWD
with open(self.config_path) as f:
self.config = json.load(f)
logger.debug("Loaded config via direct file access.")
logger.info("MCP configuration loaded successfully.")
logger.debug(f"Config content: {self.config}") # Log content only if loaded
except FileNotFoundError:
logger.error(f"MCP config file not found at {self.config_path}")
self.config = None
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON from MCP config file {self.config_path}: {e}")
self.config = None
except Exception as e:
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
self.config = None
def initialize(self) -> bool:
"""Initialize and start all MCP servers synchronously"""
logger.info("Initialize requested.")
if not self.config:
logger.warning("Initialization skipped: No configuration loaded.")
return False
if not self.config.get("mcpServers"):
logger.warning("Initialization skipped: No 'mcpServers' defined in configuration.")
return False
if self.initialized:
logger.debug("Initialization skipped: Already initialized.")
return True
with self._lock:
if self.initialized: # Double-check after acquiring lock
logger.debug("Initialization skipped inside lock: Already initialized.")
return True
logger.info("Starting asynchronous initialization...")
# Run async initialization in a new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # Ensure this loop is used by tasks
success = loop.run_until_complete(self._async_initialize())
loop.close()
asyncio.set_event_loop(None) # Clean up
if success:
logger.info("Asynchronous initialization completed successfully.")
self.initialized = True
else:
logger.error("Asynchronous initialization failed.")
self.initialized = False # Ensure state is False on failure
return self.initialized
async def _async_initialize(self) -> bool:
"""Async implementation of server initialization"""
logger.debug("Starting _async_initialize...")
all_success = True
if not self.config or not self.config.get("mcpServers"):
logger.warning("_async_initialize: No config or mcpServers found.")
return False
tasks = []
server_names = list(self.config["mcpServers"].keys())
async def start_server(server_name, server_config):
logger.info(f"Initializing server: {server_name}")
try:
client = MCPClient(server_name=server_name, command=server_config.get("command"), args=server_config.get("args", []), env=server_config.get("env", {}))
logger.debug(f"Attempting to start client for {server_name}...")
if await client.start():
logger.info(f"Client for {server_name} started successfully.")
tools = await client.list_tools()
logger.info(f"Tools listed for {server_name}: {len(tools)}")
self.servers[server_name] = {"client": client, "tools": tools}
return True
else:
logger.error(f"Failed to start MCP server: {server_name}")
return False
except Exception as e:
logger.error(f"Error initializing server {server_name}: {e}", exc_info=True)
return False
# Start servers concurrently
for server_name in server_names:
server_config = self.config["mcpServers"][server_name]
tasks.append(start_server(server_name, server_config))
results = await asyncio.gather(*tasks)
# Check if all servers started successfully
all_success = all(results)
if all_success:
logger.debug("_async_initialize completed: All servers started successfully.")
else:
failed_servers = [server_names[i] for i, res in enumerate(results) if not res]
logger.error(f"_async_initialize completed with failures. Failed servers: {failed_servers}")
# Optionally shutdown servers that did start if partial success is not desired
# await self._async_shutdown() # Uncomment to enforce all-or-nothing startup
return all_success
def shutdown(self):
"""Shut down all MCP servers synchronously"""
logger.info("Shutdown requested.")
if not self.initialized:
logger.debug("Shutdown skipped: Not initialized.")
return
with self._lock:
if not self.initialized:
logger.debug("Shutdown skipped inside lock: Not initialized.")
return
logger.info("Starting asynchronous shutdown...")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._async_shutdown())
loop.close()
asyncio.set_event_loop(None)
self.servers = {}
self.initialized = False
logger.info("Shutdown complete.")
async def _async_shutdown(self):
"""Async implementation of server shutdown"""
logger.debug("Starting _async_shutdown...")
tasks = []
for server_name, server_info in self.servers.items():
logger.debug(f"Initiating shutdown for server: {server_name}")
tasks.append(server_info["client"].stop())
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
server_name = list(self.servers.keys())[i]
if isinstance(result, Exception):
logger.error(f"Error shutting down server {server_name}: {result}", exc_info=result)
else:
logger.debug(f"Shutdown completed for server: {server_name}")
logger.debug("_async_shutdown finished.")
# Updated process_query signature
def process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
"""
Process a query using MCP tools synchronously
Args:
query: The user's input query.
model_name: The model to use for processing.
api_key: The OpenAI API key.
base_url: The OpenAI API base URL.
Returns:
Dictionary containing response or error.
"""
if not self.initialized and not self.initialize():
logger.error("process_query called but MCP manager failed to initialize.")
return {"error": "Failed to initialize MCP servers"}
logger.debug(f"Processing query synchronously: '{query}'")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Pass api_key and base_url to _async_process_query
result = loop.run_until_complete(self._async_process_query(query, model_name, api_key, base_url))
logger.debug(f"Synchronous query processing result: {result}")
return result
except Exception as e:
logger.error(f"Error during synchronous query processing: {e}", exc_info=True)
return {"error": f"Processing error: {str(e)}"}
finally:
loop.close()
# Updated _async_process_query signature
async def _async_process_query(self, query: str, model_name: str, api_key: str, base_url: str | None) -> dict:
"""Async implementation of query processing"""
# Pass api_key, base_url, and the MCP config separately to run_interaction
return await run_interaction(
user_query=query,
model_name=model_name,
api_key=api_key,
base_url=base_url,
mcp_config=self.config, # self.config only contains MCP server definitions now
stream=False,
)

View File

@@ -1,68 +0,0 @@
"""OpenAI client with custom MCP integration."""
import configparser
import logging # Import logging
from openai import OpenAI
from mcp_manager import SyncMCPManager
# Get a logger for this module
logger = logging.getLogger(__name__)
class OpenAIClient:
def __init__(self):
logger.debug("Initializing OpenAIClient...") # Add init log
self.config = configparser.ConfigParser()
self.config.read("config/config.ini")
# Validate configuration
if not self.config.has_section("openai"):
raise Exception("Missing [openai] section in config.ini")
if not self.config["openai"].get("api_key"):
raise Exception("Missing api_key in config.ini")
# Configure OpenAI client
self.client = OpenAI(
api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"}
)
# Initialize MCP manager if configured
self.mcp_manager = None
if self.config.has_section("mcp"):
mcp_config_path = self.config["mcp"].get("servers_json", "config/mcp_config.json")
self.mcp_manager = SyncMCPManager(mcp_config_path)
def get_chat_response(self, messages):
try:
# Try using MCP if available
if self.mcp_manager and self.mcp_manager.initialize():
logger.info("Using MCP with tools...") # Use logger
last_message = messages[-1]["content"]
# Pass API key and base URL from config.ini
response = self.mcp_manager.process_query(
query=last_message,
model_name=self.config["openai"]["model"],
api_key=self.config["openai"]["api_key"],
base_url=self.config["openai"].get("base_url"), # Use .get for optional base_url
)
if "error" not in response:
logger.debug("MCP processing successful, wrapping response.")
# Convert to OpenAI-compatible response format
return self._wrap_mcp_response(response)
# Fall back to standard OpenAI
logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}") # Use logger
return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)
except Exception as e:
error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
logger.error(error_msg, exc_info=True) # Use logger
raise Exception(error_msg)
def _wrap_mcp_response(self, response: dict):
"""Return the MCP response dictionary directly (for non-streaming)."""
# No conversion needed if app.py handles dicts separately
return response

57
src/providers/__init__.py Normal file
View File

@@ -0,0 +1,57 @@
import logging
from providers.anthropic_provider import AnthropicProvider
from providers.base import BaseProvider
from providers.google_provider import GoogleProvider
from providers.openai_provider import OpenAIProvider
logger = logging.getLogger(__name__)
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
"google": GoogleProvider,
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
}
def register_provider(name: str, provider_class: type[BaseProvider]):
"""Registers a provider class."""
if name.lower() in PROVIDER_MAP:
logger.warning(f"Provider '{name}' is already registered. Overwriting.")
PROVIDER_MAP[name.lower()] = provider_class
logger.info(f"Registered provider: {name}")
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None, temperature: float = 0.6) -> BaseProvider:
"""
Factory function to create an instance of a specific LLM provider.
Args:
provider_name: The name of the provider (e.g., 'openai', 'anthropic').
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
temperature: Default sampling temperature configured on the provider instance.
Returns:
An instance of the requested BaseProvider subclass.
Raises:
ValueError: If the requested provider_name is not registered.
"""
provider_class = PROVIDER_MAP.get(provider_name.lower())
if provider_class is None:
available = ", ".join(PROVIDER_MAP.keys()) or "None"
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
logger.info(f"Creating LLM provider instance for: {provider_name} with temperature: {temperature}")
try:
return provider_class(api_key=api_key, base_url=base_url, temperature=temperature)
except Exception as e:
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
def get_available_providers() -> list[str]:
"""Returns a list of registered provider names."""
return list(PROVIDER_MAP.keys())
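
A brief sketch of the factory in use; the keys are placeholders, and routing OpenRouter through `OpenAIProvider` with a custom base_url follows the comment in `PROVIDER_MAP` rather than anything implemented here:

from src.providers import create_llm_provider, get_available_providers

print(get_available_providers())  # ['openai', 'anthropic', 'google']
anthropic = create_llm_provider("anthropic", api_key="sk-ant-...", temperature=0.6)
openrouter = create_llm_provider("openai", api_key="sk-or-...", base_url="https://openrouter.ai/api/v1")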

View File

@@ -0,0 +1,44 @@
from providers.anthropic_provider.client import initialize_client
from providers.anthropic_provider.completion import create_chat_completion
from providers.anthropic_provider.response import get_content, get_streaming_content, get_usage
from providers.anthropic_provider.tools import convert_tools, format_tool_results, has_tool_calls, parse_tool_calls
from providers.base import BaseProvider
class AnthropicProvider(BaseProvider):
temperature: float
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
self.client = initialize_client(api_key, base_url)
self.temperature = temperature
def create_chat_completion(
self,
messages,
model,
max_tokens=None,
stream=True,
tools=None,
):
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
def get_streaming_content(self, response):
return get_streaming_content(response)
def get_content(self, response):
return get_content(response)
def has_tool_calls(self, response):
return has_tool_calls(response)
def parse_tool_calls(self, response):
return parse_tool_calls(response)
def format_tool_results(self, tool_call_id, result):
return format_tool_results(tool_call_id, result)
def convert_tools(self, tools):
return convert_tools(tools)
def get_usage(self, response):
return get_usage(response)

View File

@@ -0,0 +1,17 @@
import logging
from anthropic import Anthropic
logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> Anthropic:
logger.info("Initializing Anthropic client")
try:
client = Anthropic(api_key=api_key)
if base_url:
logger.warning(f"base_url '{base_url}' provided but not used by Anthropic client")
return client
except Exception as e:
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise

View File

@@ -0,0 +1,38 @@
import logging
from typing import Any
from anthropic import Stream
from anthropic.types import Message
from providers.anthropic_provider.messages import convert_messages, truncate_messages
logger = logging.getLogger(__name__)
def create_chat_completion(
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
) -> Stream | Message:
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
truncated_messages, final_system_prompt, _, _ = truncate_messages(provider, temp_anthropic_messages, temp_system_prompt, model)
if max_tokens is None:
max_tokens = 4096
logger.warning(f"max_tokens not provided, defaulting to {max_tokens}")
completion_params = {
"model": model,
"messages": truncated_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if final_system_prompt:
completion_params["system"] = final_system_prompt
if tools:
completion_params["tools"] = tools
try:
response = provider.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
return response
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise

View File

@@ -0,0 +1,61 @@
import logging
from typing import Any
from providers.anthropic_provider.utils import count_anthropic_tokens, get_context_window
logger = logging.getLogger(__name__)
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
anthropic_messages = []
system_prompt = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role == "system":
if i == 0:
system_prompt = content
else:
logger.warning("System message not at beginning. Treating as user message.")
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
continue
if role == "tool":
tool_use_id = message.get("tool_call_id")
tool_content = content
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
continue
if role == "assistant":
anthropic_messages.append({"role": "assistant", "content": content})
continue
if role == "user":
anthropic_messages.append({"role": "user", "content": content})
continue
logger.warning(f"Unsupported role '{role}' in message conversion.")
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
logger.warning("Conversation must start with user message. Prepending placeholder.")
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"})
return system_prompt, anthropic_messages
def truncate_messages(provider, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
context_limit = get_context_window(model)
buffer = 200
effective_limit = context_limit - buffer
initial_token_count = count_anthropic_tokens(provider.client, messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages)
while final_token_count > effective_limit and len(truncated_messages) > 0:
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating message (Role: {removed_message.get('role')})")
final_token_count = count_anthropic_tokens(provider.client, truncated_messages, system_prompt)
if initial_token_count != final_token_count:
logger.info(f"Truncated messages. Initial tokens: {initial_token_count}, Final: {final_token_count}")
else:
logger.debug(f"No truncation needed. Tokens: {final_token_count}")
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
return truncated_messages, system_prompt, initial_token_count, final_token_count
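
A small illustration of the mapping `convert_messages` performs on a typical OpenAI-style history (the content and tool id are made up):

history = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Read config.ini"},
    {"role": "tool", "tool_call_id": "toolu_123", "content": "{\"ok\": true}"},
]
system_prompt, anthropic_messages = convert_messages(history)
# system_prompt         -> "You are terse."
# anthropic_messages[0] -> {"role": "user", "content": "Read config.ini"}
# anthropic_messages[1] -> {"role": "user", "content": [{"type": "tool_result",
#                             "tool_use_id": "toolu_123", "content": "{\"ok\": true}"}]}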

View File

@@ -0,0 +1,62 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from anthropic import Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
logger = logging.getLogger(__name__)
def get_streaming_content(response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
logger.debug("Processing Anthropic stream...")
full_delta = ""
try:
for event in response:
if event.type == "content_block_delta":
if isinstance(event.delta, TextDelta):
delta_text = event.delta.text
if delta_text:
full_delta += delta_text
yield delta_text
elif event.type == "message_start":
logger.debug(f"Stream started. Model: {event.message.model}")
elif event.type == "message_stop":
logger.debug("Stream message_stop event received.")
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
logger.debug(f"Tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
elif event.type == "content_block_stop":
logger.debug(f"Content block stop. Index: {event.index}")
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: Message) -> str:
try:
text_content = "".join([block.text for block in response.content if block.type == "text"])
logger.debug(f"Extracted content (length {len(text_content)})")
return text_content
except Exception as e:
logger.error(f"Error extracting content: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def get_usage(response: Any) -> dict[str, int] | None:
try:
if isinstance(response, Message) and response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
}
logger.debug(f"Extracted usage: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,115 @@
import json
import logging
from typing import Any
from anthropic.types import Message
logger = logging.getLogger(__name__)
def has_tool_calls(response: Any) -> bool:
try:
if isinstance(response, Message):
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
logger.debug(f"Tool calls check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
return has_calls
else:
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(response: Message) -> list[dict[str, Any]]:
parsed_calls = []
try:
if not isinstance(response, Message):
logger.error(f"parse_tool_calls expects Message, got {type(response)}")
return []
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
if not tool_use_blocks:
logger.debug("No 'tool_use' content blocks found.")
return []
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks.")
for block in tool_use_blocks:
parts = block.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from tool name '{block.name}'.")
server_name = None
func_name = block.name
parsed_calls.append({"id": block.id, "server_name": server_name, "function_name": func_name, "arguments": block.input})
return parsed_calls
except Exception as e:
logger.error(f"Error parsing tool calls: {e}", exc_info=True)
return []
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error encoding tool result for {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {"type": "tool_result", "tool_use_id": tool_call_id, "content": content_str}
def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Anthropic tool definitions.
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List of Anthropic tool definitions.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Anthropic format")
anthropic_tools = []
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
continue
prefixed_tool_name = f"{server_name}__{tool_name}"
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
anthropic_tool["input_schema"] = input_schema
anthropic_tools.append(anthropic_tool)
logger.debug(f"Converted MCP tool to Anthropic: {prefixed_tool_name}")
return anthropic_tools
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
try:
anthropic_tools = convert_to_anthropic_tools(tools)
logger.debug(f"Tool conversion result: {anthropic_tools}")
return anthropic_tools
except Exception as e:
logger.error(f"Error during tool conversion: {e}", exc_info=True)
return []
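
For illustration, the shape `convert_to_anthropic_tools` produces for one hypothetical MCP tool (server and tool names are made up):

mcp_tool = {
    "server_name": "files",
    "name": "read_file",
    "description": "Read a file from disk.",
    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}},
}
convert_to_anthropic_tools([mcp_tool])
# -> [{"name": "files__read_file",
#      "description": "Read a file from disk.",
#      "input_schema": {"type": "object", "properties": {"path": {"type": "string"}}}}]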

View File

@@ -0,0 +1,50 @@
import json
import logging
import math
from typing import Any
from anthropic import Anthropic
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int:
default_window = 100000
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def count_anthropic_tokens(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list):
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
count = client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens: {count}")
return count
except Exception as e:
logger.error(f"Error counting Anthropic tokens: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0)
logger.warning(f"Falling back to approximation: {estimated_tokens}")
return estimated_tokens
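
The `MODELS` structure this lookup assumes, inferred from the accessors above; the model id and window size are placeholders, not the actual contents of `src/llm_models.py`:

MODELS = {
    "anthropic": {
        "models": [
            {"id": "claude-3-7-sonnet-latest", "context_window": 200000},  # placeholder entry
        ],
    },
}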

148
src/providers/base.py Normal file
View File

@@ -0,0 +1,148 @@
import abc
from collections.abc import Generator
from typing import Any
class BaseProvider(abc.ABC):
"""
Abstract base class for LLM providers.
Defines the common interface for interacting with different LLM APIs,
including handling chat completions and tool usage.
"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
Initialize the provider.
Args:
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
"""
self.api_key = api_key
self.base_url = base_url
@abc.abstractmethod
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Send a chat completion request to the LLM provider.
Args:
messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier.
temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format.
Returns:
Provider-specific response object (e.g., API response, stream object).
"""
pass
@abc.abstractmethod
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""
Extracts and yields content chunks from a streaming response object.
Args:
response: The streaming response object returned by create_chat_completion.
Yields:
String chunks of the response content.
"""
pass
@abc.abstractmethod
def get_content(self, response: Any) -> str:
"""
Extracts the complete content from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
The complete response content as a string.
"""
pass
@abc.abstractmethod
def has_tool_calls(self, response: Any) -> bool:
"""
Checks if the response object contains tool calls.
Args:
response: The response object (streaming or non-streaming).
Returns:
True if tool calls are present, False otherwise.
"""
pass
@abc.abstractmethod
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
"""
Parses tool calls from the response object.
Args:
response: The response object containing tool calls.
Returns:
A list of dictionaries, each representing a tool call with details
like 'id', 'function_name', 'arguments'. The exact structure might
vary slightly based on provider needs but should contain enough
info for execution.
"""
pass
@abc.abstractmethod
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""
Formats the result of a tool execution into the structure expected
by the provider for follow-up requests.
Args:
tool_call_id: The unique ID of the tool call (from parse_tool_calls).
result: The data returned by the tool execution.
Returns:
A dictionary representing the tool result in the provider's format.
"""
pass
@abc.abstractmethod
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Converts a list of tools from the standard internal format to the
provider-specific format required for the API call.
Args:
tools: List of tool definitions in the standard internal format.
Each dict contains 'server_name', 'name', 'description', 'input_schema'.
Returns:
List of tool definitions in the provider-specific format.
"""
pass
@abc.abstractmethod
def get_usage(self, response: Any) -> dict[str, int] | None:
"""
Extracts token usage information from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
A dictionary containing 'prompt_tokens' and 'completion_tokens',
or None if usage information is not available.
"""
pass
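
A skeletal sketch of a provider written against this interface; every method body is a placeholder that simply echoes the last user message, and registering it under the name "echo" is only one plausible way to wire it in:

from collections.abc import Generator
from typing import Any

from src.providers import register_provider
from src.providers.base import BaseProvider

class EchoProvider(BaseProvider):
    def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
        super().__init__(api_key, base_url)
        self.temperature = temperature

    def create_chat_completion(self, messages, model, temperature=0.6, max_tokens=None, stream=True, tools=None) -> Any:
        # Echo the last user message instead of calling a real API.
        return {"content": messages[-1]["content"]}

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        yield response["content"]

    def get_content(self, response: Any) -> str:
        return response["content"]

    def has_tool_calls(self, response: Any) -> bool:
        return False

    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        return {"role": "tool", "tool_call_id": tool_call_id, "content": str(result)}

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return tools

    def get_usage(self, response: Any) -> dict[str, int] | None:
        return None

register_provider("echo", EchoProvider)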

View File

@@ -0,0 +1,78 @@
import logging
from collections.abc import Generator
from typing import Any
from google.genai.types import GenerateContentResponse
from providers.google_provider.client import initialize_client
from providers.google_provider.completion import create_chat_completion
from providers.google_provider.response import get_content, get_streaming_content, get_usage
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
from src.providers.base import BaseProvider
logger = logging.getLogger(__name__)
class GoogleProvider(BaseProvider):
"""Provider implementation for Google Generative AI (Gemini)."""
client_module: Any
temperature: float
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
"""
Initializes the GoogleProvider.
Args:
api_key: The Google API key.
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
temperature: The default temperature for completions.
"""
self.client_module = initialize_client(api_key, base_url)
self.api_key = api_key
self.base_url = base_url
self.temperature = temperature
logger.info(f"GoogleProvider initialized with temperature: {self.temperature}")
def create_chat_completion(
self,
messages: list[dict[str, Any]],
model: str,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""Creates a chat completion using the Google Gemini API."""
raw_response = create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
print(f"Raw response type: {type(raw_response)}")
print(f"Raw response: {raw_response}")
return raw_response
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Extracts content chunks from a Google streaming response."""
return get_streaming_content(response)
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
"""Extracts the full text content from a non-streaming Google response."""
return get_content(response)
def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool:
"""Checks if the Google response contains tool calls (FunctionCalls)."""
return has_google_tool_calls(response)
def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]:
"""Parses tool calls (FunctionCalls) from a non-streaming Google response."""
return parse_google_tool_calls(response)
def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for a Google follow-up request (into standard message format)."""
return format_google_tool_results(tool_call_id, function_name, result)
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts MCP tools list to Google's intermediate dictionary format."""
return convert_to_google_tools(tools)
def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
"""Extracts token usage information from a Google response."""
return get_usage(response)
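
A minimal, hypothetical use of `GoogleProvider` on its own; the key and model id are placeholders:

provider = GoogleProvider(api_key="AIza...", temperature=0.6)
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    model="gemini-2.0-flash",  # assumed model id
    max_tokens=256,
    stream=False,
)
print(provider.get_content(response))
print(provider.get_usage(response))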

View File

@@ -0,0 +1,25 @@
import logging
from typing import Any
try:
from google import genai
except ImportError:
genai = None
logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> Any:
"""Initializes and returns the Google Generative AI client module."""
logger.info("Initializing Google Generative AI client")
if genai is None:
logger.error("Google Generative AI SDK (google-genai) is not installed.")
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try:
client = genai.Client(api_key=api_key)
logger.info("Google Generative AI client instantiated.")
if base_url:
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
return client
except Exception as e:
logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
raise

View File

@@ -0,0 +1,140 @@
import logging
import traceback
from collections.abc import Iterable
from typing import Any
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
from providers.google_provider.tools import convert_to_google_tool_objects
from providers.google_provider.utils import convert_messages
logger = logging.getLogger(__name__)
def _create_chat_completion_non_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> GenerateContentResponse | dict[str, Any]:
"""Handles the non-streaming API call."""
try:
logger.debug("Calling client.models.generate_content...")
response = provider.client_module.models.generate_content(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content call successful, returning raw response object.")
return response
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
return {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during non-stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
return {"error": error_msg, "traceback": traceback.format_exc()}
def _create_chat_completion_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> Iterable[GenerateContentResponse | dict[str, Any]]:
"""Handles the streaming API call and yields results."""
try:
logger.debug("Calling client.models.generate_content_stream...")
response_iterator = provider.client_module.models.generate_content_stream(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content_stream call successful, yielding from iterator.")
yield from response_iterator
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
yield {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
yield {"error": error_msg, "traceback": traceback.format_exc()}
def create_chat_completion(
provider,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Creates a chat completion using the Google Gemini API.
Delegates to the streaming or non-streaming helper; this function contains no yield itself, so calling it never unintentionally produces a generator.
"""
logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
if provider.client_module is None:
error_msg = "Google Generative AI client not initialized on provider."
logger.error(error_msg)
return iter([{"error": error_msg}]) if stream else {"error": error_msg}
try:
google_messages, system_prompt = convert_messages(messages)
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
generation_config: GenerationConfigDict = {"temperature": temperature}
if max_tokens is not None:
generation_config["max_output_tokens"] = max_tokens
logger.debug(f"Setting max_output_tokens: {max_tokens}")
else:
default_max_tokens = 8192
generation_config["max_output_tokens"] = default_max_tokens
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
google_tool_objects: list[Tool] | None = None
if tools:
try:
google_tool_objects = convert_to_google_tool_objects(tools)
if google_tool_objects:
num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
else:
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
google_tool_objects = None
else:
logger.debug("No tools provided for conversion.")
if system_prompt:
generation_config["system_instruction"] = system_prompt
logger.debug("Added system_instruction to generation_config.")
if google_tool_objects:
generation_config["tools"] = google_tool_objects
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
log_params = {
"model": model,
"stream": stream,
"temperature": temperature,
"max_output_tokens": generation_config.get("max_output_tokens"),
"system_prompt_present": bool(system_prompt),
"num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
"num_messages": len(google_messages),
}
logger.info(f"Calling Google API via helper with params: {log_params}")
if stream:
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
else:
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
except Exception as e:
error_msg = f"Error during Google completion setup: {e}"
logger.error(error_msg, exc_info=True)
return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}

View File

@@ -0,0 +1,185 @@
"""
Response handling utilities specific to the Google Generative AI provider.
Includes functions for:
- Extracting content from streaming responses.
- Extracting content from non-streaming responses.
- Extracting token usage information.
"""
import json
import logging
from collections.abc import Generator
from typing import Any
from google.genai.types import GenerateContentResponse
logger = logging.getLogger(__name__)
def get_streaming_content(response: Any) -> Generator[str, None, None]:
"""
Yields content chunks (text) from a Google streaming response iterator.
Args:
        response: The streaming response iterator returned by `generate_content_stream`.
Yields:
String chunks of the generated text content.
May yield JSON strings containing error information if errors occur during streaming.
"""
logger.debug("Processing Google stream...")
full_delta = ""
try:
if isinstance(response, dict) and "error" in response:
yield json.dumps(response)
logger.error(f"Stream processing stopped due to initial error: {response['error']}")
return
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
first_item = next(response, None)
if first_item and isinstance(first_item, str):
try:
error_data = json.loads(first_item)
if "error" in error_data:
yield first_item
yield from response
logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}")
return
except json.JSONDecodeError:
yield first_item
            elif first_item is not None:
                # Do not drop the chunk consumed by next(); yield any text it carries.
                first_text = getattr(first_item, "text", None)
                if first_text:
                    full_delta += first_text
                    yield first_text
for chunk in response:
if isinstance(chunk, dict) and "error" in chunk:
yield json.dumps(chunk)
logger.error(f"Error encountered during Google stream: {chunk['error']}")
continue
delta = ""
try:
if hasattr(chunk, "text"):
delta = chunk.text
elif hasattr(chunk, "candidates") and chunk.candidates:
first_candidate = chunk.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
first_part = first_candidate.content.parts[0]
if hasattr(first_part, "text"):
delta = first_part.text
except Exception as e:
logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True)
delta = ""
if delta:
full_delta += delta
yield delta
try:
if hasattr(chunk, "candidates") and chunk.candidates:
for part in chunk.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Function call detected during stream: {part.function_call.name}")
break
except Exception:
pass
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
except StopIteration:
logger.debug("Google stream finished (StopIteration).")
except Exception as e:
logger.error(f"Error processing Google stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
"""
Extracts the full text content from a non-streaming Google response.
Args:
response: The non-streaming response object (`GenerateContentResponse`) or
an error dictionary.
Returns:
The concatenated text content, or an error message string.
"""
try:
if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error dict: {response['error']}")
return f"[Error: {response['error']}]"
if not isinstance(response, GenerateContentResponse):
logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
return f"[Error: Unexpected response type {type(response)}]"
if hasattr(response, "text") and response.text:
content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content
if hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
if text_parts:
content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
return content
else:
logger.warning("Google response candidate parts contained no text.")
return ""
else:
logger.warning("Google response candidate has no valid content or parts.")
return ""
else:
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
return ""
except AttributeError as ae:
logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return f"[Error extracting content: Attribute missing - {str(ae)}]"
except Exception as e:
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
"""
Extracts token usage information from a Google response object.
Args:
response: The response object (`GenerateContentResponse`) or an error dictionary.
Returns:
A dictionary containing 'prompt_tokens' and 'completion_tokens', or None if
usage information is unavailable or an error occurred.
"""
try:
if isinstance(response, dict) and "error" in response:
logger.warning(f"Cannot get usage from error dict: {response['error']}")
return None
if not isinstance(response, GenerateContentResponse):
logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
return None
metadata = getattr(response, "usage_metadata", None)
if metadata:
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
completion_tokens = getattr(metadata, "candidates_token_count", 0)
usage = {
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
}
logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
return None
except AttributeError as ae:
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return None
except Exception as e:
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
return None
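
A hedged consumption sketch for the three helpers above; `stream_iter` and `response` are assumed to come from the completion module (stream=True and stream=False respectively), and error chunks arrive as JSON strings, as implemented above:

import json

collected = []
for delta in get_streaming_content(stream_iter):  # assumed: iterator from a stream=True call
    try:
        payload = json.loads(delta)
        if isinstance(payload, dict) and "error" in payload:
            print(f"Stream error: {payload['error']}")
            break
    except json.JSONDecodeError:
        pass  # ordinary text chunk, not an error payload
    collected.append(delta)
print("".join(collected))

text = get_content(response)   # assumed: GenerateContentResponse from a stream=False call
usage = get_usage(response) or {"prompt_tokens": 0, "completion_tokens": 0}
print(f"{text!r} ({usage['prompt_tokens']} prompt / {usage['completion_tokens']} completion tokens)")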

View File

@@ -0,0 +1,365 @@
"""
Tool handling utilities specific to the Google Generative AI provider.
Includes functions for:
- Converting MCP tool definitions to Google's format.
- Creating Google Tool/FunctionDeclaration objects.
- Parsing tool calls (FunctionCalls) from Google responses.
- Formatting tool results for subsequent API calls.
"""
import json
import logging
from typing import Any
from google.genai.types import FunctionDeclaration, Schema, Tool, Type
logger = logging.getLogger(__name__)
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Google Gemini format (dictionary structure).
This format is an intermediate step before creating Tool objects.
Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
Returns:
List containing one dictionary with 'function_declarations'.
Returns an empty list if no valid tools are provided or converted.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")
function_declarations = []
for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
continue
prefixed_tool_name = f"{server_name}__{tool_name}"
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.")
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema or input_schema["type"] != "object":
input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}}
logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.")
if "properties" not in input_schema:
input_schema["properties"] = {}
if not input_schema["properties"]:
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}}
if "required" in input_schema and not isinstance(input_schema.get("required"), list):
input_schema["required"] = []
function_declaration = {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,
}
function_declarations.append(function_declaration)
logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}")
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}")
return google_tool_config
def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | None:
"""
Recursively creates Google Schema objects from a JSON schema dictionary.
Handles type mapping and nested structures. Returns None on failure.
"""
if Schema is None or Type is None:
logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
return None
if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
return None
type_mapping = {
"string": Type.STRING,
"number": Type.NUMBER,
"integer": Type.INTEGER,
"boolean": Type.BOOLEAN,
"array": Type.ARRAY,
"object": Type.OBJECT,
}
original_type = schema_dict.get("type")
google_type = type_mapping.get(str(original_type).lower()) if original_type else None
if not google_type:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
return None
schema_args = {
"type": google_type,
"format": schema_dict.get("format"),
"description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"),
"enum": schema_dict.get("enum"),
"items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
"properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
if google_type == Type.OBJECT and schema_dict.get("properties")
else None,
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
}
schema_args = {k: v for k, v in schema_args.items() if v is not None}
if google_type == Type.ARRAY and "items" not in schema_args:
logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
return None
if google_type == Type.OBJECT and "properties" not in schema_args:
pass
try:
created_schema = Schema(**schema_args)
return created_schema
except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
return None
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
"""
Convert the dictionary-based tool configurations into Google's Tool objects.
Args:
tool_configs: A list containing a dictionary with 'function_declarations',
as produced by `convert_to_google_tools`.
Returns:
A list containing a single Google `Tool` object, or None if conversion fails
or no valid declarations are found.
"""
if Tool is None or FunctionDeclaration is None:
logger.error("Cannot create Tool objects: google.genai types not available.")
return None
if not tool_configs:
logger.debug("No tool configurations provided to convert to Tool objects.")
return None
all_func_declarations = []
if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]:
func_declarations_list = tool_configs[0]["function_declarations"]
if not isinstance(func_declarations_list, list):
logger.error(f"Expected 'function_declarations' to be a list, got {type(func_declarations_list)}")
return None
for func_dict in func_declarations_list:
func_name = func_dict.get("name", "Unknown")
try:
params_schema_dict = func_dict.get("parameters", {})
if not isinstance(params_schema_dict, dict):
logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
params_schema_dict = {"type": "object", "properties": {}}
elif "type" not in params_schema_dict:
params_schema_dict["type"] = "object"
elif params_schema_dict["type"] != "object":
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
original_properties = params_schema_dict.get("properties", {})
if not isinstance(original_properties, dict):
original_properties = {}
params_schema_dict = {"type": "object", "properties": original_properties}
properties_dict = params_schema_dict.get("properties", {})
google_properties = {}
if isinstance(properties_dict, dict):
for prop_name, prop_schema_dict in properties_dict.items():
prop_schema = _create_google_schema_recursive(prop_schema_dict)
if prop_schema:
google_properties[prop_name] = prop_schema
else:
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
else:
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
if not google_properties:
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
required_list = []
else:
original_required = params_schema_dict.get("required", [])
if isinstance(original_required, list):
required_list = [req for req in original_required if req in google_properties]
if len(required_list) != len(original_required):
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
else:
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
required_list = []
parameters_schema = Schema(
type=Type.OBJECT,
properties=google_properties,
required=required_list if required_list else None,
)
declaration = FunctionDeclaration(
name=func_name,
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")
except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)
else:
logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
return None
if not all_func_declarations:
logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.")
return None
logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.")
return [Tool(function_declarations=all_func_declarations)]
def has_google_tool_calls(response: Any) -> bool:
"""
Checks if the Google response object contains tool calls (FunctionCalls).
Args:
response: The response object from the Google generate_content API call.
Returns:
True if FunctionCalls are present, False otherwise.
"""
try:
if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
return True
logger.debug("No tool calls (FunctionCall) detected in Google response.")
return False
except Exception as e:
logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
return False
def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]:
"""
Parses tool calls (FunctionCalls) from a non-streaming Google response object.
Args:
response: The non-streaming response object from the Google generate_content API call.
Returns:
A list of dictionaries, each representing a tool call in the standard MCP format
(id, server_name, function_name, arguments as JSON string).
Returns an empty list if no calls are found or an error occurs.
"""
parsed_calls = []
try:
if not (hasattr(response, "candidates") and response.candidates):
logger.warning("Cannot parse tool calls: Response has no candidates.")
return []
candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
return []
logger.debug("Parsing tool calls (FunctionCall) from Google response.")
call_index = 0
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
call_id = f"call_{call_index}"
call_index += 1
full_name = func_call.name
parts = full_name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.")
server_name = None
func_name = full_name
try:
args_dict = dict(func_call.args) if func_call.args else {}
args_str = json.dumps(args_dict)
except Exception as json_err:
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
parsed_calls.append({
"id": call_id,
"server_name": server_name,
"function_name": func_name,
"arguments": args_str,
"_google_tool_name": full_name,
})
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
return []
def format_google_tool_results(tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
"""
Formats a tool result for a Google follow-up request (FunctionResponse).
Args:
tool_call_id: The unique ID assigned during parsing (e.g., "call_0").
Note: Google's API itself doesn't use this ID directly in the
FunctionResponse part, but we need it for mapping in the message list.
function_name: The original function name (without server prefix) that was called.
result: The data returned by the tool execution. Should be JSON-serializable.
Returns:
A dictionary representing the tool result message in the standard MCP format.
    This will be converted later by `convert_messages`.
"""
try:
if isinstance(result, (str, int, float, bool, list)):
content_dict = {"result": result}
elif isinstance(result, dict):
content_dict = result
else:
logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.")
content_dict = {"result": str(result)}
try:
content_str = json.dumps(content_dict)
except Exception as json_err:
logger.error(f"Error JSON-encoding tool result content for Google {function_name} ({tool_call_id}): {json_err}")
content_str = json.dumps({"error": "Failed to encode tool result content", "original_type": str(type(result))})
except Exception as e:
logger.error(f"Error preparing tool result content for Google {function_name} ({tool_call_id}): {e}")
content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)})
logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content_str,
"name": function_name,
}
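
A small end-to-end sketch of the tool pipeline above, using a made-up MCP tool definition (server name, tool name, schema, and result payload are illustrative only):

mcp_tools = [{
    "server_name": "filesystem",          # illustrative server name
    "name": "read_file",                  # illustrative tool name
    "description": "Read a UTF-8 text file and return its contents.",
    "inputSchema": {
        "type": "object",
        "properties": {"path": {"type": "string", "description": "Absolute file path."}},
        "required": ["path"],
    },
}]

intermediate = convert_to_google_tools(mcp_tools)            # dict-based declarations
tool_objects = convert_to_google_tool_objects(intermediate)  # [Tool(...)] or None

# If the model answers with a FunctionCall named "filesystem__read_file",
# parse_google_tool_calls returns entries shaped like:
#   {"id": "call_0", "server_name": "filesystem", "function_name": "read_file",
#    "arguments": '{"path": "/tmp/example.txt"}', "_google_tool_name": "filesystem__read_file"}
# The executed result is then wrapped for the follow-up request:
tool_message = format_google_tool_results("call_0", "read_file", {"content": "hello"})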

View File

@@ -0,0 +1,127 @@
import json
import logging
from typing import Any
from google.genai.types import Content, Part
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given Google model."""
default_window = 1000000
try:
provider_models = MODELS.get("google", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
"""
Converts standard message format to Google's format, extracting system prompt.
Handles mapping roles and structuring tool calls/results.
"""
google_messages: list[Content] = []
system_prompt: str | None = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")
if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Google.")
else:
logger.warning("System message found not at the beginning. Skipping for Google API.")
continue
google_role = {"user": "user", "assistant": "model"}.get(role)
if not google_role and role != "tool":
logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
continue
        parts: list[Part] = []
if role == "tool":
if tool_call_id and content:
try:
response_content_dict = json.loads(content)
except json.JSONDecodeError:
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
response_content_dict = {"result": content}
func_name = "unknown_function"
if i > 0 and messages[i - 1].get("role") == "assistant":
prev_tool_calls = messages[i - 1].get("tool_calls")
if prev_tool_calls:
for tc in prev_tool_calls:
if tc.get("id") == tool_call_id:
full_name = tc.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1]
break
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
google_role = "function"
else:
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
continue
elif role == "assistant" and tool_calls:
for tool_call in tool_calls:
args = tool_call.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
args = {"error": "failed to parse arguments"}
full_name = tool_call.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1]
parts.append(Part.from_function_call(name=func_name, args=args))
if content and isinstance(content, str):
parts.append(Part(text=content))
elif content:
if isinstance(content, str):
parts.append(Part(text=content))
else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part(text=str(content)))
if parts:
google_messages.append(Content(role=google_role, parts=parts))
else:
logger.debug(f"No parts generated for message: {message}")
last_role = None
valid_alternation = True
for msg in google_messages:
current_role = msg.role
if current_role == last_role and current_role in ["user", "model"]:
valid_alternation = False
logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.")
break
if last_role == "function" and current_role != "user":
valid_alternation = False
logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.")
break
last_role = current_role
if not valid_alternation:
raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.")
return google_messages, system_prompt
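
A short sketch of the role mapping performed by convert_messages (the conversation content is illustrative):

history = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "Thanks!"},
]
google_messages, system_prompt = convert_messages(history)

# The system prompt is extracted rather than sent as a turn, "assistant" maps to
# "model", and the remaining roles must alternate user/model.
assert system_prompt == "Answer briefly."
assert [m.role for m in google_messages] == ["user", "model", "user"]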

View File

@@ -0,0 +1,62 @@
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from providers.openai_provider.client import initialize_client
from providers.openai_provider.completion import create_chat_completion
from providers.openai_provider.response import get_content, get_streaming_content, get_usage
from providers.openai_provider.tools import (
convert_tools,
format_tool_results,
get_original_message_with_calls,
has_tool_calls,
parse_tool_calls,
)
from src.providers.base import BaseProvider
class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""
temperature: float
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
self.client = initialize_client(api_key, base_url)
self.api_key = api_key
self.base_url = self.client.base_url
self.temperature = temperature
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
return get_streaming_content(response)
def get_content(self, response: ChatCompletion) -> str:
return get_content(response)
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
return has_tool_calls(response)
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
return parse_tool_calls(response)
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
return format_tool_results(tool_call_id, result)
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return convert_tools(tools)
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
return get_original_message_with_calls(response)
def get_usage(self, response: Any) -> dict[str, int] | None:
return get_usage(response)
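
A minimal wiring sketch for the provider class; the API key source and model id are placeholders, and base_url falls back to the endpoint configured in MODELS:

import os

provider = OpenAIProvider(api_key=os.environ["OPENAI_API_KEY"], temperature=0.6)

response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello in five words."}],
    model="gpt-4o-mini",  # assumed model id present in the MODELS config
    stream=False,
)
print(provider.get_content(response))
print(provider.get_usage(response))  # {"prompt_tokens": ..., "completion_tokens": ...} or None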

View File

@@ -0,0 +1,19 @@
import logging
from openai import OpenAI
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
"""Initializes and returns an OpenAI client instance."""
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
try:
client = OpenAI(api_key=api_key, base_url=effective_base_url)
return client
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise

View File

@@ -0,0 +1,61 @@
import logging
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from providers.openai_provider.utils import truncate_messages
logger = logging.getLogger(__name__)
def create_chat_completion(
provider,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
try:
completion_params = {
"model": model,
"messages": truncated_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto"
completion_params = {k: v for k, v in completion_params.items() if v is not None}
log_params = completion_params.copy()
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params: {log_params}")
response = provider.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
actual_usage = None
if isinstance(response, ChatCompletion) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
logger.info(f"Actual OpenAI API usage: {actual_usage}")
return response
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
raise
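
A hedged streaming sketch for the call above; the provider instance and model id are placeholders, and message truncation happens inside create_chat_completion before the API request:

stream = create_chat_completion(
    openai_provider,  # assumed: an initialized OpenAIProvider
    messages=[{"role": "user", "content": "Stream a haiku about rain."}],
    model="gpt-4o-mini",  # assumed model id
    temperature=0.6,
    stream=True,
)
# With stream=True the OpenAI SDK returns a Stream[ChatCompletionChunk]; each
# chunk carries an incremental delta rather than the full message.
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)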

View File

@@ -0,0 +1,60 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
logger = logging.getLogger(__name__)
def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
if chunk.choices:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
if response.choices:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or ""
else:
logger.warning("No choices found in OpenAI non-streaming response.")
return "[No content received]"
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def get_usage(response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming OpenAI response."""
try:
if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
if not isinstance(response, Stream):
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,147 @@
import json
import logging
from typing import Any
from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
logger = logging.getLogger(__name__)
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion):
if response.choices:
return bool(response.choices[0].message.tool_calls)
else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False
elif isinstance(response, Stream):
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False
else:
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return []
if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None
func_name = call.function.name
arguments_obj = None
try:
if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments)
else:
arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}")
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
parsed_calls.append({
"id": call.id,
"server_name": server_name,
"function_name": func_name,
"arguments": arguments_obj,
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return []
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
try:
if isinstance(result, dict):
content = json.dumps(result)
elif isinstance(result, str):
content = result
else:
content = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

View File

@@ -0,0 +1,100 @@
import logging
import math
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given model."""
default_window = 8000
try:
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens
def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.
Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = get_context_window(model)
buffer = 200
effective_limit = context_limit - buffer
initial_estimated_count = estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count
truncated_messages = list(messages)
has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break
removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
final_estimated_count = estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break
if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
return truncated_messages, initial_estimated_count, final_estimated_count
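
A worked example of the char/4 estimate and the truncation loop above; the model id is an unknown placeholder, so get_context_window falls back to the 8000-token default and the effective limit is 7800:

messages = [{"role": "system", "content": "x" * 400}]        # 6 + 400 = 406 chars
messages += [{"role": "user", "content": "y" * 4000}] * 10   # 10 * (4 + 4000) chars

kept, before, after = truncate_messages(messages, model="unknown-model")
# before == ceil((406 + 40040) / 4) == 10112 estimated tokens, above the 7800 limit,
# so the oldest non-system messages (index 1) are dropped until the estimate fits:
# 10112 -> 9111 -> 8110 -> 7109.
assert kept[0]["role"] == "system"
assert before == 10112 and after == 7109
assert len(kept) == 8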

1323
uv.lock generated Normal file

File diff suppressed because it is too large