Compare commits

2 Commits: 678f395649 ... 2fb6c5af3c

Commits:
- 2fb6c5af3c
- 6b390a35f8

.gitignore (vendored, 9 lines changed)
@@ -5,6 +5,7 @@ __pycache__/
 
 # Virtual environment
 env/
+.venv/
 
 # Configuration
 config/config.ini
@@ -20,4 +21,10 @@ config/mcp_config.json
 # resources
 resources/
 
-# __pycache__/
+# Ruff
+.ruff_cache/
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
Deleted file (68 lines removed):
@@ -1,68 +0,0 @@
"""OpenAI client with custom MCP integration."""

import configparser
import logging  # Import logging

from openai import OpenAI

from mcp_manager import SyncMCPManager

# Get a logger for this module
logger = logging.getLogger(__name__)


class OpenAIClient:
    def __init__(self):
        logger.debug("Initializing OpenAIClient...")  # Add init log
        self.config = configparser.ConfigParser()
        self.config.read("config/config.ini")

        # Validate configuration
        if not self.config.has_section("openai"):
            raise Exception("Missing [openai] section in config.ini")
        if not self.config["openai"].get("api_key"):
            raise Exception("Missing api_key in config.ini")

        # Configure OpenAI client
        self.client = OpenAI(
            api_key=self.config["openai"]["api_key"], base_url=self.config["openai"]["base_url"], default_headers={"HTTP-Referer": "https://streamlit-chat-app.com", "X-Title": "Streamlit Chat App"}
        )

        # Initialize MCP manager if configured
        self.mcp_manager = None
        if self.config.has_section("mcp"):
            mcp_config_path = self.config["mcp"].get("servers_json", "config/mcp_config.json")
            self.mcp_manager = SyncMCPManager(mcp_config_path)

    def get_chat_response(self, messages):
        try:
            # Try using MCP if available
            if self.mcp_manager and self.mcp_manager.initialize():
                logger.info("Using MCP with tools...")  # Use logger
                last_message = messages[-1]["content"]
                # Pass API key and base URL from config.ini
                response = self.mcp_manager.process_query(
                    query=last_message,
                    model_name=self.config["openai"]["model"],
                    api_key=self.config["openai"]["api_key"],
                    base_url=self.config["openai"].get("base_url"),  # Use .get for optional base_url
                )

                if "error" not in response:
                    logger.debug("MCP processing successful, wrapping response.")
                    # Convert to OpenAI-compatible response format
                    return self._wrap_mcp_response(response)

            # Fall back to standard OpenAI
            logger.info(f"Falling back to standard OpenAI API with model: {self.config['openai']['model']}")  # Use logger
            return self.client.chat.completions.create(model=self.config["openai"]["model"], messages=messages, stream=True)

        except Exception as e:
            error_msg = f"API Error (Code: {getattr(e, 'code', 'N/A')}): {str(e)}"
            logger.error(error_msg, exc_info=True)  # Use logger
            raise Exception(error_msg)

    def _wrap_mcp_response(self, response: dict):
        """Return the MCP response dictionary directly (for non-streaming)."""
        # No conversion needed if app.py handles dicts separately
        return response
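For context, the deleted client reads all of its settings from config/config.ini. A minimal sketch of the layout that code expects, with placeholder values, looks like:

[openai]
api_key = YOUR_API_KEY                          ; required: __init__ raises if missing
base_url = https://your-openai-compatible-endpoint/v1
model = your-model-id

[mcp]
; optional section: enables SyncMCPManager
servers_json = config/mcp_config.json           ; default path used if omitted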
Deleted file: src/providers/google_provider.py (483 lines removed)
@@ -1,483 +0,0 @@
# src/providers/google_provider.py
import json
import logging
import traceback
from collections.abc import Generator
from typing import Any

from google import genai
from google.genai.types import (
    Content,
    FunctionDeclaration,
    Part,
    Schema,
    Tool,
)

from src.llm_models import MODELS
from src.providers.base import BaseProvider
from src.tools.conversion import convert_to_google_tools

logger = logging.getLogger(__name__)


class GoogleProvider(BaseProvider):
    """Provider implementation for Google Gemini models."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # Google client typically doesn't use a base_url, but we accept it for consistency
        effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)
        logger.info("Initializing GoogleProvider")

        if genai is None:
            raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")

        try:
            # Configure the client
            genai.configure(api_key=self.api_key)
            self.client_module = genai
        except Exception as e:
            logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
            raise

    def _get_context_window(self, model: str) -> int:
        """Retrieves the context window size for a given Google model."""
        default_window = 1000000  # Default fallback for Gemini
        try:
            provider_models = MODELS.get("google", {}).get("models", [])
            for m in provider_models:
                if m.get("id") == model:
                    return m.get("context_window", default_window)
            logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
            return default_window
        except Exception as e:
            logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
            return default_window

    def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
        """
        Converts standard message format to Google's format, extracting system prompt.
        Handles mapping roles and structuring tool calls/results.
        """
        google_messages: list[Content] = []
        system_prompt: str | None = None

        for i, message in enumerate(messages):
            role = message.get("role")
            content = message.get("content")
            tool_calls = message.get("tool_calls")
            tool_call_id = message.get("tool_call_id")

            if role == "system":
                if i == 0:
                    system_prompt = content
                    logger.debug("Extracted system prompt for Google.")
                else:
                    logger.warning("System message found not at the beginning. Merging into subsequent user message.")
                continue

            google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role)

            if not google_role:
                logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
                continue

            parts: list[Part | str] = []
            if role == "tool":
                if tool_call_id and content:
                    try:
                        response_content_dict = json.loads(content)
                    except json.JSONDecodeError:
                        logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
                        response_content_dict = {"result": content}

                    func_name = "unknown_function"
                    if i > 0 and messages[i - 1].get("role") == "assistant":
                        prev_tool_calls = messages[i - 1].get("tool_calls")
                        if prev_tool_calls:
                            for tc in prev_tool_calls:
                                if tc.get("id") == tool_call_id:
                                    func_name = tc.get("function_name", "unknown_function")
                                    break

                    parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
                    google_role = "function"
                else:
                    logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
                    continue

            elif role == "assistant" and tool_calls:
                for tool_call in tool_calls:
                    args = tool_call.get("arguments", {})
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
                            args = {"error": "failed to parse arguments"}
                    func_name = tool_call.get("function_name", "unknown_function")
                    parts.append(Part.from_function_call(name=func_name, args=args))
                if content:
                    parts.append(Part.from_text(content))

            elif content:
                if isinstance(content, str):
                    parts.append(Part.from_text(content))
                else:
                    logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
                    parts.append(Part.from_text(str(content)))

            if parts:
                google_messages.append(Content(role=google_role, parts=parts))
            else:
                logger.debug(f"No parts generated for message: {message}")

        last_role = None
        valid_alternation = True
        for msg in google_messages:
            current_role = msg.role
            if current_role == last_role and current_role in ["user", "model"]:
                valid_alternation = False
                logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.")
                break
            if last_role == "function" and current_role != "user":
                valid_alternation = False
                logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.")
                break
            last_role = current_role

        if not valid_alternation:
            logger.error("Message list does not follow required user/model alternation for Google API.")
            raise ValueError("Invalid message sequence for Google API.")

        return google_messages, system_prompt

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """Creates a chat completion using the Google Gemini API."""
        logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

        if self.client_module is None:
            return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})])

        try:
            google_messages, system_prompt = self._convert_messages(messages)
            generation_config: dict[str, Any] = {"temperature": temperature}
            if max_tokens is not None:
                generation_config["max_output_tokens"] = max_tokens

            google_tools = None
            if tools:
                try:
                    tool_dict_list = convert_to_google_tools(tools)
                    google_tools = self._convert_to_tool_objects(tool_dict_list)
                    logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.")
                except Exception as tool_conv_err:
                    logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
                    google_tools = None

            gemini_model = self.client_module.GenerativeModel(
                model_name=model,
                system_instruction=system_prompt,
                tools=google_tools if google_tools else None,
            )

            log_params = {
                "model": model,
                "stream": stream,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "system_prompt_present": bool(system_prompt),
                "num_tools": len(google_tools) if google_tools else 0,
                "num_messages": len(google_messages),
            }
            logger.debug(f"Calling Google API with params: {log_params}")

            response = gemini_model.generate_content(
                contents=google_messages,
                generation_config=generation_config,
                stream=stream,
            )
            logger.debug("Google API call successful.")
            return response

        except Exception as e:
            error_msg = f"Google API error: {e}"
            logger.error(error_msg, exc_info=True)
            if stream:
                yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
            else:
                return {"error": error_msg, "traceback": traceback.format_exc()}

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """Yields content chunks from a Google streaming response."""
        logger.debug("Processing Google stream...")
        full_delta = ""
        try:
            if isinstance(response, dict) and "error" in response:
                yield json.dumps(response)
                return
            if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
                yield from response
                return

            for chunk in response:
                if isinstance(chunk, dict) and "error" in chunk:
                    yield json.dumps(chunk)
                    continue
                if hasattr(chunk, "text"):
                    delta = chunk.text
                    if delta:
                        full_delta += delta
                        yield delta
                elif hasattr(chunk, "candidates") and chunk.candidates:
                    for part in chunk.candidates[0].content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            logger.debug(f"Function call detected during stream: {part.function_call.name}")
                            break

            logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing Google stream: {e}", exc_info=True)
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: Any) -> str:
        """Extracts content from a non-streaming Google response."""
        try:
            if isinstance(response, dict) and "error" in response:
                logger.error(f"Cannot get content from error response: {response['error']}")
                return f"[Error: {response['error']}]"
            if hasattr(response, "text"):
                content = response.text
                logger.debug(f"Extracted content (length {len(content)}) from response.text.")
                return content
            elif hasattr(response, "candidates") and response.candidates:
                first_candidate = response.candidates[0]
                if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
                    text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
                    content = "".join(text_parts)
                    logger.debug(f"Extracted content (length {len(content)}) from response candidates.")
                    return content
                else:
                    logger.warning("Google response candidate has no content or parts.")
                    return ""
            else:
                logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
                return ""
        except Exception as e:
            logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Any) -> bool:
        """Checks if the Google response contains tool calls (function calls)."""
        try:
            if isinstance(response, dict) and "error" in response:
                return False
            if hasattr(response, "candidates") and response.candidates:
                candidate = response.candidates[0]
                if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                    for part in candidate.content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
                            return True
            logger.debug("No tool calls (FunctionCall) detected in Google response.")
            return False
        except Exception as e:
            logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        """Parses tool calls (function calls) from a non-streaming Google response."""
        parsed_calls = []
        try:
            if not (hasattr(response, "candidates") and response.candidates):
                logger.warning("Cannot parse tool calls: Response has no candidates.")
                return []

            candidate = response.candidates[0]
            if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
                logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
                return []

            logger.debug("Parsing tool calls (FunctionCall) from Google response.")
            call_index = 0
            for part in candidate.content.parts:
                if hasattr(part, "function_call") and part.function_call:
                    func_call = part.function_call
                    call_id = f"call_{call_index}"
                    call_index += 1

                    full_name = func_call.name
                    parts = full_name.split("__", 1)
                    if len(parts) == 2:
                        server_name, func_name = parts
                    else:
                        logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
                        server_name = None
                        func_name = full_name

                    try:
                        args_str = json.dumps(func_call.args or {})
                    except Exception as json_err:
                        logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
                        args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})

                    parsed_calls.append({
                        "id": call_id,
                        "server_name": server_name,
                        "function_name": func_name,
                        "arguments": args_str,
                    })
                    logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
            return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for a Google follow-up request."""
        try:
            if isinstance(result, dict):
                content_str = json.dumps(result)
            else:
                content_str = str(result)
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
            content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content_str,
            "function_name": "unknown_function",
        }

    def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls for Google."""
        try:
            if not (hasattr(response, "candidates") and response.candidates):
                return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}

            candidate = response.candidates[0]
            if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
                return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}

            tool_calls_formatted = []
            text_content_parts = []
            for part in candidate.content.parts:
                if hasattr(part, "function_call") and part.function_call:
                    func_call = part.function_call
                    args = func_call.args or {}
                    tool_calls_formatted.append({
                        "function_name": func_call.name,
                        "arguments": args,
                    })
                elif hasattr(part, "text"):
                    text_content_parts.append(part.text)

            message = {"role": "assistant"}
            if tool_calls_formatted:
                message["tool_calls"] = tool_calls_formatted
            text_content = "".join(text_content_parts)
            if text_content:
                message["content"] = text_content
            elif not tool_calls_formatted:
                message["content"] = ""

            return message

        except Exception as e:
            logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

    def get_usage(self, response: Any) -> dict[str, int] | None:
        """Extracts token usage from a Google response."""
        try:
            if isinstance(response, dict) and "error" in response:
                return None
            if hasattr(response, "usage_metadata"):
                metadata = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                    "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                }
                logger.debug(f"Extracted usage from Google response metadata: {usage}")
                return usage
            else:
                logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
                return None
        except Exception as e:
            logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
            return None

    def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
        """Convert dictionary-format tools into Google's Tool objects."""
        if not tool_configs:
            return None

        all_func_declarations = []
        for config in tool_configs:
            if "function_declarations" in config:
                for func_dict in config["function_declarations"]:
                    try:
                        params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
                        if params_schema_dict.get("type") != "object":
                            logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.")
                            params_schema_dict = {"type": "object", "properties": params_schema_dict}

                        def create_schema(schema_dict):
                            if not isinstance(schema_dict, dict):
                                logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
                                return Schema()
                            schema_args = {
                                "type": schema_dict.get("type"),
                                "format": schema_dict.get("format"),
                                "description": schema_dict.get("description"),
                                "nullable": schema_dict.get("nullable"),
                                "enum": schema_dict.get("enum"),
                                "items": create_schema(schema_dict["items"]) if "items" in schema_dict else None,
                                "properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None,
                                "required": schema_dict.get("required"),
                            }
                            schema_args = {k: v for k, v in schema_args.items() if v is not None}
                            if "type" in schema_args:
                                type_mapping = {
                                    "string": "STRING",
                                    "number": "NUMBER",
                                    "integer": "INTEGER",
                                    "boolean": "BOOLEAN",
                                    "array": "ARRAY",
                                    "object": "OBJECT",
                                }
                                schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"])
                            try:
                                return Schema(**schema_args)
                            except Exception as schema_creation_err:
                                logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True)
                                return Schema()

                        parameters_schema = create_schema(params_schema_dict)
                        declaration = FunctionDeclaration(
                            name=func_dict["name"],
                            description=func_dict.get("description", ""),
                            parameters=parameters_schema,
                        )
                        all_func_declarations.append(declaration)
                    except Exception as decl_err:
                        logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)

        if not all_func_declarations:
            logger.warning("No valid function declarations found after conversion.")
            return None

        return [Tool(function_declarations=all_func_declarations)]
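For reference, a sketch of the role mapping the deleted _convert_messages performed; the message content, IDs, and tool name below are illustrative:

messages = [
    {"role": "system", "content": "You are helpful."},    # extracted as the system prompt (only if first)
    {"role": "user", "content": "Weather in Paris?"},     # -> Content(role="user")
    {
        "role": "assistant",                              # -> Content(role="model") carrying
        "tool_calls": [{                                  #    Part.from_function_call(...)
            "id": "call_0",
            "function_name": "weather__get_forecast",
            "arguments": '{"city": "Paris"}',
        }],
    },
    {
        "role": "tool",                                   # -> Content(role="function") carrying
        "tool_call_id": "call_0",                         #    Part.from_function_response(...)
        "content": '{"forecast": "sunny"}',
    },
]
# The final validation pass enforces user/model alternation and requires a
# "user" turn after each "function" message, raising ValueError otherwise.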
src/providers/google_provider/__init__.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# src/providers/google_provider/__init__.py
import logging
from collections.abc import Generator
from typing import Any

from google.genai.types import GenerateContentResponse

from providers.google_provider.client import initialize_client
from providers.google_provider.completion import create_chat_completion
from providers.google_provider.response import get_content, get_streaming_content, get_usage
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
from src.providers.base import BaseProvider

logger = logging.getLogger(__name__)


class GoogleProvider(BaseProvider):
    """Provider implementation for Google Generative AI (Gemini)."""

    # Type hint for the client (it's the configured 'genai' module itself)
    client_module: Any

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initializes the GoogleProvider.

        Args:
            api_key: The Google API key.
            base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
        """
        # initialize_client returns the configured genai module
        self.client_module = initialize_client(api_key, base_url)
        self.api_key = api_key  # Store if needed later
        self.base_url = base_url  # Store if needed later
        logger.info("GoogleProvider initialized.")

    def create_chat_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:  # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
        """Creates a chat completion using the Google Gemini API."""
        # Pass self (provider instance) to the helper function
        return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """Extracts content chunks from a Google streaming response."""
        # Response is expected to be an iterator from generate_content(stream=True)
        return get_streaming_content(response)

    def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
        """Extracts the full text content from a non-streaming Google response."""
        return get_content(response)

    def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool:
        """Checks if the Google response contains tool calls (FunctionCalls)."""
        # Note: For streaming responses, this check is reliable only after the stream is fully consumed
        # or if the specific chunk containing the call is processed.
        return has_google_tool_calls(response)

    def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]:
        """Parses tool calls (FunctionCalls) from a non-streaming Google response."""
        # Expects a non-streaming GenerateContentResponse or an error dict
        return parse_google_tool_calls(response)

    # Note: Google's format_tool_results helper requires the original function_name.
    # Ensure the calling code (e.g., LLMClient) provides this when invoking this method.
    def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for a Google follow-up request (into standard message format)."""
        return format_google_tool_results(tool_call_id, function_name, result)

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Converts MCP tools list to Google's intermediate dictionary format."""
        # The `create_chat_completion` function handles the final conversion
        # from this intermediate format to Google's `Tool` objects internally.
        return convert_to_google_tools(tools)

    def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
        """Extracts token usage information from a Google response."""
        # Expects a non-streaming GenerateContentResponse or an error dict
        return get_usage(response)

    # `get_original_message_with_calls` (present in OpenAIProvider) is not implemented here
    # as Google's API structure integrates FunctionCall parts directly into the assistant's
    # message content, rather than having a separate `tool_calls` attribute on the message object.
    # The necessary information is handled during message conversion and tool call parsing.
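A minimal sketch of the intended call pattern for this facade, inferred from the signatures and docstrings above; the API key, prompt, and model ID are illustrative:

provider = GoogleProvider(api_key="YOUR_GOOGLE_API_KEY")

response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gemini-1.5-flash",  # example ID taken from the completion.py docstring
    stream=True,
)

# Chunks are plain text; errors arrive as JSON strings (see response.py below).
for chunk in provider.get_streaming_content(response):
    print(chunk, end="")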
src/providers/google_provider/client.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# src/providers/google_provider/client.py
import logging
from typing import Any

from google import genai

logger = logging.getLogger(__name__)


def initialize_client(api_key: str, base_url: str | None = None) -> Any:
    """Initializes and returns the Google Generative AI client module."""
    logger.info("Initializing Google Generative AI client")

    if genai is None:
        logger.error("Google Generative AI SDK (google-generativeai) is not installed.")
        raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")

    try:
        # Configure the client
        genai.configure(api_key=api_key)
        if base_url:
            logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.")
        # Return the configured module itself, as it's used directly
        return genai
    except Exception as e:
        logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
        raise
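A short sketch of how the returned module is used downstream; this mirrors what completion.py does via provider.client_module, and the key and model ID are placeholders:

genai_module = initialize_client(api_key="YOUR_GOOGLE_API_KEY")
gemini_model = genai_module.GenerativeModel(model_name="gemini-1.5-flash")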
src/providers/google_provider/completion.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# src/providers/google_provider/completion.py
import json
import logging
import traceback
from typing import Any

from google.genai.types import Tool

from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools
from providers.google_provider.utils import convert_messages

logger = logging.getLogger(__name__)


def create_chat_completion(
    provider,
    messages: list[dict[str, Any]],
    model: str,
    temperature: float = 0.4,
    max_tokens: int | None = None,
    stream: bool = True,
    tools: list[dict[str, Any]] | None = None,
) -> Any:
    """
    Creates a chat completion using the Google Gemini API.

    Args:
        provider: The instance of the GoogleProvider.
        messages: A list of message dictionaries in the standard format.
        model: The model ID to use (e.g., "gemini-1.5-flash").
        temperature: The sampling temperature.
        max_tokens: The maximum number of tokens to generate.
        stream: Whether to stream the response.
        tools: A list of tool definitions in the MCP format.

    Returns:
        The response object from the Google API (could be a stream iterator or
        a GenerateContentResponse object), or an error dictionary/iterator.
    """
    logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

    if provider.client_module is None:
        error_msg = "Google Generative AI SDK not configured or installed."
        logger.error(error_msg)
        # Return an error structure compatible with both streaming and non-streaming expectations
        if stream:
            return iter([json.dumps({"error": error_msg})])
        else:
            return {"error": error_msg}

    try:
        # 1. Convert messages to Google's format
        google_messages, system_prompt = convert_messages(messages)
        logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")

        # 2. Prepare generation configuration
        generation_config: dict[str, Any] = {"temperature": temperature}
        if max_tokens is not None:
            # Google uses 'max_output_tokens'
            generation_config["max_output_tokens"] = max_tokens
            logger.debug(f"Setting max_output_tokens: {max_tokens}")
        else:
            logger.debug("No max_tokens specified.")

        # 3. Convert tools if provided
        google_tool_objects: list[Tool] | None = None
        if tools:
            try:
                # Step 3a: Convert MCP tools to intermediate Google dict format
                tool_dict_list = convert_to_google_tools(tools)
                # Step 3b: Convert intermediate dict format to Google Tool objects
                google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
                if google_tool_objects:
                    logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.")
                else:
                    logger.warning("Tool conversion resulted in no valid Google Tool objects.")
            except Exception as tool_conv_err:
                logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
                # Decide whether to proceed without tools or raise an error
                # Proceeding without tools for now, but logging the error.
                google_tool_objects = None
        else:
            logger.debug("No tools provided for conversion.")

        # 4. Initialize the Google Generative Model
        # Ensure client_module is callable and has GenerativeModel
        if not hasattr(provider.client_module, "GenerativeModel"):
            raise AttributeError("Configured Google client module does not have 'GenerativeModel'")

        gemini_model = provider.client_module.GenerativeModel(
            model_name=model,
            system_instruction=system_prompt,  # Pass extracted system prompt
            tools=google_tool_objects,  # Pass converted Tool objects (or None)
            # Add safety_settings if needed: safety_settings=...
        )
        logger.debug(f"Initialized Google GenerativeModel for '{model}'.")

        # 5. Log parameters before API call
        log_params = {
            "model": model,
            "stream": stream,
            "temperature": temperature,
            "max_output_tokens": generation_config.get("max_output_tokens"),
            "system_prompt_present": bool(system_prompt),
            "num_tools": len(google_tool_objects) if google_tool_objects else 0,
            "num_messages": len(google_messages),
        }
        logger.info(f"Calling Google generate_content API with params: {log_params}")
        # Avoid logging full message content unless necessary for debugging specific issues
        # logger.debug(f"Google messages being sent: {google_messages}")

        # 6. Call the Google API
        response = gemini_model.generate_content(
            contents=google_messages,
            generation_config=generation_config,
            stream=stream,
            # tool_config={"function_calling_config": "AUTO"} # AUTO is default
        )
        logger.debug("Google API call successful, returning response object.")
        return response

    except ValueError as ve:  # Catch specific errors like invalid message sequence
        error_msg = f"Google API request validation error: {ve}"
        logger.error(error_msg, exc_info=True)
        if stream:
            # Yield a JSON error message in an iterator
            yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
        else:
            # Return an error dictionary
            return {"error": error_msg, "traceback": traceback.format_exc()}
    except Exception as e:
        # Catch any other exceptions during setup or API call
        error_msg = f"Google API error during chat completion: {e}"
        logger.error(error_msg, exc_info=True)
        if stream:
            # Yield a JSON error message in an iterator
            yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
        else:
            # Return an error dictionary
            return {"error": error_msg, "traceback": traceback.format_exc()}
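To make steps 3a/3b concrete, here is a hypothetical MCP tool and the shape it passes through; the field values are illustrative, and tools.py below holds the actual conversion code:

mcp_tool = {
    "server_name": "weather",
    "name": "get_forecast",
    "description": "Get a city forecast.",
    "inputSchema": {"type": "object", "properties": {"city": {"type": "string"}}},
}

# Step 3a: convert_to_google_tools([mcp_tool]) produces the intermediate dict
# format, prefixing the name as "server__tool" so parsed calls can be routed back:
#   [{"function_declarations": [{"name": "weather__get_forecast",
#                                "description": "Get a city forecast.",
#                                "parameters": {...}}]}]
# Step 3b: convert_to_google_tool_objects(...) wraps those declarations in a
# single google.genai.types.Tool object.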
src/providers/google_provider/response.py (new file, 205 lines)
@@ -0,0 +1,205 @@
# src/providers/google_provider/response.py
"""
Response handling utilities specific to the Google Generative AI provider.

Includes functions for:
- Extracting content from streaming responses.
- Extracting content from non-streaming responses.
- Extracting token usage information.
"""

import json
import logging
from collections.abc import Generator
from typing import Any

from google.genai.types import GenerateContentResponse

logger = logging.getLogger(__name__)


def get_streaming_content(response: Any) -> Generator[str, None, None]:
    """
    Yields content chunks (text) from a Google streaming response iterator.

    Args:
        response: The streaming response iterator returned by `generate_content(stream=True)`.

    Yields:
        String chunks of the generated text content.
        May yield JSON strings containing error information if errors occur during streaming.
    """
    logger.debug("Processing Google stream...")
    full_delta = ""
    try:
        # Check if the response itself is an error indicator (e.g., from create_chat_completion error handling)
        if isinstance(response, dict) and "error" in response:
            yield json.dumps(response)
            logger.error(f"Stream processing stopped due to initial error: {response['error']}")
            return
        # Check if response is already an error iterator
        if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
            # If it looks like an error iterator from create_chat_completion
            first_item = next(response, None)
            if first_item and isinstance(first_item, str):
                try:
                    error_data = json.loads(first_item)
                    if "error" in error_data:
                        yield first_item  # Yield the error JSON
                        yield from response
                        logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}")
                        return
                except json.JSONDecodeError:
                    # Not a JSON error, yield it as is and continue? Or stop?
                    # Assuming it might be valid content if not JSON error.
                    yield first_item
            elif first_item:  # Put the first item back if it wasn't an error
                # This requires a way to chain iterators, simple yield doesn't work well here.
                # For simplicity, we assume error iterators yield JSON strings.
                # If the stream is valid, the loop below will handle it.
                # Re-assigning response might be complex. Let the main loop handle valid streams.
                pass  # Let the main loop handle the original response iterator

        # Process the stream chunk by chunk
        for chunk in response:
            # Check for errors embedded within the stream chunks (less common for Google?)
            if isinstance(chunk, dict) and "error" in chunk:
                yield json.dumps(chunk)
                logger.error(f"Error encountered during Google stream: {chunk['error']}")
                continue  # Continue processing stream or stop? Continuing for now.

            # Extract text content
            delta = ""
            try:
                if hasattr(chunk, "text"):
                    delta = chunk.text
                elif hasattr(chunk, "candidates") and chunk.candidates:
                    # Sometimes content might be nested under candidates even in stream?
                    # Check the first candidate's first part for text.
                    first_candidate = chunk.candidates[0]
                    if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
                        first_part = first_candidate.content.parts[0]
                        if hasattr(first_part, "text"):
                            delta = first_part.text
            except Exception as e:
                logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True)
                delta = ""  # Ensure delta is a string

            if delta:
                full_delta += delta
                yield delta

            # Detect function calls during stream (optional, for logging/early detection)
            try:
                if hasattr(chunk, "candidates") and chunk.candidates:
                    for part in chunk.candidates[0].content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            logger.debug(f"Function call detected during stream: {part.function_call.name}")
                            # Note: We don't yield the function call itself here, just the text.
                            # Function calls are typically processed after the stream completes.
                            break  # Found a function call in this chunk
            except Exception:
                # Ignore errors during optional function call detection in stream
                pass

        logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")

    except StopIteration:
        logger.debug("Google stream finished (StopIteration).")  # Normal end of iteration
    except Exception as e:
        logger.error(f"Error processing Google stream: {e}", exc_info=True)
        # Yield a final error message
        yield json.dumps({"error": f"Stream processing error: {str(e)}"})


def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
    """
    Extracts the full text content from a non-streaming Google response.

    Args:
        response: The non-streaming response object (`GenerateContentResponse`) or
            an error dictionary.

    Returns:
        The concatenated text content, or an error message string.
    """
    try:
        # Handle error dictionary case
        if isinstance(response, dict) and "error" in response:
            logger.error(f"Cannot get content from error response: {response['error']}")
            return f"[Error: {response['error']}]"

        # Handle successful GenerateContentResponse object
        if hasattr(response, "text"):
            # The `.text` attribute usually provides the concatenated text content directly
            content = response.text
            logger.debug(f"Extracted content (length {len(content)}) from response.text.")
            return content
        elif hasattr(response, "candidates") and response.candidates:
            # Fallback: manually concatenate text from parts if .text is missing
            first_candidate = response.candidates[0]
            if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
                text_parts = []
                for part in first_candidate.content.parts:
                    if hasattr(part, "text"):
                        text_parts.append(part.text)
                    # We are only interested in text content here, ignore function calls etc.
                content = "".join(text_parts)
                logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.")
                return content
            else:
                logger.warning("Google response candidate has no content or parts.")
                return ""  # Return empty string if no text found
        else:
            logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}")
            return ""  # Return empty string if no text found
    except AttributeError as ae:
        logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True)
        return f"[Error extracting content: Attribute missing - {str(ae)}]"
    except Exception as e:
        logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
        return f"[Error extracting content: {str(e)}]"


def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
    """
    Extracts token usage information from a Google response object.

    Args:
        response: The response object (`GenerateContentResponse`) or an error dictionary.

    Returns:
        A dictionary containing 'prompt_tokens' and 'completion_tokens', or None if
        usage information is unavailable or an error occurred.
    """
    try:
        # Handle error dictionary case
        if isinstance(response, dict) and "error" in response:
            logger.warning("Cannot get usage from error response.")
            return None

        # Check for usage metadata in the response object
        if hasattr(response, "usage_metadata"):
            metadata = response.usage_metadata
            # Google uses prompt_token_count and candidates_token_count
            usage = {
                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                # Google also provides total_token_count, could be added if needed
                # "total_tokens": getattr(metadata, "total_token_count", 0),
            }
            # Ensure values are integers
            usage = {k: int(v) for k, v in usage.items()}
            logger.debug(f"Extracted usage from Google response metadata: {usage}")
            return usage
        else:
            # Log a warning only if it's not clearly an error dict already handled
            if not (isinstance(response, dict) and "error" in response):
                logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
            return None
    except AttributeError as ae:
        logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
        return None
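A consumer-side sketch of the error convention above, assuming stream_response is a streaming response iterator; the variable names are illustrative:

import json

for chunk in get_streaming_content(stream_response):
    try:
        maybe_error = json.loads(chunk)
        if isinstance(maybe_error, dict) and "error" in maybe_error:
            print(f"\n[stream error] {maybe_error['error']}")
            break
    except json.JSONDecodeError:
        pass  # an ordinary text chunk, not a JSON error payload
    print(chunk, end="")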
359
src/providers/google_provider/tools.py
Normal file
359
src/providers/google_provider/tools.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
# src/providers/google_provider/tools.py
|
||||||
|
"""
|
||||||
|
Tool handling utilities specific to the Google Generative AI provider.
|
||||||
|
|
||||||
|
Includes functions for:
|
||||||
|
- Converting MCP tool definitions to Google's format.
|
||||||
|
- Creating Google Tool/FunctionDeclaration objects.
|
||||||
|
- Parsing tool calls (FunctionCalls) from Google responses.
|
||||||
|
- Formatting tool results for subsequent API calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from google.genai.types import FunctionDeclaration, Schema, Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tool Conversion (from MCP format to Google format) ---
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Convert MCP tools to Google Gemini format (dictionary structure).
|
||||||
|
|
||||||
|
This format is an intermediate step before creating Tool objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List containing one dictionary with 'function_declarations'.
|
||||||
|
Returns an empty list if no valid tools are provided or converted.
|
||||||
|
"""
|
||||||
|
logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")
|
||||||
|
|
||||||
|
function_declarations = []
|
||||||
|
|
||||||
|
for tool in mcp_tools:
|
||||||
|
server_name = tool.get("server_name")
|
||||||
|
tool_name = tool.get("name")
|
||||||
|
description = tool.get("description")
|
||||||
|
input_schema = tool.get("inputSchema")
|
||||||
|
|
||||||
|
if not server_name or not tool_name or not description or not input_schema:
|
||||||
|
logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prefix tool name with server name for routing
|
||||||
|
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||||
|
|
||||||
|
# Basic validation/cleaning of schema for Google compatibility
|
||||||
|
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||||
|
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.")
|
||||||
|
# Ensure basic structure if missing
|
||||||
|
if not isinstance(input_schema, dict):
|
||||||
|
input_schema = {} # Start fresh if not a dict
|
||||||
|
if "type" not in input_schema or input_schema["type"] != "object":
|
||||||
|
# Wrap existing schema or create new if type is wrong/missing
|
||||||
|
input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}}
|
||||||
|
logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.")
|
||||||
|
|
||||||
|
if "properties" not in input_schema:
|
||||||
|
input_schema["properties"] = {}
|
||||||
|
|
||||||
|
# Google requires properties for object type, add dummy if empty
|
||||||
|
if not input_schema["properties"]:
|
||||||
|
logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
|
||||||
|
input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}}
|
||||||
|
if "required" in input_schema and not isinstance(input_schema.get("required"), list):
|
||||||
|
input_schema["required"] = [] # Clear invalid required list
|
||||||
|
|
||||||
|
# Create function declaration dictionary for Google's format
|
||||||
|
function_declaration = {
|
||||||
|
"name": prefixed_tool_name,
|
||||||
|
"description": description,
|
||||||
|
"parameters": input_schema, # Google uses JSON Schema directly
|
||||||
|
}
|
||||||
|
|
||||||
|
function_declarations.append(function_declaration)
|
||||||
|
logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}")
|
||||||
|
|
||||||
|
# Google API expects a list containing one dictionary with 'function_declarations' key
|
||||||
|
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
|
||||||
|
|
||||||
|
logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}")
|
||||||
|
return google_tool_config
|
||||||
|
|
||||||
|
|
||||||
|

def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | None:
    """
    Recursively creates Google Schema objects from a JSON schema dictionary.

    Handles type mapping and nested structures. Returns None on failure.
    """
    if Schema is None:
        logger.error("Cannot create Schema object: google.genai types not available.")
        return None

    if not isinstance(schema_dict, dict):
        logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
        return Schema()  # Return empty schema to avoid breaking the parent

    # Map JSON Schema types to Google's Type enum strings
    type_mapping = {
        "string": "STRING",
        "number": "NUMBER",
        "integer": "INTEGER",
        "boolean": "BOOLEAN",
        "array": "ARRAY",
        "object": "OBJECT",
        # Add other mappings if necessary
    }
    original_type = schema_dict.get("type")
    google_type = type_mapping.get(str(original_type).lower()) if original_type else None

    # Prepare arguments for the Schema constructor, filtering out None values
    schema_args = {
        "type": google_type,
        "format": schema_dict.get("format"),
        "description": schema_dict.get("description"),
        "nullable": schema_dict.get("nullable"),
        "enum": schema_dict.get("enum"),
        "items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None,
        "properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None,
        "required": schema_dict.get("required") if google_type == "OBJECT" else None,
    }
    # Remove keys with None values
    schema_args = {k: v for k, v in schema_args.items() if v is not None}

    if not schema_args.get("type"):
        logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.")
        return Schema()  # Return empty schema

    try:
        return Schema(**schema_args)
    except Exception as schema_creation_err:
        logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
        return Schema()  # Return empty schema on error
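As a sketch of the recursion, a nested JSON Schema dictionary (invented for illustration) maps to nested Schema objects roughly as follows:

# Illustrative only — expected mapping for a nested schema:
nested = {
    "type": "object",
    "properties": {"tags": {"type": "array", "items": {"type": "string"}}},
    "required": ["tags"],
}
# _create_google_schema_recursive(nested) -> Schema(
#     type="OBJECT",
#     properties={"tags": Schema(type="ARRAY", items=Schema(type="STRING"))},
#     required=["tags"],
# )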

def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
    """
    Convert the dictionary-based tool configurations into Google's Tool objects.

    Args:
        tool_configs: A list containing a dictionary with 'function_declarations',
            as produced by `convert_to_google_tools`.

    Returns:
        A list containing a single Google `Tool` object, or None if conversion fails
        or no valid declarations are found.
    """
    if Tool is None or FunctionDeclaration is None:
        logger.error("Cannot create Tool objects: google.genai types not available.")
        return None
    if not tool_configs:
        logger.debug("No tool configurations provided to convert to Tool objects.")
        return None

    all_func_declarations = []
    # Expecting a structure like [{"function_declarations": [...]}]
    if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]:
        func_declarations_list = tool_configs[0]["function_declarations"]
        if not isinstance(func_declarations_list, list):
            logger.error(f"Expected 'function_declarations' to be a list, got {type(func_declarations_list)}")
            return None

        for func_dict in func_declarations_list:
            try:
                params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
                # Ensure parameters is a valid schema dict for the recursive creator
                if not isinstance(params_schema_dict, dict):
                    logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.")
                    params_schema_dict = {"type": "object", "properties": {}}
                elif params_schema_dict.get("type") != "object":
                    logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.")
                    params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})}  # Attempt to salvage properties

                parameters_schema = _create_google_schema_recursive(params_schema_dict)

                # Only proceed if schema creation succeeded
                if parameters_schema is not None:
                    declaration = FunctionDeclaration(
                        name=func_dict["name"],
                        description=func_dict.get("description", ""),
                        parameters=parameters_schema,
                    )
                    all_func_declarations.append(declaration)
                else:
                    logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")

            except Exception as decl_err:
                logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)

    else:
        logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
        return None

    if not all_func_declarations:
        logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.")
        return None

    # Google expects a list containing one Tool object
    logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.")
    return [Tool(function_declarations=all_func_declarations)]
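A minimal usage sketch chaining the two converters (assumes the google.genai types guarded at the top of this module are importable; `mcp_tools` is whatever the MCP manager reports):

# Illustrative only:
tool_dicts = convert_to_google_tools(mcp_tools)            # [{"function_declarations": [...]}]
google_tools = convert_to_google_tool_objects(tool_dicts)  # [Tool(...)] or None
# google_tools is then ready to attach to the Gemini request in the provider code.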

# --- Tool Call Parsing and Handling (from Google response) ---


def has_google_tool_calls(response: Any) -> bool:
    """
    Checks whether the Google response object contains tool calls (FunctionCalls).

    Args:
        response: The response object from the Google generate_content API call.

    Returns:
        True if FunctionCalls are present, False otherwise.
    """
    try:
        # Check the non-streaming response structure
        if hasattr(response, "candidates") and response.candidates:
            candidate = response.candidates[0]
            if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                for part in candidate.content.parts:
                    if hasattr(part, "function_call") and part.function_call:
                        logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
                        return True

        # Note: Detecting function calls reliably in a stream might require accumulating parts.
        # This function works reliably only for non-streaming responses; for streaming,
        # the check happens during stream processing itself.

        logger.debug("No tool calls (FunctionCall) detected in Google response.")
        return False
    except Exception as e:
        logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
        return False

def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]:
    """
    Parses tool calls (FunctionCalls) from a non-streaming Google response object.

    Args:
        response: The non-streaming response object from the Google generate_content API call.

    Returns:
        A list of dictionaries, each representing a tool call in the standard MCP format
        (id, server_name, function_name, arguments as a JSON string).
        Returns an empty list if no calls are found or an error occurs.
    """
    parsed_calls = []
    try:
        if not (hasattr(response, "candidates") and response.candidates):
            logger.warning("Cannot parse tool calls: Response has no candidates.")
            return []

        candidate = response.candidates[0]
        if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
            logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
            return []

        logger.debug("Parsing tool calls (FunctionCall) from Google response.")
        call_index = 0
        for part in candidate.content.parts:
            if hasattr(part, "function_call") and part.function_call:
                func_call = part.function_call
                # Generate a simple unique ID for this call within this response
                call_id = f"call_{call_index}"
                call_index += 1

                # Extract server_name and func_name from the prefixed name
                # (renamed from 'parts' to avoid shadowing candidate.content.parts)
                full_name = func_call.name
                name_parts = full_name.split("__", 1)
                if len(name_parts) == 2:
                    server_name, func_name = name_parts
                else:
                    # If the prefix isn't found, assume it's just the function name
                    logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.")
                    server_name = None
                    func_name = full_name

                # Convert the arguments dict to a JSON string
                try:
                    # func_call.args is already a dict-like object (Mapping)
                    args_dict = dict(func_call.args) if func_call.args else {}
                    args_str = json.dumps(args_dict)
                except Exception as json_err:
                    logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
                    # Provide error info in arguments if serialization fails
                    args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})

                parsed_calls.append({
                    "id": call_id,  # Internal ID for tracking this call
                    "server_name": server_name,
                    "function_name": func_name,  # The original function name
                    "arguments": args_str,  # Arguments as a JSON string
                    "_google_tool_name": full_name,  # Keep the original name if needed later
                })
                logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")

        return parsed_calls
    except Exception as e:
        logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
        return []
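A runnable sketch of both helpers using stand-in objects (SimpleNamespace stands in for the real google.genai response classes; all values invented):

# Illustrative only:
from types import SimpleNamespace

fake_call = SimpleNamespace(name="search__web_search", args={"query": "llm"})
fake_part = SimpleNamespace(function_call=fake_call)
fake_response = SimpleNamespace(
    candidates=[SimpleNamespace(content=SimpleNamespace(parts=[fake_part]))]
)
assert has_google_tool_calls(fake_response)
parse_google_tool_calls(fake_response)
# -> [{"id": "call_0", "server_name": "search", "function_name": "web_search",
#      "arguments": '{"query": "llm"}', "_google_tool_name": "search__web_search"}]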

def format_google_tool_results(tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]:
    """
    Formats a tool result for a Google follow-up request (FunctionResponse).

    Args:
        tool_call_id: The unique ID assigned during parsing (e.g., "call_0").
            Note: Google's API itself doesn't use this ID directly in the
            FunctionResponse part, but we need it for mapping in the message list.
        function_name: The original function name (without server prefix) that was called.
        result: The data returned by the tool execution. Should be JSON-serializable.

    Returns:
        A dictionary representing the tool result message in the standard MCP format.
        This will be converted later by `_convert_messages`.
    """
    try:
        # Google expects the 'response' field in FunctionResponse to contain a dict.
        # The content should ideally be JSON-serializable, so we wrap the result.
        if isinstance(result, (str, int, float, bool, list)):
            content_dict = {"result": result}
        elif isinstance(result, dict):
            content_dict = result  # Assume it's already a suitable dict
        else:
            logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.")
            content_dict = {"result": str(result)}

        # Ensure the content is JSON-serializable for the 'content' field
        try:
            content_str = json.dumps(content_dict)
        except Exception as json_err:
            logger.error(f"Error JSON-encoding tool result content for Google {function_name} ({tool_call_id}): {json_err}")
            content_str = json.dumps({"error": "Failed to encode tool result content", "original_type": str(type(result))})

    except Exception as e:
        logger.error(f"Error preparing tool result content for Google {function_name} ({tool_call_id}): {e}")
        content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)})

    logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})")
    # Return in the standard message format; _convert_messages will handle Google's structure
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,  # Used by _convert_messages to find the original call
        "content": content_str,  # The JSON string representing the result content
        "name": function_name,  # Store the original function name for _convert_messages
        # Note: Google's FunctionResponse Part needs 'name' and 'response' (dict).
        # This standard format is converted by the provider's message conversion logic.
    }
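For illustration (invented values), a list-valued tool result would be wrapped and serialized like this:

# Illustrative only:
format_google_tool_results("call_0", "web_search", ["a", "b"])
# -> {"role": "tool", "tool_call_id": "call_0",
#     "content": '{"result": ["a", "b"]}', "name": "web_search"}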
150
src/providers/google_provider/utils.py
Normal file
@@ -0,0 +1,150 @@
# src/providers/google_provider/utils.py
import json
import logging
from typing import Any

from google.genai.types import Content, Part

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given Google model."""
    default_window = 1000000  # Default fallback for Gemini
    try:
        provider_models = MODELS.get("google", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window
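A quick sketch of the lookup, assuming a hypothetical MODELS entry (model id and window values invented):

# Illustrative only — assuming MODELS contains:
#   {"google": {"models": [{"id": "gemini-example", "context_window": 2000000}]}}
# get_context_window("gemini-example")  -> 2000000
# get_context_window("unlisted-model")  -> 1000000  (default, with a warning logged)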

def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
    """
    Converts the standard message format to Google's format, extracting the system prompt.
    Handles mapping roles and structuring tool calls/results.
    """
    google_messages: list[Content] = []
    system_prompt: str | None = None

    for i, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")
        tool_calls = message.get("tool_calls")
        tool_call_id = message.get("tool_call_id")

        if role == "system":
            if i == 0:
                system_prompt = content
                logger.debug("Extracted system prompt for Google.")
            else:
                # The Google API expects the system prompt only at the beginning.
                # If found later, log a warning and skip (merging would be complex).
                logger.warning("System message found not at the beginning. Skipping for Google API.")
            continue  # Never add system messages to the main list

        # Map roles: 'assistant' -> 'model'; 'tool' -> 'function' (handled below)
        google_role = {"user": "user", "assistant": "model"}.get(role)

        if not google_role and role != "tool":
            logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
            continue

        parts: list[Part | str] = []
        if role == "tool":
            # Tool results map to the 'function' role in the Google API
            if tool_call_id and content:
                try:
                    # Attempt to parse the content as JSON, assuming it's the tool output
                    response_content_dict = json.loads(content)
                except json.JSONDecodeError:
                    logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
                    response_content_dict = {"result": content}  # Wrap raw string if not JSON

                # Find the original function name from the preceding assistant message
                func_name = "unknown_function"  # Default if the name can't be found
                if i > 0 and messages[i - 1].get("role") == "assistant":
                    prev_tool_calls = messages[i - 1].get("tool_calls")
                    if prev_tool_calls:
                        for tc in prev_tool_calls:
                            # Match based on the ID provided in the tool message
                            if tc.get("id") == tool_call_id:
                                # Google uses the 'server__func' format; extract the original func name if possible
                                full_name = tc.get("function_name", "unknown_function")
                                func_name = full_name.split("__", 1)[-1]  # The part after '__', or the full name
                                break

                # Create a FunctionResponse part
                parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
                google_role = "function"  # Explicitly set the role for tool results
            else:
                logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
                continue  # Skip if essential parts are missing

        elif role == "assistant" and tool_calls:
            # Assistant message requesting tool calls
            for tool_call in tool_calls:
                args = tool_call.get("arguments", {})
                # Ensure arguments are a dict, not a string
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
                        args = {"error": "failed to parse arguments"}  # Provide error feedback

                # Google uses the 'server__func' format; extract the original func name if possible
                full_name = tool_call.get("function_name", "unknown_function")
                func_name = full_name.split("__", 1)[-1]  # The part after '__', or the full name

                # Create a FunctionCall part
                parts.append(Part.from_function_call(name=func_name, args=args))

            # Include any text content alongside the function calls
            if content and isinstance(content, str):
                parts.append(Part.from_text(content))

        elif content:
            # Regular user or assistant message content
            if isinstance(content, str):
                parts.append(Part.from_text(content))
            # TODO: Handle potential image content if needed in the future
            else:
                logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
                parts.append(Part.from_text(str(content)))

        # Add the constructed Content object if parts were generated
        if parts:
            google_messages.append(Content(role=google_role, parts=parts))
        else:
            # Log if a message produced no parts (e.g., empty content)
            logger.debug(f"No parts generated for message: {message}")

    # Validate message alternation (user -> model -> user/function -> user -> ...)
    last_role = None
    valid_alternation = True
    for msg in google_messages:
        current_role = msg.role
        # Check for consecutive user/model roles
        if current_role == last_role and current_role in ["user", "model"]:
            valid_alternation = False
            logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.")
            break
        # Check that a 'function' role is followed by 'user'
        if last_role == "function" and current_role != "user":
            valid_alternation = False
            logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.")
            break
        last_role = current_role

    # Raise an error if alternation is invalid, as the Google API enforces this
    if not valid_alternation:
        raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.")

    return google_messages, system_prompt
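As a sketch of the role mapping (message values invented), a short conversation converts as follows:

# Illustrative only:
msgs = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
contents, system_prompt = convert_messages(msgs)
# system_prompt == "You are helpful."
# contents -> [Content(role="user", parts=[...]), Content(role="model", parts=[...])]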
@@ -1,6 +0,0 @@
# src/tools/__init__.py
# This file makes the 'tools' directory a Python package.

# Optionally import key functions/classes for easier access
# from .conversion import convert_to_openai_tools, convert_to_anthropic_tools
# from .execution import execute_tool  # Assuming execution.py will exist
@@ -1,77 +0,0 @@
"""
Conversion utilities for MCP tools.

This module contains functions to convert between different tool formats
for various LLM providers (OpenAI, Anthropic, etc.).
"""

import logging
from typing import Any

logger = logging.getLogger(__name__)


def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Convert MCP tools to Google Gemini format (dictionary structure).

    Args:
        mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

    Returns:
        List containing one dictionary with 'function_declarations'.
    """
    logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format")

    function_declarations = []

    for tool in mcp_tools:
        server_name = tool.get("server_name")
        tool_name = tool.get("name")
        description = tool.get("description")
        input_schema = tool.get("inputSchema")

        if not server_name or not tool_name or not description or not input_schema:
            logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}")
            continue

        # Prefix the tool name with the server name for routing
        prefixed_tool_name = f"{server_name}__{tool_name}"

        # Basic validation/cleaning of the schema
        if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
            logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this.")
            # Ensure the basic structure if missing
            if not isinstance(input_schema, dict):
                input_schema = {}
            if "type" not in input_schema:
                input_schema["type"] = "object"
        if "properties" not in input_schema:
            input_schema["properties"] = {}
        # Google requires properties for the object type; add a dummy if empty
        if not input_schema["properties"]:
            logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.")
            input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder"}}

        # Create the function declaration for Google's format
        function_declaration = {
            "name": prefixed_tool_name,
            "description": description,
            "parameters": input_schema,  # Google uses JSON Schema directly
        }

        function_declarations.append(function_declaration)
        logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")

    # The Google API expects a list containing one dictionary with 'function_declarations'.
    # The provider's _convert_to_tool_objects will handle creating Tool objects from this.
    google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []

    logger.debug(f"Final Google tool config structure: {google_tool_config}")
    return google_tool_config


# Note: The _handle_schema_construct helper from the reference code is not strictly
# needed if we assume the inputSchema is already valid JSON Schema.
# If complex schemas (anyOf, etc.) need specific handling beyond standard JSON Schema,
# that logic could be added here or within the provider implementations.