Compare commits
1 Commits
51e3058961 ... main

| Author | SHA1 | Date |
|---|---|---|
|  | 247835e595 |  |
@@ -8,18 +8,21 @@ api_key = YOUR_API_KEY
base_url = https://openrouter.ai/api/v1
model = openai/gpt-4o-2024-11-20
context_window = 128000
temperature = 0.6

[anthropic]
api_key = YOUR_API_KEY
base_url = https://api.anthropic.com/v1/messages
model = claude-3-7-sonnet-20250219
context_window = 128000
temperature = 0.6

[google]
api_key = YOUR_API_KEY
base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
model = gemini-2.0-flash
context_window = 1000000
temperature = 0.6


[openai]
@@ -27,6 +30,7 @@ api_key = YOUR_API_KEY
base_url = https://api.openai.com/v1
model = openai/gpt-4o
context_window = 128000
temperature = 0.6

[mcp]
servers_json = config/mcp_config.json

106
project_planning/updates.md
Normal file
@@ -0,0 +1,106 @@
## What is the google-genai Module?

The google-genai module is part of the Google Gen AI Python SDK, the kit Google provides for integrating its generative AI models into Python applications. It is distinct from the older, deprecated google-generativeai package: google-genai is the newer, unified SDK designed for Google's latest generative AI offerings, such as the Gemini models.

## Installation

The package is published on PyPI as google-genai and is installed via pip:

```bash
pip install google-genai
```

This installs the necessary dependencies and makes the module available in your Python environment.
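As a quick, optional sanity check that the environment picked up the new SDK rather than the deprecated package, you can import the module and print the installed distribution's version. This is only a sketch; it assumes Python 3.8+ (for importlib.metadata) and a standard pip install as above.

```python
# Optional check: the import succeeds only if google-genai is installed,
# and importlib.metadata reports the version of the installed distribution.
from importlib.metadata import version

from google import genai  # import check only; fails if google-genai is missing

print("google-genai version:", version("google-genai"))
```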
## Correct Import Statement

The standard import statement for the google-genai SDK, as per the official documentation and examples, is:

```python
from google import genai
```

This differs from the older SDK's import style, which was:

```python
import google.generativeai as genai
```

When you install google-genai, it provides a module structure where genai is a submodule of the google package. Thus, from google import genai is the correct way to access its functionality.
## Usage in the New SDK

Unlike the older google-generativeai SDK, which used a configure method (e.g., genai.configure(api_key='YOUR_API_KEY')) to set up the API key globally, the new google-genai SDK adopts a client-based approach. You create a Client instance with your API key and use it to interact with the models. Here's a basic example:

```python
from google import genai

# Initialize the client with your API key
client = genai.Client(api_key='YOUR_API_KEY')

# Example: Generate content using a model
response = client.models.generate_content(
    model='gemini-2.0-flash-001',  # Specify the model name
    contents='Why is the sky blue?',
)

print(response.text)
```

Key points about this usage:

- No configure method: the new SDK does not have a genai.configure method directly on the genai module. Instead, you pass the API key when creating a Client instance.
- Client-based interaction: all interactions with the generative models (e.g., generating content) are performed through the client object.

## Official Documentation

The official documentation for the google-genai SDK can be found on Google's API documentation site. Specifically:

- Google Gen AI Python SDK documentation: hosted at https://googleapis.github.io/python-genai/, this site provides detailed guides, API references, and code examples. It confirms the use of from google import genai and the client-based approach.
- GitHub repository: the source code and additional examples are available at https://github.com/googleapis/python-genai. The repository documentation reinforces that from google import genai is the import style for the new SDK.

## Why Your Import is Correct

You mentioned that your import is correct, which I take to refer to from google import genai. This is indeed the proper import for the google-genai package and aligns with the new SDK's design. If you're encountering issues (e.g., an error like module 'google.genai' has no attribute 'configure'), the code is most likely calling methods from the older SDK (such as genai.configure) that don't exist in the new one. To resolve this, update the code to use the client-based approach shown above.

## Troubleshooting Common Issues

If you're seeing errors with from google import genai, here are some things to check:

1. Correct package installed:
   - Ensure google-genai is installed (pip install google-genai).
   - If google-generativeai is installed instead, uninstall it with pip uninstall google-generativeai to avoid conflicts, then install google-genai.
2. Code compatibility:
   - If your code uses genai.configure or otherwise assumes the older SDK's structure, you'll need to refactor it. Replace configuration calls with genai.Client(api_key='...') and route model interactions through the client object (see the migration sketch after this list).
3. Environment verification:
   - Run pip show google-genai to confirm the package is installed and to check its version. This ensures you're working with the intended SDK.

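As a minimal migration sketch for point 2, here is the same request written in the old and the new style. Only the new google-genai SDK is assumed to be installed, so the deprecated google-generativeai calls are left as comments; the GenerativeModel usage reflects that older SDK's documented pattern, and the model name and prompt are placeholders.

```python
# Old, deprecated google-generativeai style (comments only, for comparison):
#   import google.generativeai as genai
#   genai.configure(api_key='YOUR_API_KEY')              # global configuration
#   model = genai.GenerativeModel('gemini-2.0-flash')
#   response = model.generate_content('Why is the sky blue?')

# New google-genai style: configuration lives on the Client instance instead.
from google import genai

client = genai.Client(api_key='YOUR_API_KEY')  # replaces genai.configure(...)
response = client.models.generate_content(
    model='gemini-2.0-flash-001',  # model is passed per call, not bound globally
    contents='Why is the sky blue?',
)
print(response.text)
```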
## Additional Resources

- PyPI page: the google-genai package on PyPI (https://pypi.org/project/google-genai/) provides installation instructions and links to the GitHub repository.
- Examples: the GitHub repository includes sample code demonstrating how to use the SDK with from google import genai.

## Conclusion

Your import, from google import genai, is correct for the google-genai module from the new Google Gen AI Python SDK; the documentation and online resources confirm it as the right approach for the current, unified SDK. If import google.generativeai was previously suggested or used, it pertains to the older, deprecated SDK, which is why it may be considered incorrect in your context. To fully leverage google-genai, make sure your code uses the client-based API outlined above, and you should be able to interact with Google's generative AI models without issue. If you're still facing specific errors, share them and I can assist further.

@@ -87,7 +87,7 @@ skip-magic-trailing-comma = false
combine-as-imports = true

[tool.ruff.lint.mccabe]
max-complexity = 16
max-complexity = 30

[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.

55
src/app.py
@@ -7,7 +7,6 @@ import streamlit as st
|
||||
from llm_client import LLMClient
|
||||
from src.custom_mcp.manager import SyncMCPManager
|
||||
|
||||
# Configure logging for the app
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,14 +21,12 @@ def init_session_state():
|
||||
logger.info("Attempting to initialize clients...")
|
||||
try:
|
||||
config = configparser.ConfigParser()
|
||||
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
|
||||
config_files_read = config.read("config/config.ini")
|
||||
if not config_files_read:
|
||||
raise FileNotFoundError("config.ini not found or could not be read.")
|
||||
logger.info(f"Read configuration from: {config_files_read}")
|
||||
|
||||
# --- MCP Manager Setup ---
|
||||
mcp_config_path = "config/mcp_config.json" # Default
|
||||
mcp_config_path = "config/mcp_config.json"
|
||||
if config.has_section("mcp") and config["mcp"].get("servers_json"):
|
||||
mcp_config_path = config["mcp"]["servers_json"]
|
||||
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
|
||||
@@ -38,39 +35,37 @@ def init_session_state():
|
||||
|
||||
mcp_manager = SyncMCPManager(mcp_config_path)
|
||||
if not mcp_manager.initialize():
|
||||
# Log warning but continue - LLMClient will operate without tools
|
||||
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
|
||||
else:
|
||||
logger.info("MCP Manager initialized successfully.")
|
||||
# Register shutdown hook for MCP manager
|
||||
atexit.register(mcp_manager.shutdown)
|
||||
logger.info("Registered MCP Manager shutdown hook.")
|
||||
|
||||
# --- LLM Client Setup ---
|
||||
provider_name = None
|
||||
model_name = None
|
||||
api_key = None
|
||||
base_url = None
|
||||
|
||||
# 1. Determine provider from [base] section
|
||||
if config.has_section("base") and config["base"].get("provider"):
|
||||
provider_name = config["base"].get("provider")
|
||||
logger.info(f"Provider selected from [base] section: {provider_name}")
|
||||
else:
|
||||
# Fallback or error if [base] provider is missing? Let's error for now.
|
||||
raise ValueError("Missing 'provider' setting in [base] section of config.ini")
|
||||
|
||||
# 2. Read details from the specific provider's section
|
||||
if config.has_section(provider_name):
|
||||
provider_config = config[provider_name]
|
||||
model_name = provider_config.get("model")
|
||||
api_key = provider_config.get("api_key")
|
||||
base_url = provider_config.get("base_url") # Optional
|
||||
base_url = provider_config.get("base_url")
|
||||
provider_temperature = provider_config.getfloat("temperature", 0.6)
|
||||
if "temperature" not in provider_config:
|
||||
logger.warning(f"Temperature not found in [{provider_name}] section, defaulting to {provider_temperature}")
|
||||
else:
|
||||
logger.info(f"Loaded temperature for {provider_name}: {provider_temperature}")
|
||||
logger.info(f"Read configuration from [{provider_name}] section.")
|
||||
else:
|
||||
raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
|
||||
|
||||
# Validate required config
|
||||
if not api_key:
|
||||
raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
|
||||
if not model_name:
|
||||
@@ -82,15 +77,15 @@ def init_session_state():
|
||||
api_key=api_key,
|
||||
mcp_manager=mcp_manager,
|
||||
base_url=base_url,
|
||||
temperature=provider_temperature,
|
||||
)
|
||||
st.session_state.model_name = model_name
|
||||
st.session_state.provider_name = provider_name # Store provider name
|
||||
st.session_state.provider_name = provider_name
|
||||
logger.info("LLMClient initialized successfully.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
|
||||
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
|
||||
# Stop the app if initialization fails critically
|
||||
st.stop()
|
||||
|
||||
|
||||
@@ -98,9 +93,7 @@ def display_chat_messages():
|
||||
"""Displays chat messages stored in session state."""
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
# Display content
|
||||
st.markdown(message["content"])
|
||||
# Display usage if available (for assistant messages)
|
||||
if message["role"] == "assistant" and "usage" in message:
|
||||
usage = message["usage"]
|
||||
prompt_tokens = usage.get("prompt_tokens", "N/A")
|
||||
@@ -121,19 +114,15 @@ def handle_user_input():
|
||||
response_placeholder = st.empty()
|
||||
full_response = ""
|
||||
error_occurred = False
|
||||
response_usage = None # Initialize usage info
|
||||
response_usage = None
|
||||
|
||||
logger.info("Processing message via LLMClient...")
|
||||
# Use the new client and method
|
||||
# NOTE: Setting stream=False to easily get usage info from the response dict.
|
||||
# A more complex solution is needed to get usage with streaming.
|
||||
response_data = st.session_state.client.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model=st.session_state.model_name,
|
||||
stream=False, # Set to False for usage info
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Handle the response (now expecting a dict)
|
||||
if isinstance(response_data, dict):
|
||||
if "error" in response_data:
|
||||
full_response = f"Error: {response_data['error']}"
|
||||
@@ -142,24 +131,19 @@ def handle_user_input():
|
||||
error_occurred = True
|
||||
else:
|
||||
full_response = response_data.get("content", "")
|
||||
response_usage = response_data.get("usage") # Get usage dict
|
||||
if not full_response and not error_occurred: # Check error_occurred flag too
|
||||
response_usage = response_data.get("usage")
|
||||
if not full_response and not error_occurred:
|
||||
logger.warning("Empty content received from LLMClient.")
|
||||
# Display nothing or a placeholder? Let's display nothing.
|
||||
# full_response = "[Empty Response]"
|
||||
# Display the full response at once (no streaming)
|
||||
response_placeholder.markdown(full_response)
|
||||
logger.debug("Non-streaming response processed.")
|
||||
|
||||
else:
|
||||
# Unexpected response type
|
||||
full_response = "[Unexpected response format from LLMClient]"
|
||||
logger.error(f"Unexpected response type: {type(response_data)}")
|
||||
st.error(full_response)
|
||||
error_occurred = True
|
||||
|
||||
# Add response to history, including usage if available
|
||||
if not error_occurred and full_response: # Only add if no error and content exists
|
||||
if not error_occurred and full_response:
|
||||
assistant_message = {"role": "assistant", "content": full_response}
|
||||
if response_usage:
|
||||
assistant_message["usage"] = response_usage
|
||||
@@ -181,35 +165,28 @@ def main():
|
||||
try:
|
||||
init_session_state()
|
||||
|
||||
# --- Display Enhanced Header ---
|
||||
provider_name = st.session_state.get("provider_name", "Unknown Provider")
|
||||
model_name = st.session_state.get("model_name", "Unknown Model")
|
||||
mcp_manager = st.session_state.client.mcp_manager # Get the manager
|
||||
mcp_manager = st.session_state.client.mcp_manager
|
||||
|
||||
server_count = 0
|
||||
tool_count = 0
|
||||
if mcp_manager and mcp_manager.initialized:
|
||||
server_count = len(mcp_manager.servers)
|
||||
try:
|
||||
# Get tool count (might be slightly slow if many tools/servers)
|
||||
tool_count = len(mcp_manager.list_all_tools())
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not retrieve tool count for header: {e}")
|
||||
tool_count = "N/A" # Display N/A if listing fails
|
||||
tool_count = "N/A"
|
||||
|
||||
# Display the new header format
|
||||
st.markdown(f"# Say Hi to **{provider_name.capitalize()}**!")
|
||||
st.write(f"MCP Servers: **{server_count}** | Tools: **{tool_count}**")
|
||||
st.write(f"Model: **{model_name}**")
|
||||
st.divider()
|
||||
# -----------------------------
|
||||
|
||||
# Removed the previous caption display
|
||||
|
||||
display_chat_messages()
|
||||
handle_user_input()
|
||||
except Exception as e:
|
||||
# Catch potential errors during rendering or handling
|
||||
logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
|
||||
st.error(f"A critical application error occurred: {e}")
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# This file makes src/mcp a Python package
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/mcp/client.py
|
||||
"""Client class for managing and interacting with a single MCP server process."""
|
||||
|
||||
import asyncio
|
||||
@@ -9,9 +8,8 @@ from custom_mcp import process, protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts
|
||||
LIST_TOOLS_TIMEOUT = 20.0 # Seconds (using the increased value from previous step)
|
||||
CALL_TOOL_TIMEOUT = 110.0 # Seconds
|
||||
LIST_TOOLS_TIMEOUT = 20.0
|
||||
CALL_TOOL_TIMEOUT = 110.0
|
||||
|
||||
|
||||
class MCPClient:
|
||||
@@ -39,7 +37,7 @@ class MCPClient:
|
||||
self._stderr_task: asyncio.Task | None = None
|
||||
self._request_counter = 0
|
||||
self._is_running = False
|
||||
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger
|
||||
self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
|
||||
|
||||
async def _log_stderr(self):
|
||||
"""Logs stderr output from the server process."""
|
||||
@@ -55,7 +53,6 @@ class MCPClient:
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("Stderr logging task cancelled.")
|
||||
except Exception as e:
|
||||
# Log errors but don't crash the logger task if possible
|
||||
self.logger.error(f"Error reading stderr: {e}", exc_info=True)
|
||||
finally:
|
||||
self.logger.debug("Stderr logging task finished.")
|
||||
@@ -79,13 +76,11 @@ class MCPClient:
|
||||
|
||||
if self.reader is None or self.writer is None:
|
||||
self.logger.error("Failed to get stdout/stdin streams after process start.")
|
||||
await self.stop() # Attempt cleanup
|
||||
await self.stop()
|
||||
return False
|
||||
|
||||
# Start background task to monitor stderr
|
||||
self._stderr_task = asyncio.create_task(self._log_stderr())
|
||||
|
||||
# --- Start MCP Initialization Handshake ---
|
||||
self.logger.info("Starting MCP initialization handshake...")
|
||||
self._request_counter += 1
|
||||
init_req_id = self._request_counter
|
||||
@@ -94,21 +89,18 @@ class MCPClient:
|
||||
"id": init_req_id,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05", # Use a recent version
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client
|
||||
"capabilities": {}, # Client capabilities (can be empty)
|
||||
"protocolVersion": "2024-11-05",
|
||||
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
|
||||
"capabilities": {},
|
||||
},
|
||||
}
|
||||
|
||||
# Define a timeout for initialization
|
||||
INITIALIZE_TIMEOUT = 15.0 # Seconds
|
||||
INITIALIZE_TIMEOUT = 15.0
|
||||
|
||||
try:
|
||||
# Send initialize request
|
||||
await protocol.send_request(self.writer, initialize_req)
|
||||
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
|
||||
|
||||
# Wait for initialize response
|
||||
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
|
||||
|
||||
if init_response and init_response.get("id") == init_req_id:
|
||||
@@ -117,9 +109,8 @@ class MCPClient:
|
||||
await self.stop()
|
||||
return False
|
||||
elif "result" in init_response:
|
||||
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided
|
||||
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}")
|
||||
|
||||
# Send initialized notification (using standard method name)
|
||||
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
|
||||
await protocol.send_request(self.writer, initialized_notify)
|
||||
self.logger.info("'notifications/initialized' notification sent.")
|
||||
@@ -135,7 +126,7 @@ class MCPClient:
|
||||
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
|
||||
await self.stop()
|
||||
return False
|
||||
else: # Timeout case
|
||||
else:
|
||||
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
|
||||
await self.stop()
|
||||
return False
|
||||
@@ -148,26 +139,23 @@ class MCPClient:
|
||||
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
|
||||
await self.stop()
|
||||
return False
|
||||
# --- End MCP Initialization Handshake ---
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
|
||||
self.process = None # Ensure process is None on failure
|
||||
self.process = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self._is_running = False
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
"""Stops the MCP server subprocess gracefully."""
|
||||
if not self._is_running and not self.process:
|
||||
self.logger.debug("Stop called but client is not running.")
|
||||
return
|
||||
|
||||
self.logger.info("Stopping MCP server process...")
|
||||
self._is_running = False # Mark as stopping
|
||||
self._is_running = False
|
||||
|
||||
# Cancel stderr logging task
|
||||
if self._stderr_task and not self._stderr_task.done():
|
||||
self._stderr_task.cancel()
|
||||
try:
|
||||
@@ -178,11 +166,9 @@ class MCPClient:
|
||||
self.logger.error(f"Error waiting for stderr task cancellation: {e}")
|
||||
self._stderr_task = None
|
||||
|
||||
# Stop the process using the utility function
|
||||
if self.process:
|
||||
await process.stop_mcp_process(self.process, self.server_name)
|
||||
|
||||
# Nullify references
|
||||
self.process = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
@@ -219,7 +205,6 @@ class MCPClient:
|
||||
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
|
||||
return None
|
||||
else:
|
||||
# Includes timeout case (read_response returns None)
|
||||
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
|
||||
return None
|
||||
|
||||
@@ -260,15 +245,12 @@ class MCPClient:
|
||||
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
|
||||
|
||||
if response and "result" in response:
|
||||
# Assuming result is the desired payload
|
||||
self.logger.info(f"Tool '{tool_name}' executed successfully.")
|
||||
return response["result"]
|
||||
elif response and "error" in response:
|
||||
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
|
||||
# Return the error structure itself? Or just None? Returning error dict for now.
|
||||
return {"error": response["error"]}
|
||||
else:
|
||||
# Includes timeout case
|
||||
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/mcp/manager.py
|
||||
"""Synchronous manager for multiple MCPClient instances."""
|
||||
|
||||
import asyncio
|
||||
@@ -7,19 +6,15 @@ import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
# Use relative imports within the mcp package
|
||||
from custom_mcp.client import MCPClient
|
||||
|
||||
# Configure basic logging
|
||||
# Consider moving this to the main app entry point if not already done
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
|
||||
INITIALIZE_TIMEOUT = 60.0 # Seconds
|
||||
SHUTDOWN_TIMEOUT = 30.0 # Seconds
|
||||
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds
|
||||
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
|
||||
INITIALIZE_TIMEOUT = 60.0
|
||||
SHUTDOWN_TIMEOUT = 30.0
|
||||
LIST_ALL_TOOLS_TIMEOUT = 30.0
|
||||
EXECUTE_TOOL_TIMEOUT = 120.0
|
||||
|
||||
|
||||
class SyncMCPManager:
|
||||
@@ -37,7 +32,6 @@ class SyncMCPManager:
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config: dict[str, Any] | None = None
|
||||
# Stores server_name -> MCPClient instance
|
||||
self.servers: dict[str, MCPClient] = {}
|
||||
self.initialized = False
|
||||
self._lock = threading.Lock()
|
||||
@@ -50,7 +44,6 @@ class SyncMCPManager:
|
||||
"""Load MCP configuration from JSON file."""
|
||||
logger.debug(f"Attempting to load MCP config from: {self.config_path}")
|
||||
try:
|
||||
# Using direct file access
|
||||
with open(self.config_path) as f:
|
||||
self.config = json.load(f)
|
||||
logger.info("MCP configuration loaded successfully.")
|
||||
@@ -65,8 +58,6 @@ class SyncMCPManager:
|
||||
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
|
||||
self.config = None
|
||||
|
||||
# --- Background Event Loop Management ---
|
||||
|
||||
def _run_event_loop(self):
|
||||
"""Target function for the background event loop thread."""
|
||||
try:
|
||||
@@ -75,14 +66,12 @@ class SyncMCPManager:
|
||||
self._loop.run_forever()
|
||||
finally:
|
||||
if self._loop and not self._loop.is_closed():
|
||||
# Clean up remaining tasks before closing
|
||||
try:
|
||||
tasks = asyncio.all_tasks(self._loop)
|
||||
if tasks:
|
||||
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
# Allow cancellation to propagate
|
||||
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
logger.debug("Outstanding tasks cancelled.")
|
||||
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
|
||||
@@ -99,9 +88,7 @@ class SyncMCPManager:
|
||||
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("Event loop thread started.")
|
||||
# Wait briefly for the loop to become available and running
|
||||
while self._loop is None or not self._loop.is_running():
|
||||
# Use time.sleep in sync context
|
||||
import time
|
||||
|
||||
time.sleep(0.01)
|
||||
@@ -121,8 +108,6 @@ class SyncMCPManager:
|
||||
self._thread = None
|
||||
logger.info("Event loop stopped.")
|
||||
|
||||
# --- Public Synchronous Interface ---
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
Initializes and starts all configured MCP servers synchronously.
|
||||
@@ -147,8 +132,6 @@ class SyncMCPManager:
|
||||
|
||||
logger.info("Submitting asynchronous server initialization...")
|
||||
|
||||
# Prepare coroutine to start all clients
|
||||
|
||||
async def _async_init_all():
|
||||
tasks = []
|
||||
for server_name, server_config in self.config["mcpServers"].items():
|
||||
@@ -161,19 +144,17 @@ class SyncMCPManager:
|
||||
|
||||
client = MCPClient(server_name, command, args, config_env)
|
||||
self.servers[server_name] = client
|
||||
tasks.append(client.start()) # Append the start coroutine
|
||||
tasks.append(client.start())
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check results - True means success, False or Exception means failure
|
||||
all_success = True
|
||||
failed_servers = []
|
||||
for i, result in enumerate(results):
|
||||
server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained
|
||||
server_name = list(self.config["mcpServers"].keys())[i]
|
||||
if isinstance(result, Exception) or result is False:
|
||||
all_success = False
|
||||
failed_servers.append(server_name)
|
||||
# Remove failed client from managed servers
|
||||
if server_name in self.servers:
|
||||
del self.servers[server_name]
|
||||
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
|
||||
@@ -182,7 +163,6 @@ class SyncMCPManager:
|
||||
logger.error(f"Initialization failed for servers: {failed_servers}")
|
||||
return all_success
|
||||
|
||||
# Run the initialization coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
|
||||
try:
|
||||
success = future.result(timeout=INITIALIZE_TIMEOUT)
|
||||
@@ -192,17 +172,16 @@ class SyncMCPManager:
|
||||
else:
|
||||
logger.error("Asynchronous initialization failed.")
|
||||
self.initialized = False
|
||||
# Attempt to clean up any partially started servers
|
||||
self.shutdown() # Call sync shutdown
|
||||
self.shutdown()
|
||||
except TimeoutError:
|
||||
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
self.shutdown()
|
||||
success = False
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during initialization future result: {e}", exc_info=True)
|
||||
self.initialized = False
|
||||
self.shutdown() # Clean up
|
||||
self.shutdown()
|
||||
success = False
|
||||
|
||||
return self.initialized
|
||||
@@ -211,20 +190,14 @@ class SyncMCPManager:
|
||||
"""Shuts down all managed MCP servers synchronously."""
|
||||
logger.info("Manager shutdown requested.")
|
||||
with self._lock:
|
||||
# Check servers dict too, in case init was partial
|
||||
if not self.initialized and not self.servers:
|
||||
logger.debug("Shutdown skipped: Not initialized or no servers running.")
|
||||
# Ensure loop is stopped if it exists
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._stop_event_loop_thread()
|
||||
return
|
||||
|
||||
if not self._loop or not self._loop.is_running():
|
||||
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
|
||||
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
|
||||
# This part is tricky as MCPClient.stop is async.
|
||||
# For simplicity, we might just log and rely on process termination on app exit.
|
||||
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
if self._thread and self._thread.is_alive():
|
||||
@@ -233,28 +206,22 @@ class SyncMCPManager:
|
||||
|
||||
logger.info("Submitting asynchronous server shutdown...")
|
||||
|
||||
# Prepare coroutine to stop all clients
|
||||
|
||||
async def _async_shutdown_all():
|
||||
tasks = [client.stop() for client in self.servers.values()]
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Run the shutdown coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
|
||||
try:
|
||||
future.result(timeout=SHUTDOWN_TIMEOUT)
|
||||
logger.info("Asynchronous shutdown completed.")
|
||||
except TimeoutError:
|
||||
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
|
||||
# Processes might still be running, OS will clean up on exit hopefully
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
|
||||
finally:
|
||||
# Always mark as uninitialized and clear servers dict
|
||||
self.servers = {}
|
||||
self.initialized = False
|
||||
# Stop the background thread
|
||||
self._stop_event_loop_thread()
|
||||
|
||||
logger.info("Manager shutdown complete.")
|
||||
@@ -277,7 +244,6 @@ class SyncMCPManager:
|
||||
|
||||
logger.info(f"Requesting tools from {len(self.servers)} servers...")
|
||||
|
||||
# Prepare coroutine to list tools from all clients
|
||||
async def _async_list_all():
|
||||
tasks = []
|
||||
server_names_in_order = []
|
||||
@@ -293,10 +259,8 @@ class SyncMCPManager:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Error listing tools for server '{server_name}': {result}")
|
||||
elif result is None:
|
||||
# MCPClient.list_tools returns None on timeout/error
|
||||
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
|
||||
elif isinstance(result, list):
|
||||
# Add server_name to each tool definition
|
||||
for tool in result:
|
||||
tool["server_name"] = server_name
|
||||
all_tools.extend(result)
|
||||
@@ -305,7 +269,6 @@ class SyncMCPManager:
|
||||
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
|
||||
return all_tools
|
||||
|
||||
# Run the coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
|
||||
try:
|
||||
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
|
||||
@@ -346,18 +309,16 @@ class SyncMCPManager:
|
||||
|
||||
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
|
||||
|
||||
# Run the client's call_tool coroutine in the background loop
|
||||
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
|
||||
try:
|
||||
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
|
||||
# MCPClient.call_tool returns the result dict or an error dict or None
|
||||
if result is None:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
|
||||
elif isinstance(result, dict) and "error" in result:
|
||||
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
|
||||
else:
|
||||
logger.info(f"Tool '{tool_name}' execution successful.")
|
||||
return result # Return result dict, error dict, or None
|
||||
return result
|
||||
except TimeoutError:
|
||||
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/mcp/process.py
|
||||
"""Async utilities for managing MCP server subprocesses."""
|
||||
|
||||
import asyncio
|
||||
@@ -29,25 +28,20 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
|
||||
"""
|
||||
logger.debug(f"Preparing to start process for command: {command}")
|
||||
|
||||
# --- Add tilde expansion for arguments ---
|
||||
expanded_args = []
|
||||
try:
|
||||
for arg in args:
|
||||
if isinstance(arg, str) and "~" in arg:
|
||||
expanded_args.append(os.path.expanduser(arg))
|
||||
else:
|
||||
# Ensure all args are strings for list2cmdline
|
||||
expanded_args.append(str(arg))
|
||||
logger.debug(f"Expanded args: {expanded_args}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to expand arguments: {e}") from e
|
||||
|
||||
# --- Merge os.environ with config_env ---
|
||||
merged_env = {**os.environ, **config_env}
|
||||
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
|
||||
|
||||
# Combine command and expanded args into a single string for shell execution
|
||||
try:
|
||||
cmd_string = subprocess.list2cmdline([command] + expanded_args)
|
||||
logger.debug(f"Executing shell command: {cmd_string}")
|
||||
@@ -55,7 +49,6 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
|
||||
logger.error(f"Error creating command string: {e}", exc_info=True)
|
||||
raise ValueError(f"Failed to create command string: {e}") from e
|
||||
|
||||
# --- Start the subprocess using shell ---
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd_string,
|
||||
@@ -68,10 +61,10 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
|
||||
return process
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
|
||||
raise # Re-raise specific error
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
|
||||
raise # Re-raise other errors
|
||||
raise
|
||||
|
||||
|
||||
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
|
||||
@@ -89,7 +82,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
|
||||
pid = process.pid
|
||||
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
|
||||
|
||||
# Close stdin first
|
||||
if process.stdin and not process.stdin.is_closing():
|
||||
try:
|
||||
process.stdin.close()
|
||||
@@ -98,7 +90,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
|
||||
|
||||
# Attempt graceful termination
|
||||
try:
|
||||
process.terminate()
|
||||
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
|
||||
@@ -108,7 +99,7 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait() # Wait for kill to complete
|
||||
await process.wait()
|
||||
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
|
||||
except ProcessLookupError:
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
|
||||
@@ -118,7 +109,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
|
||||
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
|
||||
except Exception as e_term:
|
||||
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
|
||||
# Attempt kill as fallback if terminate failed and process might still be running
|
||||
if process.returncode is None:
|
||||
try:
|
||||
process.kill()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/mcp/protocol.py
|
||||
"""Async utilities for MCP JSON-RPC communication over streams."""
|
||||
|
||||
import asyncio
|
||||
@@ -28,10 +27,10 @@ async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any
|
||||
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
|
||||
except ConnectionResetError:
|
||||
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
|
||||
raise # Re-raise for the caller (MCPClient) to handle
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
|
||||
raise # Re-raise for the caller
|
||||
raise
|
||||
|
||||
|
||||
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/llm_client.py
|
||||
"""
|
||||
Generic LLM client supporting multiple providers and MCP tool integration.
|
||||
"""
|
||||
@@ -26,6 +25,7 @@ class LLMClient:
|
||||
api_key: str,
|
||||
mcp_manager: SyncMCPManager,
|
||||
base_url: str | None = None,
|
||||
temperature: float = 0.6, # Add temperature parameter with a fallback default
|
||||
):
|
||||
"""
|
||||
Initialize the LLM client.
|
||||
@@ -35,9 +35,15 @@ class LLMClient:
|
||||
api_key: API key for the provider.
|
||||
mcp_manager: An initialized instance of SyncMCPManager.
|
||||
base_url: Optional base URL for the provider API.
|
||||
temperature: Default temperature to configure the provider with.
|
||||
"""
|
||||
logger.info(f"Initializing LLMClient for provider: {provider_name}")
|
||||
self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
|
||||
self.provider: BaseProvider = create_llm_provider(
|
||||
provider_name,
|
||||
api_key,
|
||||
base_url,
|
||||
temperature=temperature, # Pass temperature to provider factory
|
||||
)
|
||||
self.mcp_manager = mcp_manager
|
||||
self.mcp_tools: list[dict[str, Any]] = []
|
||||
self._refresh_mcp_tools() # Initial tool load
|
||||
@@ -56,7 +62,7 @@ class LLMClient:
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
# temperature: float = 0.6, # REMOVE THIS LINE
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[str, None, None] | dict[str, Any]:
|
||||
@@ -66,7 +72,7 @@ class LLMClient:
|
||||
Args:
|
||||
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
|
||||
model: Model identifier string.
|
||||
temperature: Sampling temperature.
|
||||
# temperature: REMOVED - Provider uses its configured temperature.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
|
||||
@@ -92,7 +98,7 @@ class LLMClient:
|
||||
response = self.provider.create_chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
# temperature=temperature, # REMOVE THIS LINE (provider uses its own)
|
||||
max_tokens=max_tokens,
|
||||
stream=stream,
|
||||
tools=provider_tools,
|
||||
@@ -169,7 +175,7 @@ class LLMClient:
|
||||
follow_up_response = self.provider.create_chat_completion(
|
||||
messages=messages, # Now includes assistant's turn and tool results
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
# temperature=temperature, # REMOVE THIS LINE
|
||||
max_tokens=max_tokens,
|
||||
stream=False, # Follow-up is non-streaming here
|
||||
tools=provider_tools, # Pass tools again? Some providers might need it.
|
||||
@@ -213,17 +219,3 @@ class LLMClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk
|
||||
|
||||
|
||||
# Example of how a provider might need to implement get_original_message_with_calls
|
||||
# This would be in the specific provider class (e.g., openai_provider.py)
|
||||
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
|
||||
# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
|
||||
# # or in the non-streaming response's choice message
|
||||
# # Needs careful implementation based on provider's response structure
|
||||
# assistant_message = {
|
||||
# "role": "assistant",
|
||||
# "content": None, # Often null when tool calls are present
|
||||
# "tool_calls": [...] # Extracted tool calls in provider format
|
||||
# }
|
||||
# return assistant_message
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
MODELS = {
|
||||
"openai": {
|
||||
"name": "OpenAI",
|
||||
"endpoint": "https://api.openai.com/v1",
|
||||
"models": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"name": "GPT-4o",
|
||||
"default": True,
|
||||
"context_window": 128000,
|
||||
"description": "Input $5/M tokens, Output $15/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"name": "Anthropic",
|
||||
"endpoint": "https://api.anthropic.com/v1/messages",
|
||||
"models": [
|
||||
{
|
||||
"id": "claude-3-7-sonnet-20250219",
|
||||
"name": "Claude 3.7 Sonnet",
|
||||
"default": True,
|
||||
"context_window": 200000,
|
||||
"description": "Input $3/M tokens, Output $15/M tokens",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-haiku-20241022",
|
||||
"name": "Claude 3.5 Haiku",
|
||||
"default": False,
|
||||
"context_window": 200000,
|
||||
"description": "Input $0.80/M tokens, Output $4/M tokens",
|
||||
},
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"name": "Google Gemini",
|
||||
"endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
|
||||
"models": [
|
||||
{
|
||||
"id": "gemini-2.0-flash",
|
||||
"name": "Gemini 2.0 Flash",
|
||||
"default": True,
|
||||
"context_window": 1000000,
|
||||
"description": "Input $0.1/M tokens, Output $0.4/M tokens",
|
||||
}
|
||||
],
|
||||
},
|
||||
"openrouter": {
|
||||
"name": "OpenRouter",
|
||||
"endpoint": "https://openrouter.ai/api/v1/chat/completions",
|
||||
"models": [
|
||||
{
|
||||
"id": "custom",
|
||||
"name": "Custom Model",
|
||||
"default": False,
|
||||
"context_window": 128000, # Default context window, will be updated based on model
|
||||
"description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,3 @@
# src/providers/__init__.py
import logging

from providers.anthropic_provider import AnthropicProvider
@@ -6,11 +5,8 @@ from providers.base import BaseProvider
from providers.google_provider import GoogleProvider
from providers.openai_provider import OpenAIProvider

# from providers.openrouter_provider import OpenRouterProvider

logger = logging.getLogger(__name__)

# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
@@ -27,7 +23,7 @@ def register_provider(name: str, provider_class: type[BaseProvider]):
logger.info(f"Registered provider: {name}")


def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None, temperature: float = 0.6) -> BaseProvider:
"""
Factory function to create an instance of a specific LLM provider.

@@ -48,9 +44,9 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
available = ", ".join(PROVIDER_MAP.keys()) or "None"
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")

logger.info(f"Creating LLM provider instance for: {provider_name}")
logger.info(f"Creating LLM provider instance for: {provider_name} with temperature: {temperature}")
try:
return provider_class(api_key=api_key, base_url=base_url)
return provider_class(api_key=api_key, base_url=base_url, temperature=temperature)
except Exception as e:
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
@@ -59,11 +55,3 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
def get_available_providers() -> list[str]:
"""Returns a list of registered provider names."""
return list(PROVIDER_MAP.keys())


# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
# # Potentially load providers dynamically if designed as plugins
# pass
# load_providers()

@@ -6,11 +6,21 @@ from providers.base import BaseProvider


class AnthropicProvider(BaseProvider):
def __init__(self, api_key: str, base_url: str | None = None):
self.client = initialize_client(api_key, base_url)
temperature: float

def create_chat_completion(self, messages, model, temperature=0.4, max_tokens=None, stream=True, tools=None):
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
self.client = initialize_client(api_key, base_url)
self.temperature = temperature

def create_chat_completion(
self,
messages,
model,
max_tokens=None,
stream=True,
tools=None,
):
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)

def get_streaming_content(self, response):
return get_streaming_content(response)

@@ -84,17 +84,12 @@ def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
continue

# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"

# Initialize the Anthropic tool structure
# Anthropic's format is quite close to JSON Schema
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}

# Basic validation/cleaning of schema if needed
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:

@@ -1,4 +1,3 @@
|
||||
# src/providers/base.py
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
@@ -147,8 +146,3 @@ class BaseProvider(abc.ABC):
|
||||
or None if usage information is not available.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Optional: Add a method for follow-up completions if the provider API
|
||||
# requires a specific structure different from just appending messages.
|
||||
# def create_follow_up_completion(...) -> Any:
|
||||
# pass
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
# src/providers/google_provider/__init__.py
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
# Import Generator type for isinstance check - Keep this import for type hints
|
||||
from google.genai.types import GenerateContentResponse
|
||||
|
||||
from providers.google_provider.client import initialize_client
|
||||
|
||||
# Correctly import the renamed function directly
|
||||
from providers.google_provider.completion import create_chat_completion
|
||||
from providers.google_provider.response import get_content, get_streaming_content, get_usage
|
||||
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
|
||||
@@ -20,45 +16,41 @@ logger = logging.getLogger(__name__)
|
||||
class GoogleProvider(BaseProvider):
|
||||
"""Provider implementation for Google Generative AI (Gemini)."""
|
||||
|
||||
# Type hint for the client (it's the configured 'genai' module itself)
|
||||
client_module: Any
|
||||
temperature: float
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
|
||||
"""
|
||||
Initializes the GoogleProvider.
|
||||
|
||||
Args:
|
||||
api_key: The Google API key.
|
||||
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
|
||||
temperature: The default temperature for completions.
|
||||
"""
|
||||
# initialize_client returns the client instance now
|
||||
self.client_module = initialize_client(api_key, base_url)
|
||||
self.api_key = api_key # Store if needed later
|
||||
self.base_url = base_url # Store if needed later
|
||||
logger.info("GoogleProvider initialized.")
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
logger.info(f"GoogleProvider initialized with temperature: {self.temperature}")
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
|
||||
) -> Any:
|
||||
"""Creates a chat completion using the Google Gemini API."""
|
||||
# Pass self (provider instance) to the helper function
|
||||
raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
|
||||
print(f"Raw response type: {type(raw_response)}") # Debugging line to check the type of raw_response
|
||||
print(f"Raw response: {raw_response}") # Debugging line to check the content of raw_response
|
||||
raw_response = create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
|
||||
print(f"Raw response type: {type(raw_response)}")
|
||||
print(f"Raw response: {raw_response}")
|
||||
|
||||
# The completion helper function handles returning the correct type or an error dict.
|
||||
# No need for generator handling here anymore.
|
||||
return raw_response
|
||||
|
||||
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||
"""Extracts content chunks from a Google streaming response."""
|
||||
# Response is expected to be an iterator from generate_content_stream
|
||||
return get_streaming_content(response)
|
||||
|
||||
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||
@@ -67,33 +59,20 @@ class GoogleProvider(BaseProvider):
|
||||
|
||||
def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool:
|
||||
"""Checks if the Google response contains tool calls (FunctionCalls)."""
|
||||
# Note: For streaming responses, this check is reliable only after the stream is fully consumed
|
||||
# or if the specific chunk containing the call is processed.
|
||||
return has_google_tool_calls(response)
|
||||
|
||||
def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls (FunctionCalls) from a non-streaming Google response."""
|
||||
# Expects a non-streaming GenerateContentResponse or an error dict
|
||||
return parse_google_tool_calls(response)
|
||||
|
||||
# Note: Google's format_tool_results helper requires the original function_name.
|
||||
# Ensure the calling code (e.g., LLMClient) provides this when invoking this method.
|
||||
def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for a Google follow-up request (into standard message format)."""
|
||||
return format_google_tool_results(tool_call_id, function_name, result)
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts MCP tools list to Google's intermediate dictionary format."""
|
||||
# The `create_chat_completion` function handles the final conversion
|
||||
# from this intermediate format to Google's `Tool` objects internally.
|
||||
return convert_to_google_tools(tools)
|
||||
|
||||
def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
|
||||
"""Extracts token usage information from a Google response."""
|
||||
# Expects a non-streaming GenerateContentResponse or an error dict
|
||||
return get_usage(response)
|
||||
|
||||
# `get_original_message_with_calls` (present in OpenAIProvider) is not implemented here
|
||||
# as Google's API structure integrates FunctionCall parts directly into the assistant's
|
||||
# message content, rather than having a separate `tool_calls` attribute on the message object.
|
||||
# The necessary information is handled during message conversion and tool call parsing.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/providers/google_provider/client.py
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -16,12 +15,10 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
|
||||
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
|
||||
|
||||
try:
|
||||
# Instantiate the client directly using the API key
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("Google Generative AI client instantiated.")
|
||||
if base_url:
|
||||
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
|
||||
# Return the client instance
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import logging
|
||||
import traceback
|
||||
from collections.abc import Iterable # Added Iterable
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
# Import specific types for better hinting
|
||||
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
|
||||
|
||||
# Removed convert_to_google_tools import as it's handled later
|
||||
from providers.google_provider.tools import convert_to_google_tool_objects
|
||||
from providers.google_provider.utils import convert_messages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Helper for Non-Streaming ---
|
||||
def _create_chat_completion_non_stream(
|
||||
provider,
|
||||
model: str,
|
||||
@@ -23,105 +20,85 @@ def _create_chat_completion_non_stream(
|
||||
"""Handles the non-streaming API call."""
|
||||
try:
|
||||
logger.debug("Calling client.models.generate_content...")
|
||||
# Use the client instance stored on the provider
|
||||
response = provider.client_module.models.generate_content(
|
||||
model=f"models/{model}",
|
||||
contents=google_messages,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.debug("generate_content call successful, returning raw response object.")
|
||||
# Return the direct response object
|
||||
return response
|
||||
except ValueError as ve:
|
||||
error_msg = f"Google API request validation error: {ve}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
# Return error dict
|
||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
except Exception as e:
|
||||
error_msg = f"Google API error during non-stream chat completion: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
# Return error dict
|
||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
|
||||
|
||||
# --- Helper for Streaming ---
|
||||
def _create_chat_completion_stream(
|
||||
provider,
|
||||
model: str,
|
||||
google_messages: list[ContentDict],
|
||||
generation_config: GenerationConfigDict,
|
||||
) -> Iterable[GenerateContentResponse | dict[str, Any]]: # Return Iterable of response chunks or error dict
|
||||
) -> Iterable[GenerateContentResponse | dict[str, Any]]:
|
||||
"""Handles the streaming API call and yields results."""
|
||||
try:
|
||||
logger.debug("Calling client.models.generate_content_stream...")
|
||||
# Use the client instance stored on the provider
|
||||
response_iterator = provider.client_module.models.generate_content_stream(
|
||||
model=f"models/{model}",
|
||||
contents=google_messages,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.debug("generate_content_stream call successful, yielding from iterator.")
|
||||
# Yield from the SDK's iterator which produces GenerateContentResponse chunks
|
||||
yield from response_iterator
|
||||
except ValueError as ve:
|
||||
error_msg = f"Google API request validation error: {ve}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
# Yield error as a dict matching non-streaming error structure
|
||||
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
except Exception as e:
|
||||
error_msg = f"Google API error during stream chat completion: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
# Yield error as a dict
|
||||
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
|
||||
|
||||
# --- Main Function ---
|
||||
# Renamed original function to avoid conflict if needed, though overwrite is fine
|
||||
def create_chat_completion(
|
||||
provider, # Provider instance is passed in
|
||||
provider,
|
||||
messages: list[dict[str, Any]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None, # Expects intermediate dict format
|
||||
) -> Any: # Return type depends on stream flag
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates a chat completion using the Google Gemini API.
|
||||
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
|
||||
"""
|
||||
logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# Check if client exists on the provider instance
|
||||
if provider.client_module is None:
|
||||
error_msg = "Google Generative AI client not initialized on provider."
|
||||
logger.error(error_msg)
|
||||
# Return error dict directly for non-stream, create iterator for stream
|
||||
return iter([{"error": error_msg}]) if stream else {"error": error_msg}
|
||||
|
||||
try:
|
||||
# 1. Convert messages (Common logic)
|
||||
google_messages, system_prompt = convert_messages(messages)
|
||||
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
|
||||
|
||||
# 2. Prepare generation configuration (Common logic)
|
||||
# Use GenerationConfigDict for better type hinting if possible
|
||||
generation_config: GenerationConfigDict = {"temperature": temperature}
|
||||
if max_tokens is not None:
|
||||
generation_config["max_output_tokens"] = max_tokens
|
||||
logger.debug(f"Setting max_output_tokens: {max_tokens}")
|
||||
else:
|
||||
# Google requires max_output_tokens, set a default if None
|
||||
# Defaulting to a reasonable value, e.g., 8192, check model limits if needed
|
||||
default_max_tokens = 8192
|
||||
generation_config["max_output_tokens"] = default_max_tokens
|
||||
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
|
||||
|
||||
# 3. Convert tools if provided (Common logic)
|
||||
google_tool_objects: list[Tool] | None = None
|
||||
if tools:
|
||||
try:
|
||||
# Convert intermediate dict format to Google Tool objects
|
||||
google_tool_objects = convert_to_google_tool_objects(tools)
|
||||
if google_tool_objects:
|
||||
num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
|
||||
@@ -130,21 +107,17 @@ def create_chat_completion(
|
||||
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
|
||||
except Exception as tool_conv_err:
|
||||
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
|
||||
google_tool_objects = None # Continue without tools on conversion error
|
||||
google_tool_objects = None
|
||||
else:
|
||||
logger.debug("No tools provided for conversion.")
|
||||
|
||||
# 4. Add system prompt and tools to generation_config (Common logic)
|
||||
if system_prompt:
|
||||
# Ensure system_instruction is ContentDict or compatible type
|
||||
generation_config["system_instruction"] = system_prompt
|
||||
logger.debug("Added system_instruction to generation_config.")
|
||||
if google_tool_objects:
|
||||
# Assign the list of Tool objects directly
|
||||
generation_config["tools"] = google_tool_objects
|
||||
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
|
||||
|
||||
# 5. Log parameters before API call (Common logic)
|
||||
log_params = {
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
@@ -156,19 +129,12 @@ def create_chat_completion(
|
||||
}
|
||||
logger.info(f"Calling Google API via helper with params: {log_params}")
|
||||
|
||||
# 6. Delegate to appropriate helper
|
||||
if stream:
|
||||
# Return the generator/iterator from the streaming helper
|
||||
# This helper uses 'yield from'
|
||||
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
|
||||
else:
|
||||
# Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper
|
||||
# This helper uses 'return'
|
||||
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
|
||||
|
||||
except Exception as e:
|
||||
# Catch errors during common setup (message/tool conversion etc.)
|
||||
error_msg = f"Error during Google completion setup: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
# Return error dict directly for non-stream, create iterator for stream
|
||||
return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
|
||||
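Since `create_chat_completion` reports failures by returning (or yielding) an `{"error": ...}` dict instead of raising, callers have to branch on the payload type. A rough usage sketch under that convention; the `provider` object is assumed to be an initialized GoogleProvider and the model name is only an example:

```python
result = create_chat_completion(provider, messages, model="gemini-2.0-flash", stream=False)

if isinstance(result, dict) and "error" in result:
    # Setup or API failures surface as an error dict rather than an exception.
    print(f"Google completion failed: {result['error']}")
else:
    # Otherwise this is a GenerateContentResponse; .text holds the concatenated output.
    print(result.text)
```

For `stream=True` the same convention applies per chunk: the iterator yields `GenerateContentResponse` chunks, or a single error dict if the call failed.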
@@ -1,4 +1,3 @@
|
||||
# src/providers/google_provider/response.py
|
||||
"""
|
||||
Response handling utilities specific to the Google Generative AI provider.
|
||||
|
||||
@@ -32,50 +31,36 @@ def get_streaming_content(response: Any) -> Generator[str, None, None]:
|
||||
logger.debug("Processing Google stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
# Check if the response itself is an error indicator (e.g., from create_chat_completion error handling)
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
yield json.dumps(response)
|
||||
logger.error(f"Stream processing stopped due to initial error: {response['error']}")
|
||||
return
|
||||
# Check if response is already an error iterator
|
||||
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
|
||||
# If it looks like an error iterator from create_chat_completion
|
||||
first_item = next(response, None)
|
||||
if first_item and isinstance(first_item, str):
|
||||
try:
|
||||
error_data = json.loads(first_item)
|
||||
if "error" in error_data:
|
||||
yield first_item # Yield the error JSON
|
||||
yield first_item
|
||||
yield from response
|
||||
logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}")
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
# Not a JSON error, yield it as is and continue? Or stop?
|
||||
# Assuming it might be valid content if not JSON error.
|
||||
yield first_item
|
||||
elif first_item: # Put the first item back if it wasn't an error
|
||||
# This requires a way to chain iterators, simple yield doesn't work well here.
|
||||
# For simplicity, we assume error iterators yield JSON strings.
|
||||
# If the stream is valid, the loop below will handle it.
|
||||
# Re-assigning response might be complex. Let the main loop handle valid streams.
|
||||
pass # Let the main loop handle the original response iterator
|
||||
elif first_item:
|
||||
pass
|
||||
|
||||
# Process the stream chunk by chunk
|
||||
for chunk in response:
|
||||
# Check for errors embedded within the stream chunks (less common for Google?)
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
yield json.dumps(chunk)
|
||||
logger.error(f"Error encountered during Google stream: {chunk['error']}")
|
||||
continue # Continue processing stream or stop? Continuing for now.
|
||||
continue
|
||||
|
||||
# Extract text content
|
||||
delta = ""
|
||||
try:
|
||||
if hasattr(chunk, "text"):
|
||||
delta = chunk.text
|
||||
elif hasattr(chunk, "candidates") and chunk.candidates:
|
||||
# Sometimes content might be nested under candidates even in stream?
|
||||
# Check the first candidate's first part for text.
|
||||
first_candidate = chunk.candidates[0]
|
||||
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
|
||||
first_part = first_candidate.content.parts[0]
|
||||
@@ -83,32 +68,27 @@ def get_streaming_content(response: Any) -> Generator[str, None, None]:
|
||||
delta = first_part.text
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True)
|
||||
delta = "" # Ensure delta is a string
|
||||
delta = ""
|
||||
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
|
||||
# Detect function calls during stream (optional, for logging/early detection)
|
||||
try:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
logger.debug(f"Function call detected during stream: {part.function_call.name}")
|
||||
# Note: We don't yield the function call itself here, just the text.
|
||||
# Function calls are typically processed after the stream completes.
|
||||
break # Found a function call in this chunk
|
||||
break
|
||||
except Exception:
|
||||
# Ignore errors during optional function call detection in stream
|
||||
pass
|
||||
|
||||
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
|
||||
|
||||
except StopIteration:
|
||||
logger.debug("Google stream finished (StopIteration).") # Normal end of iteration
|
||||
logger.debug("Google stream finished (StopIteration).")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google stream: {e}", exc_info=True)
|
||||
# Yield a final error message
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
|
||||
@@ -124,27 +104,21 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||
The concatenated text content, or an error message string.
|
||||
"""
|
||||
try:
|
||||
# Check if it's an error dictionary passed from upstream (e.g., completion helper)
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
logger.error(f"Cannot get content from error dict: {response['error']}")
|
||||
return f"[Error: {response['error']}]"
|
||||
|
||||
# Ensure it's a GenerateContentResponse object before accessing attributes
|
||||
if not isinstance(response, GenerateContentResponse):
|
||||
logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||
return f"[Error: Unexpected response type {type(response)}]"
|
||||
|
||||
# --- Access GenerateContentResponse attributes ---
|
||||
# Prioritize response.text if available and not empty
|
||||
if hasattr(response, "text") and response.text:
|
||||
content = response.text
|
||||
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
|
||||
return content
|
||||
|
||||
# Fallback: manually concatenate text from parts if .text is missing/empty
|
||||
if hasattr(response, "candidates") and response.candidates:
|
||||
first_candidate = response.candidates[0]
|
||||
# Check candidate content and parts carefully
|
||||
if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
|
||||
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
|
||||
if text_parts:
|
||||
@@ -153,14 +127,13 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||
return content
|
||||
else:
|
||||
logger.warning("Google response candidate parts contained no text.")
|
||||
return "" # Return empty if parts exist but have no text
|
||||
return ""
|
||||
else:
|
||||
logger.warning("Google response candidate has no valid content or parts.")
|
||||
return "" # Return empty string if no valid content/parts
|
||||
return ""
|
||||
else:
|
||||
# If neither .text nor valid candidates are found
|
||||
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
|
||||
return "" # Return empty string if no text found
|
||||
return ""
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
|
||||
@@ -182,20 +155,16 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i
|
||||
usage information is unavailable or an error occurred.
|
||||
"""
|
||||
try:
|
||||
# Check if it's an error dictionary passed from upstream
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
logger.warning(f"Cannot get usage from error dict: {response['error']}")
|
||||
return None
|
||||
|
||||
# Ensure it's a GenerateContentResponse object before accessing attributes
|
||||
if not isinstance(response, GenerateContentResponse):
|
||||
logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||
return None
|
||||
|
||||
# Safely access usage metadata
|
||||
metadata = getattr(response, "usage_metadata", None)
|
||||
if metadata:
|
||||
# Google uses prompt_token_count and candidates_token_count
|
||||
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
|
||||
completion_tokens = getattr(metadata, "candidates_token_count", 0)
|
||||
usage = {
|
||||
|
||||
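The usage extraction above relies on Gemini's `usage_metadata`, which reports `prompt_token_count` and `candidates_token_count`. A minimal sketch of the mapping; the output keys are assumed to mirror the OpenAI provider's, since the hunk is cut off before the dict literal closes:

```python
metadata = getattr(response, "usage_metadata", None)
if metadata:
    usage = {
        "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
        "completion_tokens": getattr(metadata, "candidates_token_count", 0),
    }
else:
    usage = None  # no metadata available (e.g. streaming chunks or error dicts)
```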
@@ -1,4 +1,3 @@
|
||||
# src/providers/google_provider/tools.py
|
||||
"""
|
||||
Tool handling utilities specific to the Google Generative AI provider.
|
||||
|
||||
@@ -18,9 +17,6 @@ from google.genai.types import FunctionDeclaration, Schema, Tool, Type
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Tool Conversion (from MCP format to Google format) ---
|
||||
|
||||
|
||||
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to Google Gemini format (dictionary structure).
|
||||
@@ -48,41 +44,34 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, A
|
||||
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name for routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
# Basic validation/cleaning of schema for Google compatibility
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.")
|
||||
# Ensure basic structure if missing
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {} # Start fresh if not a dict
|
||||
input_schema = {}
|
||||
if "type" not in input_schema or input_schema["type"] != "object":
|
||||
# Wrap existing schema or create new if type is wrong/missing
|
||||
input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}}
|
||||
logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.")
|
||||
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
|
||||
# Google requires properties for object type, add dummy if empty
|
||||
if not input_schema["properties"]:
|
||||
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
|
||||
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}}
|
||||
if "required" in input_schema and not isinstance(input_schema.get("required"), list):
|
||||
input_schema["required"] = [] # Clear invalid required list
|
||||
input_schema["required"] = []
|
||||
|
||||
# Create function declaration dictionary for Google's format
|
||||
function_declaration = {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # Google uses JSON Schema directly
|
||||
"parameters": input_schema,
|
||||
}
|
||||
|
||||
function_declarations.append(function_declaration)
|
||||
logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}")
|
||||
|
||||
# Google API expects a list containing one dictionary with 'function_declarations' key
|
||||
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
|
||||
|
||||
logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}")
|
||||
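To make the intermediate structure concrete, this is the shape one hypothetical MCP tool (a `get_forecast` tool exposed by a `weather` server, invented here for illustration) is expected to take after `convert_to_google_tools`: a single-element list whose dict carries every function declaration, with names prefixed `<server>__<tool>` for routing.

```python
google_tool_config = [{
    "function_declarations": [{
        "name": "weather__get_forecast",
        "description": "Return the forecast for a city.",
        "parameters": {  # the MCP input schema, passed through as JSON Schema
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name."}},
            "required": ["city"],
        },
    }]
}]
```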
@@ -101,9 +90,8 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
|
||||
|
||||
if not isinstance(schema_dict, dict):
|
||||
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
|
||||
return None # Return None on invalid input
|
||||
return None
|
||||
|
||||
# Map JSON Schema types to Google's Type enum members
|
||||
type_mapping = {
|
||||
"string": Type.STRING,
|
||||
"number": Type.NUMBER,
|
||||
@@ -117,16 +105,14 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
|
||||
|
||||
if not google_type:
|
||||
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
|
||||
return None # Return None if type is invalid/missing
|
||||
return None
|
||||
|
||||
# Prepare arguments for Schema constructor, filtering out None values
|
||||
schema_args = {
|
||||
"type": google_type, # Use the Type enum member
|
||||
"type": google_type,
|
||||
"format": schema_dict.get("format"),
|
||||
"description": schema_dict.get("description"),
|
||||
"nullable": schema_dict.get("nullable"), # Note: Google's Schema might not directly support nullable in constructor
|
||||
"nullable": schema_dict.get("nullable"),
|
||||
"enum": schema_dict.get("enum"),
|
||||
# Recursively create nested schemas, ensuring None is handled if recursion fails
|
||||
"items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
|
||||
"properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
|
||||
if google_type == Type.OBJECT and schema_dict.get("properties")
|
||||
@@ -134,27 +120,20 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
|
||||
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
|
||||
}
|
||||
|
||||
# Remove keys with None values before passing to Schema constructor
|
||||
schema_args = {k: v for k, v in schema_args.items() if v is not None}
|
||||
|
||||
# Handle specific cases for ARRAY and OBJECT where items/properties might be needed
|
||||
if google_type == Type.ARRAY and "items" not in schema_args:
|
||||
logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
|
||||
return None # Array schema requires items
|
||||
return None
|
||||
if google_type == Type.OBJECT and "properties" not in schema_args:
|
||||
# Allow object schema without properties initially, might be handled later
|
||||
pass
|
||||
# logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.")
|
||||
# schema_args["properties"] = {} # Or return None if properties are strictly required
|
||||
|
||||
try:
|
||||
# Create the Schema object
|
||||
created_schema = Schema(**schema_args)
|
||||
# logger.debug(f"Successfully created Schema: {created_schema}")
|
||||
return created_schema
|
||||
except Exception as schema_creation_err:
|
||||
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
|
||||
return None # Return None on creation error
|
||||
return None
|
||||
|
||||
|
||||
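As a concrete illustration of the recursive mapping, here is a small JSON Schema next to the hand-built `google.genai.types.Schema` it should roughly correspond to. This is a sketch using only the type mappings visible above (string, number, object); it is not output captured from the converter.

```python
from google.genai.types import Schema, Type

json_schema = {
    "type": "object",
    "properties": {
        "query": {"type": "string", "description": "Search text."},
        "limit": {"type": "number", "description": "Maximum results."},
    },
    "required": ["query"],
}

# Approximate equivalent built by hand:
expected = Schema(
    type=Type.OBJECT,
    properties={
        "query": Schema(type=Type.STRING, description="Search text."),
        "limit": Schema(type=Type.NUMBER, description="Maximum results."),
    },
    required=["query"],
)
```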
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
|
||||
@@ -177,7 +156,6 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
||||
return None
|
||||
|
||||
all_func_declarations = []
|
||||
# Expecting structure like [{"function_declarations": [...]}]
|
||||
if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]:
|
||||
func_declarations_list = tool_configs[0]["function_declarations"]
|
||||
if not isinstance(func_declarations_list, list):
|
||||
@@ -189,15 +167,13 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
||||
try:
|
||||
params_schema_dict = func_dict.get("parameters", {})
|
||||
|
||||
# Ensure parameters is a dict and defaults to object type if missing
|
||||
if not isinstance(params_schema_dict, dict):
|
||||
logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
|
||||
params_schema_dict = {"type": "object", "properties": {}}
|
||||
elif "type" not in params_schema_dict:
|
||||
params_schema_dict["type"] = "object" # Default to object if type is missing
|
||||
params_schema_dict["type"] = "object"
|
||||
elif params_schema_dict["type"] != "object":
|
||||
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
|
||||
# Attempt to salvage properties if the top level isn't object
|
||||
original_properties = params_schema_dict.get("properties", {})
|
||||
if not isinstance(original_properties, dict):
|
||||
original_properties = {}
|
||||
@@ -215,14 +191,11 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
||||
else:
|
||||
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
|
||||
|
||||
# Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty
|
||||
if not google_properties:
|
||||
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
|
||||
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
|
||||
# Clear required list if properties are empty/dummy
|
||||
required_list = []
|
||||
else:
|
||||
# Validate required list against actual properties
|
||||
original_required = params_schema_dict.get("required", [])
|
||||
if isinstance(original_required, list):
|
||||
required_list = [req for req in original_required if req in google_properties]
|
||||
@@ -232,14 +205,12 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
||||
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
|
||||
required_list = []
|
||||
|
||||
# Create the top-level parameters schema, ensuring it's OBJECT type
|
||||
parameters_schema = Schema(
|
||||
type=Type.OBJECT,
|
||||
properties=google_properties,
|
||||
required=required_list if required_list else None, # Pass None if empty list
|
||||
required=required_list if required_list else None,
|
||||
)
|
||||
|
||||
# Create the FunctionDeclaration
|
||||
declaration = FunctionDeclaration(
|
||||
name=func_name,
|
||||
description=func_dict.get("description", ""),
|
||||
@@ -259,14 +230,10 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
||||
logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.")
|
||||
return None
|
||||
|
||||
# Google expects a list containing one Tool object
|
||||
logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.")
|
||||
return [Tool(function_declarations=all_func_declarations)]
|
||||
|
||||
|
||||
# --- Tool Call Parsing and Handling (from Google response) ---
|
||||
|
||||
|
||||
def has_google_tool_calls(response: Any) -> bool:
|
||||
"""
|
||||
Checks if the Google response object contains tool calls (FunctionCalls).
|
||||
@@ -278,7 +245,6 @@ def has_google_tool_calls(response: Any) -> bool:
|
||||
True if FunctionCalls are present, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Check non-streaming response structure
|
||||
if hasattr(response, "candidates") and response.candidates:
|
||||
candidate = response.candidates[0]
|
||||
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
|
||||
@@ -287,10 +253,6 @@ def has_google_tool_calls(response: Any) -> bool:
|
||||
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
|
||||
return True
|
||||
|
||||
# Note: Detecting function calls reliably in a stream might require accumulating parts.
|
||||
# This function primarily works reliably for non-streaming responses.
|
||||
# For streaming, the check might happen during stream processing itself.
|
||||
|
||||
logger.debug("No tool calls (FunctionCall) detected in Google response.")
|
||||
return False
|
||||
except Exception as e:
|
||||
@@ -326,37 +288,31 @@ def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
func_call = part.function_call
|
||||
# Generate a simple unique ID for this call within this response
|
||||
call_id = f"call_{call_index}"
|
||||
call_index += 1
|
||||
|
||||
# Extract server_name and func_name from the prefixed name
|
||||
full_name = func_call.name
|
||||
parts = full_name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If the prefix isn't found, assume it's just the function name
|
||||
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.")
|
||||
server_name = None
|
||||
func_name = full_name
|
||||
|
||||
# Convert arguments dict to JSON string
|
||||
try:
|
||||
# func_call.args is already a dict-like object (Mapping)
|
||||
args_dict = dict(func_call.args) if func_call.args else {}
|
||||
args_str = json.dumps(args_dict)
|
||||
except Exception as json_err:
|
||||
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
|
||||
# Provide error info in arguments if serialization fails
|
||||
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call_id, # Internal ID for tracking this call
|
||||
"id": call_id,
|
||||
"server_name": server_name,
|
||||
"function_name": func_name, # The original function name
|
||||
"arguments": args_str, # Arguments as a JSON string
|
||||
"_google_tool_name": full_name, # Keep original name if needed later
|
||||
"function_name": func_name,
|
||||
"arguments": args_str,
|
||||
"_google_tool_name": full_name,
|
||||
})
|
||||
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
|
||||
|
||||
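For reference, a response whose first candidate carries one `FunctionCall` part named `weather__get_forecast` with args `{"city": "Oslo"}` would be parsed into a list like the following (the example values are illustrative; the IDs are generated locally as `call_<index>`, not returned by the API):

```python
parsed_calls = [{
    "id": "call_0",                        # locally generated tracking ID
    "server_name": "weather",              # prefix before "__"
    "function_name": "get_forecast",       # original, unprefixed function name
    "arguments": '{"city": "Oslo"}',       # args serialized to a JSON string
    "_google_tool_name": "weather__get_forecast",
}]
```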
@@ -382,17 +338,14 @@ def format_google_tool_results(tool_call_id: str, function_name: str, result: An
|
||||
This will be converted later by `_convert_messages`.
|
||||
"""
|
||||
try:
|
||||
# Google expects the 'response' field in FunctionResponse to contain a dict.
|
||||
# The content should ideally be JSON serializable. We wrap the result.
|
||||
if isinstance(result, (str, int, float, bool, list)):
|
||||
content_dict = {"result": result}
|
||||
elif isinstance(result, dict):
|
||||
content_dict = result # Assume it's already a suitable dict
|
||||
content_dict = result
|
||||
else:
|
||||
logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.")
|
||||
content_dict = {"result": str(result)}
|
||||
|
||||
# Ensure the content is JSON serializable for the 'content' field
|
||||
try:
|
||||
content_str = json.dumps(content_dict)
|
||||
except Exception as json_err:
|
||||
@@ -404,12 +357,9 @@ def format_google_tool_results(tool_call_id: str, function_name: str, result: An
|
||||
content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)})
|
||||
|
||||
logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})")
|
||||
# Return in the standard message format, _convert_messages will handle Google's structure
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id, # Used by _convert_messages to find the original call
|
||||
"content": content_str, # The JSON string representing the result content
|
||||
"name": function_name, # Store original function name for _convert_messages
|
||||
# Note: Google's FunctionResponse Part needs 'name' and 'response' (dict).
|
||||
# This standard format will be converted by the provider's message conversion logic.
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content_str,
|
||||
"name": function_name,
|
||||
}
|
||||
|
||||
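Downstream, `convert_messages` turns this standard tool message into a Gemini `function` turn. A sketch of that final shape, assuming the result content parses back as JSON; the tool name and result values mirror the earlier example and are illustrative:

```python
import json

from google.genai.types import Content, Part

tool_message = format_google_tool_results(
    tool_call_id="call_0",
    function_name="get_forecast",
    result={"temperature_c": 4, "conditions": "overcast"},
)

# What the message conversion is expected to build from it:
function_turn = Content(
    role="function",
    parts=[Part.from_function_response(
        name="get_forecast",
        response={"content": json.loads(tool_message["content"])},
    )],
)
```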
@@ -1,4 +1,3 @@
|
||||
# src/providers/google_provider/utils.py
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -12,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def get_context_window(model: str) -> int:
|
||||
"""Retrieves the context window size for a given Google model."""
|
||||
default_window = 1000000 # Default fallback for Gemini
|
||||
default_window = 1000000
|
||||
try:
|
||||
provider_models = MODELS.get("google", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
@@ -44,12 +43,9 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
|
||||
system_prompt = content
|
||||
logger.debug("Extracted system prompt for Google.")
|
||||
else:
|
||||
# Google API expects system prompt only at the beginning.
|
||||
# If found later, log a warning and skip or merge if possible (though merging is complex).
|
||||
logger.warning("System message found not at the beginning. Skipping for Google API.")
|
||||
continue # Skip adding system messages to the main list
|
||||
continue
|
||||
|
||||
# Map roles: 'assistant' -> 'model', 'tool' -> 'function' (handled below)
|
||||
google_role = {"user": "user", "assistant": "model"}.get(role)
|
||||
|
||||
if not google_role and role != "tool":
|
||||
@@ -58,92 +54,73 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
|
||||
|
||||
parts: list[Part | str] = []
|
||||
if role == "tool":
|
||||
# Tool results are mapped to 'function' role in Google API
|
||||
if tool_call_id and content:
|
||||
try:
|
||||
# Attempt to parse the content as JSON, assuming it's the tool output
|
||||
response_content_dict = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
|
||||
response_content_dict = {"result": content} # Wrap raw string if not JSON
|
||||
response_content_dict = {"result": content}
|
||||
|
||||
# Find the original function name from the preceding assistant message
|
||||
func_name = "unknown_function" # Default if name can't be found
|
||||
func_name = "unknown_function"
|
||||
if i > 0 and messages[i - 1].get("role") == "assistant":
|
||||
prev_tool_calls = messages[i - 1].get("tool_calls")
|
||||
if prev_tool_calls:
|
||||
for tc in prev_tool_calls:
|
||||
# Match based on the ID provided in the tool message
|
||||
if tc.get("id") == tool_call_id:
|
||||
# Google uses 'server__func' format, extract original func name if possible
|
||||
full_name = tc.get("function_name", "unknown_function")
|
||||
func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name
|
||||
func_name = full_name.split("__", 1)[-1]
|
||||
break
|
||||
|
||||
# Create a FunctionResponse part
|
||||
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
|
||||
google_role = "function" # Explicitly set role for tool results
|
||||
google_role = "function"
|
||||
else:
|
||||
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
|
||||
continue # Skip if essential parts are missing
|
||||
continue
|
||||
|
||||
elif role == "assistant" and tool_calls:
|
||||
# Assistant message requesting tool calls
|
||||
for tool_call in tool_calls:
|
||||
args = tool_call.get("arguments", {})
|
||||
# Ensure arguments are a dict, not a string
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
|
||||
args = {"error": "failed to parse arguments"} # Provide error feedback
|
||||
args = {"error": "failed to parse arguments"}
|
||||
|
||||
# Google uses 'server__func' format, extract original func name if possible
|
||||
full_name = tool_call.get("function_name", "unknown_function")
|
||||
func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name
|
||||
func_name = full_name.split("__", 1)[-1]
|
||||
|
||||
# Create a FunctionCall part
|
||||
parts.append(Part.from_function_call(name=func_name, args=args))
|
||||
|
||||
# Include any text content alongside the function calls
|
||||
if content and isinstance(content, str):
|
||||
parts.append(Part(text=content)) # Use direct instantiation
|
||||
parts.append(Part(text=content))
|
||||
|
||||
elif content:
|
||||
# Regular user or assistant message content
|
||||
if isinstance(content, str):
|
||||
parts.append(Part(text=content)) # Use direct instantiation
|
||||
# TODO: Handle potential image content if needed in the future
|
||||
parts.append(Part(text=content))
|
||||
else:
|
||||
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
|
||||
parts.append(Part(text=str(content))) # Use direct instantiation
|
||||
parts.append(Part(text=str(content)))
|
||||
|
||||
# Add the constructed Content object if parts were generated
|
||||
if parts:
|
||||
google_messages.append(Content(role=google_role, parts=parts))
|
||||
else:
|
||||
# Log if a message resulted in no parts (e.g., empty content, skipped system message)
|
||||
logger.debug(f"No parts generated for message: {message}")
|
||||
|
||||
# Validate message alternation (user -> model -> user/function -> user -> ...)
|
||||
last_role = None
|
||||
valid_alternation = True
|
||||
for msg in google_messages:
|
||||
current_role = msg.role
|
||||
# Check for consecutive user/model roles
|
||||
if current_role == last_role and current_role in ["user", "model"]:
|
||||
valid_alternation = False
|
||||
logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.")
|
||||
break
|
||||
# Check if 'function' role is followed by 'user'
|
||||
if last_role == "function" and current_role != "user":
|
||||
valid_alternation = False
|
||||
logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.")
|
||||
break
|
||||
last_role = current_role
|
||||
|
||||
# Raise error if alternation is invalid, as Google API enforces this
|
||||
if not valid_alternation:
|
||||
raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.")
|
||||
|
||||
|
||||
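The alternation rule above means a well-formed converted history looks like `user → model → function → user → ...`. A tiny sketch of an input that satisfies it versus one that raises; these are standard message dicts before conversion, with invented contents:

```python
# Valid: roles alternate, and the tool ("function") result is followed by a user turn.
ok_history = [
    {"role": "system", "content": "You are terse."},  # extracted as system_instruction
    {"role": "user", "content": "What's the weather in Oslo?"},
    {"role": "assistant", "tool_calls": [{
        "id": "call_0",
        "function_name": "weather__get_forecast",
        "arguments": '{"city": "Oslo"}',
    }]},
    {"role": "tool", "tool_call_id": "call_0", "name": "get_forecast",
     "content": '{"result": "overcast, 4C"}'},
    {"role": "user", "content": "Thanks, summarize that."},
]

# Invalid: two consecutive 'user' turns -> convert_messages raises ValueError.
bad_history = [
    {"role": "user", "content": "Hello"},
    {"role": "user", "content": "Are you there?"},
]
```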
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/__init__.py
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
@@ -20,25 +19,23 @@ from src.providers.base import BaseProvider
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""Provider implementation for OpenAI and compatible APIs."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# BaseProvider __init__ might not be needed if client init handles base_url logic
|
||||
# super().__init__(api_key, base_url) # Let's see if we need this
|
||||
temperature: float
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
|
||||
self.client = initialize_client(api_key, base_url)
|
||||
# Store api_key and base_url if needed by BaseProvider or other methods
|
||||
self.api_key = api_key
|
||||
self.base_url = self.client.base_url # Get effective base_url from client
|
||||
self.base_url = self.client.base_url
|
||||
self.temperature = temperature
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion:
|
||||
# Pass self (provider instance) to the helper function
|
||||
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
|
||||
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
|
||||
|
||||
def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
|
||||
return get_streaming_content(response)
|
||||
@@ -47,7 +44,6 @@ class OpenAIProvider(BaseProvider):
|
||||
return get_content(response)
|
||||
|
||||
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
# This method might need the full response after streaming, handled by LLMClient
|
||||
return has_tool_calls(response)
|
||||
|
||||
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
|
||||
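The practical effect of this change is that the temperature stored on the provider at construction time now overrides whatever a caller passes per call. A brief wiring sketch; `load_config` is an assumed helper for illustration and is not part of this PR:

```python
cfg = load_config()["openai"]  # hypothetical config loader, not shown in this diff
provider = OpenAIProvider(
    api_key=cfg["api_key"],
    base_url=cfg.get("base_url"),
    temperature=float(cfg.get("temperature", 0.6)),
)

# The per-call temperature argument is now ignored in favor of provider.temperature.
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Say hi."}],
    model=cfg["model"],
    stream=False,
)
```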
@@ -1,4 +1,3 @@
# src/providers/openai_provider/client.py
import logging

from openai import OpenAI

@@ -10,12 +9,9 @@ logger = logging.getLogger(__name__)

def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
    """Initializes and returns an OpenAI client instance."""
    # Use default OpenAI endpoint if base_url is not provided explicitly
    effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
    logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
    try:
        # TODO: Add default headers if needed, similar to the original openai_client.py?
        # default_headers={"HTTP-Referer": "...", "X-Title": "..."}
        client = OpenAI(api_key=api_key, base_url=effective_base_url)
        return client
    except Exception as e:
||||
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/completion.py
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -11,7 +10,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_chat_completion(
|
||||
provider, # The OpenAIProvider instance
|
||||
provider,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
@@ -22,44 +21,30 @@ def create_chat_completion(
|
||||
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
|
||||
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# --- Truncation Step ---
|
||||
truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
|
||||
# -----------------------
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": truncated_messages, # Use truncated messages
|
||||
"messages": truncated_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
|
||||
completion_params["tool_choice"] = "auto"
|
||||
|
||||
# Remove None values like max_tokens if not provided
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
# --- Added Debug Logging ---
|
||||
log_params = completion_params.copy()
|
||||
# Avoid logging full messages if they are too long
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [
|
||||
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
|
||||
for msg in log_params["messages"][-2:] # Log last 2 messages summary
|
||||
]
|
||||
# Specifically log tools structure if present
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
# --- End Added Debug Logging ---
|
||||
logger.debug(f"Full API Params: {log_params}")
|
||||
|
||||
response = provider.client.chat.completions.create(**completion_params)
|
||||
logger.debug("OpenAI API call successful.")
|
||||
|
||||
# --- Capture Actual Usage (for UI display later) ---
|
||||
# Log usage if available (primarily non-streaming)
|
||||
actual_usage = None
|
||||
if isinstance(response, ChatCompletion) and response.usage:
|
||||
actual_usage = {
|
||||
@@ -68,13 +53,9 @@ def create_chat_completion(
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
logger.info(f"Actual OpenAI API usage: {actual_usage}")
|
||||
# TODO: How to handle usage for streaming responses? Needs investigation.
|
||||
|
||||
# Return the raw response for now. LLMClient will process it.
|
||||
return response
|
||||
# ----------------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
# Re-raise for the LLMClient to handle
|
||||
raise
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/response.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
@@ -16,30 +15,24 @@ def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[st
|
||||
full_delta = ""
|
||||
try:
|
||||
for chunk in response:
|
||||
# Check if choices exist and are not empty
|
||||
if chunk.choices:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
# Handle potential finish reasons or other stream elements if needed
|
||||
# else:
|
||||
# logger.debug(f"Stream chunk without choices: {chunk}") # Or handle finish reason etc.
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
|
||||
# Yield an error message? Or let the generator stop?
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
|
||||
def get_content(response: ChatCompletion) -> str:
|
||||
"""Extracts content from a non-streaming OpenAI response."""
|
||||
try:
|
||||
# Check if choices exist and are not empty
|
||||
if response.choices:
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
|
||||
return content or "" # Return empty string if content is None
|
||||
return content or ""
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response.")
|
||||
return "[No content received]"
|
||||
@@ -55,12 +48,10 @@ def get_usage(response: Any) -> dict[str, int] | None:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
# "total_tokens": response.usage.total_tokens, # Optional
|
||||
}
|
||||
logger.debug(f"Extracted usage from OpenAI response: {usage}")
|
||||
return usage
|
||||
else:
|
||||
# Don't log warning for streams, as usage isn't expected here
|
||||
if not isinstance(response, Stream):
|
||||
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/tools.py
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -13,20 +12,16 @@ logger = logging.getLogger(__name__)
|
||||
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
# Check if choices exist and are not empty
|
||||
if isinstance(response, ChatCompletion):
|
||||
if response.choices:
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
|
||||
return False
|
||||
elif isinstance(response, Stream):
|
||||
# This check remains unreliable for unconsumed streams.
|
||||
# LLMClient needs robust handling after consumption.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
return False # Assume no for unconsumed stream for now
|
||||
return False
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
@@ -36,14 +31,12 @@ def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bo
|
||||
|
||||
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
return []
|
||||
|
||||
# Check if choices exist and are not empty
|
||||
if not response.choices:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
|
||||
return []
|
||||
@@ -55,38 +48,30 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
server_name = None
|
||||
func_name = call.function.name
|
||||
|
||||
# Arguments might be a string needing JSON parsing, or already parsed dict
|
||||
arguments_obj = None
|
||||
try:
|
||||
if isinstance(call.function.arguments, str):
|
||||
arguments_obj = json.loads(call.function.arguments)
|
||||
else:
|
||||
# Assuming it might already be a dict if not a string (less common)
|
||||
arguments_obj = call.function.arguments
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
|
||||
logger.error(f"Raw arguments string: {call.function.arguments}")
|
||||
# Decide how to handle: skip tool, pass raw string, pass error?
|
||||
# Passing raw string for now, but this might break consumers.
|
||||
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"server_name": server_name,
|
||||
"function_name": func_name,
|
||||
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
|
||||
"arguments": arguments_obj,
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
@@ -94,20 +79,18 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
return []
|
||||
|
||||
|
||||
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
elif isinstance(result, str):
|
||||
content = result # Allow plain strings if result is already string
|
||||
content = result
|
||||
else:
|
||||
content = str(result) # Ensure it's a string otherwise
|
||||
content = str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
@@ -122,9 +105,6 @@ def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
# This function seems identical to the one in src/tools/conversion.py
|
||||
# We can potentially remove it from here and import from the central location.
|
||||
# For now, keep it duplicated to maintain modularity until a decision is made.
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
@@ -137,7 +117,6 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
@@ -145,7 +124,7 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
"parameters": input_schema,
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
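For comparison with the Google converter, the same hypothetical MCP tool ends up in OpenAI's function-tool format like this (a sketch; the tool is invented for illustration):

```python
openai_tool_format = {
    "type": "function",
    "function": {
        "name": "weather__get_forecast",   # "<server>__<tool>" prefix used for routing
        "description": "Return the forecast for a city.",
        "parameters": {                    # JSON Schema passed through unchanged
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
```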
@@ -159,11 +138,9 @@ def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# src/providers/openai_provider/utils.py
|
||||
import logging
|
||||
import math
|
||||
|
||||
@@ -9,15 +8,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def get_context_window(model: str) -> int:
|
||||
"""Retrieves the context window size for a given model."""
|
||||
# Default to a safe fallback if model or provider info is missing
|
||||
default_window = 8000
|
||||
try:
|
||||
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
|
||||
provider_models = MODELS.get("openai", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
if m.get("id") == model:
|
||||
return m.get("context_window", default_window)
|
||||
# Fallback if specific model ID not found in our list
|
||||
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
|
||||
return default_window
|
||||
except Exception as e:
|
||||
@@ -36,8 +32,6 @@ def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
total_chars += len(content)
|
||||
# Rough approximation for function/tool call overhead if needed later
|
||||
# Using math.ceil to round up, ensuring we don't underestimate too much.
|
||||
estimated_tokens = math.ceil(total_chars / 4.0)
|
||||
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
|
||||
return estimated_tokens
|
||||
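As a worked example of the chars/4 heuristic: a history whose string contents total 1,234 characters is estimated at `ceil(1234 / 4) = 309` tokens. The same calculation in code, with invented message contents:

```python
import math

messages = [
    {"role": "system", "content": "You are a helpful assistant."},   # 28 chars
    {"role": "user", "content": "Summarize the latest build log."},  # 31 chars
]
total_chars = sum(len(m["content"]) for m in messages if isinstance(m.get("content"), str))
estimated_tokens = math.ceil(total_chars / 4.0)  # 59 chars -> 15 estimated tokens
```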
@@ -54,49 +48,41 @@ def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[
|
||||
- The final estimated token count after truncation (if any).
|
||||
"""
|
||||
context_limit = get_context_window(model)
|
||||
# Add a buffer to be safer with approximation
|
||||
buffer = 200 # Reduce buffer slightly as we round up now
|
||||
buffer = 200
|
||||
effective_limit = context_limit - buffer
|
||||
|
||||
initial_estimated_count = estimate_openai_token_count(messages)
|
||||
final_estimated_count = initial_estimated_count
|
||||
|
||||
truncated_messages = list(messages) # Make a copy
|
||||
truncated_messages = list(messages)
|
||||
|
||||
# Identify if the first message is a system prompt
|
||||
has_system_prompt = False
|
||||
if truncated_messages and truncated_messages[0].get("role") == "system":
|
||||
has_system_prompt = True
|
||||
# If only system prompt exists, don't truncate further
|
||||
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
|
||||
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
|
||||
# Return original messages to avoid removing the only message
|
||||
return messages, initial_estimated_count, final_estimated_count
|
||||
|
||||
while final_estimated_count > effective_limit:
|
||||
if has_system_prompt and len(truncated_messages) <= 1:
|
||||
# Should not happen if check above works, but safety break
|
||||
logger.warning("Truncation stopped: Only system prompt remains.")
|
||||
break
|
||||
if not has_system_prompt and len(truncated_messages) <= 0:
|
||||
logger.warning("Truncation stopped: No messages left.")
|
||||
break # No messages left
|
||||
break
|
||||
|
||||
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
|
||||
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
|
||||
|
||||
if remove_index >= len(truncated_messages):
|
||||
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
|
||||
break # Avoid index error
|
||||
break
|
||||
|
||||
removed_message = truncated_messages.pop(remove_index)
|
||||
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
|
||||
|
||||
# Recalculate estimated count
|
||||
final_estimated_count = estimate_openai_token_count(truncated_messages)
|
||||
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
|
||||
|
||||
# Safety break if list becomes unexpectedly empty
|
||||
if not truncated_messages:
|
||||
logger.warning("Truncation resulted in empty message list.")
|
||||
break
|
||||
|
||||
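Putting the pieces together, the truncation strategy keeps the system prompt (if any) and repeatedly drops the oldest remaining message until the chars/4 estimate fits under `context_window - buffer`. A condensed sketch of that loop, reusing `estimate_openai_token_count` from above; it omits the logging, the returned token counts, and some edge-case handling present in the real function:

```python
def truncate_for_context(messages, context_window, buffer=200):
    """Drop oldest non-system messages until the estimate fits the window."""
    effective_limit = context_window - buffer
    kept = list(messages)
    has_system = bool(kept) and kept[0].get("role") == "system"

    while estimate_openai_token_count(kept) > effective_limit:
        if has_system and len(kept) <= 1:
            break  # only the system prompt is left; keep it even if oversized
        if not kept:
            break
        # Never remove the system prompt; drop the oldest message after it.
        kept.pop(1 if has_system else 0)
    return kept
```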