# src/providers/google_provider.py
"""Google Gemini provider built on the ``google-genai`` SDK (``from google import genai``)."""

import json
import logging
import traceback
from collections.abc import Generator
from typing import Any

from google import genai
from google.genai.types import (
    Content,
    FunctionDeclaration,
    GenerateContentConfig,
    Part,
    Schema,
    Tool,
)

from src.llm_models import MODELS
from src.providers.base import BaseProvider
from src.tools.conversion import convert_to_google_tools

logger = logging.getLogger(__name__)

# Fallback name used when the function behind a tool call/result cannot be recovered.
_UNKNOWN_FUNCTION = "unknown_function"

# Maps lower-case JSON-Schema type names onto the upper-case names used for Schema.type.
_SCHEMA_TYPE_MAPPING = {
    "string": "STRING",
    "number": "NUMBER",
    "integer": "INTEGER",
    "boolean": "BOOLEAN",
    "array": "ARRAY",
    "object": "OBJECT",
}


class GoogleProvider(BaseProvider):
    """Provider implementation for Google Gemini models."""

    def __init__(self, api_key: str, base_url: str | None = None):
        """Create the provider and its ``google-genai`` client.

        Args:
            api_key: Gemini API key.
            base_url: Accepted for consistency with other providers. The resolved value
                (argument, else ``MODELS["google"]["endpoint"]``) is handed to
                ``BaseProvider`` but is NOT passed to the Google client.

        Raises:
            ImportError: If the SDK module is unavailable.
            Exception: Any error raised while constructing the client is logged and re-raised.
        """
        # Google client typically doesn't use a base_url, but we accept it for consistency
        effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
        # NOTE(review): assumes BaseProvider.__init__ stores the key as self.api_key (used below).
        super().__init__(api_key, effective_base_url)
        logger.info("Initializing GoogleProvider")

        if genai is None:
            raise ImportError("Google Gen AI SDK is required for GoogleProvider. Please install google-genai.")

        try:
            # google-genai has no module-level configure(); credentials live on a Client.
            self.client = genai.Client(api_key=self.api_key)
            # Kept for backward compatibility with code that inspects the SDK module.
            self.client_module = genai
        except Exception as e:
            logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
            raise

    def _get_context_window(self, model: str) -> int:
        """Retrieves the context window size for a given Google model.

        Looks ``model`` up by ``id`` in ``MODELS["google"]["models"]``; falls back to
        1,000,000 tokens when the id is missing or the config cannot be read.
        """
        default_window = 1000000  # Default fallback for Gemini
        try:
            provider_models = MODELS.get("google", {}).get("models", [])
            for m in provider_models:
                if m.get("id") == model:
                    return m.get("context_window", default_window)
            logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
            return default_window
        except Exception as e:
            logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
            return default_window

    # ------------------------------------------------------------------ #
    # Message conversion
    # ------------------------------------------------------------------ #

    def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
        """Convert standard chat messages into Gemini ``Content`` turns.

        - A ``system`` message at index 0 becomes the returned system prompt. Later
          ``system`` messages are prepended to the next ``user`` message (or emitted as a
          trailing user turn if none follows).
        - ``assistant`` maps to role ``model``; its ``tool_calls`` become function-call parts.
        - ``tool`` results become function-response parts sent with role ``user``.
        - Consecutive turns with the same role are merged into one ``Content``. Parallel
          function responses therefore travel together in a single turn.
        - Messages with any other role are skipped with a warning.

        Returns:
            ``(contents, system_prompt)`` where ``system_prompt`` may be None.
        """
        google_messages: list[Content] = []
        system_prompt: str | None = None
        pending_system: list[str] = []

        for i, message in enumerate(messages):
            role = message.get("role")
            content = message.get("content")

            if role == "system":
                if i == 0:
                    system_prompt = content
                    logger.debug("Extracted system prompt for Google.")
                else:
                    logger.warning("System message found not at the beginning. Merging into subsequent user message.")
                    if content:
                        pending_system.append(str(content))
                continue

            if role == "tool":
                parts = self._tool_result_parts(messages, i)
            elif role == "assistant":
                parts = self._assistant_parts(message)
            elif role == "user":
                parts = []
                if pending_system:
                    parts.append(Part.from_text(text="\n\n".join(pending_system)))
                    pending_system.clear()
                parts.extend(self._text_parts(role, content))
            else:
                logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
                continue

            if parts:
                google_role = "model" if role == "assistant" else "user"
                self._append_content(google_messages, google_role, parts)
            else:
                logger.debug(f"No parts generated for message: {message}")

        if pending_system:
            # A mid-conversation system message with no later user turn: emit it on its own.
            self._append_content(google_messages, "user", [Part.from_text(text="\n\n".join(pending_system))])

        return google_messages, system_prompt

    @staticmethod
    def _append_content(contents: list[Content], role: str, parts: list[Part]) -> None:
        """Append ``parts`` as a ``role`` turn, merging into the last turn if it has the same role."""
        if contents and contents[-1].role == role:
            contents[-1] = Content(role=role, parts=[*(contents[-1].parts or []), *parts])
        else:
            contents.append(Content(role=role, parts=list(parts)))

    @staticmethod
    def _text_parts(role: str | None, content: Any) -> list[Part]:
        """Return a single text part for truthy ``content`` (stringifying non-str values), else []."""
        if not content:
            return []
        if not isinstance(content, str):
            logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
            content = str(content)
        return [Part.from_text(text=content)]

    @staticmethod
    def _qualified_function_name(tool_call: dict[str, Any]) -> str:
        """Return the model-facing function name for a stored tool call.

        ``parse_tool_calls`` splits ``"<server>__<func>"`` into ``server_name`` and
        ``function_name``; this re-joins them so the name matches the declared tool.
        Names that already carry the prefix, or calls without ``server_name``, are
        returned unchanged.
        """
        name = tool_call.get("function_name") or _UNKNOWN_FUNCTION
        server = tool_call.get("server_name")
        if server and not name.startswith(f"{server}__"):
            name = f"{server}__{name}"
        return name

    def _find_function_name(self, messages: list[dict[str, Any]], index: int, tool_call_id: str) -> str | None:
        """Find the function name for ``tool_call_id`` in the assistant turn that issued it.

        Walks backwards from ``index``, skipping earlier tool results (parallel calls), and
        inspects the first non-tool message. Returns None when no matching call is found.
        """
        for prev in reversed(messages[:index]):
            prev_role = prev.get("role")
            if prev_role == "tool":
                continue
            if prev_role == "assistant":
                for tool_call in prev.get("tool_calls") or []:
                    if tool_call.get("id") == tool_call_id:
                        return self._qualified_function_name(tool_call)
            break
        return None

    def _tool_result_parts(self, messages: list[dict[str, Any]], index: int) -> list[Part]:
        """Build the function-response part for the tool message at ``messages[index]``.

        JSON content is decoded; undecodable content is wrapped as ``{"result": content}``.
        Either way the payload is sent as ``{"content": <payload>}``. Returns [] (with a
        warning) when ``tool_call_id`` or ``content`` is missing.
        """
        message = messages[index]
        tool_call_id = message.get("tool_call_id")
        content = message.get("content")
        if not tool_call_id or content is None:
            logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
            return []

        if isinstance(content, (dict, list)):
            response_content: Any = content
        else:
            try:
                response_content = json.loads(content)
            except (json.JSONDecodeError, TypeError):
                logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
                response_content = {"result": content}

        func_name = (
            self._find_function_name(messages, index, tool_call_id)
            or message.get("function_name")
            or _UNKNOWN_FUNCTION
        )
        return [Part.from_function_response(name=func_name, response={"content": response_content})]

    def _assistant_parts(self, message: dict[str, Any]) -> list[Part]:
        """Build function-call parts (from ``tool_calls``) followed by any text part."""
        parts: list[Part] = []
        for tool_call in message.get("tool_calls") or []:
            args = tool_call.get("arguments") or {}
            if isinstance(args, str):
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
                    args = {"error": "failed to parse arguments"}
            parts.append(Part.from_function_call(name=self._qualified_function_name(tool_call), args=args))
        parts.extend(self._text_parts("assistant", message.get("content")))
        return parts

    # ------------------------------------------------------------------ #
    # API call
    # ------------------------------------------------------------------ #

    def _build_tools(self, tools: list[dict[str, Any]]) -> list[Tool] | None:
        """Convert provider-neutral tool specs to Gemini ``Tool`` objects; None on any failure."""
        try:
            tool_dict_list = convert_to_google_tools(tools)
            google_tools = self._convert_to_tool_objects(tool_dict_list)
            logger.debug(f"Converted {len(tools)} tools to {len(google_tools or [])} Google Tool objects.")
            return google_tools
        except Exception as tool_conv_err:
            logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
            return None

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """Creates a chat completion using the Google Gemini API.

        Returns:
            ``stream=True``: the SDK chunk iterator, or, on failure, an iterator yielding a
            single JSON error string. ``stream=False``: the SDK response, or, on failure,
            a dict with ``error`` and ``traceback`` keys.

        This is a regular function, not a generator. Errors are returned, never raised.
        """
        logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
        if getattr(self, "client", None) is None:
            sdk_error = {"error": "Google Generative AI SDK not installed."}
            return iter([json.dumps(sdk_error)]) if stream else sdk_error

        try:
            google_messages, system_prompt = self._convert_messages(messages)
            google_tools = self._build_tools(tools) if tools else None

            config = GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_tokens,
                system_instruction=system_prompt,
                tools=google_tools or None,
            )

            log_params = {
                "model": model,
                "stream": stream,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "system_prompt_present": bool(system_prompt),
                "num_tools": len(google_tools) if google_tools else 0,
                "num_messages": len(google_messages),
            }
            logger.debug(f"Calling Google API with params: {log_params}")

            if stream:
                # Errors raised while iterating this stream are handled in get_streaming_content.
                response = self.client.models.generate_content_stream(
                    model=model, contents=google_messages, config=config
                )
            else:
                response = self.client.models.generate_content(
                    model=model, contents=google_messages, config=config
                )
            logger.debug("Google API call successful.")
            return response

        except Exception as e:
            error_msg = f"Google API error: {e}"
            logger.error(error_msg, exc_info=True)
            error = {"error": error_msg, "traceback": traceback.format_exc()}
            # Returning (not yielding) keeps this method a plain function for both modes.
            return iter([json.dumps(error)]) if stream else error

    # ------------------------------------------------------------------ #
    # Response inspection helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _first_candidate_parts(response: Any) -> list[Any] | None:
        """Return the parts of the first candidate, or None if candidate/content/parts is missing."""
        candidates = getattr(response, "candidates", None)
        if not candidates:
            return None
        content = getattr(candidates[0], "content", None)
        parts = getattr(content, "parts", None)
        return None if parts is None else list(parts)

    @classmethod
    def _response_text(cls, response: Any) -> str:
        """Return the concatenated text of a response or stream chunk ('' when there is none)."""
        try:
            text = getattr(response, "text", None)
        except ValueError:
            # NOTE(review): some SDK versions may raise instead of returning None when no text part exists.
            text = None
        if isinstance(text, str):
            return text
        parts = cls._first_candidate_parts(response) or []
        # Function-call parts carry text=None; skip anything that is not a string.
        return "".join(part.text for part in parts if isinstance(getattr(part, "text", None), str))

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """Yield text deltas from a Google streaming response.

        Pre-serialized error strings (from ``create_chat_completion``) are passed through.
        Error dicts are re-encoded as JSON. Function calls seen in the stream are only logged.
        Exceptions raised while iterating are converted into a single JSON error string.
        """
        logger.debug("Processing Google stream...")
        total_length = 0
        try:
            if isinstance(response, dict) and "error" in response:
                yield json.dumps(response)
                return

            for chunk in response:
                if isinstance(chunk, str):
                    yield chunk
                    continue
                if isinstance(chunk, dict) and "error" in chunk:
                    yield json.dumps(chunk)
                    continue

                delta = self._response_text(chunk)
                if delta:
                    total_length += len(delta)
                    yield delta

                for part in self._first_candidate_parts(chunk) or []:
                    function_call = getattr(part, "function_call", None)
                    if function_call:
                        logger.debug(f"Function call detected during stream: {function_call.name}")

            logger.debug(f"Google stream finished. Total delta length: {total_length}")
        except Exception as e:
            logger.error(f"Error processing Google stream: {e}", exc_info=True)
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: Any) -> str:
        """Extract text from a non-streaming Google response.

        Returns ``"[Error: ...]"`` for error dicts and ``""`` when no text is present.
        """
        try:
            if isinstance(response, dict) and "error" in response:
                logger.error(f"Cannot get content from error response: {response['error']}")
                return f"[Error: {response['error']}]"

            content = self._response_text(response)
            if not content and not getattr(response, "candidates", None):
                logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
            else:
                logger.debug(f"Extracted content (length {len(content)}) from Google response.")
            return content
        except Exception as e:
            logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Any) -> bool:
        """Return True if the first candidate contains at least one function-call part."""
        try:
            if isinstance(response, dict) and "error" in response:
                return False
            for part in self._first_candidate_parts(response) or []:
                function_call = getattr(part, "function_call", None)
                if function_call:
                    logger.debug(f"Tool call (FunctionCall) detected in Google response part: {function_call.name}")
                    return True
            logger.debug("No tool calls (FunctionCall) detected in Google response.")
            return False
        except Exception as e:
            logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        """Parse function calls from a non-streaming Google response.

        Each call gets ``id="call_<n>"`` (n counts function-call parts in order). Its name
        ``"<server>__<func>"`` is split into ``server_name`` / ``function_name``; names
        without ``__`` get ``server_name=None``. ``arguments`` is a JSON string.
        Returns [] on any failure.
        """
        try:
            if not getattr(response, "candidates", None):
                logger.warning("Cannot parse tool calls: Response has no candidates.")
                return []
            parts = self._first_candidate_parts(response)
            if parts is None:
                logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
                return []

            logger.debug("Parsing tool calls (FunctionCall) from Google response.")
            function_calls = [part.function_call for part in parts if getattr(part, "function_call", None)]
            parsed_calls = []
            for call_index, func_call in enumerate(function_calls):
                call_id = f"call_{call_index}"
                full_name = func_call.name or ""
                server_name, separator, func_name = full_name.partition("__")
                if not separator:
                    logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
                    server_name, func_name = None, full_name

                try:
                    args_str = json.dumps(func_call.args or {})
                except Exception as json_err:
                    logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
                    args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})

                parsed_calls.append({
                    "id": call_id,
                    "server_name": server_name,
                    "function_name": func_name,
                    "arguments": args_str,
                })
                logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
            return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Format a tool result as a ``role="tool"`` message for a follow-up request.

        Dicts and lists are JSON-encoded; other values are passed through ``str()``.
        """
        try:
            if isinstance(result, (dict, list)):
                content_str = json.dumps(result)
            else:
                content_str = str(result)
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
            content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content_str,
            "function_name": "unknown_function",
        }

    def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
        """Rebuild the assistant message (text + tool calls) from a Google response.

        Tool calls carry ``id`` values in the same ``call_<n>`` scheme as
        ``parse_tool_calls``. Later tool results can therefore be matched back to the
        full function name.
        """
        try:
            if not getattr(response, "candidates", None):
                return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}
            parts = self._first_candidate_parts(response)
            if parts is None:
                return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}

            tool_calls_formatted: list[dict[str, Any]] = []
            text_content_parts: list[str] = []
            for part in parts:
                func_call = getattr(part, "function_call", None)
                if func_call:
                    tool_calls_formatted.append({
                        "id": f"call_{len(tool_calls_formatted)}",
                        "function_name": func_call.name,
                        "arguments": func_call.args or {},
                    })
                elif isinstance(getattr(part, "text", None), str):
                    text_content_parts.append(part.text)

            message: dict[str, Any] = {"role": "assistant"}
            if tool_calls_formatted:
                message["tool_calls"] = tool_calls_formatted
            text_content = "".join(text_content_parts)
            if text_content:
                message["content"] = text_content
            elif not tool_calls_formatted:
                message["content"] = ""
            return message
        except Exception as e:
            logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

    def get_usage(self, response: Any) -> dict[str, int] | None:
        """Extract prompt/completion token counts; missing or None counts become 0."""
        try:
            if isinstance(response, dict) and "error" in response:
                return None
            if hasattr(response, "usage_metadata"):
                metadata = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(metadata, "prompt_token_count", None) or 0,
                    "completion_tokens": getattr(metadata, "candidates_token_count", None) or 0,
                }
                logger.debug(f"Extracted usage from Google response metadata: {usage}")
                return usage
            logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
            return None
        except Exception as e:
            logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
            return None

    # ------------------------------------------------------------------ #
    # Tool schema conversion
    # ------------------------------------------------------------------ #

    @staticmethod
    def _normalize_parameters(tool_name: str, params: Any) -> dict[str, Any]:
        """Return an object-typed parameters schema for ``tool_name``.

        Missing/empty schemas become an empty object. A schema without ``type`` is
        treated as an object (its ``properties`` are kept). A non-object schema cannot be
        expressed as function parameters and is replaced by an empty object.
        """
        if not isinstance(params, dict) or not params:
            return {"type": "object", "properties": {}}
        if params.get("type") in (None, "object"):
            return {**params, "type": "object"}
        logger.warning(f"Tool {tool_name} parameters schema is not type 'object'. Forcing object type.")
        return {"type": "object", "properties": {}}

    @classmethod
    def _create_schema(cls, schema_dict: Any, tool_name: str) -> Schema:
        """Recursively convert a JSON-Schema-like dict into a Gemini ``Schema``.

        Only type, format, description, nullable, enum, items, properties and required
        are carried over. A list-valued ``type`` (e.g. ``["string", "null"]``) becomes
        the first non-null type plus ``nullable=True``. Invalid input yields an empty
        ``Schema()``.
        """
        if not isinstance(schema_dict, dict):
            logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
            return Schema()

        schema_type = schema_dict.get("type")
        nullable = schema_dict.get("nullable")
        if isinstance(schema_type, list):
            non_null = [t for t in schema_type if t != "null"]
            if len(non_null) < len(schema_type):
                nullable = True
            schema_type = non_null[0] if non_null else None
        if schema_type is not None:
            schema_type = _SCHEMA_TYPE_MAPPING.get(str(schema_type).lower(), schema_type)

        properties = schema_dict.get("properties")
        schema_args = {
            "type": schema_type,
            "format": schema_dict.get("format"),
            "description": schema_dict.get("description"),
            "nullable": nullable,
            "enum": schema_dict.get("enum"),
            "items": cls._create_schema(schema_dict["items"], tool_name) if "items" in schema_dict else None,
            "properties": {k: cls._create_schema(v, tool_name) for k, v in properties.items()} if properties else None,
            "required": schema_dict.get("required"),
        }
        schema_args = {k: v for k, v in schema_args.items() if v is not None}
        try:
            return Schema(**schema_args)
        except Exception as schema_creation_err:
            logger.error(f"Failed to create Schema object for {tool_name} with args {schema_args}: {schema_creation_err}", exc_info=True)
            return Schema()

    def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
        """Convert dictionary-format tools into a single Gemini ``Tool``.

        Reads ``function_declarations`` from each config dict. Declarations that fail to
        convert are logged and skipped. Returns None when nothing valid remains.
        """
        if not tool_configs:
            return None

        all_func_declarations: list[FunctionDeclaration] = []
        for config in tool_configs:
            if not isinstance(config, dict):
                continue
            for func_dict in config.get("function_declarations") or []:
                try:
                    tool_name = func_dict["name"]
                    params = self._normalize_parameters(tool_name, func_dict.get("parameters"))
                    declaration = FunctionDeclaration(
                        name=tool_name,
                        description=func_dict.get("description", ""),
                        # Gemini rejects OBJECT schemas with empty properties, so parameter-less
                        # tools are declared without a parameters schema.
                        parameters=self._create_schema(params, tool_name) if params.get("properties") else None,
                    )
                    all_func_declarations.append(declaration)
                except Exception as decl_err:
                    logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)

        if not all_func_declarations:
            logger.warning("No valid function declarations found after conversion.")
            return None

        return [Tool(function_declarations=all_func_declarations)]