Compare commits
12 Commits
a4683023ad...main

| SHA1 |
|---|
| 247835e595 |
| 51e3058961 |
| ccf750fed4 |
| 2fb6c5af3c |
| 6b390a35f8 |
| 678f395649 |
| bae517a322 |
| ab8d5fe074 |
| 246d921743 |
| 15ecb9fc48 |
| 49aebc12d5 |
| bd56cc839d |
.gitignore (vendored): 9 lines changed

@@ -5,6 +5,7 @@ __pycache__/
# Virtual environment
env/
.venv/

# Configuration
config/config.ini
@@ -20,4 +21,10 @@ config/mcp_config.json
# resources
resources/

# __pycache__/
# Ruff
.ruff_cache/

# Distribution / packaging
dist/
build/
*.egg-info/
@@ -67,7 +67,7 @@ servers_json = config/mcp_config.json

Start the application:
```bash
streamlit run src/app.py
uv run mcpapp
```

The app will be available at `http://localhost:8501`
@@ -82,9 +82,6 @@ Key components:

## Development

### Running Tests
```bash
pytest
```

### Code Formatting
@@ -94,7 +91,7 @@ ruff check . --fix

### Building
```bash
python -m build
uv build
```

## License
@@ -1,24 +1,28 @@
[base]
# provider can be [ openai|openrouter|anthropic|google]
provider = openrouter
streamlit_headless = true

[openrouter]
api_key = YOUR_API_KEY
base_url = https://openrouter.ai/api/v1
model = openai/gpt-4o-2024-11-20
context_window = 128000
temperature = 0.6

[anthropic]
api_key = YOUR_API_KEY
base_url = https://api.anthropic.com/v1/messages
model = claude-3-7-sonnet-20250219
context_window = 128000
temperature = 0.6

[google]
api_key = YOUR_API_KEY
base_url = https://generativelanguage.googleapis.com/v1beta/generateContent
model = gemini-2.0-flash
context_window = 1000000
temperature = 0.6

[openai]
@@ -26,6 +30,7 @@ api_key = YOUR_API_KEY
base_url = https://api.openai.com/v1
model = openai/gpt-4o
context_window = 128000
temperature = 0.6

[mcp]
servers_json = config/mcp_config.json
project_planning/updates.md (new file, 106 lines)

@@ -0,0 +1,106 @@
What is the google-genai Module?

The google-genai module is part of the Google Gen AI Python SDK, a software development kit provided by Google to enable developers to integrate Google's generative AI models into Python applications. This SDK is distinct from the older, deprecated google-generativeai package. The google-genai package represents the newer, unified SDK designed to work with Google's latest generative AI offerings, such as the Gemini models.

Installation

To use the google-genai module, you need to install it via pip. The package name on PyPI is google-genai, and you can install it with the following command:

```bash
pip install google-genai
```

This installs the necessary dependencies and makes the module available in your Python environment.

Correct Import Statement

The standard import statement for the google-genai SDK, as per the official documentation and examples, is:

```python
from google import genai
```

This differs from the older SDK's import style, which was:

```python
import google.generativeai as genai
```

When you install google-genai, it provides a module structure where genai is a submodule of the google package. Thus, from google import genai is the correct way to access its functionality.
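As a quick sanity check, a minimal sketch like the following can confirm that the new SDK is the one being imported (it assumes google-genai is installed and that the package exposes a `__version__` attribute, which is not stated in the note above):

```python
# Sanity check for the new SDK layout (assumptions noted in comments).
from google import genai

print(getattr(genai, "__version__", "unknown"))  # assumed attribute; prints the SDK version if present
print(hasattr(genai, "Client"))      # True for the new client-based SDK
print(hasattr(genai, "configure"))   # False: configure() belongs to the old google-generativeai SDK
```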
Usage in the New SDK

Unlike the older google-generativeai SDK, which used a configure method (e.g., genai.configure(api_key='YOUR_API_KEY')) to set up the API key globally, the new google-genai SDK adopts a client-based approach. You create a Client instance with your API key and use it to interact with the models. Here’s a basic example:

```python
from google import genai

# Initialize the client with your API key
client = genai.Client(api_key='YOUR_API_KEY')

# Example: Generate content using a model
response = client.models.generate_content(
    model='gemini-2.0-flash-001',  # Specify the model name
    contents='Why is the sky blue?'
)

print(response.text)
```

Key points about this usage:

- No configure Method: The new SDK does not have a genai.configure method directly on the genai module. Instead, you pass the API key when creating a Client instance.
- Client-Based Interaction: All interactions with the generative models (e.g., generating content) are performed through the client object.
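In this repository the Google API key lives in the [google] section of config/config.ini, so a small helper along these lines could build the client from that file. This is a sketch only; the helper name is hypothetical, and the real integration point in the app is the LLMClient/provider layer, not this function:

```python
import configparser

from google import genai


def make_google_client(config_path: str = "config/config.ini") -> genai.Client:
    """Build a google-genai Client from the [google] section of config.ini.

    Hypothetical helper for illustration; not part of the codebase.
    """
    config = configparser.ConfigParser()
    if not config.read(config_path):
        raise FileNotFoundError(f"{config_path} not found or unreadable")
    api_key = config["google"]["api_key"]
    return genai.Client(api_key=api_key)


# Usage sketch:
# client = make_google_client()
# response = client.models.generate_content(model="gemini-2.0-flash", contents="ping")
# print(response.text)
```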
Official Documentation

The official documentation for the google-genai SDK can be found on Google's API documentation site. Specifically:

- Google Gen AI Python SDK Documentation: Hosted at https://googleapis.github.io/python-genai/, this site provides detailed guides, API references, and code examples. It confirms the use of from google import genai and the client-based approach.
- GitHub Repository: The source code and additional examples are available at https://github.com/googleapis/python-genai. The repository documentation reinforces that from google import genai is the import style for the new SDK.

Why Your Import is Correct

You mentioned that your import is correct, which I assume refers to from google import genai. This is indeed the proper import for the google-genai package, aligning with the new SDK's design. If you're encountering issues (e.g., an error like module 'google.genai' has no attribute 'configure'), it’s likely because the code is trying to use methods from the older SDK (like genai.configure) that don’t exist in the new one. To resolve this, you should update the code to use the client-based approach shown above.
Troubleshooting Common Issues

If you're seeing errors with from google import genai, here are some things to check:

1. Correct Package Installed:
   - Ensure google-genai is installed (pip install google-genai).
   - If google-generativeai is installed instead, uninstall it with pip uninstall google-generativeai to avoid conflicts, then install google-genai.
2. Code Compatibility:
   - If your code uses genai.configure or assumes the older SDK’s structure, you’ll need to refactor it. Replace configuration calls with genai.Client(api_key='...') and adjust model interactions to use the client object (see the migration sketch after this list).
3. Environment Verification:
   - Run pip show google-genai to confirm the package is installed and check its version. This ensures you’re working with the intended SDK.
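For the refactor in item 2, a minimal before/after sketch looks roughly like this. The "before" lines are recalled from the deprecated google-generativeai package and shown only for contrast; the "after" lines follow the client-based pattern documented above:

```python
# Before: old google-generativeai style (does not exist in google-genai)
# import google.generativeai as genai
# genai.configure(api_key="YOUR_API_KEY")
# model = genai.GenerativeModel("gemini-2.0-flash")
# response = model.generate_content("Why is the sky blue?")

# After: new google-genai client-based style
from google import genai

client = genai.Client(api_key="YOUR_API_KEY")
response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="Why is the sky blue?",
)
print(response.text)
```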
Additional Resources

- PyPI Page: The google-genai package on PyPI (https://pypi.org/project/google-genai/) provides installation instructions and links to the GitHub repository.
- Examples: The GitHub repository includes sample code demonstrating how to use the SDK with from google import genai.

Conclusion

Your import, from google import genai, aligns with the google-genai module from the new Google Gen AI Python SDK. The documentation and online resources confirm this as the correct approach for the current, unified SDK. If import google.generativeai was previously suggested or used, it pertains to the older, deprecated SDK, which explains why it might be considered incorrect in your context. To fully leverage google-genai, ensure your code uses the client-based API as outlined, and you should be able to interact with Google’s generative AI models effectively. If you’re still facing specific errors, feel free to share them, and I can assist further!
@@ -1,5 +1,5 @@
[project]
name = "streamlit-chat-app"
name = "macpapp"
version = "0.1.0"
description = "Streamlit chat app with MCP"
readme = "README.md"
@@ -27,6 +30,9 @@ license-files = ["LICEN[CS]E*"]
GitHub = "https://git.bhakat.dev/abhishekbhakat/mcpapp"
Issues = "https://git.bhakat.dev/abhishekbhakat/mcpapp/issues"

[project.scripts]
mcpapp = "run_app:run_streamlit_app"

[project.optional-dependencies]
dev = [
    "build>=1.2.2",
@@ -84,7 +87,7 @@ skip-magic-trailing-comma = false
combine-as-imports = true

[tool.ruff.lint.mccabe]
max-complexity = 16
max-complexity = 30

[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
run_app.py (new file, 57 lines)

@@ -0,0 +1,57 @@
import configparser
import os
import subprocess
import sys


def run_streamlit_app():
    """
    Reads the configuration file and launches the Streamlit app,
    optionally in headless mode.
    """
    config_path = "config/config.ini"
    headless = False

    try:
        if os.path.exists(config_path):
            config = configparser.ConfigParser()
            config.read(config_path)
            if config.has_section("base"):
                headless = config.getboolean("base", "streamlit_headless", fallback=False)
                if headless:
                    print(f"INFO: Headless mode enabled via {config_path}.")
                else:
                    print(f"INFO: Headless mode disabled via {config_path}.")
            else:
                print(f"WARNING: [base] section not found in {config_path}. Defaulting to non-headless.")
        else:
            print(f"WARNING: Configuration file not found at {config_path}. Defaulting to non-headless.")
    except Exception as e:
        print(f"ERROR: Could not read headless config from {config_path}: {e}. Defaulting to non-headless.")
        headless = False  # Ensure default on error

    # Construct the command
    command = [sys.executable, "-m", "streamlit", "run", "src/app.py"]
    if headless:
        command.extend(["--server.headless", "true"])

    print(f"Running command: {' '.join(command)}")

    try:
        # Run Streamlit using subprocess.run which waits for completion
        # Use check=True to raise an error if Streamlit fails
        # Capture output might be useful for debugging but can be complex with interactive apps
        process = subprocess.Popen(command)
        process.wait()  # Wait for the Streamlit process to exit
        print(f"Streamlit process finished with exit code: {process.returncode}")

    except FileNotFoundError:
        print("ERROR: 'streamlit' command not found. Make sure Streamlit is installed and in your PATH.")
        sys.exit(1)
    except Exception as e:
        print(f"ERROR: Failed to run Streamlit: {e}")
        sys.exit(1)


if __name__ == "__main__":
    run_streamlit_app()
src/app.py: 119 lines changed

@@ -1,15 +1,12 @@
import atexit
import configparser
import json  # For handling potential error JSON in stream
import logging

import streamlit as st

# Updated imports
from llm_client import LLMClient
from src.custom_mcp.manager import SyncMCPManager  # Updated import path
from src.custom_mcp.manager import SyncMCPManager

# Configure logging for the app
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

@@ -24,14 +21,12 @@ def init_session_state():
    logger.info("Attempting to initialize clients...")
    try:
        config = configparser.ConfigParser()
        # TODO: Improve config file path handling (e.g., environment variable, absolute path)
        config_files_read = config.read("config/config.ini")
        if not config_files_read:
            raise FileNotFoundError("config.ini not found or could not be read.")
        logger.info(f"Read configuration from: {config_files_read}")

        # --- MCP Manager Setup ---
        mcp_config_path = "config/mcp_config.json"  # Default
        mcp_config_path = "config/mcp_config.json"
        if config.has_section("mcp") and config["mcp"].get("servers_json"):
            mcp_config_path = config["mcp"]["servers_json"]
            logger.info(f"Using MCP config path from config.ini: {mcp_config_path}")
@@ -40,39 +35,37 @@ def init_session_state():

        mcp_manager = SyncMCPManager(mcp_config_path)
        if not mcp_manager.initialize():
            # Log warning but continue - LLMClient will operate without tools
            logger.warning("MCP Manager failed to initialize. Proceeding without MCP tools.")
        else:
            logger.info("MCP Manager initialized successfully.")
            # Register shutdown hook for MCP manager
            atexit.register(mcp_manager.shutdown)
            logger.info("Registered MCP Manager shutdown hook.")

        # --- LLM Client Setup ---
        provider_name = None
        model_name = None
        api_key = None
        base_url = None

        # 1. Determine provider from [base] section
        if config.has_section("base") and config["base"].get("provider"):
            provider_name = config["base"].get("provider")
            logger.info(f"Provider selected from [base] section: {provider_name}")
        else:
            # Fallback or error if [base] provider is missing? Let's error for now.
            raise ValueError("Missing 'provider' setting in [base] section of config.ini")

        # 2. Read details from the specific provider's section
        if config.has_section(provider_name):
            provider_config = config[provider_name]
            model_name = provider_config.get("model")
            api_key = provider_config.get("api_key")
            base_url = provider_config.get("base_url")  # Optional
            base_url = provider_config.get("base_url")
            provider_temperature = provider_config.getfloat("temperature", 0.6)
            if "temperature" not in provider_config:
                logger.warning(f"Temperature not found in [{provider_name}] section, defaulting to {provider_temperature}")
            else:
                logger.info(f"Loaded temperature for {provider_name}: {provider_temperature}")
            logger.info(f"Read configuration from [{provider_name}] section.")
        else:
            raise ValueError(f"Missing configuration section '[{provider_name}]' in config.ini for the selected provider.")

        # Validate required config
        if not api_key:
            raise ValueError(f"Missing 'api_key' in [{provider_name}] section of config.ini")
        if not model_name:
@@ -84,14 +77,15 @@ def init_session_state():
            api_key=api_key,
            mcp_manager=mcp_manager,
            base_url=base_url,
            temperature=provider_temperature,
        )
        st.session_state.model_name = model_name
        st.session_state.provider_name = provider_name
        logger.info("LLMClient initialized successfully.")

    except Exception as e:
        logger.error(f"Failed to initialize application clients: {e}", exc_info=True)
        st.error(f"Application Initialization Error: {e}. Please check configuration and logs.")
        # Stop the app if initialization fails critically
        st.stop()


@@ -99,8 +93,12 @@ def display_chat_messages():
    """Displays chat messages stored in session state."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            # Simple markdown display for now
            st.markdown(message["content"])
            if message["role"] == "assistant" and "usage" in message:
                usage = message["usage"]
                prompt_tokens = usage.get("prompt_tokens", "N/A")
                completion_tokens = usage.get("completion_tokens", "N/A")
                st.caption(f"Tokens: Prompt {prompt_tokens}, Completion {completion_tokens}")


def handle_user_input():
@@ -116,60 +114,41 @@ def handle_user_input():
            response_placeholder = st.empty()
            full_response = ""
            error_occurred = False
            response_usage = None

            logger.info("Processing message via LLMClient...")
            # Use the new client and method, always requesting stream for UI
            response_stream = st.session_state.client.chat_completion(
            response_data = st.session_state.client.chat_completion(
                messages=st.session_state.messages,
                model=st.session_state.model_name,  # Get model from session state
                stream=True,
                model=st.session_state.model_name,
                stream=False,
            )

            # Handle the response (stream generator or error dict)
            if hasattr(response_stream, "__iter__") and not isinstance(response_stream, dict):
                logger.debug("Processing response stream...")
                for chunk in response_stream:
                    # Check for potential error JSON yielded by the stream
                    try:
                        # Attempt to parse chunk as JSON only if it looks like it
                        if isinstance(chunk, str) and chunk.strip().startswith("{"):
                            error_data = json.loads(chunk)
                            if isinstance(error_data, dict) and "error" in error_data:
                                full_response = f"Error: {error_data['error']}"
                                logger.error(f"Error received in stream: {full_response}")
                                st.error(full_response)
                                error_occurred = True
                                break  # Stop processing stream on error
                        # If not error JSON, treat as content chunk
                        if not error_occurred and isinstance(chunk, str):
                            full_response += chunk
                            response_placeholder.markdown(full_response + "▌")  # Add cursor effect
                    except (json.JSONDecodeError, TypeError):
                        # Not JSON or not error structure, treat as content chunk
                        if not error_occurred and isinstance(chunk, str):
                            full_response += chunk
                            response_placeholder.markdown(full_response + "▌")  # Add cursor effect
            if isinstance(response_data, dict):
                if "error" in response_data:
                    full_response = f"Error: {response_data['error']}"
                    logger.error(f"Error returned from chat_completion: {full_response}")
                    st.error(full_response)
                    error_occurred = True
                else:
                    full_response = response_data.get("content", "")
                    response_usage = response_data.get("usage")
                    if not full_response and not error_occurred:
                        logger.warning("Empty content received from LLMClient.")
                    response_placeholder.markdown(full_response)
                    logger.debug("Non-streaming response processed.")

            if not error_occurred:
                response_placeholder.markdown(full_response)  # Final update without cursor
                logger.debug("Stream processing complete.")

            elif isinstance(response_stream, dict) and "error" in response_stream:
                # Handle error dict returned directly (e.g., API error before streaming)
                full_response = f"Error: {response_stream['error']}"
                logger.error(f"Error returned directly from chat_completion: {full_response}")
                st.error(full_response)
                error_occurred = True
            else:
                # Unexpected response type
                full_response = "[Unexpected response format from LLMClient]"
                logger.error(f"Unexpected response type: {type(response_stream)}")
                logger.error(f"Unexpected response type: {type(response_data)}")
                st.error(full_response)
                error_occurred = True

            # Only add non-error, non-empty responses to history
            if not error_occurred and full_response:
                st.session_state.messages.append({"role": "assistant", "content": full_response})
                assistant_message = {"role": "assistant", "content": full_response}
                if response_usage:
                    assistant_message["usage"] = response_usage
                    logger.info(f"Assistant response usage: {response_usage}")
                st.session_state.messages.append(assistant_message)
                logger.info("Assistant response added to history.")
            elif error_occurred:
                logger.warning("Assistant response not added to history due to error.")
@@ -183,13 +162,31 @@ def handle_user_input():

def main():
    """Main function to run the Streamlit app."""
    st.title("MCP Chat App")  # Updated title
    try:
        init_session_state()

        provider_name = st.session_state.get("provider_name", "Unknown Provider")
        model_name = st.session_state.get("model_name", "Unknown Model")
        mcp_manager = st.session_state.client.mcp_manager

        server_count = 0
        tool_count = 0
        if mcp_manager and mcp_manager.initialized:
            server_count = len(mcp_manager.servers)
            try:
                tool_count = len(mcp_manager.list_all_tools())
            except Exception as e:
                logger.warning(f"Could not retrieve tool count for header: {e}")
                tool_count = "N/A"

        st.markdown(f"# Say Hi to **{provider_name.capitalize()}**!")
        st.write(f"MCP Servers: **{server_count}** | Tools: **{tool_count}**")
        st.write(f"Model: **{model_name}**")
        st.divider()

        display_chat_messages()
        handle_user_input()
    except Exception as e:
        # Catch potential errors during rendering or handling
        logger.critical(f"Critical error in main app flow: {e}", exc_info=True)
        st.error(f"A critical application error occurred: {e}")


@@ -1 +0,0 @@
# This file makes src/mcp a Python package
@@ -1,4 +1,3 @@
# src/mcp/client.py
"""Client class for managing and interacting with a single MCP server process."""

import asyncio
@@ -9,9 +8,8 @@ from custom_mcp import process, protocol

logger = logging.getLogger(__name__)

# Define reasonable timeouts
LIST_TOOLS_TIMEOUT = 20.0  # Seconds (using the increased value from previous step)
CALL_TOOL_TIMEOUT = 110.0  # Seconds
LIST_TOOLS_TIMEOUT = 20.0
CALL_TOOL_TIMEOUT = 110.0


class MCPClient:
@@ -39,7 +37,7 @@ class MCPClient:
        self._stderr_task: asyncio.Task | None = None
        self._request_counter = 0
        self._is_running = False
        self.logger = logging.getLogger(f"{__name__}.{self.server_name}")  # Instance-specific logger
        self.logger = logging.getLogger(f"{__name__}.{self.server_name}")

    async def _log_stderr(self):
        """Logs stderr output from the server process."""
@@ -55,7 +53,6 @@ class MCPClient:
        except asyncio.CancelledError:
            self.logger.debug("Stderr logging task cancelled.")
        except Exception as e:
            # Log errors but don't crash the logger task if possible
            self.logger.error(f"Error reading stderr: {e}", exc_info=True)
        finally:
            self.logger.debug("Stderr logging task finished.")
@@ -79,13 +76,11 @@ class MCPClient:

        if self.reader is None or self.writer is None:
            self.logger.error("Failed to get stdout/stdin streams after process start.")
            await self.stop()  # Attempt cleanup
            await self.stop()
            return False

        # Start background task to monitor stderr
        self._stderr_task = asyncio.create_task(self._log_stderr())

        # --- Start MCP Initialization Handshake ---
        self.logger.info("Starting MCP initialization handshake...")
        self._request_counter += 1
        init_req_id = self._request_counter
@@ -94,21 +89,18 @@ class MCPClient:
            "id": init_req_id,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",  # Use a recent version
                "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},  # Identify the client
                "capabilities": {},  # Client capabilities (can be empty)
                "protocolVersion": "2024-11-05",
                "clientInfo": {"name": "CustomMCPClient", "version": "1.0.0"},
                "capabilities": {},
            },
        }

        # Define a timeout for initialization
        INITIALIZE_TIMEOUT = 15.0  # Seconds
        INITIALIZE_TIMEOUT = 15.0

        try:
            # Send initialize request
            await protocol.send_request(self.writer, initialize_req)
            self.logger.debug(f"Sent 'initialize' request (ID: {init_req_id}). Waiting for response...")

            # Wait for initialize response
            init_response = await protocol.read_response(self.reader, INITIALIZE_TIMEOUT)

            if init_response and init_response.get("id") == init_req_id:
@@ -117,9 +109,8 @@ class MCPClient:
                    await self.stop()
                    return False
                elif "result" in init_response:
                    self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}")  # Log server capabilities if provided
                    self.logger.info(f"Received 'initialize' response: {init_response.get('result', '{}')}")

                    # Send initialized notification (using standard method name)
                    initialized_notify = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
                    await protocol.send_request(self.writer, initialized_notify)
                    self.logger.info("'notifications/initialized' notification sent.")
@@ -135,7 +126,7 @@ class MCPClient:
                    self.logger.error(f"Received response with mismatched ID during initialization. Expected {init_req_id}, got {init_response.get('id')}")
                    await self.stop()
                    return False
            else:  # Timeout case
            else:
                self.logger.error(f"'initialize' request timed out after {INITIALIZE_TIMEOUT} seconds.")
                await self.stop()
                return False
@@ -148,26 +139,23 @@ class MCPClient:
            self.logger.error(f"Unexpected error during initialization handshake: {e}", exc_info=True)
            await self.stop()
            return False
        # --- End MCP Initialization Handshake ---

        except Exception as e:
            self.logger.error(f"Failed to start MCP server process: {e}", exc_info=True)
            self.process = None  # Ensure process is None on failure
            self.process = None
            self.reader = None
            self.writer = None
            self._is_running = False
            return False

    async def stop(self):
        """Stops the MCP server subprocess gracefully."""
        if not self._is_running and not self.process:
            self.logger.debug("Stop called but client is not running.")
            return

        self.logger.info("Stopping MCP server process...")
        self._is_running = False  # Mark as stopping
        self._is_running = False

        # Cancel stderr logging task
        if self._stderr_task and not self._stderr_task.done():
            self._stderr_task.cancel()
            try:
@@ -178,11 +166,9 @@ class MCPClient:
                self.logger.error(f"Error waiting for stderr task cancellation: {e}")
            self._stderr_task = None

        # Stop the process using the utility function
        if self.process:
            await process.stop_mcp_process(self.process, self.server_name)

        # Nullify references
        self.process = None
        self.reader = None
        self.writer = None
@@ -219,7 +205,6 @@ class MCPClient:
            self.logger.error(f"Error response for listTools ID {req_id}: {response['error']}")
            return None
        else:
            # Includes timeout case (read_response returns None)
            self.logger.error(f"No valid response or timeout for listTools ID {req_id}.")
            return None

@@ -260,15 +245,12 @@ class MCPClient:
        response = await protocol.read_response(self.reader, CALL_TOOL_TIMEOUT)

        if response and "result" in response:
            # Assuming result is the desired payload
            self.logger.info(f"Tool '{tool_name}' executed successfully.")
            return response["result"]
        elif response and "error" in response:
            self.logger.error(f"Error response for tool '{tool_name}' ID {req_id}: {response['error']}")
            # Return the error structure itself? Or just None? Returning error dict for now.
            return {"error": response["error"]}
        else:
            # Includes timeout case
            self.logger.error(f"No valid response or timeout for tool '{tool_name}' ID {req_id}.")
            return None
@@ -1,4 +1,3 @@
# src/mcp/manager.py
"""Synchronous manager for multiple MCPClient instances."""

import asyncio
@@ -7,19 +6,15 @@ import logging
import threading
from typing import Any

# Use relative imports within the mcp package
from custom_mcp.client import MCPClient

# Configure basic logging
# Consider moving this to the main app entry point if not already done
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Define reasonable timeouts for sync calls (should be slightly longer than async timeouts)
INITIALIZE_TIMEOUT = 60.0  # Seconds
SHUTDOWN_TIMEOUT = 30.0  # Seconds
LIST_ALL_TOOLS_TIMEOUT = 30.0  # Seconds
EXECUTE_TOOL_TIMEOUT = 120.0  # Seconds
INITIALIZE_TIMEOUT = 60.0
SHUTDOWN_TIMEOUT = 30.0
LIST_ALL_TOOLS_TIMEOUT = 30.0
EXECUTE_TOOL_TIMEOUT = 120.0


class SyncMCPManager:
@@ -37,7 +32,6 @@ class SyncMCPManager:
        """
        self.config_path = config_path
        self.config: dict[str, Any] | None = None
        # Stores server_name -> MCPClient instance
        self.servers: dict[str, MCPClient] = {}
        self.initialized = False
        self._lock = threading.Lock()
@@ -50,7 +44,6 @@ class SyncMCPManager:
        """Load MCP configuration from JSON file."""
        logger.debug(f"Attempting to load MCP config from: {self.config_path}")
        try:
            # Using direct file access
            with open(self.config_path) as f:
                self.config = json.load(f)
            logger.info("MCP configuration loaded successfully.")
@@ -65,8 +58,6 @@ class SyncMCPManager:
            logger.error(f"Error loading MCP config from {self.config_path}: {e}", exc_info=True)
            self.config = None

    # --- Background Event Loop Management ---

    def _run_event_loop(self):
        """Target function for the background event loop thread."""
        try:
@@ -75,14 +66,12 @@ class SyncMCPManager:
            self._loop.run_forever()
        finally:
            if self._loop and not self._loop.is_closed():
                # Clean up remaining tasks before closing
                try:
                    tasks = asyncio.all_tasks(self._loop)
                    if tasks:
                        logger.debug(f"Cancelling {len(tasks)} outstanding tasks before closing loop...")
                        for task in tasks:
                            task.cancel()
                        # Allow cancellation to propagate
                        self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
                        logger.debug("Outstanding tasks cancelled.")
                    self._loop.run_until_complete(self._loop.shutdown_asyncgens())
@@ -99,9 +88,7 @@ class SyncMCPManager:
        self._thread = threading.Thread(target=self._run_event_loop, name="MCPEventLoop", daemon=True)
        self._thread.start()
        logger.info("Event loop thread started.")
        # Wait briefly for the loop to become available and running
        while self._loop is None or not self._loop.is_running():
            # Use time.sleep in sync context
            import time

            time.sleep(0.01)
@@ -121,8 +108,6 @@ class SyncMCPManager:
        self._thread = None
        logger.info("Event loop stopped.")

    # --- Public Synchronous Interface ---

    def initialize(self) -> bool:
        """
        Initializes and starts all configured MCP servers synchronously.
@@ -147,8 +132,6 @@ class SyncMCPManager:

            logger.info("Submitting asynchronous server initialization...")

            # Prepare coroutine to start all clients

            async def _async_init_all():
                tasks = []
                for server_name, server_config in self.config["mcpServers"].items():
@@ -161,19 +144,17 @@ class SyncMCPManager:

                    client = MCPClient(server_name, command, args, config_env)
                    self.servers[server_name] = client
                    tasks.append(client.start())  # Append the start coroutine
                    tasks.append(client.start())

                results = await asyncio.gather(*tasks, return_exceptions=True)

                # Check results - True means success, False or Exception means failure
                all_success = True
                failed_servers = []
                for i, result in enumerate(results):
                    server_name = list(self.config["mcpServers"].keys())[i]  # Assumes order is maintained
                    server_name = list(self.config["mcpServers"].keys())[i]
                    if isinstance(result, Exception) or result is False:
                        all_success = False
                        failed_servers.append(server_name)
                        # Remove failed client from managed servers
                        if server_name in self.servers:
                            del self.servers[server_name]
                        logger.error(f"Failed to start client for server '{server_name}'. Result/Error: {result}")
@@ -182,7 +163,6 @@ class SyncMCPManager:
                    logger.error(f"Initialization failed for servers: {failed_servers}")
                return all_success

            # Run the initialization coroutine in the background loop
            future = asyncio.run_coroutine_threadsafe(_async_init_all(), self._loop)
            try:
                success = future.result(timeout=INITIALIZE_TIMEOUT)
@@ -192,17 +172,16 @@ class SyncMCPManager:
                else:
                    logger.error("Asynchronous initialization failed.")
                    self.initialized = False
                    # Attempt to clean up any partially started servers
                    self.shutdown()  # Call sync shutdown
                    self.shutdown()
            except TimeoutError:
                logger.error(f"Initialization timed out after {INITIALIZE_TIMEOUT}s.")
                self.initialized = False
                self.shutdown()  # Clean up
                self.shutdown()
                success = False
            except Exception as e:
                logger.error(f"Exception during initialization future result: {e}", exc_info=True)
                self.initialized = False
                self.shutdown()  # Clean up
                self.shutdown()
                success = False

            return self.initialized
@@ -211,20 +190,14 @@ class SyncMCPManager:
        """Shuts down all managed MCP servers synchronously."""
        logger.info("Manager shutdown requested.")
        with self._lock:
            # Check servers dict too, in case init was partial
            if not self.initialized and not self.servers:
                logger.debug("Shutdown skipped: Not initialized or no servers running.")
                # Ensure loop is stopped if it exists
                if self._thread and self._thread.is_alive():
                    self._stop_event_loop_thread()
                return

            if not self._loop or not self._loop.is_running():
                logger.warning("Shutdown requested but event loop not running. Attempting direct cleanup.")
                # Attempt direct cleanup if loop isn't running (shouldn't happen ideally)
                # This part is tricky as MCPClient.stop is async.
                # For simplicity, we might just log and rely on process termination on app exit.
                # Or, try a temporary loop just for shutdown? Let's stick to stopping the thread for now.
                self.servers = {}
                self.initialized = False
                if self._thread and self._thread.is_alive():
@@ -233,28 +206,22 @@ class SyncMCPManager:

            logger.info("Submitting asynchronous server shutdown...")

            # Prepare coroutine to stop all clients

            async def _async_shutdown_all():
                tasks = [client.stop() for client in self.servers.values()]
                if tasks:
                    await asyncio.gather(*tasks, return_exceptions=True)

            # Run the shutdown coroutine in the background loop
            future = asyncio.run_coroutine_threadsafe(_async_shutdown_all(), self._loop)
            try:
                future.result(timeout=SHUTDOWN_TIMEOUT)
                logger.info("Asynchronous shutdown completed.")
            except TimeoutError:
                logger.error(f"Shutdown timed out after {SHUTDOWN_TIMEOUT}s. Event loop will be stopped.")
                # Processes might still be running, OS will clean up on exit hopefully
            except Exception as e:
                logger.error(f"Exception during shutdown future result: {e}", exc_info=True)
            finally:
                # Always mark as uninitialized and clear servers dict
                self.servers = {}
                self.initialized = False
                # Stop the background thread
                self._stop_event_loop_thread()

        logger.info("Manager shutdown complete.")
@@ -277,7 +244,6 @@ class SyncMCPManager:

        logger.info(f"Requesting tools from {len(self.servers)} servers...")

        # Prepare coroutine to list tools from all clients
        async def _async_list_all():
            tasks = []
            server_names_in_order = []
@@ -293,10 +259,8 @@ class SyncMCPManager:
                if isinstance(result, Exception):
                    logger.error(f"Error listing tools for server '{server_name}': {result}")
                elif result is None:
                    # MCPClient.list_tools returns None on timeout/error
                    logger.error(f"Failed to list tools for server '{server_name}' (timeout or error).")
                elif isinstance(result, list):
                    # Add server_name to each tool definition
                    for tool in result:
                        tool["server_name"] = server_name
                    all_tools.extend(result)
@@ -305,7 +269,6 @@ class SyncMCPManager:
                    logger.error(f"Unexpected result type ({type(result)}) when listing tools for {server_name}.")
            return all_tools

        # Run the coroutine in the background loop
        future = asyncio.run_coroutine_threadsafe(_async_list_all(), self._loop)
        try:
            aggregated_tools = future.result(timeout=LIST_ALL_TOOLS_TIMEOUT)
@@ -346,18 +309,16 @@ class SyncMCPManager:

        logger.info(f"Executing tool '{tool_name}' on server '{server_name}' with args: {arguments}")

        # Run the client's call_tool coroutine in the background loop
        future = asyncio.run_coroutine_threadsafe(client.call_tool(tool_name, arguments), self._loop)
        try:
            result = future.result(timeout=EXECUTE_TOOL_TIMEOUT)
            # MCPClient.call_tool returns the result dict or an error dict or None
            if result is None:
                logger.error(f"Tool execution '{tool_name}' on {server_name} failed (timeout or comm error).")
            elif isinstance(result, dict) and "error" in result:
                logger.error(f"Tool execution '{tool_name}' on {server_name} returned error: {result['error']}")
            else:
                logger.info(f"Tool '{tool_name}' execution successful.")
            return result  # Return result dict, error dict, or None
            return result
        except TimeoutError:
            logger.error(f"Tool execution timed out after {EXECUTE_TOOL_TIMEOUT}s for '{tool_name}' on {server_name}.")
            return None
@@ -1,4 +1,3 @@
# src/mcp/process.py
"""Async utilities for managing MCP server subprocesses."""

import asyncio
@@ -29,25 +28,20 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
    """
    logger.debug(f"Preparing to start process for command: {command}")

    # --- Add tilde expansion for arguments ---
    expanded_args = []
    try:
        for arg in args:
            if isinstance(arg, str) and "~" in arg:
                expanded_args.append(os.path.expanduser(arg))
            else:
                # Ensure all args are strings for list2cmdline
                expanded_args.append(str(arg))
        logger.debug(f"Expanded args: {expanded_args}")
    except Exception as e:
        logger.error(f"Error expanding arguments for {command}: {e}", exc_info=True)
        raise ValueError(f"Failed to expand arguments: {e}") from e

    # --- Merge os.environ with config_env ---
    merged_env = {**os.environ, **config_env}
    # logger.debug(f"Merged environment prepared (keys: {list(merged_env.keys())})")  # Avoid logging values

    # Combine command and expanded args into a single string for shell execution
    try:
        cmd_string = subprocess.list2cmdline([command] + expanded_args)
        logger.debug(f"Executing shell command: {cmd_string}")
@@ -55,7 +49,6 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
        logger.error(f"Error creating command string: {e}", exc_info=True)
        raise ValueError(f"Failed to create command string: {e}") from e

    # --- Start the subprocess using shell ---
    try:
        process = await asyncio.create_subprocess_shell(
            cmd_string,
@@ -68,10 +61,10 @@ async def start_mcp_process(command: str, args: list[str], config_env: dict[str,
        return process
    except FileNotFoundError:
        logger.error(f"Command not found: '{command}' when trying to execute '{cmd_string}'")
        raise  # Re-raise specific error
        raise
    except Exception as e:
        logger.error(f"Failed to create subprocess for '{cmd_string}': {e}", exc_info=True)
        raise  # Re-raise other errors
        raise


async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str = "MCP Server"):
@@ -89,7 +82,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
    pid = process.pid
    logger.info(f"Attempting to stop process {server_name} (PID: {pid})...")

    # Close stdin first
    if process.stdin and not process.stdin.is_closing():
        try:
            process.stdin.close()
@@ -98,7 +90,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
        except Exception as e:
            logger.warning(f"Error closing stdin for {server_name} (PID: {pid}): {e}")

    # Attempt graceful termination
    try:
        process.terminate()
        logger.debug(f"Sent terminate signal to {server_name} (PID: {pid})")
@@ -108,7 +99,7 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
        logger.warning(f"Process {server_name} (PID: {pid}) did not terminate gracefully after 5s, killing.")
        try:
            process.kill()
            await process.wait()  # Wait for kill to complete
            await process.wait()
            logger.info(f"Process {server_name} (PID: {pid}) killed (return code: {process.returncode}).")
        except ProcessLookupError:
            logger.warning(f"Process {server_name} (PID: {pid}) already exited before kill.")
@@ -118,7 +109,6 @@ async def stop_mcp_process(process: asyncio.subprocess.Process, server_name: str
        logger.warning(f"Process {server_name} (PID: {pid}) already exited before termination.")
    except Exception as e_term:
        logger.error(f"Error during termination of {server_name} (PID: {pid}): {e_term}")
        # Attempt kill as fallback if terminate failed and process might still be running
        if process.returncode is None:
            try:
                process.kill()
@@ -1,4 +1,3 @@
# src/mcp/protocol.py
"""Async utilities for MCP JSON-RPC communication over streams."""

import asyncio
@@ -28,10 +27,10 @@ async def send_request(writer: asyncio.StreamWriter, request_dict: dict[str, Any
        logger.debug(f"Sent request ID {request_dict.get('id')}: {request_json.strip()}")
    except ConnectionResetError:
        logger.error(f"Connection lost while sending request ID {request_dict.get('id')}")
        raise  # Re-raise for the caller (MCPClient) to handle
        raise
    except Exception as e:
        logger.error(f"Error sending request ID {request_dict.get('id')}: {e}", exc_info=True)
        raise  # Re-raise for the caller
        raise


async def read_response(reader: asyncio.StreamReader, timeout: float) -> dict[str, Any] | None:
@@ -1,4 +1,3 @@
# src/llm_client.py
"""
Generic LLM client supporting multiple providers and MCP tool integration.
"""
@@ -26,6 +25,7 @@ class LLMClient:
        api_key: str,
        mcp_manager: SyncMCPManager,
        base_url: str | None = None,
        temperature: float = 0.6,  # Add temperature parameter with a fallback default
    ):
        """
        Initialize the LLM client.
@@ -35,9 +35,15 @@ class LLMClient:
            api_key: API key for the provider.
            mcp_manager: An initialized instance of SyncMCPManager.
            base_url: Optional base URL for the provider API.
            temperature: Default temperature to configure the provider with.
        """
        logger.info(f"Initializing LLMClient for provider: {provider_name}")
        self.provider: BaseProvider = create_llm_provider(provider_name, api_key, base_url)
        self.provider: BaseProvider = create_llm_provider(
            provider_name,
            api_key,
            base_url,
            temperature=temperature,  # Pass temperature to provider factory
        )
        self.mcp_manager = mcp_manager
        self.mcp_tools: list[dict[str, Any]] = []
        self._refresh_mcp_tools()  # Initial tool load
@@ -56,7 +62,7 @@ class LLMClient:
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        # temperature: float = 0.6,  # REMOVE THIS LINE
        max_tokens: int | None = None,
        stream: bool = True,
    ) -> Generator[str, None, None] | dict[str, Any]:
@@ -66,14 +72,15 @@ class LLMClient:
        Args:
            messages: List of message dictionaries ({'role': 'user'/'assistant', 'content': ...}).
            model: Model identifier string.
            temperature: Sampling temperature.
            # temperature: REMOVED - Provider uses its configured temperature.
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.

        Returns:
            If stream=True: A generator yielding content chunks.
            If stream=False: A dictionary containing the final content or an error.
                e.g., {"content": "..."} or {"error": "..."}
            If stream=False: A dictionary containing the final content, usage, or an error.
                e.g., {"content": "...", "usage": {"prompt_tokens": ..., "completion_tokens": ...}}
                or {"error": "..."}
        """
        # Ensure tools are up-to-date (optional, could be done less frequently)
        # self._refresh_mcp_tools()
@@ -91,11 +98,12 @@ class LLMClient:
            response = self.provider.create_chat_completion(
                messages=messages,
                model=model,
                temperature=temperature,
                # temperature=temperature,  # REMOVE THIS LINE (provider uses its own)
                max_tokens=max_tokens,
                stream=stream,
                tools=provider_tools,
            )
            print(f"Response: {response}")  # Debugging line to check the response
            logger.info("Received response from provider.")

            if stream:
@@ -167,14 +175,18 @@ class LLMClient:
                    follow_up_response = self.provider.create_chat_completion(
                        messages=messages,  # Now includes assistant's turn and tool results
                        model=model,
                        temperature=temperature,
                        # temperature=temperature,  # REMOVE THIS LINE
                        max_tokens=max_tokens,
                        stream=False,  # Follow-up is non-streaming here
                        tools=provider_tools,  # Pass tools again? Some providers might need it.
                    )
                    final_content = self.provider.get_content(follow_up_response)
                    final_usage = self.provider.get_usage(follow_up_response)  # Get usage from follow-up
                    logger.info("Received follow-up response content.")
                    return {"content": final_content}
                    result_dict = {"content": final_content}
                    if final_usage:
                        result_dict["usage"] = final_usage
                    return result_dict

                except Exception as tool_handling_err:
                    logger.error(f"Error processing tool calls: {tool_handling_err}", exc_info=True)
@@ -183,7 +195,11 @@ class LLMClient:
            else:  # No tool calls
                logger.info("No tool calls detected.")
                content = self.provider.get_content(response)
                return {"content": content}
                usage = self.provider.get_usage(response)  # Get usage from initial response
                result_dict = {"content": content}
                if usage:
                    result_dict["usage"] = usage
                return result_dict

        except Exception as e:
            error_msg = f"LLM API Error: {str(e)}"
@@ -203,17 +219,3 @@ class LLMClient:
        except Exception as e:
            logger.error(f"Error during streaming: {e}", exc_info=True)
            yield json.dumps({"error": f"Streaming error: {str(e)}"})  # Yield error as JSON chunk


# Example of how a provider might need to implement get_original_message_with_calls
# This would be in the specific provider class (e.g., openai_provider.py)
# def get_original_message_with_calls(self, response: Any) -> Dict[str, Any]:
#     # For OpenAI, the tool calls are usually in the *first* response chunk's choice delta
#     # or in the non-streaming response's choice message
#     # Needs careful implementation based on provider's response structure
#     assistant_message = {
#         "role": "assistant",
#         "content": None,  # Often null when tool calls are present
#         "tool_calls": [...]  # Extracted tool calls in provider format
#     }
#     return assistant_message
@@ -1,61 +0,0 @@
MODELS = {
    "openai": {
        "name": "OpenAI",
        "endpoint": "https://api.openai.com/v1",
        "models": [
            {
                "id": "gpt-4o",
                "name": "GPT-4o",
                "default": True,
                "context_window": 128000,
                "description": "Input $5/M tokens, Output $15/M tokens",
            }
        ],
    },
    "anthropic": {
        "name": "Anthropic",
        "endpoint": "https://api.anthropic.com/v1/messages",
        "models": [
            {
                "id": "claude-3-7-sonnet-20250219",
                "name": "Claude 3.7 Sonnet",
                "default": True,
                "context_window": 200000,
                "description": "Input $3/M tokens, Output $15/M tokens",
            },
            {
                "id": "claude-3-5-haiku-20241022",
                "name": "Claude 3.5 Haiku",
                "default": False,
                "context_window": 200000,
                "description": "Input $0.80/M tokens, Output $4/M tokens",
            },
        ],
    },
    "google": {
        "name": "Google Gemini",
        "endpoint": "https://generativelanguage.googleapis.com/v1beta/generateContent",
        "models": [
            {
                "id": "gemini-2.0-flash",
                "name": "Gemini 2.0 Flash",
                "default": True,
                "context_window": 1000000,
                "description": "Input $0.1/M tokens, Output $0.4/M tokens",
            }
        ],
    },
    "openrouter": {
        "name": "OpenRouter",
        "endpoint": "https://openrouter.ai/api/v1/chat/completions",
        "models": [
            {
                "id": "custom",
                "name": "Custom Model",
                "default": False,
                "context_window": 128000,  # Default context window, will be updated based on model
                "description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
            },
        ],
    },
}
@@ -1,68 +0,0 @@
"""OpenAI client with custom MCP integration."""

import configparser
import logging  # Import logging

from openai import OpenAI

from mcp_manager import SyncMCPManager

# Get a logger for this module
logger = logging.getLogger(__name__)


class OpenAIClient:
    def __init__(self):
        logger.debug("Initializing OpenAIClient...")  # Add init log
        self.config = configparser.ConfigParser()
        self.config.read("config/config.ini")

        # Validate configuration
        if not self.config.has_section("openai"):
            raise Exception("Missing [openai] section in config.ini")
        if not self.config["openai"].get("api_key"):
            raise Exception("Missing api_key in config.ini")

        # Configure OpenAI client
        self.client = OpenAI(
            api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"}
        )

        # Initialize MCP manager if configured
        self.mcp_manager = None
        if self.config.has_section("mcp"):
            mcp_config_path = self.config["mcp"].get("servers_json", "config/mcp_config.json")
            self.mcp_manager = SyncMCPManager(mcp_config_path)

    def get_chat_response(self, messages):
        try:
            # Try using MCP if available
            if self.mcp_manager and self.mcp_manager.initialize():
                logger.info("Using MCP with tools...")  # Use logger
                last_message = messages[-1]["content"]
                # Pass API key and base URL from config.ini
                response = self.mcp_manager.process_query(
                    query=last_message,
                    model_name=self.config["openai"]["model"],
                    api_key=self.config["openai"]["api_key"],
                    base_url=self.config["openai"].get("base_url"),  # Use .get for optional base_url
                )

                if "error" not in response:
                    logger.debug("MCP processing successful, wrapping response.")
                    # Convert to OpenAI-compatible response format
                    return self._wrap_mcp_response(response)

            # Fall back to standard OpenAI
            logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}")  # Use logger
            return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)

        except Exception as e:
            error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
            logger.error(error_msg, exc_info=True)  # Use logger
            raise Exception(error_msg)

    def _wrap_mcp_response(self, response: dict):
        """Return the MCP response dictionary directly (for non-streaming)."""
        # No conversion needed if app.py handles dicts separately
        return response
@@ -1,20 +1,16 @@
|
||||
# src/providers/__init__.py
|
||||
import logging
|
||||
|
||||
from providers.anthropic_provider import AnthropicProvider
|
||||
from providers.base import BaseProvider
|
||||
from providers.google_provider import GoogleProvider
|
||||
from providers.openai_provider import OpenAIProvider
|
||||
|
||||
# from providers.google_provider import GoogleProvider
|
||||
# from providers.openrouter_provider import OpenRouterProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map provider names (lowercase) to their corresponding class implementations
|
||||
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
|
||||
"openai": OpenAIProvider,
|
||||
"anthropic": AnthropicProvider,
|
||||
# "google": GoogleProvider,
|
||||
"google": GoogleProvider,
|
||||
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
|
||||
}
|
||||
|
||||
@@ -27,7 +23,7 @@ def register_provider(name: str, provider_class: type[BaseProvider]):
|
||||
logger.info(f"Registered provider: {name}")
|
||||
|
||||
|
||||
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
|
||||
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None, temperature: float = 0.6) -> BaseProvider:
|
||||
"""
|
||||
Factory function to create an instance of a specific LLM provider.
|
||||
|
||||
@@ -48,9 +44,9 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
|
||||
available = ", ".join(PROVIDER_MAP.keys()) or "None"
|
||||
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
|
||||
|
||||
logger.info(f"Creating LLM provider instance for: {provider_name}")
|
||||
logger.info(f"Creating LLM provider instance for: {provider_name} with temperature: {temperature}")
|
||||
try:
|
||||
return provider_class(api_key=api_key, base_url=base_url)
|
||||
return provider_class(api_key=api_key, base_url=base_url, temperature=temperature)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
|
||||
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
|
||||
@@ -59,11 +55,3 @@ def create_llm_provider(provider_name: str, api_key: str, base_url: str | None =
|
||||
def get_available_providers() -> list[str]:
|
||||
"""Returns a list of registered provider names."""
|
||||
return list(PROVIDER_MAP.keys())
|
||||
|
||||
|
||||
# Example of how specific providers would register themselves if structured as plugins,
|
||||
# but for now, we'll explicitly import and map them above.
|
||||
# def load_providers():
|
||||
# # Potentially load providers dynamically if designed as plugins
|
||||
# pass
|
||||
# load_providers()
|
||||
|
||||
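For orientation, a minimal usage sketch of the factory above (provider name, key, and temperature are placeholders taken from the sample config; not part of the diff itself):

```python
# Illustrative only: how create_llm_provider might be called after this change.
from providers import create_llm_provider, get_available_providers

print(get_available_providers())  # e.g. ['openai', 'anthropic', 'google']

provider = create_llm_provider(
    provider_name="anthropic",
    api_key="YOUR_API_KEY",   # placeholder
    base_url=None,            # optional; OpenRouter-style setups pass a custom URL
    temperature=0.6,          # the new parameter threaded through in this change
)
```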
@@ -1,295 +0,0 @@
|
||||
# src/providers/anthropic_provider.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Anthropic, Stream
|
||||
from anthropic.types import Message, MessageStreamEvent, TextDelta
|
||||
|
||||
# Use relative imports for modules within the same package
|
||||
from providers.base import BaseProvider
|
||||
|
||||
# Use absolute imports as per Ruff warning and user instructions
|
||||
from src.llm_models import MODELS
|
||||
from src.tools.conversion import convert_to_anthropic_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
"""Provider implementation for Anthropic Claude models."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# Anthropic client doesn't use base_url in the same way, but store it if needed
|
||||
# Use default Anthropic endpoint if base_url is not provided or relevant
|
||||
effective_base_url = base_url or MODELS.get("anthropic", {}).get("endpoint")
|
||||
super().__init__(api_key, effective_base_url) # Pass base_url to parent, though Anthropic client might ignore it
|
||||
logger.info("Initializing AnthropicProvider")
|
||||
try:
|
||||
self.client = Anthropic(api_key=self.api_key)
|
||||
# Note: Anthropic client doesn't take base_url during init
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
|
||||
"""Converts standard message format to Anthropic's format, extracting system prompt."""
|
||||
anthropic_messages = []
|
||||
system_prompt = None
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "system":
|
||||
if i == 0:
|
||||
system_prompt = content
|
||||
logger.debug("Extracted system prompt for Anthropic.")
|
||||
else:
|
||||
# Handle system message not at the start (append to previous user message or add as user)
|
||||
logger.warning("System message found not at the beginning. Treating as user message.")
|
||||
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
|
||||
continue
|
||||
|
||||
# Handle tool results specifically
|
||||
if role == "tool":
|
||||
# Find the preceding assistant message with the corresponding tool_use block
|
||||
# This requires careful handling in the follow-up logic
|
||||
tool_use_id = message.get("tool_call_id")
|
||||
tool_content = content
|
||||
# Format as a tool_result content block
|
||||
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
|
||||
continue
|
||||
|
||||
# Handle assistant message potentially containing tool_use blocks
|
||||
if role == "assistant":
|
||||
# Check if content is structured (e.g., from a previous tool call response)
|
||||
if isinstance(content, list): # Assuming tool calls might be represented as a list
|
||||
anthropic_messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
anthropic_messages.append({"role": "assistant", "content": content}) # Regular text content
|
||||
continue
|
||||
|
||||
# Regular user messages
|
||||
if role == "user":
|
||||
anthropic_messages.append({"role": "user", "content": content})
|
||||
continue
|
||||
|
||||
logger.warning(f"Unsupported role '{role}' in message conversion for Anthropic.")
|
||||
|
||||
# Ensure conversation starts with a user message if no system prompt was used
|
||||
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
|
||||
logger.warning("Anthropic conversation must start with a user message. Prepending empty user message.")
|
||||
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"}) # Or handle differently
|
||||
|
||||
return system_prompt, anthropic_messages
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None, # Anthropic requires max_tokens
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[MessageStreamEvent] | Message:
|
||||
"""Creates a chat completion using the Anthropic API."""
|
||||
logger.debug(f"Anthropic create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# Anthropic requires max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = 4096 # Default value if not provided
|
||||
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
|
||||
|
||||
system_prompt, anthropic_messages = self._convert_messages(messages)
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if system_prompt:
|
||||
completion_params["system"] = system_prompt
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
|
||||
|
||||
# Remove None values (though Anthropic requires max_tokens)
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
log_params = completion_params.copy()
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} for msg in log_params["messages"][-2:]]
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling Anthropic API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, System: {bool(log_params.get('system'))}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
|
||||
response = self.client.messages.create(**completion_params)
|
||||
logger.debug("Anthropic API call successful.")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_streaming_content(self, response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an Anthropic streaming response."""
|
||||
logger.debug("Processing Anthropic stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
# Iterate through events in the stream
|
||||
for event in response:
|
||||
if event.type == "content_block_delta":
|
||||
# Check if the delta is for text content before accessing .text
|
||||
if isinstance(event.delta, TextDelta):
|
||||
delta_text = event.delta.text
|
||||
if delta_text:
|
||||
full_delta += delta_text
|
||||
yield delta_text
|
||||
# Ignore other delta types like InputJSONDelta for text streaming
|
||||
# Other event types like 'message_start', 'content_block_start', etc., can be logged or handled if needed
|
||||
elif event.type == "message_start":
|
||||
logger.debug(f"Anthropic stream started. Model: {event.message.model}")
|
||||
elif event.type == "message_stop":
|
||||
# The stop_reason might be available on the 'message' object associated with the stream,
|
||||
# not directly on the stop event itself. We log that the stop event occurred.
|
||||
# Accessing the actual reason might require inspecting the final message state if needed.
|
||||
logger.debug("Anthropic stream message_stop event received.")
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
logger.debug(f"Anthropic stream detected tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
|
||||
elif event.type == "content_block_stop":
|
||||
logger.debug(f"Anthropic stream detected content block stop. Index: {event.index}")
|
||||
|
||||
logger.debug(f"Anthropic stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Anthropic stream: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
def get_content(self, response: Message) -> str:
|
||||
"""Extracts content from a non-streaming Anthropic response."""
|
||||
try:
|
||||
# Combine text content from all text blocks
|
||||
text_content = "".join([block.text for block in response.content if block.type == "text"])
|
||||
logger.debug(f"Extracted content (length {len(text_content)}) from non-streaming Anthropic response.")
|
||||
return text_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from Anthropic response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
def has_tool_calls(self, response: Stream[MessageStreamEvent] | Message) -> bool:
|
||||
"""Checks if the Anthropic response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, Message): # Non-streaming
|
||||
# Check stop reason and content blocks
|
||||
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
|
||||
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
|
||||
logger.debug(f"Non-streaming Anthropic response check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
|
||||
return has_calls
|
||||
elif isinstance(response, Stream):
|
||||
# Cannot reliably check an unconsumed stream without consuming it.
|
||||
# The LLMClient should handle this by checking after consumption or based on stop_reason if available post-stream.
|
||||
logger.warning("has_tool_calls check on an Anthropic stream is unreliable before consumption.")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"has_tool_calls received unexpected type for Anthropic: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for Anthropic tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def parse_tool_calls(self, response: Message) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming Anthropic response."""
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, Message):
|
||||
logger.error(f"parse_tool_calls expects Anthropic Message, got {type(response)}")
|
||||
return []
|
||||
|
||||
if response.stop_reason != "tool_use":
|
||||
logger.debug("No tool use indicated by stop_reason.")
|
||||
# return [] # Might still have tool_use blocks even if stop_reason isn't tool_use? Check API docs. Let's check content anyway.
|
||||
|
||||
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
|
||||
if not tool_use_blocks:
|
||||
logger.debug("No 'tool_use' content blocks found in Anthropic response.")
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks from Anthropic response.")
|
||||
for block in tool_use_blocks:
|
||||
# Adapt server/tool name splitting if needed (similar to OpenAI provider)
|
||||
# Assuming Anthropic tool names might also be prefixed like "server__tool"
|
||||
parts = block.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
logger.warning(f"Could not determine server_name from Anthropic tool name '{block.name}'.")
|
||||
server_name = None
|
||||
func_name = block.name
|
||||
|
||||
parsed_calls.append({
|
||||
"id": block.id,
|
||||
"server_name": server_name,
|
||||
"function_name": func_name,
|
||||
"arguments": json.dumps(block.input), # Anthropic input is already a dict, dump to string like OpenAI provider expects? Or keep as dict? Let's keep as dict for now.
|
||||
# "arguments": block.input, # Keep as dict? Let's try this first.
|
||||
})
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Anthropic tool calls: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an Anthropic follow-up request."""
|
||||
# Anthropic expects a 'tool_result' content block
|
||||
# The content of the result block should typically be a string.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content_str = json.dumps(result)
|
||||
else:
|
||||
content_str = str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for Anthropic {tool_call_id}: {e}")
|
||||
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting Anthropic tool result for call ID {tool_call_id}")
|
||||
# This needs to be placed inside a "user" role message's content list
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content_str,
|
||||
# Optionally add is_error=True if result indicates an error
|
||||
}
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to Anthropic's format."""
|
||||
# Use the conversion function, assuming it's correctly placed and imported
|
||||
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
|
||||
try:
|
||||
# The conversion function needs to handle the server__tool prefixing
|
||||
anthropic_tools = convert_to_anthropic_tools(tools)
|
||||
logger.debug(f"Tool conversion result: {anthropic_tools}")
|
||||
return anthropic_tools
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Anthropic tool conversion: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
# Helper needed by LLMClient's current tool handling logic (if adapting OpenAI's pattern)
|
||||
def get_original_message_with_calls(self, response: Message) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls for Anthropic."""
|
||||
try:
|
||||
if isinstance(response, Message) and any(block.type == "tool_use" for block in response.content):
|
||||
# Anthropic's response structure is different. The 'message' itself is the assistant's turn.
|
||||
# We need to return a representation of this turn, including the tool_use blocks.
|
||||
# Convert Pydantic models within content to dicts
|
||||
content_list = [block.model_dump(exclude_unset=True) for block in response.content]
|
||||
return {"role": "assistant", "content": content_list}
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from Anthropic response.")
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
44
src/providers/anthropic_provider/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from providers.anthropic_provider.client import initialize_client
|
||||
from providers.anthropic_provider.completion import create_chat_completion
|
||||
from providers.anthropic_provider.response import get_content, get_streaming_content, get_usage
|
||||
from providers.anthropic_provider.tools import convert_tools, format_tool_results, has_tool_calls, parse_tool_calls
|
||||
from providers.base import BaseProvider
|
||||
|
||||
|
||||
class AnthropicProvider(BaseProvider):
|
||||
temperature: float
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
|
||||
self.client = initialize_client(api_key, base_url)
|
||||
self.temperature = temperature
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages,
|
||||
model,
|
||||
max_tokens=None,
|
||||
stream=True,
|
||||
tools=None,
|
||||
):
|
||||
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
|
||||
|
||||
def get_streaming_content(self, response):
|
||||
return get_streaming_content(response)
|
||||
|
||||
def get_content(self, response):
|
||||
return get_content(response)
|
||||
|
||||
def has_tool_calls(self, response):
|
||||
return has_tool_calls(response)
|
||||
|
||||
def parse_tool_calls(self, response):
|
||||
return parse_tool_calls(response)
|
||||
|
||||
def format_tool_results(self, tool_call_id, result):
|
||||
return format_tool_results(tool_call_id, result)
|
||||
|
||||
def convert_tools(self, tools):
|
||||
return convert_tools(tools)
|
||||
|
||||
def get_usage(self, response):
|
||||
return get_usage(response)
|
||||
17
src/providers/anthropic_provider/client.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
|
||||
from anthropic import Anthropic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_client(api_key: str, base_url: str | None = None) -> Anthropic:
|
||||
logger.info("Initializing Anthropic client")
|
||||
try:
|
||||
client = Anthropic(api_key=api_key)
|
||||
if base_url:
|
||||
logger.warning(f"base_url '{base_url}' provided but not used by Anthropic client")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
|
||||
raise
|
||||
38
src/providers/anthropic_provider/completion.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Stream
|
||||
from anthropic.types import Message
|
||||
|
||||
from providers.anthropic_provider.messages import convert_messages, truncate_messages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_chat_completion(
|
||||
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
|
||||
) -> Stream | Message:
|
||||
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
|
||||
truncated_messages, final_system_prompt, _, _ = truncate_messages(provider, temp_anthropic_messages, temp_system_prompt, model)
|
||||
if max_tokens is None:
|
||||
max_tokens = 4096
|
||||
logger.warning(f"max_tokens not provided, defaulting to {max_tokens}")
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": truncated_messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if final_system_prompt:
|
||||
completion_params["system"] = final_system_prompt
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
try:
|
||||
response = provider.client.messages.create(**completion_params)
|
||||
logger.debug("Anthropic API call successful.")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}", exc_info=True)
|
||||
raise
|
||||
61
src/providers/anthropic_provider/messages.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from providers.anthropic_provider.utils import count_anthropic_tokens, get_context_window
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
|
||||
anthropic_messages = []
|
||||
system_prompt = None
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
if role == "system":
|
||||
if i == 0:
|
||||
system_prompt = content
|
||||
else:
|
||||
logger.warning("System message not at beginning. Treating as user message.")
|
||||
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
|
||||
continue
|
||||
if role == "tool":
|
||||
tool_use_id = message.get("tool_call_id")
|
||||
tool_content = content
|
||||
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
|
||||
continue
|
||||
if role == "assistant":
|
||||
anthropic_messages.append({"role": "assistant", "content": content})
|
||||
continue
|
||||
if role == "user":
|
||||
anthropic_messages.append({"role": "user", "content": content})
|
||||
continue
|
||||
logger.warning(f"Unsupported role '{role}' in message conversion.")
|
||||
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
|
||||
logger.warning("Conversation must start with user message. Prepending placeholder.")
|
||||
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"})
|
||||
return system_prompt, anthropic_messages
|
||||
|
||||
|
||||
def truncate_messages(provider, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
|
||||
context_limit = get_context_window(model)
|
||||
buffer = 200
|
||||
effective_limit = context_limit - buffer
|
||||
initial_token_count = count_anthropic_tokens(provider.client, messages, system_prompt)
|
||||
final_token_count = initial_token_count
|
||||
truncated_messages = list(messages)
|
||||
while final_token_count > effective_limit and len(truncated_messages) > 0:
|
||||
removed_message = truncated_messages.pop(0)
|
||||
logger.debug(f"Truncating message (Role: {removed_message.get('role')})")
|
||||
final_token_count = count_anthropic_tokens(provider.client, truncated_messages, system_prompt)
|
||||
if initial_token_count != final_token_count:
|
||||
logger.info(f"Truncated messages. Initial tokens: {initial_token_count}, Final: {final_token_count}")
|
||||
else:
|
||||
logger.debug(f"No truncation needed. Tokens: {final_token_count}")
|
||||
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
|
||||
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
|
||||
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
|
||||
return truncated_messages, system_prompt, initial_token_count, final_token_count
|
||||
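A rough sketch of the truncation above, assuming `provider` is an already-initialized AnthropicProvider (message contents and token counts are illustrative):

```python
# Illustrative only: oldest messages are dropped until the estimated token count
# fits within the model's context window minus a 200-token buffer.
messages = [
    {"role": "user", "content": "old question"},
    {"role": "assistant", "content": "old answer"},
    {"role": "user", "content": "new question"},
]
truncated, system_prompt, before, after = truncate_messages(
    provider, messages, None, "claude-3-7-sonnet-20250219"
)
# If `before` exceeded the limit, `truncated` starts at a later message and may be
# prefixed with {"role": "user", "content": "[Context truncated]"}.
```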
62
src/providers/anthropic_provider/response.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Stream
|
||||
from anthropic.types import Message, MessageStreamEvent, TextDelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_streaming_content(response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
|
||||
logger.debug("Processing Anthropic stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
for event in response:
|
||||
if event.type == "content_block_delta":
|
||||
if isinstance(event.delta, TextDelta):
|
||||
delta_text = event.delta.text
|
||||
if delta_text:
|
||||
full_delta += delta_text
|
||||
yield delta_text
|
||||
elif event.type == "message_start":
|
||||
logger.debug(f"Stream started. Model: {event.message.model}")
|
||||
elif event.type == "message_stop":
|
||||
logger.debug("Stream message_stop event received.")
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
logger.debug(f"Tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
|
||||
elif event.type == "content_block_stop":
|
||||
logger.debug(f"Content block stop. Index: {event.index}")
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing stream: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
|
||||
def get_content(response: Message) -> str:
|
||||
try:
|
||||
text_content = "".join([block.text for block in response.content if block.type == "text"])
|
||||
logger.debug(f"Extracted content (length {len(text_content)})")
|
||||
return text_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
|
||||
def get_usage(response: Any) -> dict[str, int] | None:
|
||||
try:
|
||||
if isinstance(response, Message) and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
}
|
||||
logger.debug(f"Extracted usage: {usage}")
|
||||
return usage
|
||||
else:
|
||||
logger.warning(f"Could not extract usage from {type(response)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting usage: {e}", exc_info=True)
|
||||
return None
|
||||
115
src/providers/anthropic_provider/tools.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from anthropic.types import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def has_tool_calls(response: Any) -> bool:
|
||||
try:
|
||||
if isinstance(response, Message):
|
||||
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
|
||||
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
|
||||
logger.debug(f"Tool calls check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
|
||||
return has_calls
|
||||
else:
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def parse_tool_calls(response: Message) -> list[dict[str, Any]]:
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, Message):
|
||||
logger.error(f"parse_tool_calls expects Message, got {type(response)}")
|
||||
return []
|
||||
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
|
||||
if not tool_use_blocks:
|
||||
logger.debug("No 'tool_use' content blocks found.")
|
||||
return []
|
||||
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks.")
|
||||
for block in tool_use_blocks:
|
||||
parts = block.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
logger.warning(f"Could not determine server_name from tool name '{block.name}'.")
|
||||
server_name = None
|
||||
func_name = block.name
|
||||
parsed_calls.append({"id": block.id, "server_name": server_name, "function_name": func_name, "arguments": block.input})
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing tool calls: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content_str = json.dumps(result)
|
||||
else:
|
||||
content_str = str(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding tool result for {tool_call_id}: {e}")
|
||||
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {"type": "tool_result", "tool_use_id": tool_call_id, "content": content_str}
|
||||
|
||||
|
||||
def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to Anthropic tool definitions.
|
||||
|
||||
Args:
|
||||
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||
|
||||
Returns:
|
||||
List of Anthropic tool definitions.
|
||||
"""
|
||||
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Anthropic format")
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in mcp_tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
|
||||
continue
|
||||
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}
|
||||
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {}
|
||||
if "type" not in input_schema:
|
||||
input_schema["type"] = "object"
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
anthropic_tool["input_schema"] = input_schema
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
logger.debug(f"Converted MCP tool to Anthropic: {prefixed_tool_name}")
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
|
||||
try:
|
||||
anthropic_tools = convert_to_anthropic_tools(tools)
|
||||
logger.debug(f"Tool conversion result: {anthropic_tools}")
|
||||
return anthropic_tools
|
||||
except Exception as e:
|
||||
logger.error(f"Error during tool conversion: {e}", exc_info=True)
|
||||
return []
|
||||
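As an illustration of the conversion above, a hypothetical MCP tool and the Anthropic definition it maps to:

```python
# Hypothetical input; the output shape mirrors convert_to_anthropic_tools() above.
mcp_tool = {
    "server_name": "filesystem",
    "name": "read_file",
    "description": "Read a file from disk.",
    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}},
}

anthropic_tools = convert_to_anthropic_tools([mcp_tool])
# -> [{"name": "filesystem__read_file",
#      "description": "Read a file from disk.",
#      "input_schema": {"type": "object", "properties": {"path": {"type": "string"}}}}]
```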
50
src/providers/anthropic_provider/utils.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from anthropic import Anthropic
|
||||
|
||||
from src.llm_models import MODELS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_context_window(model: str) -> int:
|
||||
default_window = 100000
|
||||
try:
|
||||
provider_models = MODELS.get("anthropic", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
if m.get("id") == model:
|
||||
return m.get("context_window", default_window)
|
||||
logger.warning(f"Context window for Anthropic model '{model}' not found. Using default: {default_window}")
|
||||
return default_window
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
|
||||
return default_window
|
||||
|
||||
|
||||
def count_anthropic_tokens(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
|
||||
text_to_count = ""
|
||||
if system_prompt:
|
||||
text_to_count += f"System: {system_prompt}\n\n"
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
text_to_count += f"{role}: {content}\n"
|
||||
elif isinstance(content, list):
|
||||
try:
|
||||
content_str = json.dumps(content)
|
||||
text_to_count += f"{role}: {content_str}\n"
|
||||
except Exception:
|
||||
text_to_count += f"{role}: [Unserializable Content]\n"
|
||||
try:
|
||||
count = client.count_tokens(text=text_to_count)
|
||||
logger.debug(f"Counted Anthropic tokens: {count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting Anthropic tokens: {e}", exc_info=True)
|
||||
estimated_tokens = math.ceil(len(text_to_count) / 4.0)
|
||||
logger.warning(f"Falling back to approximation: {estimated_tokens}")
|
||||
return estimated_tokens
|
||||
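A quick illustration of the character-based fallback above (the SDK's own count, when available, takes precedence):

```python
# Illustrative fallback estimate only: roughly four characters per token.
import math

text_to_count = "x" * 1000
estimated_tokens = math.ceil(len(text_to_count) / 4.0)  # 250 tokens for 1,000 characters
```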
@@ -1,4 +1,3 @@
|
||||
# src/providers/base.py
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
@@ -28,7 +27,7 @@ class BaseProvider(abc.ABC):
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
temperature: float = 0.6,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
@@ -39,7 +38,7 @@ class BaseProvider(abc.ABC):
|
||||
Args:
|
||||
messages: List of message dictionaries with 'role' and 'content'.
|
||||
model: Model identifier.
|
||||
temperature: Sampling temperature (0-1).
|
||||
temperature: Sampling temperature (0-2).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
stream: Whether to stream the response.
|
||||
tools: Optional list of tools in the provider-specific format.
|
||||
@@ -134,7 +133,16 @@ class BaseProvider(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# Optional: Add a method for follow-up completions if the provider API
|
||||
# requires a specific structure different from just appending messages.
|
||||
# def create_follow_up_completion(...) -> Any:
|
||||
# pass
|
||||
@abc.abstractmethod
|
||||
def get_usage(self, response: Any) -> dict[str, int] | None:
|
||||
"""
|
||||
Extracts token usage information from a non-streaming response object.
|
||||
|
||||
Args:
|
||||
response: The non-streaming response object.
|
||||
|
||||
Returns:
|
||||
A dictionary containing 'prompt_tokens' and 'completion_tokens',
|
||||
or None if usage information is not available.
|
||||
"""
|
||||
pass
|
||||
|
||||
78
src/providers/google_provider/__init__.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from google.genai.types import GenerateContentResponse
|
||||
|
||||
from providers.google_provider.client import initialize_client
|
||||
from providers.google_provider.completion import create_chat_completion
|
||||
from providers.google_provider.response import get_content, get_streaming_content, get_usage
|
||||
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
|
||||
from src.providers.base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleProvider(BaseProvider):
|
||||
"""Provider implementation for Google Generative AI (Gemini)."""
|
||||
|
||||
client_module: Any
|
||||
temperature: float
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
|
||||
"""
|
||||
Initializes the GoogleProvider.
|
||||
|
||||
Args:
|
||||
api_key: The Google API key.
|
||||
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
|
||||
temperature: The default temperature for completions.
|
||||
"""
|
||||
self.client_module = initialize_client(api_key, base_url)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
logger.info(f"GoogleProvider initialized with temperature: {self.temperature}")
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any:
|
||||
"""Creates a chat completion using the Google Gemini API."""
|
||||
raw_response = create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)
|
||||
print(f"Raw response type: {type(raw_response)}")
|
||||
print(f"Raw response: {raw_response}")
|
||||
|
||||
return raw_response
|
||||
|
||||
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||
"""Extracts content chunks from a Google streaming response."""
|
||||
return get_streaming_content(response)
|
||||
|
||||
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||
"""Extracts the full text content from a non-streaming Google response."""
|
||||
return get_content(response)
|
||||
|
||||
def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool:
|
||||
"""Checks if the Google response contains tool calls (FunctionCalls)."""
|
||||
return has_google_tool_calls(response)
|
||||
|
||||
def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls (FunctionCalls) from a non-streaming Google response."""
|
||||
return parse_google_tool_calls(response)
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for a Google follow-up request (into standard message format)."""
|
||||
return format_google_tool_results(tool_call_id, function_name, result)
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts MCP tools list to Google's intermediate dictionary format."""
|
||||
return convert_to_google_tools(tools)
|
||||
|
||||
def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
|
||||
"""Extracts token usage information from a Google response."""
|
||||
return get_usage(response)
|
||||
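A minimal streaming sketch of driving this provider (model name and key are placeholders):

```python
# Illustrative only; assumes a valid API key and the google-genai SDK installed.
provider = GoogleProvider(api_key="YOUR_API_KEY", temperature=0.6)

stream = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gemini-2.0-flash",
    stream=True,
)
for chunk in provider.get_streaming_content(stream):
    print(chunk, end="", flush=True)
```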
25
src/providers/google_provider/client.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google import genai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_client(api_key: str, base_url: str | None = None) -> Any:
|
||||
"""Initializes and returns the Google Generative AI client module."""
|
||||
logger.info("Initializing Google Generative AI client")
|
||||
|
||||
if genai is None:
|
||||
logger.error("Google Generative AI SDK (google-genai) is not installed.")
|
||||
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
|
||||
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("Google Generative AI client instantiated.")
|
||||
if base_url:
|
||||
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
|
||||
raise
|
||||
140
src/providers/google_provider/completion.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
import traceback
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
|
||||
|
||||
from providers.google_provider.tools import convert_to_google_tool_objects
|
||||
from providers.google_provider.utils import convert_messages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_chat_completion_non_stream(
|
||||
provider,
|
||||
model: str,
|
||||
google_messages: list[ContentDict],
|
||||
generation_config: GenerationConfigDict,
|
||||
) -> GenerateContentResponse | dict[str, Any]:
|
||||
"""Handles the non-streaming API call."""
|
||||
try:
|
||||
logger.debug("Calling client.models.generate_content...")
|
||||
response = provider.client_module.models.generate_content(
|
||||
model=f"models/{model}",
|
||||
contents=google_messages,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.debug("generate_content call successful, returning raw response object.")
|
||||
return response
|
||||
except ValueError as ve:
|
||||
error_msg = f"Google API request validation error: {ve}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
except Exception as e:
|
||||
error_msg = f"Google API error during non-stream chat completion: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
|
||||
|
||||
def _create_chat_completion_stream(
|
||||
provider,
|
||||
model: str,
|
||||
google_messages: list[ContentDict],
|
||||
generation_config: GenerationConfigDict,
|
||||
) -> Iterable[GenerateContentResponse | dict[str, Any]]:
|
||||
"""Handles the streaming API call and yields results."""
|
||||
try:
|
||||
logger.debug("Calling client.models.generate_content_stream...")
|
||||
response_iterator = provider.client_module.models.generate_content_stream(
|
||||
model=f"models/{model}",
|
||||
contents=google_messages,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.debug("generate_content_stream call successful, yielding from iterator.")
|
||||
yield from response_iterator
|
||||
except ValueError as ve:
|
||||
error_msg = f"Google API request validation error: {ve}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
except Exception as e:
|
||||
error_msg = f"Google API error during stream chat completion: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
|
||||
|
||||
def create_chat_completion(
|
||||
provider,
|
||||
messages: list[dict[str, Any]],
|
||||
model: str,
|
||||
temperature: float = 0.6,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates a chat completion using the Google Gemini API.
|
||||
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
|
||||
"""
|
||||
logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
if provider.client_module is None:
|
||||
error_msg = "Google Generative AI client not initialized on provider."
|
||||
logger.error(error_msg)
|
||||
return iter([{"error": error_msg}]) if stream else {"error": error_msg}
|
||||
|
||||
try:
|
||||
google_messages, system_prompt = convert_messages(messages)
|
||||
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
|
||||
|
||||
generation_config: GenerationConfigDict = {"temperature": temperature}
|
||||
if max_tokens is not None:
|
||||
generation_config["max_output_tokens"] = max_tokens
|
||||
logger.debug(f"Setting max_output_tokens: {max_tokens}")
|
||||
else:
|
||||
default_max_tokens = 8192
|
||||
generation_config["max_output_tokens"] = default_max_tokens
|
||||
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
|
||||
|
||||
google_tool_objects: list[Tool] | None = None
|
||||
if tools:
|
||||
try:
|
||||
google_tool_objects = convert_to_google_tool_objects(tools)
|
||||
if google_tool_objects:
|
||||
num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
|
||||
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
|
||||
else:
|
||||
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
|
||||
except Exception as tool_conv_err:
|
||||
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
|
||||
google_tool_objects = None
|
||||
else:
|
||||
logger.debug("No tools provided for conversion.")
|
||||
|
||||
if system_prompt:
|
||||
generation_config["system_instruction"] = system_prompt
|
||||
logger.debug("Added system_instruction to generation_config.")
|
||||
if google_tool_objects:
|
||||
generation_config["tools"] = google_tool_objects
|
||||
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
|
||||
|
||||
log_params = {
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"temperature": temperature,
|
||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||
"system_prompt_present": bool(system_prompt),
|
||||
"num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
|
||||
"num_messages": len(google_messages),
|
||||
}
|
||||
logger.info(f"Calling Google API via helper with params: {log_params}")
|
||||
|
||||
if stream:
|
||||
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
|
||||
else:
|
||||
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during Google completion setup: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
|
||||
185
src/providers/google_provider/response.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Response handling utilities specific to the Google Generative AI provider.
|
||||
|
||||
Includes functions for:
|
||||
- Extracting content from streaming responses.
|
||||
- Extracting content from non-streaming responses.
|
||||
- Extracting token usage information.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from itertools import chain
from typing import Any
|
||||
|
||||
from google.genai.types import GenerateContentResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_streaming_content(response: Any) -> Generator[str, None, None]:
|
||||
"""
|
||||
Yields content chunks (text) from a Google streaming response iterator.
|
||||
|
||||
Args:
|
||||
response: The streaming response iterator returned by `generate_content(stream=True)`.
|
||||
|
||||
Yields:
|
||||
String chunks of the generated text content.
|
||||
May yield JSON strings containing error information if errors occur during streaming.
|
||||
"""
|
||||
logger.debug("Processing Google stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
yield json.dumps(response)
|
||||
logger.error(f"Stream processing stopped due to initial error: {response['error']}")
|
||||
return
|
||||
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
|
||||
first_item = next(response, None)
|
||||
if first_item and isinstance(first_item, str):
|
||||
try:
|
||||
error_data = json.loads(first_item)
|
||||
if "error" in error_data:
|
||||
yield first_item
|
||||
yield from response
|
||||
logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}")
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
yield first_item
|
||||
elif first_item is not None:
# Re-attach the consumed first chunk so its text is not silently dropped.
response = chain([first_item], response)
|
||||
|
||||
for chunk in response:
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
yield json.dumps(chunk)
|
||||
logger.error(f"Error encountered during Google stream: {chunk['error']}")
|
||||
continue
|
||||
|
||||
delta = ""
|
||||
try:
|
||||
if hasattr(chunk, "text"):
|
||||
delta = chunk.text
|
||||
elif hasattr(chunk, "candidates") and chunk.candidates:
|
||||
first_candidate = chunk.candidates[0]
|
||||
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
|
||||
first_part = first_candidate.content.parts[0]
|
||||
if hasattr(first_part, "text"):
|
||||
delta = first_part.text
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True)
|
||||
delta = ""
|
||||
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
|
||||
try:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
logger.debug(f"Function call detected during stream: {part.function_call.name}")
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
|
||||
|
||||
except StopIteration:
|
||||
logger.debug("Google stream finished (StopIteration).")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Google stream: {e}", exc_info=True)
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
|
||||
def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||
"""
|
||||
Extracts the full text content from a non-streaming Google response.
|
||||
|
||||
Args:
|
||||
response: The non-streaming response object (`GenerateContentResponse`) or
|
||||
an error dictionary.
|
||||
|
||||
Returns:
|
||||
The concatenated text content, or an error message string.
|
||||
"""
|
||||
try:
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
logger.error(f"Cannot get content from error dict: {response['error']}")
|
||||
return f"[Error: {response['error']}]"
|
||||
|
||||
if not isinstance(response, GenerateContentResponse):
|
||||
logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||
return f"[Error: Unexpected response type {type(response)}]"
|
||||
|
||||
if hasattr(response, "text") and response.text:
|
||||
content = response.text
|
||||
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
|
||||
return content
|
||||
|
||||
if hasattr(response, "candidates") and response.candidates:
|
||||
first_candidate = response.candidates[0]
|
||||
if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
|
||||
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
|
||||
if text_parts:
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
|
||||
return content
|
||||
else:
|
||||
logger.warning("Google response candidate parts contained no text.")
|
||||
return ""
|
||||
else:
|
||||
logger.warning("Google response candidate has no valid content or parts.")
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
|
||||
return ""
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
|
||||
return f"[Error extracting content: Attribute missing - {str(ae)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
|
||||
def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
|
||||
"""
|
||||
Extracts token usage information from a Google response object.
|
||||
|
||||
Args:
|
||||
response: The response object (`GenerateContentResponse`) or an error dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary containing 'prompt_tokens' and 'completion_tokens', or None if
|
||||
usage information is unavailable or an error occurred.
|
||||
"""
|
||||
try:
|
||||
if isinstance(response, dict) and "error" in response:
|
||||
logger.warning(f"Cannot get usage from error dict: {response['error']}")
|
||||
return None
|
||||
|
||||
if not isinstance(response, GenerateContentResponse):
|
||||
logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||
return None
|
||||
|
||||
metadata = getattr(response, "usage_metadata", None)
|
||||
if metadata:
|
||||
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
|
||||
completion_tokens = getattr(metadata, "candidates_token_count", 0)
|
||||
usage = {
|
||||
"prompt_tokens": int(prompt_tokens),
|
||||
"completion_tokens": int(completion_tokens),
|
||||
}
|
||||
logger.debug(f"Extracted usage from Google response metadata: {usage}")
|
||||
return usage
|
||||
else:
|
||||
logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
|
||||
return None
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
|
||||
return None
|
||||
365
src/providers/google_provider/tools.py
Normal file
@@ -0,0 +1,365 @@
"""
Tool handling utilities specific to the Google Generative AI provider.

Includes functions for:
- Converting MCP tool definitions to Google's format.
- Creating Google Tool/FunctionDeclaration objects.
- Parsing tool calls (FunctionCalls) from Google responses.
- Formatting tool results for subsequent API calls.
"""

import json
import logging
from typing import Any

from google.genai.types import FunctionDeclaration, Schema, Tool, Type

logger = logging.getLogger(__name__)


def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Google Gemini format (dictionary structure).

This format is an intermediate step before creating Tool objects.

Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

Returns:
List containing one dictionary with 'function_declarations'.
Returns an empty list if no valid tools are provided or converted.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")

function_declarations = []

for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
continue

prefixed_tool_name = f"{server_name}__{tool_name}"

if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.")
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema or input_schema["type"] != "object":
input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}}
logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.")

if "properties" not in input_schema:
input_schema["properties"] = {}

if not input_schema["properties"]:
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}}
if "required" in input_schema and not isinstance(input_schema.get("required"), list):
input_schema["required"] = []

function_declaration = {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,
}

function_declarations.append(function_declaration)
logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}")

google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []

logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}")
return google_tool_config


def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | None:
"""
Recursively creates Google Schema objects from a JSON schema dictionary.

Handles type mapping and nested structures. Returns None on failure.
"""
if Schema is None or Type is None:
logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
return None

if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
return None

type_mapping = {
"string": Type.STRING,
"number": Type.NUMBER,
"integer": Type.INTEGER,
"boolean": Type.BOOLEAN,
"array": Type.ARRAY,
"object": Type.OBJECT,
}
original_type = schema_dict.get("type")
google_type = type_mapping.get(str(original_type).lower()) if original_type else None

if not google_type:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
return None

schema_args = {
"type": google_type,
"format": schema_dict.get("format"),
"description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"),
"enum": schema_dict.get("enum"),
"items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
"properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
if google_type == Type.OBJECT and schema_dict.get("properties")
else None,
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
}

schema_args = {k: v for k, v in schema_args.items() if v is not None}

if google_type == Type.ARRAY and "items" not in schema_args:
logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
return None
if google_type == Type.OBJECT and "properties" not in schema_args:
pass

try:
created_schema = Schema(**schema_args)
return created_schema
except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
return None


def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
"""
Convert the dictionary-based tool configurations into Google's Tool objects.

Args:
tool_configs: A list containing a dictionary with 'function_declarations',
as produced by `convert_to_google_tools`.

Returns:
A list containing a single Google `Tool` object, or None if conversion fails
or no valid declarations are found.
"""
if Tool is None or FunctionDeclaration is None:
logger.error("Cannot create Tool objects: google.genai types not available.")
return None
if not tool_configs:
logger.debug("No tool configurations provided to convert to Tool objects.")
return None

all_func_declarations = []
if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]:
func_declarations_list = tool_configs[0]["function_declarations"]
if not isinstance(func_declarations_list, list):
logger.error(f"Expected 'function_declarations' to be a list, got {type(func_declarations_list)}")
return None

for func_dict in func_declarations_list:
func_name = func_dict.get("name", "Unknown")
try:
params_schema_dict = func_dict.get("parameters", {})

if not isinstance(params_schema_dict, dict):
logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
params_schema_dict = {"type": "object", "properties": {}}
elif "type" not in params_schema_dict:
params_schema_dict["type"] = "object"
elif params_schema_dict["type"] != "object":
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
original_properties = params_schema_dict.get("properties", {})
if not isinstance(original_properties, dict):
original_properties = {}
params_schema_dict = {"type": "object", "properties": original_properties}

properties_dict = params_schema_dict.get("properties", {})
google_properties = {}
if isinstance(properties_dict, dict):
for prop_name, prop_schema_dict in properties_dict.items():
prop_schema = _create_google_schema_recursive(prop_schema_dict)
if prop_schema:
google_properties[prop_name] = prop_schema
else:
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
else:
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")

if not google_properties:
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
required_list = []
else:
original_required = params_schema_dict.get("required", [])
if isinstance(original_required, list):
required_list = [req for req in original_required if req in google_properties]
if len(required_list) != len(original_required):
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
else:
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
required_list = []

parameters_schema = Schema(
type=Type.OBJECT,
properties=google_properties,
required=required_list if required_list else None,
)

declaration = FunctionDeclaration(
name=func_name,
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")

except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)

else:
logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
return None

if not all_func_declarations:
logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.")
return None

logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.")
return [Tool(function_declarations=all_func_declarations)]


def has_google_tool_calls(response: Any) -> bool:
"""
Checks if the Google response object contains tool calls (FunctionCalls).

Args:
response: The response object from the Google generate_content API call.

Returns:
True if FunctionCalls are present, False otherwise.
"""
try:
if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
return True

logger.debug("No tool calls (FunctionCall) detected in Google response.")
return False
except Exception as e:
logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
return False


def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]:
"""
Parses tool calls (FunctionCalls) from a non-streaming Google response object.

Args:
response: The non-streaming response object from the Google generate_content API call.

Returns:
A list of dictionaries, each representing a tool call in the standard MCP format
(id, server_name, function_name, arguments as JSON string).
Returns an empty list if no calls are found or an error occurs.
"""
parsed_calls = []
try:
if not (hasattr(response, "candidates") and response.candidates):
logger.warning("Cannot parse tool calls: Response has no candidates.")
return []

candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
return []

logger.debug("Parsing tool calls (FunctionCall) from Google response.")
call_index = 0
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
call_id = f"call_{call_index}"
call_index += 1

full_name = func_call.name
parts = full_name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.")
server_name = None
func_name = full_name

try:
args_dict = dict(func_call.args) if func_call.args else {}
args_str = json.dumps(args_dict)
except Exception as json_err:
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})

parsed_calls.append({
"id": call_id,
"server_name": server_name,
"function_name": func_name,
"arguments": args_str,
"_google_tool_name": full_name,
})
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")

return parsed_calls
except Exception as e:
logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
return []


def format_google_tool_results(tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
"""
Formats a tool result for a Google follow-up request (FunctionResponse).

Args:
tool_call_id: The unique ID assigned during parsing (e.g., "call_0").
Note: Google's API itself doesn't use this ID directly in the
FunctionResponse part, but we need it for mapping in the message list.
function_name: The original function name (without server prefix) that was called.
result: The data returned by the tool execution. Should be JSON-serializable.

Returns:
A dictionary representing the tool result message in the standard MCP format.
This will be converted later by `_convert_messages`.
"""
try:
if isinstance(result, (str, int, float, bool, list)):
content_dict = {"result": result}
elif isinstance(result, dict):
content_dict = result
else:
logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.")
content_dict = {"result": str(result)}

try:
content_str = json.dumps(content_dict)
except Exception as json_err:
logger.error(f"Error JSON-encoding tool result content for Google {function_name} ({tool_call_id}): {json_err}")
content_str = json.dumps({"error": "Failed to encode tool result content", "original_type": str(type(result))})

except Exception as e:
logger.error(f"Error preparing tool result content for Google {function_name} ({tool_call_id}): {e}")
content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)})

logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content_str,
"name": function_name,
}
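A minimal usage sketch of the two-step conversion above, assuming a hypothetical MCP tool definition (`files` / `read_file` are illustrative names, not part of this change):

```python
# Hypothetical example: chaining the two conversion helpers for the Gemini API.
mcp_tools = [{
    "server_name": "files",
    "name": "read_file",
    "description": "Read a file from disk.",
    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]},
}]

tool_config = convert_to_google_tools(mcp_tools)           # dict-based intermediate format
google_tools = convert_to_google_tool_objects(tool_config)  # list[Tool] or None
```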
127
src/providers/google_provider/utils.py
Normal file
@@ -0,0 +1,127 @@
import json
import logging
from typing import Any

from google.genai.types import Content, Part

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given Google model."""
default_window = 1000000
try:
provider_models = MODELS.get("google", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window


def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
"""
Converts standard message format to Google's format, extracting system prompt.
Handles mapping roles and structuring tool calls/results.
"""
google_messages: list[Content] = []
system_prompt: str | None = None

for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")

if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Google.")
else:
logger.warning("System message found not at the beginning. Skipping for Google API.")
continue

google_role = {"user": "user", "assistant": "model"}.get(role)

if not google_role and role != "tool":
logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
continue

parts: list[Part | str] = []
if role == "tool":
if tool_call_id and content:
try:
response_content_dict = json.loads(content)
except json.JSONDecodeError:
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
response_content_dict = {"result": content}

func_name = "unknown_function"
if i > 0 and messages[i - 1].get("role") == "assistant":
prev_tool_calls = messages[i - 1].get("tool_calls")
if prev_tool_calls:
for tc in prev_tool_calls:
if tc.get("id") == tool_call_id:
full_name = tc.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1]
break

parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
google_role = "function"
else:
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
continue

elif role == "assistant" and tool_calls:
for tool_call in tool_calls:
args = tool_call.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
args = {"error": "failed to parse arguments"}

full_name = tool_call.get("function_name", "unknown_function")
func_name = full_name.split("__", 1)[-1]

parts.append(Part.from_function_call(name=func_name, args=args))

if content and isinstance(content, str):
parts.append(Part(text=content))

elif content:
if isinstance(content, str):
parts.append(Part(text=content))
else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part(text=str(content)))

if parts:
google_messages.append(Content(role=google_role, parts=parts))
else:
logger.debug(f"No parts generated for message: {message}")

last_role = None
valid_alternation = True
for msg in google_messages:
current_role = msg.role
if current_role == last_role and current_role in ["user", "model"]:
valid_alternation = False
logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.")
break
if last_role == "function" and current_role != "user":
valid_alternation = False
logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.")
break
last_role = current_role

if not valid_alternation:
raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.")

return google_messages, system_prompt
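A short sketch of what `convert_messages` produces for a well-formed conversation; the message texts are illustrative only and the exact `Content` repr depends on the google-genai types:

```python
# Hypothetical message list in the app's standard format; convert_messages maps
# "assistant" -> "model", pulls out the leading system prompt, and enforces
# user/model alternation (tool results become "function" parts).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "Thanks!"},
]

google_messages, system_prompt = convert_messages(messages)
# google_messages -> [Content(role="user", ...), Content(role="model", ...), Content(role="user", ...)]
# system_prompt   -> "You are a helpful assistant."
```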
@@ -1,239 +0,0 @@
# src/providers/openai_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

from providers.base import BaseProvider
from src.llm_models import MODELS  # Use absolute import

logger = logging.getLogger(__name__)


class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""

def __init__(self, api_key: str, base_url: str | None = None):
# Use default OpenAI endpoint if base_url is not provided
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
try:
# TODO: Add default headers like in original client?
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise

def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API."""
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
try:
completion_params = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools

# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None}

# --- Added Debug Logging ---
log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:]  # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
# --- End Added Debug Logging ---

response = self.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
return response
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise

def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"})

def get_content(self, response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or ""  # Return empty string if content is None
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"

def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion):  # Non-streaming
return bool(response.choices[0].message.tool_calls)
elif hasattr(response, "_iterator"):  # Check if it looks like our stream wrapper
# This is tricky for streams. We'd need to peek at the first chunk(s)
# or buffer the response. For simplicity, this check might be unreliable
# for streams *before* they are consumed. LLMClient needs robust handling.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
# A more robust check would involve consuming the start of the stream
# or relying on the structure after consumption.
return False  # Assume no for unconsumed stream for now
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False

def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
# Attempt to handle buffered stream if possible? Complex.
return []

tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []

logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None  # Or raise error, or use a default?
func_name = call.function.name

parsed_calls.append({
"id": call.id,
"server_name": server_name,  # May be None if not prefixed
"function_name": func_name,
"arguments": call.function.arguments,  # Arguments are already a string here
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")

return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return []  # Return empty list on error

def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
else:
content = str(result)  # Ensure it's a string
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}

def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue

# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"

openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,  # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")

return openai_tools

# Helper needed by LLMClient's current tool handling logic
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}


# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)
62
src/providers/openai_provider/__init__.py
Normal file
@@ -0,0 +1,62 @@
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from providers.openai_provider.client import initialize_client
from providers.openai_provider.completion import create_chat_completion
from providers.openai_provider.response import get_content, get_streaming_content, get_usage
from providers.openai_provider.tools import (
convert_tools,
format_tool_results,
get_original_message_with_calls,
has_tool_calls,
parse_tool_calls,
)
from src.providers.base import BaseProvider


class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""

temperature: float

def __init__(self, api_key: str, base_url: str | None = None, temperature: float = 0.6):
self.client = initialize_client(api_key, base_url)
self.api_key = api_key
self.base_url = self.client.base_url
self.temperature = temperature

def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
return create_chat_completion(self, messages, model, self.temperature, max_tokens, stream, tools)

def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
return get_streaming_content(response)

def get_content(self, response: ChatCompletion) -> str:
return get_content(response)

def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
return has_tool_calls(response)

def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
return parse_tool_calls(response)

def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
return format_tool_results(tool_call_id, result)

def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return convert_tools(tools)

def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
return get_original_message_with_calls(response)

def get_usage(self, response: Any) -> dict[str, int] | None:
return get_usage(response)
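A minimal sketch of how the refactored facade might be used, assuming a valid API key and a model name available in the configured MODELS registry (both placeholders here):

```python
# Hypothetical wiring of the new package-based provider.
provider = OpenAIProvider(api_key="sk-...", temperature=0.6)
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4o",
    stream=False,
)
print(provider.get_content(response))  # assistant text
print(provider.get_usage(response))    # {"prompt_tokens": ..., "completion_tokens": ...}
```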
19
src/providers/openai_provider/client.py
Normal file
@@ -0,0 +1,19 @@
import logging

from openai import OpenAI

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
"""Initializes and returns an OpenAI client instance."""
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
try:
client = OpenAI(api_key=api_key, base_url=effective_base_url)
return client
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise
61
src/providers/openai_provider/completion.py
Normal file
@@ -0,0 +1,61 @@
import logging
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from providers.openai_provider.utils import truncate_messages

logger = logging.getLogger(__name__)


def create_chat_completion(
provider,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)

try:
completion_params = {
"model": model,
"messages": truncated_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto"

completion_params = {k: v for k, v in completion_params.items() if v is not None}

log_params = completion_params.copy()
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params: {log_params}")

response = provider.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")

actual_usage = None
if isinstance(response, ChatCompletion) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
logger.info(f"Actual OpenAI API usage: {actual_usage}")

return response

except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
raise
60
src/providers/openai_provider/response.py
Normal file
@@ -0,0 +1,60 @@
import json
import logging
from collections.abc import Generator
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

logger = logging.getLogger(__name__)


def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
if chunk.choices:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})


def get_content(response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
if response.choices:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or ""
else:
logger.warning("No choices found in OpenAI non-streaming response.")
return "[No content received]"
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"


def get_usage(response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming OpenAI response."""
try:
if isinstance(response, ChatCompletion) and response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
}
logger.debug(f"Extracted usage from OpenAI response: {usage}")
return usage
else:
if not isinstance(response, Stream):
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
return None
147
src/providers/openai_provider/tools.py
Normal file
@@ -0,0 +1,147 @@
import json
import logging
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

logger = logging.getLogger(__name__)


def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion):
if response.choices:
return bool(response.choices[0].message.tool_calls)
else:
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
return False
elif isinstance(response, Stream):
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
return False
else:
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False


def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
return []

if not response.choices:
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
return []

tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []

logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None
func_name = call.function.name

arguments_obj = None
try:
if isinstance(call.function.arguments, str):
arguments_obj = json.loads(call.function.arguments)
else:
arguments_obj = call.function.arguments
except json.JSONDecodeError as json_err:
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
logger.error(f"Raw arguments string: {call.function.arguments}")
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}

parsed_calls.append({
"id": call.id,
"server_name": server_name,
"function_name": func_name,
"arguments": arguments_obj,
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")

return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return []


def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
try:
if isinstance(result, dict):
content = json.dumps(result)
elif isinstance(result, str):
content = result
else:
content = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}


def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue

prefixed_tool_name = f"{server_name}__{tool_name}"

openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")

return openai_tools


def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
message = response.choices[0].message
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
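A small sketch of the shape `convert_tools` produces; the `weather` tool shown is a hypothetical example, not part of this change:

```python
# Hypothetical input/output for convert_tools: the MCP definition becomes an
# OpenAI "function" tool whose name is prefixed with the server for routing.
mcp_tool = {
    "server_name": "weather",
    "name": "get_forecast",
    "description": "Get a forecast.",
    "inputSchema": {"type": "object", "properties": {"city": {"type": "string"}}},
}
openai_tools = convert_tools([mcp_tool])
# [{"type": "function", "function": {"name": "weather__get_forecast", ...}}]
```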
100
src/providers/openai_provider/utils.py
Normal file
@@ -0,0 +1,100 @@
import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
"""Retrieves the context window size for a given model."""
default_window = 8000
try:
provider_models = MODELS.get("openai", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
"""
Estimates the token count for OpenAI messages using char count / 4 approximation.
Note: This is less accurate than using tiktoken.
"""
total_chars = 0
for message in messages:
total_chars += len(message.get("role", ""))
content = message.get("content")
if isinstance(content, str):
total_chars += len(content)
estimated_tokens = math.ceil(total_chars / 4.0)
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
return estimated_tokens


def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
"""
Truncates messages from the beginning if estimated token count exceeds the limit.
Preserves the first message if it's a system prompt.

Returns:
- The potentially truncated list of messages.
- The initial estimated token count.
- The final estimated token count after truncation (if any).
"""
context_limit = get_context_window(model)
buffer = 200
effective_limit = context_limit - buffer

initial_estimated_count = estimate_openai_token_count(messages)
final_estimated_count = initial_estimated_count

truncated_messages = list(messages)

has_system_prompt = False
if truncated_messages and truncated_messages[0].get("role") == "system":
has_system_prompt = True
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
return messages, initial_estimated_count, final_estimated_count

while final_estimated_count > effective_limit:
if has_system_prompt and len(truncated_messages) <= 1:
logger.warning("Truncation stopped: Only system prompt remains.")
break
if not has_system_prompt and len(truncated_messages) <= 0:
logger.warning("Truncation stopped: No messages left.")
break

remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

if remove_index >= len(truncated_messages):
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
break

removed_message = truncated_messages.pop(remove_index)
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

final_estimated_count = estimate_openai_token_count(truncated_messages)
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

if not truncated_messages:
logger.warning("Truncation resulted in empty message list.")
break

if initial_estimated_count != final_estimated_count:
logger.info(
f"Truncated messages for model {model}. "
f"Initial estimated tokens: {initial_estimated_count}, "
f"Final estimated tokens: {final_estimated_count}, "
f"Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

return truncated_messages, initial_estimated_count, final_estimated_count
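A rough sketch of the arithmetic behind the char/4 estimate and the truncation loop above; the message lengths and resulting counts are illustrative only:

```python
# Each message contributes ceil((len(role) + len(content)) / 4) toward the estimate.
messages = [
    {"role": "system", "content": "x" * 400},     # ~100 estimated tokens
    {"role": "user", "content": "y" * 4000},      # ~1000 estimated tokens
    {"role": "assistant", "content": "z" * 4000}, # ~1000 estimated tokens
]
truncated, before, after = truncate_messages(messages, model="gpt-4o")
# If the estimate exceeds the effective limit (context_window - 200), the oldest
# non-system message is dropped first and the estimate is recomputed until it fits.
```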
@@ -1,6 +0,0 @@
# src/tools/__init__.py
# This file makes the 'tools' directory a Python package.

# Optionally import key functions/classes for easier access
# from .conversion import convert_to_openai_tools, convert_to_anthropic_tools
# from .execution import execute_tool  # Assuming execution.py will exist
@@ -1,177 +0,0 @@
"""
Conversion utilities for MCP tools.

This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""

import logging
from typing import Any

logger = logging.getLogger(__name__)


def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to OpenAI tool definitions.

Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

Returns:
List of OpenAI tool definitions.
"""
openai_tools = []
logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")

for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
continue

# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"

# Initialize the OpenAI tool structure
openai_tool = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,  # OpenAI uses JSON Schema directly
},
}
# Basic validation/cleaning of schema if needed could go here
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
openai_tool["function"]["parameters"] = input_schema

openai_tools.append(openai_tool)
logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")

return openai_tools


def convert_to_anthropic_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Anthropic tool definitions.

Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

Returns:
List of Anthropic tool definitions.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Anthropic format")
anthropic_tools = []

for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Anthropic conversion: {tool}")
continue

# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"

# Initialize the Anthropic tool structure
# Anthropic's format is quite close to JSON Schema
anthropic_tool = {"name": prefixed_tool_name, "description": description, "input_schema": input_schema}

# Basic validation/cleaning of schema if needed
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Anthropic might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
anthropic_tool["input_schema"] = input_schema

anthropic_tools.append(anthropic_tool)
logger.debug(f"Converted MCP tool to Anthropic: {prefixed_tool_name}")

return anthropic_tools


def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert MCP tools to Google Gemini format (dictionary structure).

Args:
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

Returns:
List containing one dictionary with 'function_declarations'.
"""
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")

function_declarations = []

for tool in mcp_tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")

if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
continue

# Prefix tool name with server name for routing
prefixed_tool_name = f"{server_name}__{tool_name}"

# Basic validation/cleaning of schema
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this.")
# Ensure basic structure if missing
if not isinstance(input_schema, dict):
input_schema = {}
if "type" not in input_schema:
input_schema["type"] = "object"
if "properties" not in input_schema:
input_schema["properties"] = {}
# Google requires properties for object type, add dummy if empty
if not input_schema["properties"]:
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder"}}

# Create function declaration for Google's format
function_declaration = {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema,  # Google uses JSON Schema directly
}

function_declarations.append(function_declaration)
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")

# Google API expects a list containing one Tool object dict
google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []

logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
return google_tools_wrapper


# Note: The _handle_schema_construct helper from the reference code is not strictly
# needed if we assume the inputSchema is already valid JSON Schema.
# If complex schemas (anyOf, etc.) need specific handling beyond standard JSON Schema,
# that logic could be added here or within the provider implementations.