Compare commits

...

3 Commits

Author SHA1 Message Date
247835e595 Refactor Google and OpenAI provider response handling and tool utilities
- Improved error handling and logging in Google response processing.
- Simplified streaming content extraction and error detection in Google provider.
- Enhanced content extraction logic in OpenAI provider to handle edge cases.
- Streamlined tool conversion functions for both Google and OpenAI providers.
- Removed redundant comments and improved code readability across multiple files.
- Updated context window retrieval and message truncation logic for better performance.
- Ensured consistent handling of tool calls and arguments in OpenAI responses.
2025-03-28 04:20:39 +00:00
51e3058961 fix: update temperature parameter to 0.6 across multiple providers and add debugging output 2025-03-27 19:02:52 +00:00
ccf750fed4 fix: correct logging error message for Google Generative AI SDK 2025-03-27 15:22:19 +00:00
28 changed files with 458 additions and 645 deletions

View File

@@ -8,18 +8,21 @@ api_key = YOUR_API_KEY
 base_url = https://openrouter.ai/api/v1
 model = openai/gpt-4o-2024-11-20
 context_window = 128000
+temperature = 0.6
 [anthropic]
 api_key = YOUR_API_KEY
 base_url = https://api.anthropic.com/v1/messages
 model = claude-3-7-sonnet-20250219
 context_window = 128000
+temperature = 0.6
 [google]
 api_key = YOUR_API_KEY
 base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
 model = gemini-2.0-flash
 context_window = 1000000
+temperature = 0.6
 [openai]
@@ -27,6 +30,7 @@ api_key = YOUR_API_KEY
 base_url = https://api.openai.com/v1
 model = openai/gpt-4o
 context_window = 128000
+temperature = 0.6
 [mcp]
 servers_json = config/mcp_config.json

project_planning/updates.md (new file, 106 lines)
View File

@@ -0,0 +1,106 @@
What is the google-genai Module?
The google-genai module is part of the Google Gen AI Python SDK, a software development kit provided by Google to enable developers to integrate Google's generative AI models into Python applications. This SDK is distinct from the older, deprecated google-generativeai package. The google-genai package represents the newer, unified SDK designed to work with Google's latest generative AI offerings, such as the Gemini models.
Installation
To use the google-genai module, you need to install it via pip. The package name on PyPI is google-genai, and you can install it with the following command:
```bash
pip install google-genai
```
This installs the necessary dependencies and makes the module available in your Python environment.
Correct Import Statement
The standard import statement for the google-genai SDK, as per the official documentation and examples, is:
```python
from google import genai
```
This differs from the older SDK's import style, which was:
```python
import google.generativeai as genai
```
When you install google-genai, it provides a module structure where genai is a submodule of the google package. Thus, from google import genai is the correct way to access its functionality.
Usage in the New SDK
Unlike the older google-generativeai SDK, which used a configure method (e.g., genai.configure(api_key='YOUR_API_KEY')) to set up the API key globally, the new google-genai SDK adopts a client-based approach. You create a Client instance with your API key and use it to interact with the models. Here's a basic example:
```python
from google import genai
# Initialize the client with your API key
client = genai.Client(api_key='YOUR_API_KEY')
# Example: Generate content using a model
response = client.models.generate_content(
    model='gemini-2.0-flash-001',  # Specify the model name
    contents='Why is the sky blue?',
)
print(response.text)
```
Key points about this usage:
- No configure Method: The new SDK does not have a genai.configure method directly on the genai module. Instead, you pass the API key when creating a Client instance.
- Client-Based Interaction: All interactions with the generative models (e.g., generating content) are performed through the client object.
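For longer responses, the client also supports streaming. The snippet below is a minimal sketch, assuming the google-genai client exposes generate_content_stream and that the model name shown is available to your key; check the documentation for your installed version before relying on it.
```python
from google import genai

client = genai.Client(api_key='YOUR_API_KEY')

# Stream the answer chunk by chunk instead of waiting for the full response.
for chunk in client.models.generate_content_stream(
    model='gemini-2.0-flash-001',
    contents='Explain why the sky is blue in two sentences.',
):
    print(chunk.text, end='')
```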
Official Documentation
The official documentation for the google-genai SDK can be found on Google's API documentation site. Specifically:
- Google Gen AI Python SDK Documentation: Hosted at https://googleapis.github.io/python-genai/, this site provides detailed guides, API references, and code examples. It confirms the use of from google import genai and the client-based approach.
- GitHub Repository: The source code and additional examples are available at https://github.com/googleapis/python-genai. The repository documentation reinforces that from google import genai is the import style for the new SDK.
Why Your Import is Correct
You mentioned that your import is correct, which I assume refers to from google import genai. This is indeed the proper import for the google-genai package, aligning with the new SDK's design. If you're encountering issues (e.g., an error like module 'google.genai' has no attribute 'configure'), it's likely because the code is trying to use methods from the older SDK (like genai.configure) that don't exist in the new one. To resolve this, update the code to use the client-based approach shown above.
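To make the refactor concrete, here is a minimal before/after sketch of that migration; the old snippet assumes the deprecated google-generativeai package and a generic model name, so adapt both to your actual code.
```python
# Before (deprecated google-generativeai SDK), shown for contrast:
# import google.generativeai as genai
# genai.configure(api_key='YOUR_API_KEY')
# model = genai.GenerativeModel('gemini-1.5-flash')
# response = model.generate_content('Why is the sky blue?')

# After (new google-genai SDK): configuration moves onto a Client instance.
from google import genai

client = genai.Client(api_key='YOUR_API_KEY')
response = client.models.generate_content(
    model='gemini-2.0-flash-001',
    contents='Why is the sky blue?',
)
print(response.text)
```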
Troubleshooting Common Issues
If you're seeing errors with from google import genai, here are some things to check:
1. Correct Package Installed:
- Ensure google-genai is installed (pip install google-genai).
- If google-generativeai is installed instead, uninstall it with pip uninstall google-generativeai to avoid conflicts, then install google-genai.
2. Code Compatibility:
- If your code uses genai.configure or assumes the older SDK's structure, you'll need to refactor it. Replace configuration calls with genai.Client(api_key='...') and adjust model interactions to use the client object.
3. Environment Verification:
- Run pip show google-genai to confirm the package is installed and check its version. This ensures you're working with the intended SDK (see the snippet below).
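As a quick runtime check (a rough sketch, assuming the PyPI distribution name is google-genai), you can confirm which SDK your environment actually resolves:
```python
import importlib.metadata

try:
    from google import genai  # noqa: F401  (import check only)
    print('google-genai version:', importlib.metadata.version('google-genai'))
except ImportError:
    print('google-genai is not importable; run: pip install google-genai')
```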
Additional Resources
- PyPI Page: The google-genai package on PyPI (https://pypi.org/project/google-genai/) provides installation instructions and links to the GitHub repository.
- Examples: The GitHub repository includes sample code demonstrating how to use the SDK with from google import genai.
Conclusion
Your import, from google import genai, aligns with the google-genai module from the new Google Gen AI Python SDK. The documentation and online resources confirm this as the correct approach for the current, unified SDK. If import google.generativeai was previously suggested or used, it pertains to the older, deprecated SDK, which explains why it might be considered incorrect in your context. To fully leverage google-genai, ensure your code uses the client-based API as outlined, and you should be able to interact with Google's generative AI models effectively. If you're still facing specific errors, feel free to share them, and I can assist further!

View File

@@ -87,7 +87,7 @@ skip-magic-trailing-comma = false
 combine-as-imports = true
 [tool.ruff.lint.mccabe]
-max-complexity = 16
+max-complexity = 30
 [tool.ruff.lint.flake8-tidy-imports]
 # Disallow all relative imports.

View File

@@ -7,7 +7,6 @@ import streamlit as st
from llm_client import LLMClient from llm_client import LLMClient
from src.custom_mcp.manager import SyncMCPManager from src.custom_mcp.manager import SyncMCPManager
# Configure logging for the app
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,14 +21,12 @@ def init_session_state():
logger.info("Attempting to initialize clients...") logger.info("Attempting to initialize clients...")
try: try:
config = configparser.ConfigParser() config = configparser.ConfigParser()
# TODO: Improve config file path handling (e.g., environment variable, absolute path)
config_files_read = config.read("config/config.ini") config_files_read = config.read("config/config.ini")
if not config_files_read: if not config_files_read:
raise FileNotFoundError("config.ini not found or could not be read.") raise FileNotFoundError("config.ini not found or could not be read.")
logger.info(f"Read configuration from: {config_files_read}") logger.info(f"Read configuration from: {config_files_read}")
# --- MCP Manager Setup --- mcp_config_path = "config/mcp_config.json"
mcp_config_path = "config/mcp_config.json" # Default
if config.has_section("mcp") and config["mcp"].get("servers_json"): if config.has_section("mcp") and config["mcp"].get("servers_json"):
mcp_config_path = config["mcp"]["servers_json"] mcp_config_path = config["mcp"]["servers_json"]
logger.info(f"Using MCP config path from config.ini: {mcp_config_path}") logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
@@ -38,39 +35,37 @@ def init_session_state():
mcp_manager = SyncMCPManager(mcp_config_path) mcp_manager = SyncMCPManager(mcp_config_path)
if not mcp_manager.initialize(): if not mcp_manager.initialize():
# Log warning but continue - LLMClient will operate without tools
logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.") logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
else: else:
logger.info("MCP Manager initialized successfully.") logger.info("MCP Manager initialized successfully.")
# Register shutdown hook for MCP manager
atexit.register(mcp_manager.shutdown) atexit.register(mcp_manager.shutdown)
logger.info("Registered MCP Manager shutdown hook.") logger.info("Registered MCP Manager shutdown hook.")
# --- LLM Client Setup ---
provider_name = None provider_name = None
model_name = None model_name = None
api_key = None api_key = None
base_url = None base_url = None
# 1. Determine provider from [base] section
if config.has_section("base") and config["base"].get("provider"): if config.has_section("base") and config["base"].get("provider"):
provider_name = config["base"].get("provider") provider_name = config["base"].get("provider")
logger.info(f"Provider selected from [base] section: {provider_name}") logger.info(f"Provider selected from [base] section: {provider_name}")
else: else:
# Fallback or error if [base] provider is missing? Let's error for now.
raise ValueError("Missing 'provider' setting in [base] section of config.ini") raise ValueError("Missing 'provider' setting in [base] section of config.ini")
# 2. Read details from the specific provider's section
if config.has_section(provider_name): if config.has_section(provider_name):
provider_config = config[provider_name] provider_config = config[provider_name]
model_name = provider_config.get("model") model_name = provider_config.get("model")
api_key = provider_config.get("api_key") api_key = provider_config.get("api_key")
base_url = provider_config.get("base_url") # Optional base_url = provider_config.get("base_url")
provider_temperature = provider_config.getfloat("temperature", 0.6)
if "temperature" not in provider_config:
logger.warning(f"Temperature not found in [{provider_name}] section, defaulting to {provider_temperature}")
else:
logger.info(f"Loaded temperature for {provider_name}: {provider_temperature}")
logger.info(f"Read configuration from [{provider_name}] section.") logger.info(f"Read configuration from [{provider_name}] section.")
else: else:
raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.") raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")
# Validate required config
if not api_key: if not api_key:
raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini") raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
if not model_name: if not model_name:
@@ -82,15 +77,15 @@ def init_session_state():
api_key=api_key, api_key=api_key,
mcp_manager=mcp_manager, mcp_manager=mcp_manager,
base_url=base_url, base_url=base_url,
temperature=provider_temperature,
) )
st.session_state.model_name = model_name st.session_state.model_name = model_name
st.session_state.provider_name = provider_name # Store provider name st.session_state.provider_name = provider_name
logger.info("LLMClient initialized successfully.") logger.info("LLMClient initialized successfully.")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize application clients: {e}", exc_info=True) logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
st.error(f"Application Initialization Error: {e}. Please check configuration and logs.") st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
# Stop the app if initialization fails critically
st.stop() st.stop()
@@ -98,9 +93,7 @@ def display_chat_messages():
"""Displays chat messages stored in session state.""" """Displays chat messages stored in session state."""
for message in st.session_state.messages: for message in st.session_state.messages:
with st.chat_message(message["role"]): with st.chat_message(message["role"]):
# Display content
st.markdown(message["content"]) st.markdown(message["content"])
# Display usage if available (for assistant messages)
if message["role"] == "assistant" and "usage" in message: if message["role"] == "assistant" and "usage" in message:
usage = message["usage"] usage = message["usage"]
prompt_tokens = usage.get("prompt_tokens", "N/A") prompt_tokens = usage.get("prompt_tokens", "N/A")
@@ -121,19 +114,15 @@ def handle_user_input():
response_placeholder = st.empty() response_placeholder = st.empty()
full_response = "" full_response = ""
error_occurred = False error_occurred = False
response_usage = None # Initialize usage info response_usage = None
logger.info("Processing message via LLMClient...") logger.info("Processing message via LLMClient...")
# Use the new client and method
# NOTE: Setting stream=False to easily get usage info from the response dict.
# A more complex solution is needed to get usage with streaming.
response_data = st.session_state.client.chat_completion( response_data = st.session_state.client.chat_completion(
messages=st.session_state.messages, messages=st.session_state.messages,
model=st.session_state.model_name, model=st.session_state.model_name,
stream=False, # Set to False for usage info stream=False,
) )
# Handle the response (now expecting a dict)
if isinstance(response_data, dict): if isinstance(response_data, dict):
if "error" in response_data: if "error" in response_data:
full_response = f"Error: {response_data['error']}" full_response = f"Error: {response_data['error']}"
@@ -142,24 +131,19 @@ def handle_user_input():
error_occurred = True error_occurred = True
else: else:
full_response = response_data.get("content", "") full_response = response_data.get("content", "")
response_usage = response_data.get("usage") # Get usage dict response_usage = response_data.get("usage")
if not full_response and not error_occurred: # Check error_occurred flag too if not full_response and not error_occurred:
logger.warning("Empty content received from LLMClient.") logger.warning("Empty content received from LLMClient.")
# Display nothing or a placeholder? Let's display nothing.
# full_response = "[Empty Response]"
# Display the full response at once (no streaming)
response_placeholder.markdown(full_response) response_placeholder.markdown(full_response)
logger.debug("Non-streaming response processed.") logger.debug("Non-streaming response processed.")
else: else:
# Unexpected response type
full_response = "[Unexpected response format from LLMClient]" full_response = "[Unexpected response format from LLMClient]"
logger.error(f"Unexpected response type: {type(response_data)}") logger.error(f"Unexpected response type: {type(response_data)}")
st.error(full_response) st.error(full_response)
error_occurred = True error_occurred = True
# Add response to history, including usage if available if not error_occurred and full_response:
if not error_occurred and full_response: # Only add if no error and content exists
assistant_message = {"role": "assistant", "content": full_response} assistant_message = {"role": "assistant", "content": full_response}
if response_usage: if response_usage:
assistant_message["usage"] = response_usage assistant_message["usage"] = response_usage
@@ -181,35 +165,28 @@ def main():
try: try:
init_session_state() init_session_state()
# --- Display Enhanced Header ---
provider_name = st.session_state.get("provider_name", "Unknown Provider") provider_name = st.session_state.get("provider_name", "Unknown Provider")
model_name = st.session_state.get("model_name", "Unknown Model") model_name = st.session_state.get("model_name", "Unknown Model")
mcp_manager = st.session_state.client.mcp_manager # Get the manager mcp_manager = st.session_state.client.mcp_manager
server_count = 0 server_count = 0
tool_count = 0 tool_count = 0
if mcp_manager and mcp_manager.initialized: if mcp_manager and mcp_manager.initialized:
server_count = len(mcp_manager.servers) server_count = len(mcp_manager.servers)
try: try:
# Get tool count (might be slightly slow if many tools/servers)
tool_count = len(mcp_manager.list_all_tools()) tool_count = len(mcp_manager.list_all_tools())
except Exception as e: except Exception as e:
logger.warning(f"Could not retrieve tool count for header: {e}") logger.warning(f"Could not retrieve tool count for header: {e}")
tool_count = "N/A" # Display N/A if listing fails tool_count = "N/A"
# Display the new header format
st.markdown(f"# Say Hi to **{provider_name.capitalize()}**!") st.markdown(f"# Say Hi to **{provider_name.capitalize()}**!")
st.write(f"MCP Servers: **{server_count}** | Tools: **{tool_count}**") st.write(f"MCP Servers: **{server_count}** | Tools: **{tool_count}**")
st.write(f"Model: **{model_name}**") st.write(f"Model: **{model_name}**")
st.divider() st.divider()
# -----------------------------
# Removed the previous caption display
display_chat_messages() display_chat_messages()
handle_user_input() handle_user_input()
except Exception as e: except Exception as e:
# Catch potential errors during rendering or handling
logger.critical(f"Critical error in main app flow: {e}", exc_info=True) logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
st.error(f"A critical application error occurred: {e}") st.error(f"A critical application error occurred: {e}")

View File

@@ -1 +0,0 @@
# This file makes src/mcp a Python package

View File

@@ -1,4 +1,3 @@
# src/mcp/client.py
"""Client class for managing and interacting with a single MCP server process.""" """Client class for managing and interacting with a single MCP server process."""
import asyncio import asyncio
@@ -9,9 +8,8 @@ from custom_mcp import process, protocol
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define reasonable timeouts LIST_TOOLS_TIMEOUT = 20.0
LIST_TOOLS_TIMEOUT = 20.0 # Seconds (using the increased value from previous step) CALL_TOOL_TIMEOUT = 110.0
CALL_TOOL_TIMEOUT = 110.0 # Seconds
class MCPClient: class MCPClient:
@@ -39,7 +37,7 @@ class MCPClient:
self._stderr_task: asyncio.Task | None = None self._stderr_task: asyncio.Task | None = None
self._request_counter = 0 self._request_counter = 0
self._is_running = False self._is_running = False
self.logger = logging.getLogger(f"{__name__}.{self.server_name}") # Instance-specific logger self.logger = logging.getLogger(f"{__name__}.{self.server_name}")
async def _log_stderr(self): async def _log_stderr(self):
"""Logs stderr output from the server process.""" """Logs stderr output from the server process."""
@@ -55,7 +53,6 @@ class MCPClient:
except asyncio.CancelledError: except asyncio.CancelledError:
self.logger.debug("Stderr logging task cancelled.") self.logger.debug("Stderr logging task cancelled.")
except Exception as e: except Exception as e:
# Log errors but don't crash the logger task if possible
self.logger.error(f"Error reading stderr: {e}", exc_info=True) self.logger.error(f"Error reading stderr: {e}", exc_info=True)
finally: finally:
self.logger.debug("Stderr logging task finished.") self.logger.debug("Stderr logging task finished.")
@@ -79,13 +76,11 @@ class MCPClient:
if self.reader is None or self.writer is None: if self.reader is None or self.writer is None:
self.logger.error("Failed to get stdout/stdin streams after process start.") self.logger.error("Failed to get stdout/stdin streams after process start.")
await self.stop() # Attempt cleanup await self.stop()
return False return False
# Start background task to monitor stderr
self._stderr_task = asyncio.create_task(self._log_stderr()) self._stderr_task = asyncio.create_task(self._log_stderr())
# --- Start MCP Initialization Handshake ---
self.logger.info("Starting MCP initialization handshake...") self.logger.info("Starting MCP initialization handshake...")
self._request_counter += 1 self._request_counter += 1
init_req_id = self._request_counter init_req_id = self._request_counter
@@ -94,21 +89,18 @@ class MCPClient:
"id": init_req_id, "id": init_req_id,
"method": "initialize", "method": "initialize",
"params": { "params": {
"protocolVersion": "2024-11-05", # Use a recent version "protocolVersion": "2024-11-05",
"clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"}, # Identify the client "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
"capabilities": {}, # Client capabilities (can be empty) "capabilities": {},
}, },
} }
# Define a timeout for initialization INITIALIZE_TIMEOUT = 15.0
INITIALIZE_TIMEOUT = 15.0 # Seconds
try: try:
# Send initialize request
await protocol.send_request(self.writer, initialize_req) await protocol.send_request(self.writer, initialize_req)
self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...") self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")
# Wait for initialize response
init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT) init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)
if init_response and init_response.get("id") == init_req_id: if init_response and init_response.get("id") == init_req_id:
@@ -117,9 +109,8 @@ class MCPClient:
await self.stop() await self.stop()
return False return False
elif "result" in init_response: elif "result" in init_response:
self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}") # Log server capabilities if provided self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}")
# Send initialized notification (using standard method name)
initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}} initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
await protocol.send_request(self.writer, initialized_notify) await protocol.send_request(self.writer, initialized_notify)
self.logger.info("'notifications/initialized' notification sent.") self.logger.info("'notifications/initialized' notification sent.")
@@ -135,7 +126,7 @@ class MCPClient:
self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}") self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
await self.stop() await self.stop()
return False return False
else: # Timeout case else:
self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.") self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
await self.stop() await self.stop()
return False return False
@@ -148,26 +139,23 @@ class MCPClient:
self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True) self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
await self.stop() await self.stop()
return False return False
# --- End MCP Initialization Handshake ---
except Exception as e: except Exception as e:
self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True) self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
self.process = None # Ensure process is None on failure self.process = None
self.reader = None self.reader = None
self.writer = None self.writer = None
self._is_running = False self._is_running = False
return False return False
async def stop(self): async def stop(self):
"""Stops the MCP server subprocess gracefully."""
if not self._is_running and not self.process: if not self._is_running and not self.process:
self.logger.debug("Stop called but client is not running.") self.logger.debug("Stop called but client is not running.")
return return
self.logger.info("Stopping MCP server process...") self.logger.info("Stopping MCP server process...")
self._is_running = False # Mark as stopping self._is_running = False
# Cancel stderr logging task
if self._stderr_task and not self._stderr_task.done(): if self._stderr_task and not self._stderr_task.done():
self._stderr_task.cancel() self._stderr_task.cancel()
try: try:
@@ -178,11 +166,9 @@ class MCPClient:
self.logger.error(f"Error waiting for stderr task cancellation: {e}") self.logger.error(f"Error waiting for stderr task cancellation: {e}")
self._stderr_task = None self._stderr_task = None
# Stop the process using the utility function
if self.process: if self.process:
await process.stop_mcp_process(self.process, self.server_name) await process.stop_mcp_process(self.process, self.server_name)
# Nullify references
self.process = None self.process = None
self.reader = None self.reader = None
self.writer = None self.writer = None
@@ -219,7 +205,6 @@ class MCPClient:
self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}") self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
return None return None
else: else:
# Includes timeout case (read_response returns None)
self.logger.error(f"No valid response or timeout for listTools ID {req_id}.") self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
return None return None
@@ -260,15 +245,12 @@ class MCPClient:
response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT) response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)
if response and "result" in response: if response and "result" in response:
# Assuming result is the desired payload
self.logger.info(f"Tool '{tool_name}' executed successfully.") self.logger.info(f"Tool '{tool_name}' executed successfully.")
return response["result"] return response["result"]
elif response and "error" in response: elif response and "error" in response:
self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}") self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
# Return the error structure itself? Or just None? Returning error dict for now.
return {"error": response["error"]} return {"error": response["error"]}
else: else:
# Includes timeout case
self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.") self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
return None return None

View File

@@ -1,4 +1,3 @@
# src/mcp/manager.py
"""Synchronous manager for multiple MCPClient instances.""" """Synchronous manager for multiple MCPClient instances."""
import asyncio import asyncio
@@ -7,19 +6,15 @@ import logging
import threading import threading
from typing import Any from typing import Any
# Use relative imports within the mcp package
from custom_mcp.client import MCPClient from custom_mcp.client import MCPClient
# Configure basic logging
# Consider moving this to the main app entry point if not already done
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts) INITIALIZE_TIMEOUT = 60.0
INITIALIZE_TIMEOUT = 60.0 # Seconds SHUTDOWN_TIMEOUT = 30.0
SHUTDOWN_TIMEOUT = 30.0 # Seconds LIST_ALL_TOOLS_TIMEOUT = 30.0
LIST_ALL_TOOLS_TIMEOUT = 30.0 # Seconds EXECUTE_TOOL_TIMEOUT = 120.0
EXECUTE_TOOL_TIMEOUT = 120.0 # Seconds
class SyncMCPManager: class SyncMCPManager:
@@ -37,7 +32,6 @@ class SyncMCPManager:
""" """
self.config_path = config_path self.config_path = config_path
self.config: dict[str, Any] | None = None self.config: dict[str, Any] | None = None
# Stores server_name -> MCPClient instance
self.servers: dict[str, MCPClient] = {} self.servers: dict[str, MCPClient] = {}
self.initialized = False self.initialized = False
self._lock = threading.Lock() self._lock = threading.Lock()
@@ -50,7 +44,6 @@ class SyncMCPManager:
"""Load MCP configuration from JSON file.""" """Load MCP configuration from JSON file."""
logger.debug(f"Attempting to load MCP config from: {self.config_path}") logger.debug(f"Attempting to load MCP config from: {self.config_path}")
try: try:
# Using direct file access
with open(self.config_path) as f: with open(self.config_path) as f:
self.config = json.load(f) self.config = json.load(f)
logger.info("MCP configuration loaded successfully.") logger.info("MCP configuration loaded successfully.")
@@ -65,8 +58,6 @@ class SyncMCPManager:
logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True) logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
self.config = None self.config = None
# --- Background Event Loop Management ---
def _run_event_loop(self): def _run_event_loop(self):
"""Target function for the background event loop thread.""" """Target function for the background event loop thread."""
try: try:
@@ -75,14 +66,12 @@ class SyncMCPManager:
self._loop.run_forever() self._loop.run_forever()
finally: finally:
if self._loop and not self._loop.is_closed(): if self._loop and not self._loop.is_closed():
# Clean up remaining tasks before closing
try: try:
tasks = asyncio.all_tasks(self._loop) tasks = asyncio.all_tasks(self._loop)
if tasks: if tasks:
logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...") logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
for task in tasks: for task in tasks:
task.cancel() task.cancel()
# Allow cancellation to propagate
self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
logger.debug("Outstanding tasks cancelled.") logger.debug("Outstanding tasks cancelled.")
self._loop.run_until_complete(self._loop.shutdown_asyncgens()) self._loop.run_until_complete(self._loop.shutdown_asyncgens())
@@ -99,9 +88,7 @@ class SyncMCPManager:
self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True) self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
self._thread.start() self._thread.start()
logger.info("Event loop thread started.") logger.info("Event loop thread started.")
# Wait briefly for the loop to become available and running
while self._loop is None or not self._loop.is_running(): while self._loop is None or not self._loop.is_running():
# Use time.sleep in sync context
import time import time
time.sleep(0.01) time.sleep(0.01)
@@ -121,8 +108,6 @@ class SyncMCPManager:
self._thread = None self._thread = None
logger.info("Event loop stopped.") logger.info("Event loop stopped.")
# --- Public Synchronous Interface ---
def initialize(self) -> bool: def initialize(self) -> bool:
""" """
Initializes and starts all configured MCP servers synchronously. Initializes and starts all configured MCP servers synchronously.
@@ -147,8 +132,6 @@ class SyncMCPManager:
logger.info("Submitting asynchronous server initialization...") logger.info("Submitting asynchronous server initialization...")
# Prepare coroutine to start all clients
async def _async_init_all(): async def _async_init_all():
tasks = [] tasks = []
for server_name, server_config in self.config["mcpServers"].items(): for server_name, server_config in self.config["mcpServers"].items():
@@ -161,19 +144,17 @@ class SyncMCPManager:
client = MCPClient(server_name, command, args, config_env) client = MCPClient(server_name, command, args, config_env)
self.servers[server_name] = client self.servers[server_name] = client
tasks.append(client.start()) # Append the start coroutine tasks.append(client.start())
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
# Check results - True means success, False or Exception means failure
all_success = True all_success = True
failed_servers = [] failed_servers = []
for i, result in enumerate(results): for i, result in enumerate(results):
server_name = list(self.config["mcpServers"].keys())[i] # Assumes order is maintained server_name = list(self.config["mcpServers"].keys())[i]
if isinstance(result, Exception) or result is False: if isinstance(result, Exception) or result is False:
all_success = False all_success = False
failed_servers.append(server_name) failed_servers.append(server_name)
# Remove failed client from managed servers
if server_name in self.servers: if server_name in self.servers:
del self.servers[server_name] del self.servers[server_name]
logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}") logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
@@ -182,7 +163,6 @@ class SyncMCPManager:
logger.error(f"Initialization failed for servers: {failed_servers}") logger.error(f"Initialization failed for servers: {failed_servers}")
return all_success return all_success
# Run the initialization coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop) future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
try: try:
success = future.result(timeout=INITIALIZE_TIMEOUT) success = future.result(timeout=INITIALIZE_TIMEOUT)
@@ -192,17 +172,16 @@ class SyncMCPManager:
else: else:
logger.error("Asynchronous initialization failed.") logger.error("Asynchronous initialization failed.")
self.initialized = False self.initialized = False
# Attempt to clean up any partially started servers self.shutdown()
self.shutdown() # Call sync shutdown
except TimeoutError: except TimeoutError:
logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.") logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
self.initialized = False self.initialized = False
self.shutdown() # Clean up self.shutdown()
success = False success = False
except Exception as e: except Exception as e:
logger.error(f"Exception during initialization future result: {e}", exc_info=True) logger.error(f"Exception during initialization future result: {e}", exc_info=True)
self.initialized = False self.initialized = False
self.shutdown() # Clean up self.shutdown()
success = False success = False
return self.initialized return self.initialized
@@ -211,20 +190,14 @@ class SyncMCPManager:
"""Shuts down all managed MCP servers synchronously.""" """Shuts down all managed MCP servers synchronously."""
logger.info("Manager shutdown requested.") logger.info("Manager shutdown requested.")
with self._lock: with self._lock:
# Check servers dict too, in case init was partial
if not self.initialized and not self.servers: if not self.initialized and not self.servers:
logger.debug("Shutdown skipped: Not initialized or no servers running.") logger.debug("Shutdown skipped: Not initialized or no servers running.")
# Ensure loop is stopped if it exists
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
self._stop_event_loop_thread() self._stop_event_loop_thread()
return return
if not self._loop or not self._loop.is_running(): if not self._loop or not self._loop.is_running():
logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.") logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
# Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
# This part is tricky as MCPClient.stop is async.
# For simplicity, we might just log and rely on process termination on app exit.
# Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
self.servers = {} self.servers = {}
self.initialized = False self.initialized = False
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
@@ -233,28 +206,22 @@ class SyncMCPManager:
logger.info("Submitting asynchronous server shutdown...") logger.info("Submitting asynchronous server shutdown...")
# Prepare coroutine to stop all clients
async def _async_shutdown_all(): async def _async_shutdown_all():
tasks = [client.stop() for client in self.servers.values()] tasks = [client.stop() for client in self.servers.values()]
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
# Run the shutdown coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop) future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
try: try:
future.result(timeout=SHUTDOWN_TIMEOUT) future.result(timeout=SHUTDOWN_TIMEOUT)
logger.info("Asynchronous shutdown completed.") logger.info("Asynchronous shutdown completed.")
except TimeoutError: except TimeoutError:
logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.") logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
# Processes might still be running, OS will clean up on exit hopefully
except Exception as e: except Exception as e:
logger.error(f"Exception during shutdown future result: {e}", exc_info=True) logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
finally: finally:
# Always mark as uninitialized and clear servers dict
self.servers = {} self.servers = {}
self.initialized = False self.initialized = False
# Stop the background thread
self._stop_event_loop_thread() self._stop_event_loop_thread()
logger.info("Manager shutdown complete.") logger.info("Manager shutdown complete.")
@@ -277,7 +244,6 @@ class SyncMCPManager:
logger.info(f"Requesting tools from {len(self.servers)} servers...") logger.info(f"Requesting tools from {len(self.servers)} servers...")
# Prepare coroutine to list tools from all clients
async def _async_list_all(): async def _async_list_all():
tasks = [] tasks = []
server_names_in_order = [] server_names_in_order = []
@@ -293,10 +259,8 @@ class SyncMCPManager:
if isinstance(result, Exception): if isinstance(result, Exception):
logger.error(f"Error listing tools for server '{server_name}': {result}") logger.error(f"Error listing tools for server '{server_name}': {result}")
elif result is None: elif result is None:
# MCPClient.list_tools returns None on timeout/error
logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).") logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
elif isinstance(result, list): elif isinstance(result, list):
# Add server_name to each tool definition
for tool in result: for tool in result:
tool["server_name"] = server_name tool["server_name"] = server_name
all_tools.extend(result) all_tools.extend(result)
@@ -305,7 +269,6 @@ class SyncMCPManager:
logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.") logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
return all_tools return all_tools
# Run the coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop) future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
try: try:
aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT) aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
@@ -346,18 +309,16 @@ class SyncMCPManager:
logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}") logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")
# Run the client's call_tool coroutine in the background loop
future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop) future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
try: try:
result = future.result(timeout=EXECUTE_TOOL_TIMEOUT) result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
# MCPClient.call_tool returns the result dict or an error dict or None
if result is None: if result is None:
logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).") logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
elif isinstance(result, dict) and "error" in result: elif isinstance(result, dict) and "error" in result:
logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}") logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
else: else:
logger.info(f"Tool '{tool_name}' execution successful.") logger.info(f"Tool '{tool_name}' execution successful.")
return result # Return result dict, error dict, or None return result
except TimeoutError: except TimeoutError:
logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.") logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
return None return None

View File

@@ -1,4 +1,3 @@
# src/mcp/process.py
"""Async utilities for managing MCP server subprocesses.""" """Async utilities for managing MCP server subprocesses."""
import asyncio import asyncio
@@ -29,25 +28,20 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
""" """
logger.debug(f"Preparing to start process for command: {command}") logger.debug(f"Preparing to start process for command: {command}")
# --- Add tilde expansion for arguments ---
expanded_args = [] expanded_args = []
try: try:
for arg in args: for arg in args:
if isinstance(arg, str) and "~" in arg: if isinstance(arg, str) and "~" in arg:
expanded_args.append(os.path.expanduser(arg)) expanded_args.append(os.path.expanduser(arg))
else: else:
# Ensure all args are strings for list2cmdline
expanded_args.append(str(arg)) expanded_args.append(str(arg))
logger.debug(f"Expanded args: {expanded_args}") logger.debug(f"Expanded args: {expanded_args}")
except Exception as e: except Exception as e:
logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True) logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
raise ValueError(f"Failed to expand arguments: {e}") from e raise ValueError(f"Failed to expand arguments: {e}") from e
# --- Merge os.environ with config_env ---
merged_env = {**os.environ, **config_env} merged_env = {**os.environ, **config_env}
# logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})") # Avoid logging values
# Combine command and expanded args into a single string for shell execution
try: try:
cmd_string = subprocess.list2cmdline([command] + expanded_args) cmd_string = subprocess.list2cmdline([command] + expanded_args)
logger.debug(f"Executing shell command: {cmd_string}") logger.debug(f"Executing shell command: {cmd_string}")
@@ -55,7 +49,6 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
logger.error(f"Error creating command string: {e}", exc_info=True) logger.error(f"Error creating command string: {e}", exc_info=True)
raise ValueError(f"Failed to create command string: {e}") from e raise ValueError(f"Failed to create command string: {e}") from e
# --- Start the subprocess using shell ---
try: try:
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
cmd_string, cmd_string,
@@ -68,10 +61,10 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
return process return process
except FileNotFoundError: except FileNotFoundError:
logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'") logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
raise # Re-raise specific error raise
except Exception as e: except Exception as e:
logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True) logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
raise # Re-raise other errors raise
async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"): async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
@@ -89,7 +82,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
pid = process.pid pid = process.pid
logger.info(f"Attempting to stop process {server_name} (PID: {pid})...") logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")
# Close stdin first
if process.stdin and not process.stdin.is_closing(): if process.stdin and not process.stdin.is_closing():
try: try:
process.stdin.close() process.stdin.close()
@@ -98,7 +90,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
except Exception as e: except Exception as e:
logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}") logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")
# Attempt graceful termination
try: try:
process.terminate() process.terminate()
logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})") logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
@@ -108,7 +99,7 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.") logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
try: try:
process.kill() process.kill()
await process.wait() # Wait for kill to complete await process.wait()
logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).") logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
except ProcessLookupError: except ProcessLookupError:
logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.") logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
@@ -118,7 +109,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.") logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
except Exception as e_term: except Exception as e_term:
logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}") logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
# Attempt kill as fallback if terminate failed and process might still be running
if process.returncode is None: if process.returncode is None:
try: try:
process.kill() process.kill()

View File

@@ -1,4 +1,3 @@
# src/mcp/protocol.py
"""Async utilities for MCP JSON-RPC communication over streams.""" """Async utilities for MCP JSON-RPC communication over streams."""
import asyncio import asyncio
@@ -28,10 +27,10 @@ async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any
logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}") logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
except ConnectionResetError: except ConnectionResetError:
logger.error(f"Connection lost while sending request ID {request_dict.get('id')}") logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
raise # Re-raise for the caller (MCPClient) to handle raise
except Exception as e: except Exception as e:
logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True) logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
raise # Re-raise for the caller raise
async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None: async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:

View File

@@ -1,4 +1,3 @@
# src/llm_client.py
""" """
Generic LLM client supporting multiple providers and MCP tool integration. Generic LLM client supporting multiple providers and MCP tool integration.
""" """
@@ -26,6 +25,7 @@ class LLMClient:
api_key: str, api_key: str,
mcp_manager: SyncMCPManager, mcp_manager: SyncMCPManager,
base_url: str | None = None, base_url: str | None = None,
temperature: float = 0.6, # Add temperature parameter with a fallback default
): ):
""" """
Initialize the LLM client. Initialize the LLM client.
@@ -35,9 +35,15 @@ class LLMClient:
api_key: API key for the provider. api_key: API key for the provider.
mcp_manager: An initialized instance of SyncMCPManager. mcp_manager: An initialized instance of SyncMCPManager.
base_url: Optional base URL for the provider API. base_url: Optional base URL for the provider API.
temperature: Default temperature to configure the provider with.
""" """
logger.info(f"Initializing LLMClient for provider: {provider_name}") logger.info(f"Initializing LLMClient for provider: {provider_name}")
self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url) self.provider: BaseProvider = create_llm_provider(
provider_name,
api_key,
base_url,
temperature=temperature, # Pass temperature to provider factory
)
self.mcp_manager = mcp_manager self.mcp_manager = mcp_manager
self.mcp_tools: list[dict[str, Any]] = [] self.mcp_tools: list[dict[str, Any]] = []
self._refresh_mcp_tools() # Initial tool load self._refresh_mcp_tools() # Initial tool load
@@ -56,7 +62,7 @@ class LLMClient:
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, # temperature: float = 0.6, # REMOVE THIS LINE
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
) -> Generator[str, None, None] | dict[str, Any]: ) -> Generator[str, None, None] | dict[str, Any]:
@@ -66,7 +72,7 @@ class LLMClient:
Args: Args:
messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}). messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
model: Model identifier string. model: Model identifier string.
temperature: Sampling temperature. # temperature: REMOVED - Provider uses its configured temperature.
max_tokens: Maximum tokens to generate. max_tokens: Maximum tokens to generate.
stream: Whether to stream the response. stream: Whether to stream the response.
@@ -92,11 +98,12 @@ class LLMClient:
response = self.provider.create_chat_completion( response = self.provider.create_chat_completion(
messages=messages, messages=messages,
model=model, model=model,
temperature=temperature, # temperature=temperature, # REMOVE THIS LINE (provider uses its own)
max_tokens=max_tokens, max_tokens=max_tokens,
stream=stream, stream=stream,
tools=provider_tools, tools=provider_tools,
) )
print(f"Response: {response}") # Debugging line to check the response
logger.info("Received response from provider.") logger.info("Received response from provider.")
if stream: if stream:
@@ -168,7 +175,7 @@ class LLMClient:
follow_up_response = self.provider.create_chat_completion( follow_up_response = self.provider.create_chat_completion(
messages=messages, # Now includes assistant's turn and tool results messages=messages, # Now includes assistant's turn and tool results
model=model, model=model,
temperature=temperature, # temperature=temperature, # REMOVE THIS LINE
max_tokens=max_tokens, max_tokens=max_tokens,
stream=False, # Follow-up is non-streaming here stream=False, # Follow-up is non-streaming here
tools=provider_tools, # Pass tools again? Some providers might need it. tools=provider_tools, # Pass tools again? Some providers might need it.
@@ -212,17 +219,3 @@ class LLMClient:
except Exception as e: except Exception as e:
logger.error(f"Error during streaming: {e}", exc_info=True) logger.error(f"Error during streaming: {e}", exc_info=True)
yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk yield json.dumps({"error": f"Streaming error: {str(e)}"}) # Yield error as JSON chunk
# Example of how a provider might need to implement get_original_message_with_calls
# This would be in the specific provider class (e.g., openai_provider.py)
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
# # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
# # or in the non-streaming response's choice message
# # Needs careful implementation based on provider's response structure
# assistant_message = {
# "role": "assistant",
# "content": None, # Often null when tool calls are present
# "tool_calls": [...] # Extracted tool calls in provider format
# }
# return assistant_message

View File

@@ -1,61 +0,0 @@
MODELS = {
"openai": {
"name": "OpenAI",
"endpoint": "https://api.openai.com/v1",
"models": [
{
"id": "gpt-4o",
"name": "GPT-4o",
"default": True,
"context_window": 128000,
"description": "Input $5/M tokens, Output $15/M tokens",
}
],
},
"anthropic": {
"name": "Anthropic",
"endpoint": "https://api.anthropic.com/v1/messages",
"models": [
{
"id": "claude-3-7-sonnet-20250219",
"name": "Claude 3.7 Sonnet",
"default": True,
"context_window": 200000,
"description": "Input $3/M tokens, Output $15/M tokens",
},
{
"id": "claude-3-5-haiku-20241022",
"name": "Claude 3.5 Haiku",
"default": False,
"context_window": 200000,
"description": "Input $0.80/M tokens, Output $4/M tokens",
},
],
},
"google": {
"name": "Google Gemini",
"endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
"models": [
{
"id": "gemini-2.0-flash",
"name": "Gemini 2.0 Flash",
"default": True,
"context_window": 1000000,
"description": "Input $0.1/M tokens, Output $0.4/M tokens",
}
],
},
"openrouter": {
"name": "OpenRouter",
"endpoint": "https://openrouter.ai/api/v1/chat/completions",
"models": [
{
"id": "custom",
"name": "Custom Model",
"default": False,
"context_window": 128000, # Default context window, will be updated based on model
"description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
},
],
},
}

View File

@@ -1,4 +1,3 @@
# src/providers/__init__.py
import logging import logging
from providers.anthropic_provider import AnthropicProvider from providers.anthropic_provider import AnthropicProvider
@@ -6,11 +5,8 @@ from providers.base import BaseProvider
from providers.google_provider import GoogleProvider from providers.google_provider import GoogleProvider
from providers.openai_provider import OpenAIProvider from providers.openai_provider import OpenAIProvider
# from providers.openrouter_provider import OpenRouterProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = { PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider, "openai": OpenAIProvider,
"anthropic": AnthropicProvider, "anthropic": AnthropicProvider,
@@ -27,7 +23,7 @@ def register_provider(name: str, provider_class: type[BaseProvider]):
logger.info(f"Registered provider: {name}") logger.info(f"Registered provider: {name}")
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider: def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None, temperature: float = 0.6) -> BaseProvider:
""" """
Factory function to create an instance of a specific LLM provider. Factory function to create an instance of a specific LLM provider.
@@ -48,9 +44,9 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
available = ", ".join(PROVIDER_MAP.keys()) or "None" available = ", ".join(PROVIDER_MAP.keys()) or "None"
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}") raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
logger.info(f"Creating LLM provider instance for: {provider_name}") logger.info(f"Creating LLM provider instance for: {provider_name} with temperature: {temperature}")
try: try:
return provider_class(api_key=api_key, base_url=base_url) return provider_class(api_key=api_key, base_url=base_url, temperature=temperature)
except Exception as e: except Exception as e:
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True) logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
@@ -59,11 +55,3 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
def get_available_providers() -> list[str]: def get_available_providers() -> list[str]:
"""Returns a list of registered provider names.""" """Returns a list of registered provider names."""
return list(PROVIDER_MAP.keys()) return list(PROVIDER_MAP.keys())
# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
# # Potentially load providers dynamically if designed as plugins
# pass
# load_providers()
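The factory now threads `temperature` through to every provider constructor. A self-contained sketch of the same registry/factory pattern, using a stand-in provider class rather than the project's real classes:

```python
# A minimal sketch of the registry/factory pattern shown above; DummyProvider
# is a stand-in for the real OpenAI/Anthropic/Google provider classes.
import logging

logger = logging.getLogger(__name__)


class DummyProvider:
    def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
        self.api_key = api_key
        self.base_url = base_url
        self.temperature = temperature


PROVIDER_MAP: dict[str, type] = {"dummy": DummyProvider}


def create_llm_provider(provider_name: str, api_key: str,
                        base_url: str | None = None,
                        temperature: float = 0.6):
    provider_class = PROVIDER_MAP.get(provider_name.lower())
    if provider_class is None:
        available = ", ".join(PROVIDER_MAP.keys()) or "None"
        raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
    logger.info(f"Creating LLM provider instance for: {provider_name} with temperature: {temperature}")
    return provider_class(api_key=api_key, base_url=base_url, temperature=temperature)


provider = create_llm_provider("dummy", api_key="YOUR_API_KEY", temperature=0.6)
print(provider.temperature)  # 0.6
```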

View File

@@ -6,11 +6,21 @@ from providers.base import BaseProvider
class AnthropicProvider(BaseProvider): class AnthropicProvider(BaseProvider):
def __init__(self, api_key: str, base_url: str | None = None): temperature: float
self.client = initialize_client(api_key, base_url)
def create_chat_completion(self, messages, model, temperature=0.4, max_tokens=None, stream=True, tools=None): def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) self.client = initialize_client(api_key, base_url)
self.temperature = temperature
def create_chat_completion(
self,
messages,
model,
max_tokens=None,
stream=True,
tools=None,
):
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
def get_streaming_content(self, response): def get_streaming_content(self, response):
return get_streaming_content(response) return get_streaming_content(response)
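The Anthropic provider now fixes `temperature` at construction time and reuses it for every completion instead of accepting it per call. A stripped-down sketch of that pattern with a stand-in client:

```python
# A stripped-down sketch of the "temperature fixed at construction" pattern;
# EchoClient is a stand-in for the real Anthropic client.
class EchoClient:
    def complete(self, messages, model, temperature):
        return f"[{model} @ T={temperature}] {messages[-1]['content']}"


class AnthropicLikeProvider:
    def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
        self.client = EchoClient()  # the real code calls initialize_client(api_key, base_url)
        self.temperature = temperature

    def create_chat_completion(self, messages, model, max_tokens=None, stream=True, tools=None):
        # self.temperature replaces the old per-call temperature argument
        return self.client.complete(messages, model, self.temperature)


provider = AnthropicLikeProvider(api_key="YOUR_API_KEY", temperature=0.6)
print(provider.create_chat_completion([{"role": "user", "content": "hi"}], "claude-3-7-sonnet-20250219"))
```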

View File

@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
def create_chat_completion( def create_chat_completion(
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
) -> Stream | Message: ) -> Stream | Message:
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
temp_system_prompt, temp_anthropic_messages = convert_messages(messages) temp_system_prompt, temp_anthropic_messages = convert_messages(messages)

View File

@@ -84,17 +84,12 @@ def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}") logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
continue continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}" prefixed_tool_name = f"{server_name}__{tool_name}"
# Initialize the Anthropic tool structure
# Anthropic's format is quite close to JSON Schema
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema} anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}
# Basic validation/cleaning of schema if needed
if not isinstance(input_schema, dict) or input_schema.get("type") != "object": if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.") logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict): if not isinstance(input_schema, dict):
input_schema = {} input_schema = {}
if "type" not in input_schema: if "type" not in input_schema:

View File

@@ -1,4 +1,3 @@
# src/providers/base.py
import abc import abc
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
@@ -28,7 +27,7 @@ class BaseProvider(abc.ABC):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
@@ -39,7 +38,7 @@ class BaseProvider(abc.ABC):
Args: Args:
messages: List of message dictionaries with 'role' and 'content'. messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier. model: Model identifier.
temperature: Sampling temperature (0-1). temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens to generate. max_tokens: Maximum tokens to generate.
stream: Whether to stream the response. stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format. tools: Optional list of tools in the provider-specific format.
@@ -147,8 +146,3 @@ class BaseProvider(abc.ABC):
or None if usage information is not available. or None if usage information is not available.
""" """
pass pass
# Optional: Add a method for follow-up completions if the provider API
# requires a specific structure different from just appending messages.
# def create_follow_up_completion(...) -> Any:
# pass
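A condensed, runnable sketch of the abstract interface above with the new 0.6 default temperature; `MinimalProvider` is a throwaway subclass used only to show the contract, not one of the project's providers:

```python
# A condensed sketch of the BaseProvider contract: subclasses must implement
# create_chat_completion, and the shared default temperature is 0.6.
import abc
from typing import Any


class BaseProvider(abc.ABC):
    @abc.abstractmethod
    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.6,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """Create a chat completion; providers return their native response object."""


class MinimalProvider(BaseProvider):
    def create_chat_completion(self, messages, model, temperature=0.6, max_tokens=None, stream=True, tools=None):
        # Echo back enough detail to show which arguments reached the provider.
        return {"model": model, "temperature": temperature, "echo": messages[-1]["content"]}


print(MinimalProvider().create_chat_completion([{"role": "user", "content": "ping"}], "test-model"))
```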

View File

@@ -1,4 +1,3 @@
# src/providers/google_provider/__init__.py
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
@@ -17,39 +16,41 @@ logger = logging.getLogger(__name__)
class GoogleProvider(BaseProvider): class GoogleProvider(BaseProvider):
"""Provider implementation for Google Generative AI (Gemini).""" """Provider implementation for Google Generative AI (Gemini)."""
# Type hint for the client (it's the configured 'genai' module itself)
client_module: Any client_module: Any
temperature: float
def __init__(self, api_key: str, base_url: str | None = None): def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
""" """
Initializes the GoogleProvider. Initializes the GoogleProvider.
Args: Args:
api_key: The Google API key. api_key: The Google API key.
base_url: Base URL (typically not used by Google client config, but kept for interface consistency). base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
temperature: The default temperature for completions.
""" """
# initialize_client returns the configured genai module
self.client_module = initialize_client(api_key, base_url) self.client_module = initialize_client(api_key, base_url)
self.api_key = api_key # Store if needed later self.api_key = api_key
self.base_url = base_url # Store if needed later self.base_url = base_url
logger.info("GoogleProvider initialized.") self.temperature = temperature
logger.info(f"GoogleProvider initialized with temperature: {self.temperature}")
def create_chat_completion( def create_chat_completion(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
temperature: float = 0.4,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator ) -> Any:
"""Creates a chat completion using the Google Gemini API.""" """Creates a chat completion using the Google Gemini API."""
# Pass self (provider instance) to the helper function raw_response = create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) print(f"Raw response type: {type(raw_response)}")
print(f"Raw response: {raw_response}")
return raw_response
def get_streaming_content(self, response: Any) -> Generator[str, None, None]: def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Extracts content chunks from a Google streaming response.""" """Extracts content chunks from a Google streaming response."""
# Response is expected to be an iterator from generate_content(stream=True)
return get_streaming_content(response) return get_streaming_content(response)
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str: def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
@@ -58,33 +59,20 @@ class GoogleProvider(BaseProvider):
def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool: def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool:
"""Checks if the Google response contains tool calls (FunctionCalls).""" """Checks if the Google response contains tool calls (FunctionCalls)."""
# Note: For streaming responses, this check is reliable only after the stream is fully consumed
# or if the specific chunk containing the call is processed.
return has_google_tool_calls(response) return has_google_tool_calls(response)
def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]: def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]:
"""Parses tool calls (FunctionCalls) from a non-streaming Google response.""" """Parses tool calls (FunctionCalls) from a non-streaming Google response."""
# Expects a non-streaming GenerateContentResponse or an error dict
return parse_google_tool_calls(response) return parse_google_tool_calls(response)
# Note: Google's format_tool_results helper requires the original function_name.
# Ensure the calling code (e.g., LLMClient) provides this when invoking this method.
def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]: def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for a Google follow-up request (into standard message format).""" """Formats a tool result for a Google follow-up request (into standard message format)."""
return format_google_tool_results(tool_call_id, function_name, result) return format_google_tool_results(tool_call_id, function_name, result)
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts MCP tools list to Google's intermediate dictionary format.""" """Converts MCP tools list to Google's intermediate dictionary format."""
# The `create_chat_completion` function handles the final conversion
# from this intermediate format to Google's `Tool` objects internally.
return convert_to_google_tools(tools) return convert_to_google_tools(tools)
def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None: def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
"""Extracts token usage information from a Google response.""" """Extracts token usage information from a Google response."""
# Expects a non-streaming GenerateContentResponse or an error dict
return get_usage(response) return get_usage(response)
# `get_original_message_with_calls` (present in OpenAIProvider) is not implemented here
# as Google's API structure integrates FunctionCall parts directly into the assistant's
# message content, rather than having a separate `tool_calls` attribute on the message object.
# The necessary information is handled during message conversion and tool call parsing.

View File

@@ -1,4 +1,3 @@
# src/providers/google_provider/client.py
import logging import logging
from typing import Any from typing import Any
@@ -12,16 +11,15 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
logger.info("Initializing Google Generative AI client") logger.info("Initializing Google Generative AI client")
if genai is None: if genai is None:
logger.error("Google Generative AI SDK (google-generativeai) is not installed.") logger.error("Google Generative AI SDK (google-genai) is not installed.")
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try: try:
# Configure the client client = genai.Client(api_key=api_key)
genai.configure(api_key=api_key) logger.info("Google Generative AI client instantiated.")
if base_url: if base_url:
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.") logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
# Return the configured module itself, as it's used directly return client
return genai
except Exception as e: except Exception as e:
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
raise raise
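The client module now instantiates `genai.Client` rather than calling the older module-level `genai.configure`. A minimal sketch of that initialization, assuming the google-genai SDK and a placeholder key:

```python
# A minimal sketch of the new initialization path: build a genai.Client from
# the API key; base_url is accepted for interface consistency but unused.
import logging

from google import genai

logger = logging.getLogger(__name__)


def initialize_client(api_key: str, base_url: str | None = None) -> genai.Client:
    if base_url:
        logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
    client = genai.Client(api_key=api_key)
    logger.info("Google Generative AI client instantiated.")
    return client


client = initialize_client("YOUR_API_KEY")
```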

View File

@@ -1,140 +1,140 @@
# src/providers/google_provider/completion.py
import json
import logging import logging
import traceback import traceback
from collections.abc import Iterable
from typing import Any from typing import Any
from google.genai.types import Tool from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools from providers.google_provider.tools import convert_to_google_tool_objects
from providers.google_provider.utils import convert_messages from providers.google_provider.utils import convert_messages
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _create_chat_completion_non_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> GenerateContentResponse | dict[str, Any]:
"""Handles the non-streaming API call."""
try:
logger.debug("Calling client.models.generate_content...")
response = provider.client_module.models.generate_content(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content call successful, returning raw response object.")
return response
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
return {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during non-stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
return {"error": error_msg, "traceback": traceback.format_exc()}
def _create_chat_completion_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> Iterable[GenerateContentResponse | dict[str, Any]]:
"""Handles the streaming API call and yields results."""
try:
logger.debug("Calling client.models.generate_content_stream...")
response_iterator = provider.client_module.models.generate_content_stream(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content_stream call successful, yielding from iterator.")
yield from response_iterator
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
yield {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
yield {"error": error_msg, "traceback": traceback.format_exc()}
def create_chat_completion( def create_chat_completion(
provider, provider,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
) -> Any: ) -> Any:
""" """
Creates a chat completion using the Google Gemini API. Creates a chat completion using the Google Gemini API.
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
Args:
provider: The instance of the GoogleProvider.
messages: A list of message dictionaries in the standard format.
model: The model ID to use (e.g., "gemini-1.5-flash").
temperature: The sampling temperature.
max_tokens: The maximum number of tokens to generate.
stream: Whether to stream the response.
tools: A list of tool definitions in the MCP format.
Returns:
The response object from the Google API (could be a stream iterator or
a GenerateContentResponse object), or an error dictionary/iterator.
""" """
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
if provider.client_module is None: if provider.client_module is None:
error_msg = "Google Generative AI SDK not configured or installed." error_msg = "Google Generative AI client not initialized on provider."
logger.error(error_msg) logger.error(error_msg)
# Return an error structure compatible with both streaming and non-streaming expectations return iter([{"error": error_msg}]) if stream else {"error": error_msg}
if stream:
return iter([json.dumps({"error": error_msg})])
else:
return {"error": error_msg}
try: try:
# 1. Convert messages to Google's format
google_messages, system_prompt = convert_messages(messages) google_messages, system_prompt = convert_messages(messages)
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}") logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
# 2. Prepare generation configuration generation_config: GenerationConfigDict = {"temperature": temperature}
generation_config: dict[str, Any] = {"temperature": temperature}
if max_tokens is not None: if max_tokens is not None:
# Google uses 'max_output_tokens'
generation_config["max_output_tokens"] = max_tokens generation_config["max_output_tokens"] = max_tokens
logger.debug(f"Setting max_output_tokens: {max_tokens}") logger.debug(f"Setting max_output_tokens: {max_tokens}")
else: else:
logger.debug("No max_tokens specified.") default_max_tokens = 8192
generation_config["max_output_tokens"] = default_max_tokens
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
# 3. Convert tools if provided
google_tool_objects: list[Tool] | None = None google_tool_objects: list[Tool] | None = None
if tools: if tools:
try: try:
# Step 3a: Convert MCP tools to intermediate Google dict format google_tool_objects = convert_to_google_tool_objects(tools)
tool_dict_list = convert_to_google_tools(tools)
# Step 3b: Convert intermediate dict format to Google Tool objects
google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
if google_tool_objects: if google_tool_objects:
logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.") num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
else: else:
logger.warning("Tool conversion resulted in no valid Google Tool objects.") logger.warning("Tool conversion resulted in no valid Google Tool objects.")
except Exception as tool_conv_err: except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
# Decide whether to proceed without tools or raise an error
# Proceeding without tools for now, but logging the error.
google_tool_objects = None google_tool_objects = None
else: else:
logger.debug("No tools provided for conversion.") logger.debug("No tools provided for conversion.")
# 4. Initialize the Google Generative Model if system_prompt:
# Ensure client_module is callable and has GenerativeModel generation_config["system_instruction"] = system_prompt
if not hasattr(provider.client_module, "GenerativeModel"): logger.debug("Added system_instruction to generation_config.")
raise AttributeError("Configured Google client module does not have 'GenerativeModel'") if google_tool_objects:
generation_config["tools"] = google_tool_objects
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
gemini_model = provider.client_module.GenerativeModel(
model_name=model,
system_instruction=system_prompt, # Pass extracted system prompt
tools=google_tool_objects, # Pass converted Tool objects (or None)
# Add safety_settings if needed: safety_settings=...
)
logger.debug(f"Initialized Google GenerativeModel for '{model}'.")
# 5. Log parameters before API call
log_params = { log_params = {
"model": model, "model": model,
"stream": stream, "stream": stream,
"temperature": temperature, "temperature": temperature,
"max_output_tokens": generation_config.get("max_output_tokens"), "max_output_tokens": generation_config.get("max_output_tokens"),
"system_prompt_present": bool(system_prompt), "system_prompt_present": bool(system_prompt),
"num_tools": len(google_tool_objects) if google_tool_objects else 0, "num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
"num_messages": len(google_messages), "num_messages": len(google_messages),
} }
logger.info(f"Calling Google generate_content API with params: {log_params}") logger.info(f"Calling Google API via helper with params: {log_params}")
# Avoid logging full message content unless necessary for debugging specific issues
# logger.debug(f"Google messages being sent: {google_messages}")
# 6. Call the Google API
response = gemini_model.generate_content(
contents=google_messages,
generation_config=generation_config,
stream=stream,
# tool_config={"function_calling_config": "AUTO"} # AUTO is default
)
logger.debug("Google API call successful, returning response object.")
return response
except ValueError as ve: # Catch specific errors like invalid message sequence
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
if stream: if stream:
# Yield a JSON error message in an iterator return _create_chat_completion_stream(provider, model, google_messages, generation_config)
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
else: else:
# Return an error dictionary return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
return {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e: except Exception as e:
# Catch any other exceptions during setup or API call error_msg = f"Error during Google completion setup: {e}"
error_msg = f"Google API error during chat completion: {e}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
if stream: return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
# Yield a JSON error message in an iterator
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
else:
# Return an error dictionary
return {"error": error_msg, "traceback": traceback.format_exc()}

View File

@@ -1,4 +1,3 @@
# src/providers/google_provider/response.py
""" """
Response handling utilities specific to the Google Generative AI provider. Response handling utilities specific to the Google Generative AI provider.
@@ -32,50 +31,36 @@ def get_streaming_content(response: Any) -> Generator[str, None, None]:
logger.debug("Processing Google stream...") logger.debug("Processing Google stream...")
full_delta = "" full_delta = ""
try: try:
# Check if the response itself is an error indicator (e.g., from create_chat_completion error handling)
if isinstance(response, dict) and "error" in response: if isinstance(response, dict) and "error" in response:
yield json.dumps(response) yield json.dumps(response)
logger.error(f"Stream processing stopped due to initial error: {response['error']}") logger.error(f"Stream processing stopped due to initial error: {response['error']}")
return return
# Check if response is already an error iterator
if hasattr(response, "__iter__") and not hasattr(response, "candidates"): if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
# If it looks like an error iterator from create_chat_completion
first_item = next(response, None) first_item = next(response, None)
if first_item and isinstance(first_item, str): if first_item and isinstance(first_item, str):
try: try:
error_data = json.loads(first_item) error_data = json.loads(first_item)
if "error" in error_data: if "error" in error_data:
yield first_item # Yield the error JSON yield first_item
yield from response yield from response
logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}") logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}")
return return
except json.JSONDecodeError: except json.JSONDecodeError:
# Not a JSON error, yield it as is and continue? Or stop?
# Assuming it might be valid content if not JSON error.
yield first_item yield first_item
elif first_item: # Put the first item back if it wasn't an error elif first_item:
# This requires a way to chain iterators, simple yield doesn't work well here. pass
# For simplicity, we assume error iterators yield JSON strings.
# If the stream is valid, the loop below will handle it.
# Re-assigning response might be complex. Let the main loop handle valid streams.
pass # Let the main loop handle the original response iterator
# Process the stream chunk by chunk
for chunk in response: for chunk in response:
# Check for errors embedded within the stream chunks (less common for Google?)
if isinstance(chunk, dict) and "error" in chunk: if isinstance(chunk, dict) and "error" in chunk:
yield json.dumps(chunk) yield json.dumps(chunk)
logger.error(f"Error encountered during Google stream: {chunk['error']}") logger.error(f"Error encountered during Google stream: {chunk['error']}")
continue # Continue processing stream or stop? Continuing for now. continue
# Extract text content
delta = "" delta = ""
try: try:
if hasattr(chunk, "text"): if hasattr(chunk, "text"):
delta = chunk.text delta = chunk.text
elif hasattr(chunk, "candidates") and chunk.candidates: elif hasattr(chunk, "candidates") and chunk.candidates:
# Sometimes content might be nested under candidates even in stream?
# Check the first candidate's first part for text.
first_candidate = chunk.candidates[0] first_candidate = chunk.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts: if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
first_part = first_candidate.content.parts[0] first_part = first_candidate.content.parts[0]
@@ -83,32 +68,27 @@ def get_streaming_content(response: Any) -> Generator[str, None, None]:
delta = first_part.text delta = first_part.text
except Exception as e: except Exception as e:
logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True) logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True)
delta = "" # Ensure delta is a string delta = ""
if delta: if delta:
full_delta += delta full_delta += delta
yield delta yield delta
# Detect function calls during stream (optional, for logging/early detection)
try: try:
if hasattr(chunk, "candidates") and chunk.candidates: if hasattr(chunk, "candidates") and chunk.candidates:
for part in chunk.candidates[0].content.parts: for part in chunk.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call: if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Function call detected during stream: {part.function_call.name}") logger.debug(f"Function call detected during stream: {part.function_call.name}")
# Note: We don't yield the function call itself here, just the text. break
# Function calls are typically processed after the stream completes.
break # Found a function call in this chunk
except Exception: except Exception:
# Ignore errors during optional function call detection in stream
pass pass
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}") logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
except StopIteration: except StopIteration:
logger.debug("Google stream finished (StopIteration).") # Normal end of iteration logger.debug("Google stream finished (StopIteration).")
except Exception as e: except Exception as e:
logger.error(f"Error processing Google stream: {e}", exc_info=True) logger.error(f"Error processing Google stream: {e}", exc_info=True)
# Yield a final error message
yield json.dumps({"error": f"Stream processing error: {str(e)}"}) yield json.dumps({"error": f"Stream processing error: {str(e)}"})
@@ -124,37 +104,39 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
The concatenated text content, or an error message string. The concatenated text content, or an error message string.
""" """
try: try:
# Handle error dictionary case
if isinstance(response, dict) and "error" in response: if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error response: {response['error']}") logger.error(f"Cannot get content from error dict: {response['error']}")
return f"[Error: {response['error']}]" return f"[Error: {response['error']}]"
# Handle successful GenerateContentResponse object if not isinstance(response, GenerateContentResponse):
if hasattr(response, "text"): logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
# The `.text` attribute usually provides the concatenated text content directly return f"[Error: Unexpected response type {type(response)}]"
if hasattr(response, "text") and response.text:
content = response.text content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.") logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content return content
elif hasattr(response, "candidates") and response.candidates:
# Fallback: manually concatenate text from parts if .text is missing if hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0] first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
text_parts = [] text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
for part in first_candidate.content.parts: if text_parts:
if hasattr(part, "text"):
text_parts.append(part.text)
# We are only interested in text content here, ignore function calls etc.
content = "".join(text_parts) content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.") logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
return content return content
else: else:
logger.warning("Google response candidate has no content or parts.") logger.warning("Google response candidate parts contained no text.")
return "" # Return empty string if no text found return ""
else: else:
logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}") logger.warning("Google response candidate has no valid content or parts.")
return "" # Return empty string if no text found return ""
else:
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
return ""
except AttributeError as ae: except AttributeError as ae:
logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True) logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return f"[Error extracting content: Attribute missing - {str(ae)}]" return f"[Error extracting content: Attribute missing - {str(ae)}]"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True) logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
@@ -173,32 +155,30 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i
usage information is unavailable or an error occurred. usage information is unavailable or an error occurred.
""" """
try: try:
# Handle error dictionary case
if isinstance(response, dict) and "error" in response: if isinstance(response, dict) and "error" in response:
logger.warning("Cannot get usage from error response.") logger.warning(f"Cannot get usage from error dict: {response['error']}")
return None return None
# Check for usage metadata in the response object if not isinstance(response, GenerateContentResponse):
if hasattr(response, "usage_metadata"): logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
metadata = response.usage_metadata return None
# Google uses prompt_token_count and candidates_token_count
metadata = getattr(response, "usage_metadata", None)
if metadata:
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
completion_tokens = getattr(metadata, "candidates_token_count", 0)
usage = { usage = {
"prompt_tokens": getattr(metadata, "prompt_token_count", 0), "prompt_tokens": int(prompt_tokens),
"completion_tokens": getattr(metadata, "candidates_token_count", 0), "completion_tokens": int(completion_tokens),
# Google also provides total_token_count, could be added if needed
# "total_tokens": getattr(metadata, "total_token_count", 0),
} }
# Ensure values are integers
usage = {k: int(v) for k, v in usage.items()}
logger.debug(f"Extracted usage from Google response metadata: {usage}") logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage return usage
else: else:
# Log a warning only if it's not clearly an error dict already handled logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
if not (isinstance(response, dict) and "error" in response):
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
return None return None
except AttributeError as ae: except AttributeError as ae:
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True) logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True) logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)

View File

@@ -1,4 +1,3 @@
# src/providers/google_provider/tools.py
""" """
Tool handling utilities specific to the Google Generative AI provider. Tool handling utilities specific to the Google Generative AI provider.
@@ -13,14 +12,11 @@ import json
import logging import logging
from typing import Any from typing import Any
from google.genai.types import FunctionDeclaration, Schema, Tool from google.genai.types import FunctionDeclaration, Schema, Tool, Type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- Tool Conversion (from MCP format to Google format) ---
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
""" """
Convert MCP tools to Google Gemini format (dictionary structure). Convert MCP tools to Google Gemini format (dictionary structure).
@@ -48,41 +44,34 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, A
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}") logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
continue continue
# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}" prefixed_tool_name = f"{server_name}__{tool_name}"
# Basic validation/cleaning of schema for Google compatibility
if not isinstance(input_schema, dict) or input_schema.get("type") != "object": if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.") logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict): if not isinstance(input_schema, dict):
input_schema = {} # Start fresh if not a dict input_schema = {}
if "type" not in input_schema or input_schema["type"] != "object": if "type" not in input_schema or input_schema["type"] != "object":
# Wrap existing schema or create new if type is wrong/missing
input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}} input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}}
logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.") logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.")
if "properties" not in input_schema: if "properties" not in input_schema:
input_schema["properties"] = {} input_schema["properties"] = {}
# Google requires properties for object type, add dummy if empty
if not input_schema["properties"]: if not input_schema["properties"]:
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.") logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}} input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}}
if "required" in input_schema and not isinstance(input_schema.get("required"), list): if "required" in input_schema and not isinstance(input_schema.get("required"), list):
input_schema["required"] = [] # Clear invalid required list input_schema["required"] = []
# Create function declaration dictionary for Google's format
function_declaration = { function_declaration = {
"name": prefixed_tool_name, "name": prefixed_tool_name,
"description": description, "description": description,
"parameters": input_schema, # Google uses JSON Schema directly "parameters": input_schema,
} }
function_declarations.append(function_declaration) function_declarations.append(function_declaration)
logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}") logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}")
# Google API expects a list containing one dictionary with 'function_declarations' key
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else [] google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}") logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}")
@@ -95,50 +84,56 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
Handles type mapping and nested structures. Returns None on failure. Handles type mapping and nested structures. Returns None on failure.
""" """
if Schema is None: if Schema is None or Type is None:
logger.error("Cannot create Schema object: google.genai types not available.") logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
return None return None
if not isinstance(schema_dict, dict): if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
return Schema() # Return empty schema to avoid breaking the parent return None
# Map JSON Schema types to Google's Type enum strings
type_mapping = { type_mapping = {
"string": "STRING", "string": Type.STRING,
"number": "NUMBER", "number": Type.NUMBER,
"integer": "INTEGER", "integer": Type.INTEGER,
"boolean": "BOOLEAN", "boolean": Type.BOOLEAN,
"array": "ARRAY", "array": Type.ARRAY,
"object": "OBJECT", "object": Type.OBJECT,
# Add other mappings if necessary
} }
original_type = schema_dict.get("type") original_type = schema_dict.get("type")
google_type = type_mapping.get(str(original_type).lower()) if original_type else None google_type = type_mapping.get(str(original_type).lower()) if original_type else None
# Prepare arguments for Schema constructor, filtering out None values if not google_type:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
return None
schema_args = { schema_args = {
"type": google_type, "type": google_type,
"format": schema_dict.get("format"), "format": schema_dict.get("format"),
"description": schema_dict.get("description"), "description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"), "nullable": schema_dict.get("nullable"),
"enum": schema_dict.get("enum"), "enum": schema_dict.get("enum"),
"items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None, "items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
"properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None, "properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
"required": schema_dict.get("required") if google_type == "OBJECT" else None, if google_type == Type.OBJECT and schema_dict.get("properties")
else None,
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
} }
# Remove keys with None values
schema_args = {k: v for k, v in schema_args.items() if v is not None} schema_args = {k: v for k, v in schema_args.items() if v is not None}
if not schema_args.get("type"): if google_type == Type.ARRAY and "items" not in schema_args:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.") logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
return Schema() # Return empty schema return None
if google_type == Type.OBJECT and "properties" not in schema_args:
pass
try: try:
return Schema(**schema_args) created_schema = Schema(**schema_args)
return created_schema
except Exception as schema_creation_err: except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True) logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
return Schema() # Return empty schema on error return None
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None: def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
@@ -161,7 +156,6 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
return None return None
all_func_declarations = [] all_func_declarations = []
# Expecting structure like [{"function_declarations": [...]}]
if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]: if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]:
func_declarations_list = tool_configs[0]["function_declarations"] func_declarations_list = tool_configs[0]["function_declarations"]
if not isinstance(func_declarations_list, list): if not isinstance(func_declarations_list, list):
@@ -169,31 +163,64 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
return None return None
for func_dict in func_declarations_list: for func_dict in func_declarations_list:
func_name = func_dict.get("name", "Unknown")
try: try:
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) params_schema_dict = func_dict.get("parameters", {})
# Ensure parameters is a valid schema dict for the recursive creator
if not isinstance(params_schema_dict, dict): if not isinstance(params_schema_dict, dict):
logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.") logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
params_schema_dict = {"type": "object", "properties": {}} params_schema_dict = {"type": "object", "properties": {}}
elif params_schema_dict.get("type") != "object": elif "type" not in params_schema_dict:
logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.") params_schema_dict["type"] = "object"
params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties elif params_schema_dict["type"] != "object":
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
original_properties = params_schema_dict.get("properties", {})
if not isinstance(original_properties, dict):
original_properties = {}
params_schema_dict = {"type": "object", "properties": original_properties}
parameters_schema = _create_google_schema_recursive(params_schema_dict) properties_dict = params_schema_dict.get("properties", {})
google_properties = {}
if isinstance(properties_dict, dict):
for prop_name, prop_schema_dict in properties_dict.items():
prop_schema = _create_google_schema_recursive(prop_schema_dict)
if prop_schema:
google_properties[prop_name] = prop_schema
else:
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
else:
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
if not google_properties:
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
required_list = []
else:
original_required = params_schema_dict.get("required", [])
if isinstance(original_required, list):
required_list = [req for req in original_required if req in google_properties]
if len(required_list) != len(original_required):
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
else:
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
required_list = []
parameters_schema = Schema(
type=Type.OBJECT,
properties=google_properties,
required=required_list if required_list else None,
)
# Only proceed if schema creation was somewhat successful
if parameters_schema is not None:
declaration = FunctionDeclaration( declaration = FunctionDeclaration(
name=func_dict["name"], name=func_name,
description=func_dict.get("description", ""), description=func_dict.get("description", ""),
parameters=parameters_schema, parameters=parameters_schema,
) )
all_func_declarations.append(declaration) all_func_declarations.append(declaration)
else: logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")
logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")
except Exception as decl_err: except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)
else: else:
logger.error(f"Invalid tool_configs structure provided: {tool_configs}") logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
@@ -203,14 +230,10 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.") logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.")
return None return None
# Google expects a list containing one Tool object
logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.") logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.")
return [Tool(function_declarations=all_func_declarations)] return [Tool(function_declarations=all_func_declarations)]
# --- Tool Call Parsing and Handling (from Google response) ---
def has_google_tool_calls(response: Any) -> bool: def has_google_tool_calls(response: Any) -> bool:
""" """
Checks if the Google response object contains tool calls (FunctionCalls). Checks if the Google response object contains tool calls (FunctionCalls).
@@ -222,7 +245,6 @@ def has_google_tool_calls(response: Any) -> bool:
True if FunctionCalls are present, False otherwise. True if FunctionCalls are present, False otherwise.
""" """
try: try:
# Check non-streaming response structure
if hasattr(response, "candidates") and response.candidates: if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0] candidate = response.candidates[0]
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"): if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
@@ -231,10 +253,6 @@ def has_google_tool_calls(response: Any) -> bool:
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}") logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
return True return True
# Note: Detecting function calls reliably in a stream might require accumulating parts.
# This function primarily works reliably for non-streaming responses.
# For streaming, the check might happen during stream processing itself.
logger.debug("No tool calls (FunctionCall) detected in Google response.") logger.debug("No tool calls (FunctionCall) detected in Google response.")
return False return False
except Exception as e: except Exception as e:
@@ -270,37 +288,31 @@ def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]:
for part in candidate.content.parts: for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call: if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call func_call = part.function_call
# Generate a simple unique ID for this call within this response
call_id = f"call_{call_index}" call_id = f"call_{call_index}"
call_index += 1 call_index += 1
# Extract server_name and func_name from the prefixed name
full_name = func_call.name full_name = func_call.name
parts = full_name.split("__", 1) parts = full_name.split("__", 1)
if len(parts) == 2: if len(parts) == 2:
server_name, func_name = parts server_name, func_name = parts
else: else:
# If the prefix isn't found, assume it's just the function name
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.") logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.")
server_name = None server_name = None
func_name = full_name func_name = full_name
# Convert arguments dict to JSON string
try: try:
# func_call.args is already a dict-like object (Mapping)
args_dict = dict(func_call.args) if func_call.args else {} args_dict = dict(func_call.args) if func_call.args else {}
args_str = json.dumps(args_dict) args_str = json.dumps(args_dict)
except Exception as json_err: except Exception as json_err:
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}") logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
# Provide error info in arguments if serialization fails
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)}) args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
parsed_calls.append({ parsed_calls.append({
"id": call_id, # Internal ID for tracking this call "id": call_id,
"server_name": server_name, "server_name": server_name,
"function_name": func_name, # The original function name "function_name": func_name,
"arguments": args_str, # Arguments as a JSON string "arguments": args_str,
"_google_tool_name": full_name, # Keep original name if needed later "_google_tool_name": full_name,
}) })
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...") logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
@@ -326,17 +338,14 @@ def format_google_tool_results(tool_call_id: str, function_name: str, result: An
This will be converted later by `_convert_messages`. This will be converted later by `_convert_messages`.
""" """
try: try:
# Google expects the 'response' field in FunctionResponse to contain a dict.
# The content should ideally be JSON serializable. We wrap the result.
if isinstance(result, (str, int, float, bool, list)): if isinstance(result, (str, int, float, bool, list)):
content_dict = {"result": result} content_dict = {"result": result}
elif isinstance(result, dict): elif isinstance(result, dict):
content_dict = result # Assume it's already a suitable dict content_dict = result
else: else:
logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.") logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.")
content_dict = {"result": str(result)} content_dict = {"result": str(result)}
# Ensure the content is JSON serializable for the 'content' field
try: try:
content_str = json.dumps(content_dict) content_str = json.dumps(content_dict)
except Exception as json_err: except Exception as json_err:
@@ -348,12 +357,9 @@ def format_google_tool_results(tool_call_id: str, function_name: str, result: Any
content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)}) content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)})
logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})") logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})")
# Return in the standard message format, _convert_messages will handle Google's structure
return { return {
"role": "tool", "role": "tool",
"tool_call_id": tool_call_id, # Used by _convert_messages to find the original call "tool_call_id": tool_call_id,
"content": content_str, # The JSON string representing the result content "content": content_str,
"name": function_name, # Store original function name for _convert_messages "name": function_name,
# Note: Google's FunctionResponse Part needs 'name' and 'response' (dict).
# This standard format will be converted by the provider's message conversion logic.
} }
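A hedged sketch of the standard tool-result message produced above, before `_convert_messages` maps it onto Google's `FunctionResponse` structure. The wrapping of non-dict results mirrors the branches in the code; the sample values are made up.

```python
import json

def sketch_format_tool_result(tool_call_id: str, function_name: str, result) -> dict:
    # Wrap scalars and lists so the content is always a JSON object.
    if isinstance(result, (str, int, float, bool, list)):
        content_dict = {"result": result}
    elif isinstance(result, dict):
        content_dict = result
    else:
        content_dict = {"result": str(result)}
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": json.dumps(content_dict),
        "name": function_name,
    }

print(sketch_format_tool_result("call_0", "read_file", ["line 1", "line 2"]))
```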

View File

@@ -1,4 +1,3 @@
# src/providers/google_provider/utils.py
import json import json
import logging import logging
from typing import Any from typing import Any
@@ -12,7 +11,7 @@ logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int: def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given Google model.""" """Retrieves the context window size for a given Google model."""
default_window = 1000000 # Default fallback for Gemini default_window = 1000000
try: try:
provider_models = MODELS.get("google", {}).get("models", []) provider_models = MODELS.get("google", {}).get("models", [])
for m in provider_models: for m in provider_models:
@@ -44,12 +43,9 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
system_prompt = content system_prompt = content
logger.debug("Extracted system prompt for Google.") logger.debug("Extracted system prompt for Google.")
else: else:
# Google API expects system prompt only at the beginning.
# If found later, log a warning and skip or merge if possible (though merging is complex).
logger.warning("System message found not at the beginning. Skipping for Google API.") logger.warning("System message found not at the beginning. Skipping for Google API.")
continue # Skip adding system messages to the main list continue
# Map roles: 'assistant' -> 'model', 'tool' -> 'function' (handled below)
google_role = {"user": "user", "assistant": "model"}.get(role) google_role = {"user": "user", "assistant": "model"}.get(role)
if not google_role and role != "tool": if not google_role and role != "tool":
@@ -58,92 +54,73 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
parts: list[Part | str] = [] parts: list[Part | str] = []
if role == "tool": if role == "tool":
# Tool results are mapped to 'function' role in Google API
if tool_call_id and content: if tool_call_id and content:
try: try:
# Attempt to parse the content as JSON, assuming it's the tool output
response_content_dict = json.loads(content) response_content_dict = json.loads(content)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.") logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
response_content_dict = {"result": content} # Wrap raw string if not JSON response_content_dict = {"result": content}
# Find the original function name from the preceding assistant message func_name = "unknown_function"
func_name = "unknown_function" # Default if name can't be found
if i > 0 and messages[i - 1].get("role") == "assistant": if i > 0 and messages[i - 1].get("role") == "assistant":
prev_tool_calls = messages[i - 1].get("tool_calls") prev_tool_calls = messages[i - 1].get("tool_calls")
if prev_tool_calls: if prev_tool_calls:
for tc in prev_tool_calls: for tc in prev_tool_calls:
# Match based on the ID provided in the tool message
if tc.get("id") == tool_call_id: if tc.get("id") == tool_call_id:
# Google uses 'server__func' format, extract original func name if possible
full_name = tc.get("function_name", "unknown_function") full_name = tc.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name func_name = full_name.split("__", 1)[-1]
break break
# Create a FunctionResponse part
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict})) parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
google_role = "function" # Explicitly set role for tool results google_role = "function"
else: else:
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}") logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
continue # Skip if essential parts are missing continue
elif role == "assistant" and tool_calls: elif role == "assistant" and tool_calls:
# Assistant message requesting tool calls
for tool_call in tool_calls: for tool_call in tool_calls:
args = tool_call.get("arguments", {}) args = tool_call.get("arguments", {})
# Ensure arguments are a dict, not a string
if isinstance(args, str): if isinstance(args, str):
try: try:
args = json.loads(args) args = json.loads(args)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}") logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
args = {"error": "failed to parse arguments"} # Provide error feedback args = {"error": "failed to parse arguments"}
# Google uses 'server__func' format, extract original func name if possible
full_name = tool_call.get("function_name", "unknown_function") full_name = tool_call.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name func_name = full_name.split("__", 1)[-1]
# Create a FunctionCall part
parts.append(Part.from_function_call(name=func_name, args=args)) parts.append(Part.from_function_call(name=func_name, args=args))
# Include any text content alongside the function calls
if content and isinstance(content, str): if content and isinstance(content, str):
parts.append(Part.from_text(content)) parts.append(Part(text=content))
elif content: elif content:
# Regular user or assistant message content
if isinstance(content, str): if isinstance(content, str):
parts.append(Part.from_text(content)) parts.append(Part(text=content))
# TODO: Handle potential image content if needed in the future
else: else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part.from_text(str(content))) parts.append(Part(text=str(content)))
# Add the constructed Content object if parts were generated
if parts: if parts:
google_messages.append(Content(role=google_role, parts=parts)) google_messages.append(Content(role=google_role, parts=parts))
else: else:
# Log if a message resulted in no parts (e.g., empty content, skipped system message)
logger.debug(f"No parts generated for message: {message}") logger.debug(f"No parts generated for message: {message}")
# Validate message alternation (user -> model -> user/function -> user -> ...)
last_role = None last_role = None
valid_alternation = True valid_alternation = True
for msg in google_messages: for msg in google_messages:
current_role = msg.role current_role = msg.role
# Check for consecutive user/model roles
if current_role == last_role and current_role in ["user", "model"]: if current_role == last_role and current_role in ["user", "model"]:
valid_alternation = False valid_alternation = False
logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.") logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.")
break break
# Check if 'function' role is followed by 'user'
if last_role == "function" and current_role != "user": if last_role == "function" and current_role != "user":
valid_alternation = False valid_alternation = False
logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.") logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.")
break break
last_role = current_role last_role = current_role
# Raise error if alternation is invalid, as Google API enforces this
if not valid_alternation: if not valid_alternation:
raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.") raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.")
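A standalone sketch of the alternation rule enforced above, using plain role strings instead of `google.genai` `Content` objects so it runs without the SDK.

```python
def is_valid_sequence(roles: list[str]) -> bool:
    last = None
    for role in roles:
        if role == last and role in ("user", "model"):
            return False          # consecutive user/model turns are rejected
        if last == "function" and role != "user":
            return False          # a function result must be followed by a user turn
        last = role
    return True

print(is_valid_sequence(["user", "model", "user"]))               # True
print(is_valid_sequence(["user", "user", "model"]))               # False
print(is_valid_sequence(["user", "model", "function", "model"]))  # False
```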

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/__init__.py
from typing import Any from typing import Any
from openai import Stream from openai import Stream
@@ -20,25 +19,23 @@ from src.providers.base import BaseProvider
class OpenAIProvider(BaseProvider): class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs.""" """Provider implementation for OpenAI and compatible APIs."""
def __init__(self, api_key: str, base_url: str | None = None): temperature: float
# BaseProvider __init__ might not be needed if client init handles base_url logic
# super().__init__(api_key, base_url) # Let's see if we need this def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
self.client = initialize_client(api_key, base_url) self.client = initialize_client(api_key, base_url)
# Store api_key and base_url if needed by BaseProvider or other methods
self.api_key = api_key self.api_key = api_key
self.base_url = self.client.base_url # Get effective base_url from client self.base_url = self.client.base_url
self.temperature = temperature
def create_chat_completion( def create_chat_completion(
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion: ) -> Stream[ChatCompletionChunk] | ChatCompletion:
# Pass self (provider instance) to the helper function return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
def get_streaming_content(self, response: Stream[ChatCompletionChunk]): def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
return get_streaming_content(response) return get_streaming_content(response)
@@ -47,7 +44,6 @@ class OpenAIProvider(BaseProvider):
return get_content(response) return get_content(response)
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool: def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
# This method might need the full response after streaming, handled by LLMClient
return has_tool_calls(response) return has_tool_calls(response)
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]: def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
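A hedged usage sketch of the refactored constructor: temperature is now fixed per provider instance instead of being passed to each completion call. The import path follows the `src.providers` layout shown in this diff; the key, base URL, model name, and messages are placeholders.

```python
from src.providers.openai_provider import OpenAIProvider

# Hypothetical wiring; in the application these values presumably come from config.ini.
provider = OpenAIProvider(
    api_key="YOUR_API_KEY",
    base_url="https://api.openai.com/v1",
    temperature=0.6,   # stored on the instance and reused for every completion
)

response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4o",
    stream=False,
)
print(provider.get_content(response))
```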

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/client.py
import logging import logging
from openai import OpenAI from openai import OpenAI
@@ -10,12 +9,9 @@ logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI: def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
"""Initializes and returns an OpenAI client instance.""" """Initializes and returns an OpenAI client instance."""
# Use default OpenAI endpoint if base_url is not provided explicitly
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint") effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}") logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
try: try:
# TODO: Add default headers if needed, similar to the original openai_client.py?
# default_headers={"HTTP-Referer": "...", "X-Title": "..."}
client = OpenAI(api_key=api_key, base_url=effective_base_url) client = OpenAI(api_key=api_key, base_url=effective_base_url)
return client return client
except Exception as e: except Exception as e:

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/completion.py
import logging import logging
from typing import Any from typing import Any
@@ -11,10 +10,10 @@ logger = logging.getLogger(__name__)
def create_chat_completion( def create_chat_completion(
provider, # The OpenAIProvider instance provider,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
@@ -22,44 +21,30 @@ def create_chat_completion(
"""Creates a chat completion using the OpenAI API, handling context window truncation.""" """Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Truncation Step ---
truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model) truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
# -----------------------
try: try:
completion_params = { completion_params = {
"model": model, "model": model,
"messages": truncated_messages, # Use truncated messages "messages": truncated_messages,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"stream": stream, "stream": stream,
} }
if tools: if tools:
completion_params["tools"] = tools completion_params["tools"] = tools
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools completion_params["tool_choice"] = "auto"
# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None} completion_params = {k: v for k, v in completion_params.items() if v is not None}
# --- Added Debug Logging ---
log_params = completion_params.copy() log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:] # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present") tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}") logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}") logger.debug(f"Full API Params: {log_params}")
# --- End Added Debug Logging ---
response = provider.client.chat.completions.create(**completion_params) response = provider.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.") logger.debug("OpenAI API call successful.")
# --- Capture Actual Usage (for UI display later) ---
# Log usage if available (primarily non-streaming)
actual_usage = None actual_usage = None
if isinstance(response, ChatCompletion) and response.usage: if isinstance(response, ChatCompletion) and response.usage:
actual_usage = { actual_usage = {
@@ -68,13 +53,9 @@ def create_chat_completion(
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
} }
logger.info(f"Actual OpenAI API usage: {actual_usage}") logger.info(f"Actual OpenAI API usage: {actual_usage}")
# TODO: How to handle usage for streaming responses? Needs investigation.
# Return the raw response for now. LLMClient will process it.
return response return response
# ----------------------------------------------------
except Exception as e: except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True) logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise raise
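A small sketch of the parameter assembly above: keys whose value is None are dropped before calling the API, and `tool_choice` is only added when tools are supplied. The values are illustrative.

```python
completion_params = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Hi"}],
    "temperature": 0.6,
    "max_tokens": None,   # caller did not set a limit
    "stream": True,
}
tools = None
if tools:
    completion_params["tools"] = tools
    completion_params["tool_choice"] = "auto"

# Drop None-valued keys so optional parameters are simply omitted from the request.
completion_params = {k: v for k, v in completion_params.items() if v is not None}
print(completion_params)   # "max_tokens" is gone; the rest is unchanged
```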

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/response.py
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
@@ -16,30 +15,24 @@ def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[st
full_delta = "" full_delta = ""
try: try:
for chunk in response: for chunk in response:
# Check if choices exist and are not empty
if chunk.choices: if chunk.choices:
delta = chunk.choices[0].delta.content delta = chunk.choices[0].delta.content
if delta: if delta:
full_delta += delta full_delta += delta
yield delta yield delta
# Handle potential finish reasons or other stream elements if needed
# else:
# logger.debug(f"Stream chunk without choices: {chunk}") # Or handle finish reason etc.
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}") logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e: except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True) logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"}) yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: ChatCompletion) -> str: def get_content(response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response.""" """Extracts content from a non-streaming OpenAI response."""
try: try:
# Check if choices exist and are not empty
if response.choices: if response.choices:
content = response.choices[0].message.content content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.") logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or "" # Return empty string if content is None return content or ""
else: else:
logger.warning("No choices found in OpenAI non-streaming response.") logger.warning("No choices found in OpenAI non-streaming response.")
return "[No content received]" return "[No content received]"
@@ -55,12 +48,10 @@ def get_usage(response: Any) -> dict[str, int] | None:
usage = { usage = {
"prompt_tokens": response.usage.prompt_tokens, "prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens,
# "total_tokens": response.usage.total_tokens, # Optional
} }
logger.debug(f"Extracted usage from OpenAI response: {usage}") logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage return usage
else: else:
# Don't log warning for streams, as usage isn't expected here
if not isinstance(response, Stream): if not isinstance(response, Stream):
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}") logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None return None
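A consumer-side sketch for the streaming helper: it yields text deltas as they arrive and, on failure, yields a single JSON error string, so callers that accumulate the stream may want to watch for that case. The chunk values below are simulated.

```python
def consume_stream(deltas) -> str:
    """Accumulate deltas the way a caller of get_streaming_content would."""
    full_text = ""
    for delta in deltas:
        full_text += delta
    return full_text

# Simulated deltas; in real use they come from get_streaming_content(response).
text = consume_stream(iter(["Hel", "lo ", "world"]))
print(text)   # "Hello world"
```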

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/tools.py
import json import json
import logging import logging
from typing import Any from typing import Any
@@ -13,20 +12,16 @@ logger = logging.getLogger(__name__)
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool: def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls.""" """Checks if the OpenAI response contains tool calls."""
try: try:
if isinstance(response, ChatCompletion): # Non-streaming if isinstance(response, ChatCompletion):
# Check if choices exist and are not empty
if response.choices: if response.choices:
return bool(response.choices[0].message.tool_calls) return bool(response.choices[0].message.tool_calls)
else: else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.") logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False return False
elif isinstance(response, Stream): elif isinstance(response, Stream):
# This check remains unreliable for unconsumed streams.
# LLMClient needs robust handling after consumption.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.") logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False # Assume no for unconsumed stream for now return False
else: else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}") logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False return False
except Exception as e: except Exception as e:
@@ -36,14 +31,12 @@ def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]: def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response.""" """Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = [] parsed_calls = []
try: try:
if not isinstance(response, ChatCompletion): if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}") logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return [] return []
# Check if choices exist and are not empty
if not response.choices: if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.") logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return [] return []
@@ -55,38 +48,30 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.") logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls: for call in tool_calls:
if call.type == "function": if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1) parts = call.function.name.split("__", 1)
if len(parts) == 2: if len(parts) == 2:
server_name, func_name = parts server_name, func_name = parts
else: else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.") logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default? server_name = None
func_name = call.function.name func_name = call.function.name
# Arguments might be a string needing JSON parsing, or already parsed dict
arguments_obj = None arguments_obj = None
try: try:
if isinstance(call.function.arguments, str): if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments) arguments_obj = json.loads(call.function.arguments)
else: else:
# Assuming it might already be a dict if not a string (less common)
arguments_obj = call.function.arguments arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err: except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}") logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}") logger.error(f"Raw arguments string: {call.function.arguments}")
# Decide how to handle: skip tool, pass raw string, pass error?
# Passing raw string for now, but this might break consumers.
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments} arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
parsed_calls.append({ parsed_calls.append({
"id": call.id, "id": call.id,
"server_name": server_name, # May be None if not prefixed "server_name": server_name,
"function_name": func_name, "function_name": func_name,
"arguments": arguments_obj, # Pass parsed arguments (or error dict) "arguments": arguments_obj,
}) })
else: else:
logger.warning(f"Unsupported tool call type: {call.type}") logger.warning(f"Unsupported tool call type: {call.type}")
@@ -94,20 +79,18 @@ def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
return parsed_calls return parsed_calls
except Exception as e: except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True) logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error return []
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]: def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request.""" """Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try: try:
if isinstance(result, dict): if isinstance(result, dict):
content = json.dumps(result) content = json.dumps(result)
elif isinstance(result, str): elif isinstance(result, str):
content = result # Allow plain strings if result is already string content = result
else: else:
content = str(result) # Ensure it's a string otherwise content = str(result)
except Exception as e: except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}") logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))}) content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
@@ -122,9 +105,6 @@ def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format.""" """Converts internal tool format to OpenAI's format."""
# This function seems identical to the one in src/tools/conversion.py
# We can potentially remove it from here and import from the central location.
# For now, keep it duplicated to maintain modularity until a decision is made.
openai_tools = [] openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.") logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools: for tool in tools:
@@ -137,7 +117,6 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}") logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}" prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = { openai_tool_format = {
@@ -145,7 +124,7 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"function": { "function": {
"name": prefixed_tool_name, "name": prefixed_tool_name,
"description": description, "description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly "parameters": input_schema,
}, },
} }
openai_tools.append(openai_tool_format) openai_tools.append(openai_tool_format)
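A sketch of what `convert_tools` emits for one internal tool definition. The outer `{"type": "function", ...}` envelope follows OpenAI's documented tool format; the internal field values and the JSON Schema below are assumptions for illustration.

```python
# Assumed internal tool fields; only the output shape is taken from the code above.
server_name = "filesystem"
tool_name = "read_file"
description = "Read a text file from disk"
input_schema = {
    "type": "object",
    "properties": {"path": {"type": "string"}},
    "required": ["path"],
}

openai_tool_format = {
    "type": "function",
    "function": {
        "name": f"{server_name}__{tool_name}",   # prefix keeps names unique and routable
        "description": description,
        "parameters": input_schema,              # JSON Schema is passed through as-is
    },
}
print(openai_tool_format["function"]["name"])    # filesystem__read_file
```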
@@ -159,11 +138,9 @@ def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
try: try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls: if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True) return message.model_dump(exclude_unset=True)
else: else:
logger.warning("Could not extract original message with tool calls from response.") logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"} return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e: except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True) logger.error(f"Error extracting original message with calls: {e}", exc_info=True)

View File

@@ -1,4 +1,3 @@
# src/providers/openai_provider/utils.py
import logging import logging
import math import math
@@ -9,15 +8,12 @@ logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int: def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given model.""" """Retrieves the context window size for a given model."""
# Default to a safe fallback if model or provider info is missing
default_window = 8000 default_window = 8000
try: try:
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
provider_models = MODELS.get("openai", {}).get("models", []) provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models: for m in provider_models:
if m.get("id") == model: if m.get("id") == model:
return m.get("context_window", default_window) return m.get("context_window", default_window)
# Fallback if specific model ID not found in our list
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}") logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window return default_window
except Exception as e: except Exception as e:
@@ -36,8 +32,6 @@ def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
content = message.get("content") content = message.get("content")
if isinstance(content, str): if isinstance(content, str):
total_chars += len(content) total_chars += len(content)
# Rough approximation for function/tool call overhead if needed later
# Using math.ceil to round up, ensuring we don't underestimate too much.
estimated_tokens = math.ceil(total_chars / 4.0) estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages") logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens return estimated_tokens
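A worked example of the chars/4 heuristic above, rounded up with `math.ceil` so the estimate never undercounts a partial group of four characters.

```python
import math

messages = [
    {"role": "user", "content": "Hello world!"},       # 12 characters
    {"role": "assistant", "content": "How are you?"},  # 12 characters
]
total_chars = sum(len(m["content"]) for m in messages)  # 24
estimated_tokens = math.ceil(total_chars / 4.0)          # 24 / 4 = 6
print(estimated_tokens)   # 6
```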
@@ -54,49 +48,41 @@ def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
- The final estimated token count after truncation (if any). - The final estimated token count after truncation (if any).
""" """
context_limit = get_context_window(model) context_limit = get_context_window(model)
# Add a buffer to be safer with approximation buffer = 200
buffer = 200 # Reduce buffer slightly as we round up now
effective_limit = context_limit - buffer effective_limit = context_limit - buffer
initial_estimated_count = estimate_openai_token_count(messages) initial_estimated_count = estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count final_estimated_count = initial_estimated_count
truncated_messages = list(messages) # Make a copy truncated_messages = list(messages)
# Identify if the first message is a system prompt
has_system_prompt = False has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system": if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True has_system_prompt = True
# If only system prompt exists, don't truncate further
if len(truncated_messages) == 1 and final_estimated_count > effective_limit: if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.") logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
# Return original messages to avoid removing the only message
return messages, initial_estimated_count, final_estimated_count return messages, initial_estimated_count, final_estimated_count
while final_estimated_count > effective_limit: while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1: if has_system_prompt and len(truncated_messages) <= 1:
# Should not happen if check above works, but safety break
logger.warning("Truncation stopped: Only system prompt remains.") logger.warning("Truncation stopped: Only system prompt remains.")
break break
if not has_system_prompt and len(truncated_messages) <= 0: if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.") logger.warning("Truncation stopped: No messages left.")
break # No messages left break
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0 remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
if remove_index >= len(truncated_messages): if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.") logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break # Avoid index error break
removed_message = truncated_messages.pop(remove_index) removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.") logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
# Recalculate estimated count
final_estimated_count = estimate_openai_token_count(truncated_messages) final_estimated_count = estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}") logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
# Safety break if list becomes unexpectedly empty
if not truncated_messages: if not truncated_messages:
logger.warning("Truncation resulted in empty message list.") logger.warning("Truncation resulted in empty message list.")
break break
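Finally, an illustrative, standalone sketch of the truncation strategy implemented above, assuming the same chars/4 estimator: drop the oldest non-system message (index 1 when a system prompt is present, otherwise index 0) until the estimate fits under the effective limit.

```python
import math

def estimate_tokens(messages) -> int:
    # Same chars/4 heuristic as estimate_openai_token_count.
    return math.ceil(sum(len(m.get("content", "")) for m in messages) / 4.0)

def truncate(messages, effective_limit: int):
    msgs = list(messages)
    has_system = bool(msgs) and msgs[0].get("role") == "system"
    while estimate_tokens(msgs) > effective_limit:
        if has_system and len(msgs) <= 1:
            break                    # never drop a lone system prompt
        if not msgs:
            break
        remove_index = 1 if has_system and len(msgs) > 1 else 0
        msgs.pop(remove_index)       # drop the oldest non-system message
    return msgs

history = [
    {"role": "system", "content": "x" * 40},
    {"role": "user", "content": "y" * 400},
    {"role": "assistant", "content": "z" * 400},
    {"role": "user", "content": "latest question"},
]
print(len(truncate(history, effective_limit=120)))   # older turns are dropped first
```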