feat: add GoogleProvider implementation and update conversion utilities for Google tools

This commit is contained in:
2025-03-26 18:18:10 +00:00
parent 15ecb9fc48
commit 246d921743
3 changed files with 490 additions and 6 deletions

View File

@@ -3,9 +3,9 @@ import logging
from providers.anthropic_provider import AnthropicProvider from providers.anthropic_provider import AnthropicProvider
from providers.base import BaseProvider from providers.base import BaseProvider
from providers.google_provider import GoogleProvider
from providers.openai_provider import OpenAIProvider from providers.openai_provider import OpenAIProvider
# from providers.google_provider import GoogleProvider
# from providers.openrouter_provider import OpenRouterProvider # from providers.openrouter_provider import OpenRouterProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
PROVIDER_MAP: dict[str, type[BaseProvider]] = { PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider, "openai": OpenAIProvider,
"anthropic": AnthropicProvider, "anthropic": AnthropicProvider,
# "google": GoogleProvider, "google": GoogleProvider,
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url # "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
} }

View File

@@ -0,0 +1,483 @@
# src/providers/google_provider.py
import json
import logging
import traceback
from collections.abc import Generator
from typing import Any
from google import genai
from google.genai.types import (
Content,
FunctionDeclaration,
Part,
Schema,
Tool,
)
from src.llm_models import MODELS
from src.providers.base import BaseProvider
from src.tools.conversion import convert_to_google_tools
logger = logging.getLogger(__name__)
class GoogleProvider(BaseProvider):
"""Provider implementation for Google Gemini models."""
def __init__(self, api_key: str, base_url: str | None = None):
# Google client typically doesn't use a base_url, but we accept it for consistency
effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info("Initializing GoogleProvider")
if genai is None:
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try:
# Configure the client
genai.configure(api_key=self.api_key)
self.client_module = genai
except Exception as e:
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Google model."""
default_window = 1000000 # Default fallback for Gemini
try:
provider_models = MODELS.get("google", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
"""
Converts standard message format to Google's format, extracting system prompt.
Handles mapping roles and structuring tool calls/results.
"""
google_messages: list[Content] = []
system_prompt: str | None = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")
if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Google.")
else:
logger.warning("System message found not at the beginning. Merging into subsequent user message.")
continue
google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role)
if not google_role:
logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
continue
parts: list[Part | str] = []
if role == "tool":
if tool_call_id and content:
try:
response_content_dict = json.loads(content)
except json.JSONDecodeError:
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
response_content_dict = {"result": content}
func_name = "unknown_function"
if i > 0 and messages[i - 1].get("role") == "assistant":
prev_tool_calls = messages[i - 1].get("tool_calls")
if prev_tool_calls:
for tc in prev_tool_calls:
if tc.get("id") == tool_call_id:
func_name = tc.get("function_name", "unknown_function")
break
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
google_role = "function"
else:
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
continue
elif role == "assistant" and tool_calls:
for tool_call in tool_calls:
args = tool_call.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
args = {"error": "failed to parse arguments"}
func_name = tool_call.get("function_name", "unknown_function")
parts.append(Part.from_function_call(name=func_name, args=args))
if content:
parts.append(Part.from_text(content))
elif content:
if isinstance(content, str):
parts.append(Part.from_text(content))
else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part.from_text(str(content)))
if parts:
google_messages.append(Content(role=google_role, parts=parts))
else:
logger.debug(f"No parts generated for message: {message}")
last_role = None
valid_alternation = True
for msg in google_messages:
current_role = msg.role
if current_role == last_role and current_role in ["user", "model"]:
valid_alternation = False
logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.")
break
if last_role == "function" and current_role != "user":
valid_alternation = False
logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.")
break
last_role = current_role
if not valid_alternation:
logger.error("Message list does not follow required user/model alternation for Google API.")
raise ValueError("Invalid message sequence for Google API.")
return google_messages, system_prompt
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""Creates a chat completion using the Google Gemini API."""
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
if self.client_module is None:
return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})])
try:
google_messages, system_prompt = self._convert_messages(messages)
generation_config: dict[str, Any] = {"temperature": temperature}
if max_tokens is not None:
generation_config["max_output_tokens"] = max_tokens
google_tools = None
if tools:
try:
tool_dict_list = convert_to_google_tools(tools)
google_tools = self._convert_to_tool_objects(tool_dict_list)
logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.")
except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
google_tools = None
gemini_model = self.client_module.GenerativeModel(
model_name=model,
system_instruction=system_prompt,
tools=google_tools if google_tools else None,
)
log_params = {
"model": model,
"stream": stream,
"temperature": temperature,
"max_tokens": max_tokens,
"system_prompt_present": bool(system_prompt),
"num_tools": len(google_tools) if google_tools else 0,
"num_messages": len(google_messages),
}
logger.debug(f"Calling Google API with params: {log_params}")
response = gemini_model.generate_content(
contents=google_messages,
generation_config=generation_config,
stream=stream,
)
logger.debug("Google API call successful.")
return response
except Exception as e:
error_msg = f"Google API error: {e}"
logger.error(error_msg, exc_info=True)
if stream:
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
else:
return {"error": error_msg, "traceback": traceback.format_exc()}
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Yields content chunks from a Google streaming response."""
logger.debug("Processing Google stream...")
full_delta = ""
try:
if isinstance(response, dict) and "error" in response:
yield json.dumps(response)
return
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
yield from response
return
for chunk in response:
if isinstance(chunk, dict) and "error" in chunk:
yield json.dumps(chunk)
continue
if hasattr(chunk, "text"):
delta = chunk.text
if delta:
full_delta += delta
yield delta
elif hasattr(chunk, "candidates") and chunk.candidates:
for part in chunk.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Function call detected during stream: {part.function_call.name}")
break
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing Google stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: Any) -> str:
"""Extracts content from a non-streaming Google response."""
try:
if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error response: {response['error']}")
return f"[Error: {response['error']}]"
if hasattr(response, "text"):
content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content
elif hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidates.")
return content
else:
logger.warning("Google response candidate has no content or parts.")
return ""
else:
logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
return ""
except Exception as e:
logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Any) -> bool:
"""Checks if the Google response contains tool calls (function calls)."""
try:
if isinstance(response, dict) and "error" in response:
return False
if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
return True
logger.debug("No tool calls (FunctionCall) detected in Google response.")
return False
except Exception as e:
logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
"""Parses tool calls (function calls) from a non-streaming Google response."""
parsed_calls = []
try:
if not (hasattr(response, "candidates") and response.candidates):
logger.warning("Cannot parse tool calls: Response has no candidates.")
return []
candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
return []
logger.debug("Parsing tool calls (FunctionCall) from Google response.")
call_index = 0
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
call_id = f"call_{call_index}"
call_index += 1
full_name = func_call.name
parts = full_name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
server_name = None
func_name = full_name
try:
args_str = json.dumps(func_call.args or {})
except Exception as json_err:
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
parsed_calls.append({
"id": call_id,
"server_name": server_name,
"function_name": func_name,
"arguments": args_str,
})
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
return []
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for a Google follow-up request."""
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content_str,
"function_name": "unknown_function",
}
def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls for Google."""
try:
if not (hasattr(response, "candidates") and response.candidates):
return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}
candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}
tool_calls_formatted = []
text_content_parts = []
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
args = func_call.args or {}
tool_calls_formatted.append({
"function_name": func_call.name,
"arguments": args,
})
elif hasattr(part, "text"):
text_content_parts.append(part.text)
message = {"role": "assistant"}
if tool_calls_formatted:
message["tool_calls"] = tool_calls_formatted
text_content = "".join(text_content_parts)
if text_content:
message["content"] = text_content
elif not tool_calls_formatted:
message["content"] = ""
return message
except Exception as e:
logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a Google response."""
try:
if isinstance(response, dict) and "error" in response:
return None
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
usage = {
"prompt_tokens": getattr(metadata, "prompt_token_count", 0),
"completion_tokens": getattr(metadata, "candidates_token_count", 0),
}
logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
return None
except Exception as e:
logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
return None
def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
"""Convert dictionary-format tools into Google's Tool objects."""
if not tool_configs:
return None
all_func_declarations = []
for config in tool_configs:
if "function_declarations" in config:
for func_dict in config["function_declarations"]:
try:
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
if params_schema_dict.get("type") != "object":
logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.")
params_schema_dict = {"type": "object", "properties": params_schema_dict}
def create_schema(schema_dict):
if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
return Schema()
schema_args = {
"type": schema_dict.get("type"),
"format": schema_dict.get("format"),
"description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"),
"enum": schema_dict.get("enum"),
"items": create_schema(schema_dict["items"]) if "items" in schema_dict else None,
"properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None,
"required": schema_dict.get("required"),
}
schema_args = {k: v for k, v in schema_args.items() if v is not None}
if "type" in schema_args:
type_mapping = {
"string": "STRING",
"number": "NUMBER",
"integer": "INTEGER",
"boolean": "BOOLEAN",
"array": "ARRAY",
"object": "OBJECT",
}
schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"])
try:
return Schema(**schema_args)
except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True)
return Schema()
parameters_schema = create_schema(params_schema_dict)
declaration = FunctionDeclaration(
name=func_dict["name"],
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
if not all_func_declarations:
logger.warning("No valid function declarations found after conversion.")
return None
return [Tool(function_declarations=all_func_declarations)]

View File

@@ -164,11 +164,12 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, A
function_declarations.append(function_declaration) function_declarations.append(function_declaration)
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}") logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
# Google API expects a list containing one Tool object dict # Google API expects a list containing one dictionary with 'function_declarations'
google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else [] # The provider's _convert_to_tool_objects will handle creating Tool objects from this.
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
logger.debug(f"Final Google tools structure: {google_tools_wrapper}") logger.debug(f"Final Google tool config structure: {google_tool_config}")
return google_tools_wrapper return google_tool_config
# Note: The _handle_schema_construct helper from the reference code is not strictly # Note: The _handle_schema_construct helper from the reference code is not strictly