From 6b390a35f80fcce118ccc42b223072482196921b Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Thu, 27 Mar 2025 11:11:56 +0000 Subject: [PATCH] feat: Implement GoogleProvider for Google Generative AI integration - Added GoogleProvider class to handle chat completions with Google Gemini API. - Implemented client initialization and response handling for streaming and non-streaming responses. - Created utility functions for tool conversion, response parsing, and content extraction. - Removed legacy tool conversion utilities from the tools module. - Enhanced logging for better traceability of API interactions and error handling. --- .gitignore | 9 +- src/providers/google_provider.py | 483 -------------------- src/providers/google_provider/__init__.py | 90 ++++ src/providers/google_provider/client.py | 27 ++ src/providers/google_provider/completion.py | 140 ++++++ src/providers/google_provider/response.py | 205 +++++++++ src/providers/google_provider/tools.py | 359 +++++++++++++++ src/providers/google_provider/utils.py | 150 ++++++ src/tools/__init__.py | 6 - src/tools/conversion.py | 77 ---- 10 files changed, 979 insertions(+), 567 deletions(-) delete mode 100644 src/providers/google_provider.py create mode 100644 src/providers/google_provider/__init__.py create mode 100644 src/providers/google_provider/client.py create mode 100644 src/providers/google_provider/completion.py create mode 100644 src/providers/google_provider/response.py create mode 100644 src/providers/google_provider/tools.py create mode 100644 src/providers/google_provider/utils.py delete mode 100644 src/tools/__init__.py delete mode 100644 src/tools/conversion.py diff --git a/.gitignore b/.gitignore index 3f1c90a..c4d6902 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ # Virtual environment env/ +.venv/ # Configuration config/config.ini @@ -20,4 +21,10 @@ config/mcp_config.json # resources resources/ -# __pycache__/ +# Ruff +.ruff_cache/ + +# Distribution / packaging +dist/ +build/ +*.egg-info/ diff --git a/src/providers/google_provider.py b/src/providers/google_provider.py deleted file mode 100644 index a006700..0000000 --- a/src/providers/google_provider.py +++ /dev/null @@ -1,483 +0,0 @@ -# src/providers/google_provider.py -import json -import logging -import traceback -from collections.abc import Generator -from typing import Any - -from google import genai -from google.genai.types import ( - Content, - FunctionDeclaration, - Part, - Schema, - Tool, -) - -from src.llm_models import MODELS -from src.providers.base import BaseProvider -from src.tools.conversion import convert_to_google_tools - -logger = logging.getLogger(__name__) - - -class GoogleProvider(BaseProvider): - """Provider implementation for Google Gemini models.""" - - def __init__(self, api_key: str, base_url: str | None = None): - # Google client typically doesn't use a base_url, but we accept it for consistency - effective_base_url = base_url or MODELS.get("google", {}).get("endpoint") - super().__init__(api_key, effective_base_url) - logger.info("Initializing GoogleProvider") - - if genai is None: - raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") - - try: - # Configure the client - genai.configure(api_key=self.api_key) - self.client_module = genai - except Exception as e: - logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) - raise - - def _get_context_window(self, model: str) -> int: - """Retrieves the context window size for a given Google model.""" - default_window = 1000000 # Default fallback for Gemini - try: - provider_models = MODELS.get("google", {}).get("models", []) - for m in provider_models: - if m.get("id") == model: - return m.get("context_window", default_window) - logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}") - return default_window - except Exception as e: - logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) - return default_window - - def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]: - """ - Converts standard message format to Google's format, extracting system prompt. - Handles mapping roles and structuring tool calls/results. - """ - google_messages: list[Content] = [] - system_prompt: str | None = None - - for i, message in enumerate(messages): - role = message.get("role") - content = message.get("content") - tool_calls = message.get("tool_calls") - tool_call_id = message.get("tool_call_id") - - if role == "system": - if i == 0: - system_prompt = content - logger.debug("Extracted system prompt for Google.") - else: - logger.warning("System message found not at the beginning. Merging into subsequent user message.") - continue - - google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role) - - if not google_role: - logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.") - continue - - parts: list[Part | str] = [] - if role == "tool": - if tool_call_id and content: - try: - response_content_dict = json.loads(content) - except json.JSONDecodeError: - logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.") - response_content_dict = {"result": content} - - func_name = "unknown_function" - if i > 0 and messages[i - 1].get("role") == "assistant": - prev_tool_calls = messages[i - 1].get("tool_calls") - if prev_tool_calls: - for tc in prev_tool_calls: - if tc.get("id") == tool_call_id: - func_name = tc.get("function_name", "unknown_function") - break - - parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict})) - google_role = "function" - else: - logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}") - continue - - elif role == "assistant" and tool_calls: - for tool_call in tool_calls: - args = tool_call.get("arguments", {}) - if isinstance(args, str): - try: - args = json.loads(args) - except json.JSONDecodeError: - logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}") - args = {"error": "failed to parse arguments"} - func_name = tool_call.get("function_name", "unknown_function") - parts.append(Part.from_function_call(name=func_name, args=args)) - if content: - parts.append(Part.from_text(content)) - - elif content: - if isinstance(content, str): - parts.append(Part.from_text(content)) - else: - logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") - parts.append(Part.from_text(str(content))) - - if parts: - google_messages.append(Content(role=google_role, parts=parts)) - else: - logger.debug(f"No parts generated for message: {message}") - - last_role = None - valid_alternation = True - for msg in google_messages: - current_role = msg.role - if current_role == last_role and current_role in ["user", "model"]: - valid_alternation = False - logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.") - break - if last_role == "function" and current_role != "user": - valid_alternation = False - logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.") - break - last_role = current_role - - if not valid_alternation: - logger.error("Message list does not follow required user/model alternation for Google API.") - raise ValueError("Invalid message sequence for Google API.") - - return google_messages, system_prompt - - def create_chat_completion( - self, - messages: list[dict[str, str]], - model: str, - temperature: float = 0.4, - max_tokens: int | None = None, - stream: bool = True, - tools: list[dict[str, Any]] | None = None, - ) -> Any: - """Creates a chat completion using the Google Gemini API.""" - logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") - - if self.client_module is None: - return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})]) - - try: - google_messages, system_prompt = self._convert_messages(messages) - generation_config: dict[str, Any] = {"temperature": temperature} - if max_tokens is not None: - generation_config["max_output_tokens"] = max_tokens - - google_tools = None - if tools: - try: - tool_dict_list = convert_to_google_tools(tools) - google_tools = self._convert_to_tool_objects(tool_dict_list) - logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.") - except Exception as tool_conv_err: - logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) - google_tools = None - - gemini_model = self.client_module.GenerativeModel( - model_name=model, - system_instruction=system_prompt, - tools=google_tools if google_tools else None, - ) - - log_params = { - "model": model, - "stream": stream, - "temperature": temperature, - "max_tokens": max_tokens, - "system_prompt_present": bool(system_prompt), - "num_tools": len(google_tools) if google_tools else 0, - "num_messages": len(google_messages), - } - logger.debug(f"Calling Google API with params: {log_params}") - - response = gemini_model.generate_content( - contents=google_messages, - generation_config=generation_config, - stream=stream, - ) - logger.debug("Google API call successful.") - return response - - except Exception as e: - error_msg = f"Google API error: {e}" - logger.error(error_msg, exc_info=True) - if stream: - yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) - else: - return {"error": error_msg, "traceback": traceback.format_exc()} - - def get_streaming_content(self, response: Any) -> Generator[str, None, None]: - """Yields content chunks from a Google streaming response.""" - logger.debug("Processing Google stream...") - full_delta = "" - try: - if isinstance(response, dict) and "error" in response: - yield json.dumps(response) - return - if hasattr(response, "__iter__") and not hasattr(response, "candidates"): - yield from response - return - - for chunk in response: - if isinstance(chunk, dict) and "error" in chunk: - yield json.dumps(chunk) - continue - if hasattr(chunk, "text"): - delta = chunk.text - if delta: - full_delta += delta - yield delta - elif hasattr(chunk, "candidates") and chunk.candidates: - for part in chunk.candidates[0].content.parts: - if hasattr(part, "function_call") and part.function_call: - logger.debug(f"Function call detected during stream: {part.function_call.name}") - break - - logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}") - except Exception as e: - logger.error(f"Error processing Google stream: {e}", exc_info=True) - yield json.dumps({"error": f"Stream processing error: {str(e)}"}) - - def get_content(self, response: Any) -> str: - """Extracts content from a non-streaming Google response.""" - try: - if isinstance(response, dict) and "error" in response: - logger.error(f"Cannot get content from error response: {response['error']}") - return f"[Error: {response['error']}]" - if hasattr(response, "text"): - content = response.text - logger.debug(f"Extracted content (length {len(content)}) from response.text.") - return content - elif hasattr(response, "candidates") and response.candidates: - first_candidate = response.candidates[0] - if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): - text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")] - content = "".join(text_parts) - logger.debug(f"Extracted content (length {len(content)}) from response candidates.") - return content - else: - logger.warning("Google response candidate has no content or parts.") - return "" - else: - logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.") - return "" - except Exception as e: - logger.error(f"Error extracting content from Google response: {e}", exc_info=True) - return f"[Error extracting content: {str(e)}]" - - def has_tool_calls(self, response: Any) -> bool: - """Checks if the Google response contains tool calls (function calls).""" - try: - if isinstance(response, dict) and "error" in response: - return False - if hasattr(response, "candidates") and response.candidates: - candidate = response.candidates[0] - if hasattr(candidate, "content") and hasattr(candidate.content, "parts"): - for part in candidate.content.parts: - if hasattr(part, "function_call") and part.function_call: - logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}") - return True - logger.debug("No tool calls (FunctionCall) detected in Google response.") - return False - except Exception as e: - logger.error(f"Error checking for Google tool calls: {e}", exc_info=True) - return False - - def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]: - """Parses tool calls (function calls) from a non-streaming Google response.""" - parsed_calls = [] - try: - if not (hasattr(response, "candidates") and response.candidates): - logger.warning("Cannot parse tool calls: Response has no candidates.") - return [] - - candidate = response.candidates[0] - if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")): - logger.warning("Cannot parse tool calls: Response candidate has no content or parts.") - return [] - - logger.debug("Parsing tool calls (FunctionCall) from Google response.") - call_index = 0 - for part in candidate.content.parts: - if hasattr(part, "function_call") and part.function_call: - func_call = part.function_call - call_id = f"call_{call_index}" - call_index += 1 - - full_name = func_call.name - parts = full_name.split("__", 1) - if len(parts) == 2: - server_name, func_name = parts - else: - logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.") - server_name = None - func_name = full_name - - try: - args_str = json.dumps(func_call.args or {}) - except Exception as json_err: - logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}") - args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)}) - - parsed_calls.append({ - "id": call_id, - "server_name": server_name, - "function_name": func_name, - "arguments": args_str, - }) - logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...") - - return parsed_calls - except Exception as e: - logger.error(f"Error parsing Google tool calls: {e}", exc_info=True) - return [] - - def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]: - """Formats a tool result for a Google follow-up request.""" - try: - if isinstance(result, dict): - content_str = json.dumps(result) - else: - content_str = str(result) - except Exception as e: - logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}") - content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))}) - - logger.debug(f"Formatting Google tool result for call ID {tool_call_id}") - return { - "role": "tool", - "tool_call_id": tool_call_id, - "content": content_str, - "function_name": "unknown_function", - } - - def get_original_message_with_calls(self, response: Any) -> dict[str, Any]: - """Extracts the assistant's message containing tool calls for Google.""" - try: - if not (hasattr(response, "candidates") and response.candidates): - return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"} - - candidate = response.candidates[0] - if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")): - return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"} - - tool_calls_formatted = [] - text_content_parts = [] - for part in candidate.content.parts: - if hasattr(part, "function_call") and part.function_call: - func_call = part.function_call - args = func_call.args or {} - tool_calls_formatted.append({ - "function_name": func_call.name, - "arguments": args, - }) - elif hasattr(part, "text"): - text_content_parts.append(part.text) - - message = {"role": "assistant"} - if tool_calls_formatted: - message["tool_calls"] = tool_calls_formatted - text_content = "".join(text_content_parts) - if text_content: - message["content"] = text_content - elif not tool_calls_formatted: - message["content"] = "" - - return message - - except Exception as e: - logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True) - return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} - - def get_usage(self, response: Any) -> dict[str, int] | None: - """Extracts token usage from a Google response.""" - try: - if isinstance(response, dict) and "error" in response: - return None - if hasattr(response, "usage_metadata"): - metadata = response.usage_metadata - usage = { - "prompt_tokens": getattr(metadata, "prompt_token_count", 0), - "completion_tokens": getattr(metadata, "candidates_token_count", 0), - } - logger.debug(f"Extracted usage from Google response metadata: {usage}") - return usage - else: - logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.") - return None - except Exception as e: - logger.error(f"Error extracting usage from Google response: {e}", exc_info=True) - return None - - def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None: - """Convert dictionary-format tools into Google's Tool objects.""" - if not tool_configs: - return None - - all_func_declarations = [] - for config in tool_configs: - if "function_declarations" in config: - for func_dict in config["function_declarations"]: - try: - params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) - if params_schema_dict.get("type") != "object": - logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.") - params_schema_dict = {"type": "object", "properties": params_schema_dict} - - def create_schema(schema_dict): - if not isinstance(schema_dict, dict): - logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") - return Schema() - schema_args = { - "type": schema_dict.get("type"), - "format": schema_dict.get("format"), - "description": schema_dict.get("description"), - "nullable": schema_dict.get("nullable"), - "enum": schema_dict.get("enum"), - "items": create_schema(schema_dict["items"]) if "items" in schema_dict else None, - "properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None, - "required": schema_dict.get("required"), - } - schema_args = {k: v for k, v in schema_args.items() if v is not None} - if "type" in schema_args: - type_mapping = { - "string": "STRING", - "number": "NUMBER", - "integer": "INTEGER", - "boolean": "BOOLEAN", - "array": "ARRAY", - "object": "OBJECT", - } - schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"]) - try: - return Schema(**schema_args) - except Exception as schema_creation_err: - logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True) - return Schema() - - parameters_schema = create_schema(params_schema_dict) - declaration = FunctionDeclaration( - name=func_dict["name"], - description=func_dict.get("description", ""), - parameters=parameters_schema, - ) - all_func_declarations.append(declaration) - except Exception as decl_err: - logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) - - if not all_func_declarations: - logger.warning("No valid function declarations found after conversion.") - return None - - return [Tool(function_declarations=all_func_declarations)] diff --git a/src/providers/google_provider/__init__.py b/src/providers/google_provider/__init__.py new file mode 100644 index 0000000..272499e --- /dev/null +++ b/src/providers/google_provider/__init__.py @@ -0,0 +1,90 @@ +# src/providers/google_provider/__init__.py +import logging +from collections.abc import Generator +from typing import Any + +from google.genai.types import GenerateContentResponse + +from providers.google_provider.client import initialize_client +from providers.google_provider.completion import create_chat_completion +from providers.google_provider.response import get_content, get_streaming_content, get_usage +from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls +from src.providers.base import BaseProvider + +logger = logging.getLogger(__name__) + + +class GoogleProvider(BaseProvider): + """Provider implementation for Google Generative AI (Gemini).""" + + # Type hint for the client (it's the configured 'genai' module itself) + client_module: Any + + def __init__(self, api_key: str, base_url: str | None = None): + """ + Initializes the GoogleProvider. + + Args: + api_key: The Google API key. + base_url: Base URL (typically not used by Google client config, but kept for interface consistency). + """ + # initialize_client returns the configured genai module + self.client_module = initialize_client(api_key, base_url) + self.api_key = api_key # Store if needed later + self.base_url = base_url # Store if needed later + logger.info("GoogleProvider initialized.") + + def create_chat_completion( + self, + messages: list[dict[str, Any]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + tools: list[dict[str, Any]] | None = None, + ) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator + """Creates a chat completion using the Google Gemini API.""" + # Pass self (provider instance) to the helper function + return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) + + def get_streaming_content(self, response: Any) -> Generator[str, None, None]: + """Extracts content chunks from a Google streaming response.""" + # Response is expected to be an iterator from generate_content(stream=True) + return get_streaming_content(response) + + def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str: + """Extracts the full text content from a non-streaming Google response.""" + return get_content(response) + + def has_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> bool: + """Checks if the Google response contains tool calls (FunctionCalls).""" + # Note: For streaming responses, this check is reliable only after the stream is fully consumed + # or if the specific chunk containing the call is processed. + return has_google_tool_calls(response) + + def parse_tool_calls(self, response: GenerateContentResponse | dict[str, Any]) -> list[dict[str, Any]]: + """Parses tool calls (FunctionCalls) from a non-streaming Google response.""" + # Expects a non-streaming GenerateContentResponse or an error dict + return parse_google_tool_calls(response) + + # Note: Google's format_tool_results helper requires the original function_name. + # Ensure the calling code (e.g., LLMClient) provides this when invoking this method. + def format_tool_results(self, tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]: + """Formats a tool result for a Google follow-up request (into standard message format).""" + return format_google_tool_results(tool_call_id, function_name, result) + + def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Converts MCP tools list to Google's intermediate dictionary format.""" + # The `create_chat_completion` function handles the final conversion + # from this intermediate format to Google's `Tool` objects internally. + return convert_to_google_tools(tools) + + def get_usage(self, response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None: + """Extracts token usage information from a Google response.""" + # Expects a non-streaming GenerateContentResponse or an error dict + return get_usage(response) + + # `get_original_message_with_calls` (present in OpenAIProvider) is not implemented here + # as Google's API structure integrates FunctionCall parts directly into the assistant's + # message content, rather than having a separate `tool_calls` attribute on the message object. + # The necessary information is handled during message conversion and tool call parsing. diff --git a/src/providers/google_provider/client.py b/src/providers/google_provider/client.py new file mode 100644 index 0000000..f49e59c --- /dev/null +++ b/src/providers/google_provider/client.py @@ -0,0 +1,27 @@ +# src/providers/google_provider/client.py +import logging +from typing import Any + +from google import genai + +logger = logging.getLogger(__name__) + + +def initialize_client(api_key: str, base_url: str | None = None) -> Any: + """Initializes and returns the Google Generative AI client module.""" + logger.info("Initializing Google Generative AI client") + + if genai is None: + logger.error("Google Generative AI SDK (google-generativeai) is not installed.") + raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") + + try: + # Configure the client + genai.configure(api_key=api_key) + if base_url: + logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.") + # Return the configured module itself, as it's used directly + return genai + except Exception as e: + logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) + raise diff --git a/src/providers/google_provider/completion.py b/src/providers/google_provider/completion.py new file mode 100644 index 0000000..946c647 --- /dev/null +++ b/src/providers/google_provider/completion.py @@ -0,0 +1,140 @@ +# src/providers/google_provider/completion.py +import json +import logging +import traceback +from typing import Any + +from google.genai.types import Tool + +from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools +from providers.google_provider.utils import convert_messages + +logger = logging.getLogger(__name__) + + +def create_chat_completion( + provider, + messages: list[dict[str, Any]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + tools: list[dict[str, Any]] | None = None, +) -> Any: + """ + Creates a chat completion using the Google Gemini API. + + Args: + provider: The instance of the GoogleProvider. + messages: A list of message dictionaries in the standard format. + model: The model ID to use (e.g., "gemini-1.5-flash"). + temperature: The sampling temperature. + max_tokens: The maximum number of tokens to generate. + stream: Whether to stream the response. + tools: A list of tool definitions in the MCP format. + + Returns: + The response object from the Google API (could be a stream iterator or + a GenerateContentResponse object), or an error dictionary/iterator. + """ + logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + + if provider.client_module is None: + error_msg = "Google Generative AI SDK not configured or installed." + logger.error(error_msg) + # Return an error structure compatible with both streaming and non-streaming expectations + if stream: + return iter([json.dumps({"error": error_msg})]) + else: + return {"error": error_msg} + + try: + # 1. Convert messages to Google's format + google_messages, system_prompt = convert_messages(messages) + logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}") + + # 2. Prepare generation configuration + generation_config: dict[str, Any] = {"temperature": temperature} + if max_tokens is not None: + # Google uses 'max_output_tokens' + generation_config["max_output_tokens"] = max_tokens + logger.debug(f"Setting max_output_tokens: {max_tokens}") + else: + logger.debug("No max_tokens specified.") + + # 3. Convert tools if provided + google_tool_objects: list[Tool] | None = None + if tools: + try: + # Step 3a: Convert MCP tools to intermediate Google dict format + tool_dict_list = convert_to_google_tools(tools) + # Step 3b: Convert intermediate dict format to Google Tool objects + google_tool_objects = convert_to_google_tool_objects(tool_dict_list) + if google_tool_objects: + logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.") + else: + logger.warning("Tool conversion resulted in no valid Google Tool objects.") + except Exception as tool_conv_err: + logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) + # Decide whether to proceed without tools or raise an error + # Proceeding without tools for now, but logging the error. + google_tool_objects = None + else: + logger.debug("No tools provided for conversion.") + + # 4. Initialize the Google Generative Model + # Ensure client_module is callable and has GenerativeModel + if not hasattr(provider.client_module, "GenerativeModel"): + raise AttributeError("Configured Google client module does not have 'GenerativeModel'") + + gemini_model = provider.client_module.GenerativeModel( + model_name=model, + system_instruction=system_prompt, # Pass extracted system prompt + tools=google_tool_objects, # Pass converted Tool objects (or None) + # Add safety_settings if needed: safety_settings=... + ) + logger.debug(f"Initialized Google GenerativeModel for '{model}'.") + + # 5. Log parameters before API call + log_params = { + "model": model, + "stream": stream, + "temperature": temperature, + "max_output_tokens": generation_config.get("max_output_tokens"), + "system_prompt_present": bool(system_prompt), + "num_tools": len(google_tool_objects) if google_tool_objects else 0, + "num_messages": len(google_messages), + } + logger.info(f"Calling Google generate_content API with params: {log_params}") + # Avoid logging full message content unless necessary for debugging specific issues + # logger.debug(f"Google messages being sent: {google_messages}") + + # 6. Call the Google API + response = gemini_model.generate_content( + contents=google_messages, + generation_config=generation_config, + stream=stream, + # tool_config={"function_calling_config": "AUTO"} # AUTO is default + ) + logger.debug("Google API call successful, returning response object.") + return response + + except ValueError as ve: # Catch specific errors like invalid message sequence + error_msg = f"Google API request validation error: {ve}" + logger.error(error_msg, exc_info=True) + if stream: + # Yield a JSON error message in an iterator + yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) + else: + # Return an error dictionary + return {"error": error_msg, "traceback": traceback.format_exc()} + except Exception as e: + # Catch any other exceptions during setup or API call + error_msg = f"Google API error during chat completion: {e}" + logger.error(error_msg, exc_info=True) + if stream: + # Yield a JSON error message in an iterator + yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) + else: + # Return an error dictionary + return {"error": error_msg, "traceback": traceback.format_exc()} diff --git a/src/providers/google_provider/response.py b/src/providers/google_provider/response.py new file mode 100644 index 0000000..f8b7315 --- /dev/null +++ b/src/providers/google_provider/response.py @@ -0,0 +1,205 @@ +# src/providers/google_provider/response.py +""" +Response handling utilities specific to the Google Generative AI provider. + +Includes functions for: +- Extracting content from streaming responses. +- Extracting content from non-streaming responses. +- Extracting token usage information. +""" + +import json +import logging +from collections.abc import Generator +from typing import Any + +from google.genai.types import GenerateContentResponse + +logger = logging.getLogger(__name__) + + +def get_streaming_content(response: Any) -> Generator[str, None, None]: + """ + Yields content chunks (text) from a Google streaming response iterator. + + Args: + response: The streaming response iterator returned by `generate_content(stream=True)`. + + Yields: + String chunks of the generated text content. + May yield JSON strings containing error information if errors occur during streaming. + """ + logger.debug("Processing Google stream...") + full_delta = "" + try: + # Check if the response itself is an error indicator (e.g., from create_chat_completion error handling) + if isinstance(response, dict) and "error" in response: + yield json.dumps(response) + logger.error(f"Stream processing stopped due to initial error: {response['error']}") + return + # Check if response is already an error iterator + if hasattr(response, "__iter__") and not hasattr(response, "candidates"): + # If it looks like an error iterator from create_chat_completion + first_item = next(response, None) + if first_item and isinstance(first_item, str): + try: + error_data = json.loads(first_item) + if "error" in error_data: + yield first_item # Yield the error JSON + yield from response + logger.error(f"Stream processing stopped due to yielded error: {error_data['error']}") + return + except json.JSONDecodeError: + # Not a JSON error, yield it as is and continue? Or stop? + # Assuming it might be valid content if not JSON error. + yield first_item + elif first_item: # Put the first item back if it wasn't an error + # This requires a way to chain iterators, simple yield doesn't work well here. + # For simplicity, we assume error iterators yield JSON strings. + # If the stream is valid, the loop below will handle it. + # Re-assigning response might be complex. Let the main loop handle valid streams. + pass # Let the main loop handle the original response iterator + + # Process the stream chunk by chunk + for chunk in response: + # Check for errors embedded within the stream chunks (less common for Google?) + if isinstance(chunk, dict) and "error" in chunk: + yield json.dumps(chunk) + logger.error(f"Error encountered during Google stream: {chunk['error']}") + continue # Continue processing stream or stop? Continuing for now. + + # Extract text content + delta = "" + try: + if hasattr(chunk, "text"): + delta = chunk.text + elif hasattr(chunk, "candidates") and chunk.candidates: + # Sometimes content might be nested under candidates even in stream? + # Check the first candidate's first part for text. + first_candidate = chunk.candidates[0] + if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts") and first_candidate.content.parts: + first_part = first_candidate.content.parts[0] + if hasattr(first_part, "text"): + delta = first_part.text + except Exception as e: + logger.warning(f"Could not extract text from stream chunk: {chunk}. Error: {e}", exc_info=True) + delta = "" # Ensure delta is a string + + if delta: + full_delta += delta + yield delta + + # Detect function calls during stream (optional, for logging/early detection) + try: + if hasattr(chunk, "candidates") and chunk.candidates: + for part in chunk.candidates[0].content.parts: + if hasattr(part, "function_call") and part.function_call: + logger.debug(f"Function call detected during stream: {part.function_call.name}") + # Note: We don't yield the function call itself here, just the text. + # Function calls are typically processed after the stream completes. + break # Found a function call in this chunk + except Exception: + # Ignore errors during optional function call detection in stream + pass + + logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}") + + except StopIteration: + logger.debug("Google stream finished (StopIteration).") # Normal end of iteration + except Exception as e: + logger.error(f"Error processing Google stream: {e}", exc_info=True) + # Yield a final error message + yield json.dumps({"error": f"Stream processing error: {str(e)}"}) + + +def get_content(response: GenerateContentResponse | dict[str, Any]) -> str: + """ + Extracts the full text content from a non-streaming Google response. + + Args: + response: The non-streaming response object (`GenerateContentResponse`) or + an error dictionary. + + Returns: + The concatenated text content, or an error message string. + """ + try: + # Handle error dictionary case + if isinstance(response, dict) and "error" in response: + logger.error(f"Cannot get content from error response: {response['error']}") + return f"[Error: {response['error']}]" + + # Handle successful GenerateContentResponse object + if hasattr(response, "text"): + # The `.text` attribute usually provides the concatenated text content directly + content = response.text + logger.debug(f"Extracted content (length {len(content)}) from response.text.") + return content + elif hasattr(response, "candidates") and response.candidates: + # Fallback: manually concatenate text from parts if .text is missing + first_candidate = response.candidates[0] + if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): + text_parts = [] + for part in first_candidate.content.parts: + if hasattr(part, "text"): + text_parts.append(part.text) + # We are only interested in text content here, ignore function calls etc. + content = "".join(text_parts) + logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.") + return content + else: + logger.warning("Google response candidate has no content or parts.") + return "" # Return empty string if no text found + else: + logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}") + return "" # Return empty string if no text found + except AttributeError as ae: + logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True) + return f"[Error extracting content: Attribute missing - {str(ae)}]" + except Exception as e: + logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True) + return f"[Error extracting content: {str(e)}]" + + +def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None: + """ + Extracts token usage information from a Google response object. + + Args: + response: The response object (`GenerateContentResponse`) or an error dictionary. + + Returns: + A dictionary containing 'prompt_tokens' and 'completion_tokens', or None if + usage information is unavailable or an error occurred. + """ + try: + # Handle error dictionary case + if isinstance(response, dict) and "error" in response: + logger.warning("Cannot get usage from error response.") + return None + + # Check for usage metadata in the response object + if hasattr(response, "usage_metadata"): + metadata = response.usage_metadata + # Google uses prompt_token_count and candidates_token_count + usage = { + "prompt_tokens": getattr(metadata, "prompt_token_count", 0), + "completion_tokens": getattr(metadata, "candidates_token_count", 0), + # Google also provides total_token_count, could be added if needed + # "total_tokens": getattr(metadata, "total_token_count", 0), + } + # Ensure values are integers + usage = {k: int(v) for k, v in usage.items()} + logger.debug(f"Extracted usage from Google response metadata: {usage}") + return usage + else: + # Log a warning only if it's not clearly an error dict already handled + if not (isinstance(response, dict) and "error" in response): + logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.") + return None + except AttributeError as ae: + logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True) + return None + except Exception as e: + logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True) + return None diff --git a/src/providers/google_provider/tools.py b/src/providers/google_provider/tools.py new file mode 100644 index 0000000..f256724 --- /dev/null +++ b/src/providers/google_provider/tools.py @@ -0,0 +1,359 @@ +# src/providers/google_provider/tools.py +""" +Tool handling utilities specific to the Google Generative AI provider. + +Includes functions for: +- Converting MCP tool definitions to Google's format. +- Creating Google Tool/FunctionDeclaration objects. +- Parsing tool calls (FunctionCalls) from Google responses. +- Formatting tool results for subsequent API calls. +""" + +import json +import logging +from typing import Any + +from google.genai.types import FunctionDeclaration, Schema, Tool + +logger = logging.getLogger(__name__) + + +# --- Tool Conversion (from MCP format to Google format) --- + + +def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert MCP tools to Google Gemini format (dictionary structure). + + This format is an intermediate step before creating Tool objects. + + Args: + mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema). + + Returns: + List containing one dictionary with 'function_declarations'. + Returns an empty list if no valid tools are provided or converted. + """ + logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format") + + function_declarations = [] + + for tool in mcp_tools: + server_name = tool.get("server_name") + tool_name = tool.get("name") + description = tool.get("description") + input_schema = tool.get("inputSchema") + + if not server_name or not tool_name or not description or not input_schema: + logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}") + continue + + # Prefix tool name with server name for routing + prefixed_tool_name = f"{server_name}__{tool_name}" + + # Basic validation/cleaning of schema for Google compatibility + if not isinstance(input_schema, dict) or input_schema.get("type") != "object": + logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this. Attempting to normalize.") + # Ensure basic structure if missing + if not isinstance(input_schema, dict): + input_schema = {} # Start fresh if not a dict + if "type" not in input_schema or input_schema["type"] != "object": + # Wrap existing schema or create new if type is wrong/missing + input_schema = {"type": "object", "properties": {"_original_schema": input_schema}} if input_schema else {"type": "object", "properties": {}} + logger.warning(f"Wrapped original schema for {prefixed_tool_name} under '_original_schema' property.") + + if "properties" not in input_schema: + input_schema["properties"] = {} + + # Google requires properties for object type, add dummy if empty + if not input_schema["properties"]: + logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.") + input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder parameter as properties cannot be empty."}} + if "required" in input_schema and not isinstance(input_schema.get("required"), list): + input_schema["required"] = [] # Clear invalid required list + + # Create function declaration dictionary for Google's format + function_declaration = { + "name": prefixed_tool_name, + "description": description, + "parameters": input_schema, # Google uses JSON Schema directly + } + + function_declarations.append(function_declaration) + logger.debug(f"Prepared Google FunctionDeclaration dict for: {prefixed_tool_name}") + + # Google API expects a list containing one dictionary with 'function_declarations' key + google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else [] + + logger.debug(f"Final Google tool config structure (pre-Tool object): {google_tool_config}") + return google_tool_config + + +def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | None: + """ + Recursively creates Google Schema objects from a JSON schema dictionary. + + Handles type mapping and nested structures. Returns None on failure. + """ + if Schema is None: + logger.error("Cannot create Schema object: google.genai types not available.") + return None + + if not isinstance(schema_dict, dict): + logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") + return Schema() # Return empty schema to avoid breaking the parent + + # Map JSON Schema types to Google's Type enum strings + type_mapping = { + "string": "STRING", + "number": "NUMBER", + "integer": "INTEGER", + "boolean": "BOOLEAN", + "array": "ARRAY", + "object": "OBJECT", + # Add other mappings if necessary + } + original_type = schema_dict.get("type") + google_type = type_mapping.get(str(original_type).lower()) if original_type else None + + # Prepare arguments for Schema constructor, filtering out None values + schema_args = { + "type": google_type, + "format": schema_dict.get("format"), + "description": schema_dict.get("description"), + "nullable": schema_dict.get("nullable"), + "enum": schema_dict.get("enum"), + "items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None, + "properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None, + "required": schema_dict.get("required") if google_type == "OBJECT" else None, + } + # Remove keys with None values + schema_args = {k: v for k, v in schema_args.items() if v is not None} + + if not schema_args.get("type"): + logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.") + return Schema() # Return empty schema + + try: + return Schema(**schema_args) + except Exception as schema_creation_err: + logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True) + return Schema() # Return empty schema on error + + +def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None: + """ + Convert the dictionary-based tool configurations into Google's Tool objects. + + Args: + tool_configs: A list containing a dictionary with 'function_declarations', + as produced by `convert_to_google_tools`. + + Returns: + A list containing a single Google `Tool` object, or None if conversion fails + or no valid declarations are found. + """ + if Tool is None or FunctionDeclaration is None: + logger.error("Cannot create Tool objects: google.genai types not available.") + return None + if not tool_configs: + logger.debug("No tool configurations provided to convert to Tool objects.") + return None + + all_func_declarations = [] + # Expecting structure like [{"function_declarations": [...]}] + if isinstance(tool_configs, list) and len(tool_configs) > 0 and "function_declarations" in tool_configs[0]: + func_declarations_list = tool_configs[0]["function_declarations"] + if not isinstance(func_declarations_list, list): + logger.error(f"Expected 'function_declarations' to be a list, got {type(func_declarations_list)}") + return None + + for func_dict in func_declarations_list: + try: + params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) + # Ensure parameters is a valid schema dict for the recursive creator + if not isinstance(params_schema_dict, dict): + logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.") + params_schema_dict = {"type": "object", "properties": {}} + elif params_schema_dict.get("type") != "object": + logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.") + params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties + + parameters_schema = _create_google_schema_recursive(params_schema_dict) + + # Only proceed if schema creation was somewhat successful + if parameters_schema is not None: + declaration = FunctionDeclaration( + name=func_dict["name"], + description=func_dict.get("description", ""), + parameters=parameters_schema, + ) + all_func_declarations.append(declaration) + else: + logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'") + + except Exception as decl_err: + logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) + + else: + logger.error(f"Invalid tool_configs structure provided: {tool_configs}") + return None + + if not all_func_declarations: + logger.warning("No valid Google FunctionDeclarations were created from the provided configurations.") + return None + + # Google expects a list containing one Tool object + logger.info(f"Successfully created {len(all_func_declarations)} Google FunctionDeclarations.") + return [Tool(function_declarations=all_func_declarations)] + + +# --- Tool Call Parsing and Handling (from Google response) --- + + +def has_google_tool_calls(response: Any) -> bool: + """ + Checks if the Google response object contains tool calls (FunctionCalls). + + Args: + response: The response object from the Google generate_content API call. + + Returns: + True if FunctionCalls are present, False otherwise. + """ + try: + # Check non-streaming response structure + if hasattr(response, "candidates") and response.candidates: + candidate = response.candidates[0] + if hasattr(candidate, "content") and hasattr(candidate.content, "parts"): + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}") + return True + + # Note: Detecting function calls reliably in a stream might require accumulating parts. + # This function primarily works reliably for non-streaming responses. + # For streaming, the check might happen during stream processing itself. + + logger.debug("No tool calls (FunctionCall) detected in Google response.") + return False + except Exception as e: + logger.error(f"Error checking for Google tool calls: {e}", exc_info=True) + return False + + +def parse_google_tool_calls(response: Any) -> list[dict[str, Any]]: + """ + Parses tool calls (FunctionCalls) from a non-streaming Google response object. + + Args: + response: The non-streaming response object from the Google generate_content API call. + + Returns: + A list of dictionaries, each representing a tool call in the standard MCP format + (id, server_name, function_name, arguments as JSON string). + Returns an empty list if no calls are found or an error occurs. + """ + parsed_calls = [] + try: + if not (hasattr(response, "candidates") and response.candidates): + logger.warning("Cannot parse tool calls: Response has no candidates.") + return [] + + candidate = response.candidates[0] + if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")): + logger.warning("Cannot parse tool calls: Response candidate has no content or parts.") + return [] + + logger.debug("Parsing tool calls (FunctionCall) from Google response.") + call_index = 0 + for part in candidate.content.parts: + if hasattr(part, "function_call") and part.function_call: + func_call = part.function_call + # Generate a simple unique ID for this call within this response + call_id = f"call_{call_index}" + call_index += 1 + + # Extract server_name and func_name from the prefixed name + full_name = func_call.name + parts = full_name.split("__", 1) + if len(parts) == 2: + server_name, func_name = parts + else: + # If the prefix isn't found, assume it's just the function name + logger.warning(f"Could not determine server_name from Google tool name '{full_name}'. Using None for server_name.") + server_name = None + func_name = full_name + + # Convert arguments dict to JSON string + try: + # func_call.args is already a dict-like object (Mapping) + args_dict = dict(func_call.args) if func_call.args else {} + args_str = json.dumps(args_dict) + except Exception as json_err: + logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}") + # Provide error info in arguments if serialization fails + args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)}) + + parsed_calls.append({ + "id": call_id, # Internal ID for tracking this call + "server_name": server_name, + "function_name": func_name, # The original function name + "arguments": args_str, # Arguments as a JSON string + "_google_tool_name": full_name, # Keep original name if needed later + }) + logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...") + + return parsed_calls + except Exception as e: + logger.error(f"Error parsing Google tool calls: {e}", exc_info=True) + return [] + + +def format_google_tool_results(tool_call_id: str, function_name: str, result: Any) -> dict[str, Any]: + """ + Formats a tool result for a Google follow-up request (FunctionResponse). + + Args: + tool_call_id: The unique ID assigned during parsing (e.g., "call_0"). + Note: Google's API itself doesn't use this ID directly in the + FunctionResponse part, but we need it for mapping in the message list. + function_name: The original function name (without server prefix) that was called. + result: The data returned by the tool execution. Should be JSON-serializable. + + Returns: + A dictionary representing the tool result message in the standard MCP format. + This will be converted later by `_convert_messages`. + """ + try: + # Google expects the 'response' field in FunctionResponse to contain a dict. + # The content should ideally be JSON serializable. We wrap the result. + if isinstance(result, (str, int, float, bool, list)): + content_dict = {"result": result} + elif isinstance(result, dict): + content_dict = result # Assume it's already a suitable dict + else: + logger.warning(f"Tool result for {function_name} is of non-standard type {type(result)}. Converting to string.") + content_dict = {"result": str(result)} + + # Ensure the content is JSON serializable for the 'content' field + try: + content_str = json.dumps(content_dict) + except Exception as json_err: + logger.error(f"Error JSON-encoding tool result content for Google {function_name} ({tool_call_id}): {json_err}") + content_str = json.dumps({"error": "Failed to encode tool result content", "original_type": str(type(result))}) + + except Exception as e: + logger.error(f"Error preparing tool result content for Google {function_name} ({tool_call_id}): {e}") + content_str = json.dumps({"error": "Failed to prepare tool result content", "details": str(e)}) + + logger.debug(f"Formatting Google tool result for call ID {tool_call_id} (Function: {function_name})") + # Return in the standard message format, _convert_messages will handle Google's structure + return { + "role": "tool", + "tool_call_id": tool_call_id, # Used by _convert_messages to find the original call + "content": content_str, # The JSON string representing the result content + "name": function_name, # Store original function name for _convert_messages + # Note: Google's FunctionResponse Part needs 'name' and 'response' (dict). + # This standard format will be converted by the provider's message conversion logic. + } diff --git a/src/providers/google_provider/utils.py b/src/providers/google_provider/utils.py new file mode 100644 index 0000000..da36ff7 --- /dev/null +++ b/src/providers/google_provider/utils.py @@ -0,0 +1,150 @@ +# src/providers/google_provider/utils.py +import json +import logging +from typing import Any + +from google.genai.types import Content, Part + +from src.llm_models import MODELS + +logger = logging.getLogger(__name__) + + +def get_context_window(model: str) -> int: + """Retrieves the context window size for a given Google model.""" + default_window = 1000000 # Default fallback for Gemini + try: + provider_models = MODELS.get("google", {}).get("models", []) + for m in provider_models: + if m.get("id") == model: + return m.get("context_window", default_window) + logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}") + return default_window + except Exception as e: + logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) + return default_window + + +def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]: + """ + Converts standard message format to Google's format, extracting system prompt. + Handles mapping roles and structuring tool calls/results. + """ + google_messages: list[Content] = [] + system_prompt: str | None = None + + for i, message in enumerate(messages): + role = message.get("role") + content = message.get("content") + tool_calls = message.get("tool_calls") + tool_call_id = message.get("tool_call_id") + + if role == "system": + if i == 0: + system_prompt = content + logger.debug("Extracted system prompt for Google.") + else: + # Google API expects system prompt only at the beginning. + # If found later, log a warning and skip or merge if possible (though merging is complex). + logger.warning("System message found not at the beginning. Skipping for Google API.") + continue # Skip adding system messages to the main list + + # Map roles: 'assistant' -> 'model', 'tool' -> 'function' (handled below) + google_role = {"user": "user", "assistant": "model"}.get(role) + + if not google_role and role != "tool": + logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.") + continue + + parts: list[Part | str] = [] + if role == "tool": + # Tool results are mapped to 'function' role in Google API + if tool_call_id and content: + try: + # Attempt to parse the content as JSON, assuming it's the tool output + response_content_dict = json.loads(content) + except json.JSONDecodeError: + logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.") + response_content_dict = {"result": content} # Wrap raw string if not JSON + + # Find the original function name from the preceding assistant message + func_name = "unknown_function" # Default if name can't be found + if i > 0 and messages[i - 1].get("role") == "assistant": + prev_tool_calls = messages[i - 1].get("tool_calls") + if prev_tool_calls: + for tc in prev_tool_calls: + # Match based on the ID provided in the tool message + if tc.get("id") == tool_call_id: + # Google uses 'server__func' format, extract original func name if possible + full_name = tc.get("function_name", "unknown_function") + func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name + break + + # Create a FunctionResponse part + parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict})) + google_role = "function" # Explicitly set role for tool results + else: + logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}") + continue # Skip if essential parts are missing + + elif role == "assistant" and tool_calls: + # Assistant message requesting tool calls + for tool_call in tool_calls: + args = tool_call.get("arguments", {}) + # Ensure arguments are a dict, not a string + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}") + args = {"error": "failed to parse arguments"} # Provide error feedback + + # Google uses 'server__func' format, extract original func name if possible + full_name = tool_call.get("function_name", "unknown_function") + func_name = full_name.split("__", 1)[-1] # Get the part after '__' or the full name + + # Create a FunctionCall part + parts.append(Part.from_function_call(name=func_name, args=args)) + + # Include any text content alongside the function calls + if content and isinstance(content, str): + parts.append(Part.from_text(content)) + + elif content: + # Regular user or assistant message content + if isinstance(content, str): + parts.append(Part.from_text(content)) + # TODO: Handle potential image content if needed in the future + else: + logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") + parts.append(Part.from_text(str(content))) + + # Add the constructed Content object if parts were generated + if parts: + google_messages.append(Content(role=google_role, parts=parts)) + else: + # Log if a message resulted in no parts (e.g., empty content, skipped system message) + logger.debug(f"No parts generated for message: {message}") + + # Validate message alternation (user -> model -> user/function -> user -> ...) + last_role = None + valid_alternation = True + for msg in google_messages: + current_role = msg.role + # Check for consecutive user/model roles + if current_role == last_role and current_role in ["user", "model"]: + valid_alternation = False + logger.error(f"Invalid role sequence for Google: consecutive '{current_role}' roles.") + break + # Check if 'function' role is followed by 'user' + if last_role == "function" and current_role != "user": + valid_alternation = False + logger.error(f"Invalid role sequence for Google: '{current_role}' follows 'function'. Expected 'user'.") + break + last_role = current_role + + # Raise error if alternation is invalid, as Google API enforces this + if not valid_alternation: + raise ValueError("Invalid message sequence for Google API. Roles must alternate between 'user' and 'model', with 'function' responses followed by 'user'.") + + return google_messages, system_prompt diff --git a/src/tools/__init__.py b/src/tools/__init__.py deleted file mode 100644 index 057b8bf..0000000 --- a/src/tools/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# src/tools/__init__.py -# This file makes the 'tools' directory a Python package. - -# Optionally import key functions/classes for easier access -# from .conversion import convert_to_openai_tools, convert_to_anthropic_tools -# from .execution import execute_tool # Assuming execution.py will exist diff --git a/src/tools/conversion.py b/src/tools/conversion.py deleted file mode 100644 index 8ec6b21..0000000 --- a/src/tools/conversion.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Conversion utilities for MCP tools. - -This module contains functions to convert between different tool formats -for various LLM providers (OpenAI, Anthropic, etc.). -""" - -import logging -from typing import Any - -logger = logging.getLogger(__name__) - - -def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Convert MCP tools to Google Gemini format (dictionary structure). - - Args: - mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema). - - Returns: - List containing one dictionary with 'function_declarations'. - """ - logger.debug(f"Converting {len(mcp_tools)} MCP tools to Google Gemini format") - - function_declarations = [] - - for tool in mcp_tools: - server_name = tool.get("server_name") - tool_name = tool.get("name") - description = tool.get("description") - input_schema = tool.get("inputSchema") - - if not server_name or not tool_name or not description or not input_schema: - logger.warning(f"Skipping invalid MCP tool definition during Google conversion: {tool}") - continue - - # Prefix tool name with server name for routing - prefixed_tool_name = f"{server_name}__{tool_name}" - - # Basic validation/cleaning of schema - if not isinstance(input_schema, dict) or input_schema.get("type") != "object": - logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. Google might reject this.") - # Ensure basic structure if missing - if not isinstance(input_schema, dict): - input_schema = {} - if "type" not in input_schema: - input_schema["type"] = "object" - if "properties" not in input_schema: - input_schema["properties"] = {} - # Google requires properties for object type, add dummy if empty - if not input_schema["properties"]: - logger.warning(f"Empty properties for tool '{prefixed_tool_name}', adding dummy property for Google.") - input_schema["properties"] = {"_dummy_param": {"type": "STRING", "description": "Placeholder"}} - - # Create function declaration for Google's format - function_declaration = { - "name": prefixed_tool_name, - "description": description, - "parameters": input_schema, # Google uses JSON Schema directly - } - - function_declarations.append(function_declaration) - logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}") - - # Google API expects a list containing one dictionary with 'function_declarations' - # The provider's _convert_to_tool_objects will handle creating Tool objects from this. - google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else [] - - logger.debug(f"Final Google tool config structure: {google_tool_config}") - return google_tool_config - - -# Note: The _handle_schema_construct helper from the reference code is not strictly -# needed if we assume the inputSchema is already valid JSON Schema. -# If complex schemas (anyOf, etc.) need specific handling beyond standard JSON Schema, -# that logic could be added here or within the provider implementations.