From 246d92174317761e0c63d8a9be612c8049c2d7b9 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Mar 2025 18:18:10 +0000 Subject: [PATCH] feat: add GoogleProvider implementation and update conversion utilities for Google tools --- src/providers/__init__.py | 4 +- src/providers/google_provider.py | 483 +++++++++++++++++++++++++++++++ src/tools/conversion.py | 9 +- 3 files changed, 490 insertions(+), 6 deletions(-) create mode 100644 src/providers/google_provider.py diff --git a/src/providers/__init__.py b/src/providers/__init__.py index e1fdd53..10f1e11 100644 --- a/src/providers/__init__.py +++ b/src/providers/__init__.py @@ -3,9 +3,9 @@ import logging from providers.anthropic_provider import AnthropicProvider from providers.base import BaseProvider +from providers.google_provider import GoogleProvider from providers.openai_provider import OpenAIProvider -# from providers.google_provider import GoogleProvider # from providers.openrouter_provider import OpenRouterProvider logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) PROVIDER_MAP: dict[str, type[BaseProvider]] = { "openai": OpenAIProvider, "anthropic": AnthropicProvider, - # "google": GoogleProvider, + "google": GoogleProvider, # "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url } diff --git a/src/providers/google_provider.py b/src/providers/google_provider.py new file mode 100644 index 0000000..a006700 --- /dev/null +++ b/src/providers/google_provider.py @@ -0,0 +1,483 @@ +# src/providers/google_provider.py +import json +import logging +import traceback +from collections.abc import Generator +from typing import Any + +from google import genai +from google.genai.types import ( + Content, + FunctionDeclaration, + Part, + Schema, + Tool, +) + +from src.llm_models import MODELS +from src.providers.base import BaseProvider +from src.tools.conversion import convert_to_google_tools + +logger = logging.getLogger(__name__) + + +class GoogleProvider(BaseProvider): + """Provider implementation for Google Gemini models.""" + + def __init__(self, api_key: str, base_url: str | None = None): + # Google client typically doesn't use a base_url, but we accept it for consistency + effective_base_url = base_url or MODELS.get("google", {}).get("endpoint") + super().__init__(api_key, effective_base_url) + logger.info("Initializing GoogleProvider") + + if genai is None: + raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") + + try: + # Configure the client + genai.configure(api_key=self.api_key) + self.client_module = genai + except Exception as e: + logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) + raise + + def _get_context_window(self, model: str) -> int: + """Retrieves the context window size for a given Google model.""" + default_window = 1000000 # Default fallback for Gemini + try: + provider_models = MODELS.get("google", {}).get("models", []) + for m in provider_models: + if m.get("id") == model: + return m.get("context_window", default_window) + logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}") + return default_window + except Exception as e: + logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) + return default_window + + def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]: + """ + Converts standard message format to Google's format, extracting system prompt. + Handles mapping roles and structuring tool calls/results. + """ + google_messages: list[Content] = [] + system_prompt: str | None = None + + for i, message in enumerate(messages): + role = message.get("role") + content = message.get("content") + tool_calls = message.get("tool_calls") + tool_call_id = message.get("tool_call_id") + + if role == "system": + if i == 0: + system_prompt = content + logger.debug("Extracted system prompt for Google.") + else: + logger.warning("System message found not at the beginning. Merging into subsequent user message.") + continue + + google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role) + + if not google_role: + logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.") + continue + + parts: list[Part | str] = [] + if role == "tool": + if tool_call_id and content: + try: + response_content_dict = json.loads(content) + except json.JSONDecodeError: + logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.") + response_content_dict = {"result": content} + + func_name = "unknown_function" + if i > 0 and messages[i - 1].get("role") == "assistant": + prev_tool_calls = messages[i - 1].get("tool_calls") + if prev_tool_calls: + for tc in prev_tool_calls: + if tc.get("id") == tool_call_id: + func_name = tc.get("function_name", "unknown_function") + break + + parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict})) + google_role = "function" + else: + logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}") + continue + + elif role == "assistant" and tool_calls: + for tool_call in tool_calls: + args = tool_call.get("arguments", {}) + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}") + args = {"error": "failed to parse arguments"} + func_name = tool_call.get("function_name", "unknown_function") + parts.append(Part.from_function_call(name=func_name, args=args)) + if content: + parts.append(Part.from_text(content)) + + elif content: + if isinstance(content, str): + parts.append(Part.from_text(content)) + else: + logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") + parts.append(Part.from_text(str(content))) + + if parts: + google_messages.append(Content(role=google_role, parts=parts)) + else: + logger.debug(f"No parts generated for message: {message}") + + last_role = None + valid_alternation = True + for msg in google_messages: + current_role = msg.role + if current_role == last_role and current_role in ["user", "model"]: + valid_alternation = False + logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.") + break + if last_role == "function" and current_role != "user": + valid_alternation = False + logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.") + break + last_role = current_role + + if not valid_alternation: + logger.error("Message list does not follow required user/model alternation for Google API.") + raise ValueError("Invalid message sequence for Google API.") + + return google_messages, system_prompt + + def create_chat_completion( + self, + messages: list[dict[str, str]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + tools: list[dict[str, Any]] | None = None, + ) -> Any: + """Creates a chat completion using the Google Gemini API.""" + logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + + if self.client_module is None: + return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})]) + + try: + google_messages, system_prompt = self._convert_messages(messages) + generation_config: dict[str, Any] = {"temperature": temperature} + if max_tokens is not None: + generation_config["max_output_tokens"] = max_tokens + + google_tools = None + if tools: + try: + tool_dict_list = convert_to_google_tools(tools) + google_tools = self._convert_to_tool_objects(tool_dict_list) + logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.") + except Exception as tool_conv_err: + logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) + google_tools = None + + gemini_model = self.client_module.GenerativeModel( + model_name=model, + system_instruction=system_prompt, + tools=google_tools if google_tools else None, + ) + + log_params = { + "model": model, + "stream": stream, + "temperature": temperature, + "max_tokens": max_tokens, + "system_prompt_present": bool(system_prompt), + "num_tools": len(google_tools) if google_tools else 0, + "num_messages": len(google_messages), + } + logger.debug(f"Calling Google API with params: {log_params}") + + response = gemini_model.generate_content( + contents=google_messages, + generation_config=generation_config, + stream=stream, + ) + logger.debug("Google API call successful.") + return response + + except Exception as e: + error_msg = f"Google API error: {e}" + logger.error(error_msg, exc_info=True) + if stream: + yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) + else: + return {"error": error_msg, "traceback": traceback.format_exc()} + + def get_streaming_content(self, response: Any) -> Generator[str, None, None]: + """Yields content chunks from a Google streaming response.""" + logger.debug("Processing Google stream...") + full_delta = "" + try: + if isinstance(response, dict) and "error" in response: + yield json.dumps(response) + return + if hasattr(response, "__iter__") and not hasattr(response, "candidates"): + yield from response + return + + for chunk in response: + if isinstance(chunk, dict) and "error" in chunk: + yield json.dumps(chunk) + continue + if hasattr(chunk, "text"): + delta = chunk.text + if delta: + full_delta += delta + yield delta + elif hasattr(chunk, "candidates") and chunk.candidates: + for part in chunk.candidates[0].content.parts: + if hasattr(part, "function_call") and part.function_call: + logger.debug(f"Function call detected during stream: {part.function_call.name}") + break + + logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}") + except Exception as e: + logger.error(f"Error processing Google stream: {e}", exc_info=True) + yield json.dumps({"error": f"Stream processing error: {str(e)}"}) + + def get_content(self, response: Any) -> str: + """Extracts content from a non-streaming Google response.""" + try: + if isinstance(response, dict) and "error" in response: + logger.error(f"Cannot get content from error response: {response['error']}") + return f"[Error: {response['error']}]" + if hasattr(response, "text"): + content = response.text + logger.debug(f"Extracted content (length {len(content)}) from response.text.") + return content + elif hasattr(response, "candidates") and response.candidates: + first_candidate = response.candidates[0] + if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): + text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")] + content = "".join(text_parts) + logger.debug(f"Extracted content (length {len(content)}) from response candidates.") + return content + else: + logger.warning("Google response candidate has no content or parts.") + return "" + else: + logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.") + return "" + except Exception as e: + logger.error(f"Error extracting content from Google response: {e}", exc_info=True) + return f"[Error extracting content: {str(e)}]" + + def has_tool_calls(self, response: Any) -> bool: + """Checks if the Google response contains tool calls (function calls).""" + try: + if isinstance(response, dict) and "error" in response: + return False + if hasattr(response, "candidates") and response.candidates: + candidate = response.candidates[0] + if hasattr(candidate, "content") and hasattr(candidate.content, "parts"): + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}") + return True + logger.debug("No tool calls (FunctionCall) detected in Google response.") + return False + except Exception as e: + logger.error(f"Error checking for Google tool calls: {e}", exc_info=True) + return False + + def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]: + """Parses tool calls (function calls) from a non-streaming Google response.""" + parsed_calls = [] + try: + if not (hasattr(response, "candidates") and response.candidates): + logger.warning("Cannot parse tool calls: Response has no candidates.") + return [] + + candidate = response.candidates[0] + if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")): + logger.warning("Cannot parse tool calls: Response candidate has no content or parts.") + return [] + + logger.debug("Parsing tool calls (FunctionCall) from Google response.") + call_index = 0 + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + func_call = part.function_call + call_id = f"call_{call_index}" + call_index += 1 + + full_name = func_call.name + parts = full_name.split("__", 1) + if len(parts) == 2: + server_name, func_name = parts + else: + logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.") + server_name = None + func_name = full_name + + try: + args_str = json.dumps(func_call.args or {}) + except Exception as json_err: + logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}") + args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)}) + + parsed_calls.append({ + "id": call_id, + "server_name": server_name, + "function_name": func_name, + "arguments": args_str, + }) + logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...") + + return parsed_calls + except Exception as e: + logger.error(f"Error parsing Google tool calls: {e}", exc_info=True) + return [] + + def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]: + """Formats a tool result for a Google follow-up request.""" + try: + if isinstance(result, dict): + content_str = json.dumps(result) + else: + content_str = str(result) + except Exception as e: + logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}") + content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))}) + + logger.debug(f"Formatting Google tool result for call ID {tool_call_id}") + return { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content_str, + "function_name": "unknown_function", + } + + def get_original_message_with_calls(self, response: Any) -> dict[str, Any]: + """Extracts the assistant's message containing tool calls for Google.""" + try: + if not (hasattr(response, "candidates") and response.candidates): + return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"} + + candidate = response.candidates[0] + if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")): + return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"} + + tool_calls_formatted = [] + text_content_parts = [] + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + func_call = part.function_call + args = func_call.args or {} + tool_calls_formatted.append({ + "function_name": func_call.name, + "arguments": args, + }) + elif hasattr(part, "text"): + text_content_parts.append(part.text) + + message = {"role": "assistant"} + if tool_calls_formatted: + message["tool_calls"] = tool_calls_formatted + text_content = "".join(text_content_parts) + if text_content: + message["content"] = text_content + elif not tool_calls_formatted: + message["content"] = "" + + return message + + except Exception as e: + logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True) + return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} + + def get_usage(self, response: Any) -> dict[str, int] | None: + """Extracts token usage from a Google response.""" + try: + if isinstance(response, dict) and "error" in response: + return None + if hasattr(response, "usage_metadata"): + metadata = response.usage_metadata + usage = { + "prompt_tokens": getattr(metadata, "prompt_token_count", 0), + "completion_tokens": getattr(metadata, "candidates_token_count", 0), + } + logger.debug(f"Extracted usage from Google response metadata: {usage}") + return usage + else: + logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.") + return None + except Exception as e: + logger.error(f"Error extracting usage from Google response: {e}", exc_info=True) + return None + + def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None: + """Convert dictionary-format tools into Google's Tool objects.""" + if not tool_configs: + return None + + all_func_declarations = [] + for config in tool_configs: + if "function_declarations" in config: + for func_dict in config["function_declarations"]: + try: + params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) + if params_schema_dict.get("type") != "object": + logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.") + params_schema_dict = {"type": "object", "properties": params_schema_dict} + + def create_schema(schema_dict): + if not isinstance(schema_dict, dict): + logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") + return Schema() + schema_args = { + "type": schema_dict.get("type"), + "format": schema_dict.get("format"), + "description": schema_dict.get("description"), + "nullable": schema_dict.get("nullable"), + "enum": schema_dict.get("enum"), + "items": create_schema(schema_dict["items"]) if "items" in schema_dict else None, + "properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None, + "required": schema_dict.get("required"), + } + schema_args = {k: v for k, v in schema_args.items() if v is not None} + if "type" in schema_args: + type_mapping = { + "string": "STRING", + "number": "NUMBER", + "integer": "INTEGER", + "boolean": "BOOLEAN", + "array": "ARRAY", + "object": "OBJECT", + } + schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"]) + try: + return Schema(**schema_args) + except Exception as schema_creation_err: + logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True) + return Schema() + + parameters_schema = create_schema(params_schema_dict) + declaration = FunctionDeclaration( + name=func_dict["name"], + description=func_dict.get("description", ""), + parameters=parameters_schema, + ) + all_func_declarations.append(declaration) + except Exception as decl_err: + logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) + + if not all_func_declarations: + logger.warning("No valid function declarations found after conversion.") + return None + + return [Tool(function_declarations=all_func_declarations)] diff --git a/src/tools/conversion.py b/src/tools/conversion.py index a07b723..02a5bcf 100644 --- a/src/tools/conversion.py +++ b/src/tools/conversion.py @@ -164,11 +164,12 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, A function_declarations.append(function_declaration) logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}") - # Google API expects a list containing one Tool object dict - google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else [] + # Google API expects a list containing one dictionary with 'function_declarations' + # The provider's _convert_to_tool_objects will handle creating Tool objects from this. + google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else [] - logger.debug(f"Final Google tools structure: {google_tools_wrapper}") - return google_tools_wrapper + logger.debug(f"Final Google tool config structure: {google_tool_config}") + return google_tool_config # Note: The _handle_schema_construct helper from the reference code is not strictly