diff --git a/src/llm_client.py b/src/llm_client.py index daf18cf..9e1acc3 100644 --- a/src/llm_client.py +++ b/src/llm_client.py @@ -56,7 +56,7 @@ class LLMClient: self, messages: list[dict[str, str]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, ) -> Generator[str, None, None] | dict[str, Any]: @@ -97,6 +97,7 @@ class LLMClient: stream=stream, tools=provider_tools, ) + print(f"Response: {response}") # Debugging line to check the response logger.info("Received response from provider.") if stream: diff --git a/src/providers/anthropic_provider/completion.py b/src/providers/anthropic_provider/completion.py index 5009308..ddd63c8 100644 --- a/src/providers/anthropic_provider/completion.py +++ b/src/providers/anthropic_provider/completion.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) def create_chat_completion( - provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None + provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None ) -> Stream | Message: logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") temp_system_prompt, temp_anthropic_messages = convert_messages(messages) diff --git a/src/providers/base.py b/src/providers/base.py index 3332509..e3d93f2 100644 --- a/src/providers/base.py +++ b/src/providers/base.py @@ -28,7 +28,7 @@ class BaseProvider(abc.ABC): self, messages: list[dict[str, str]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None, @@ -39,7 +39,7 @@ class BaseProvider(abc.ABC): Args: messages: List of message dictionaries with 'role' and 'content'. model: Model identifier. - temperature: Sampling temperature (0-1). + temperature: Sampling temperature (0-2). max_tokens: Maximum tokens to generate. stream: Whether to stream the response. tools: Optional list of tools in the provider-specific format. diff --git a/src/providers/google_provider/__init__.py b/src/providers/google_provider/__init__.py index 272499e..22011ce 100644 --- a/src/providers/google_provider/__init__.py +++ b/src/providers/google_provider/__init__.py @@ -3,9 +3,12 @@ import logging from collections.abc import Generator from typing import Any +# Import Generator type for isinstance check - Keep this import for type hints from google.genai.types import GenerateContentResponse from providers.google_provider.client import initialize_client + +# Correctly import the renamed function directly from providers.google_provider.completion import create_chat_completion from providers.google_provider.response import get_content, get_streaming_content, get_usage from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls @@ -28,7 +31,7 @@ class GoogleProvider(BaseProvider): api_key: The Google API key. base_url: Base URL (typically not used by Google client config, but kept for interface consistency). """ - # initialize_client returns the configured genai module + # initialize_client returns the client instance now self.client_module = initialize_client(api_key, base_url) self.api_key = api_key # Store if needed later self.base_url = base_url # Store if needed later @@ -38,18 +41,24 @@ class GoogleProvider(BaseProvider): self, messages: list[dict[str, Any]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None, ) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator """Creates a chat completion using the Google Gemini API.""" # Pass self (provider instance) to the helper function - return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) + raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) + print(f"Raw response type: {type(raw_response)}") # Debugging line to check the type of raw_response + print(f"Raw response: {raw_response}") # Debugging line to check the content of raw_response + + # The completion helper function handles returning the correct type or an error dict. + # No need for generator handling here anymore. + return raw_response def get_streaming_content(self, response: Any) -> Generator[str, None, None]: """Extracts content chunks from a Google streaming response.""" - # Response is expected to be an iterator from generate_content(stream=True) + # Response is expected to be an iterator from generate_content_stream return get_streaming_content(response) def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str: diff --git a/src/providers/google_provider/client.py b/src/providers/google_provider/client.py index b3806c1..fe00495 100644 --- a/src/providers/google_provider/client.py +++ b/src/providers/google_provider/client.py @@ -16,12 +16,13 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any: raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") try: - # Configure the client - genai.configure(api_key=api_key) + # Instantiate the client directly using the API key + client = genai.Client(api_key=api_key) + logger.info("Google Generative AI client instantiated.") if base_url: - logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.") - # Return the configured module itself, as it's used directly - return genai + logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.") + # Return the client instance + return client except Exception as e: - logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) + logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True) raise diff --git a/src/providers/google_provider/completion.py b/src/providers/google_provider/completion.py index 946c647..77b1899 100644 --- a/src/providers/google_provider/completion.py +++ b/src/providers/google_provider/completion.py @@ -1,140 +1,174 @@ -# src/providers/google_provider/completion.py -import json import logging import traceback +from collections.abc import Iterable # Added Iterable from typing import Any -from google.genai.types import Tool +# Import specific types for better hinting +from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool -from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools +# Removed convert_to_google_tools import as it's handled later +from providers.google_provider.tools import convert_to_google_tool_objects from providers.google_provider.utils import convert_messages logger = logging.getLogger(__name__) -def create_chat_completion( +# --- Helper for Non-Streaming --- +def _create_chat_completion_non_stream( provider, + model: str, + google_messages: list[ContentDict], + generation_config: GenerationConfigDict, +) -> GenerateContentResponse | dict[str, Any]: + """Handles the non-streaming API call.""" + try: + logger.debug("Calling client.models.generate_content...") + # Use the client instance stored on the provider + response = provider.client_module.models.generate_content( + model=f"models/{model}", + contents=google_messages, + config=generation_config, + ) + logger.debug("generate_content call successful, returning raw response object.") + # Return the direct response object + return response + except ValueError as ve: + error_msg = f"Google API request validation error: {ve}" + logger.error(error_msg, exc_info=True) + # Return error dict + return {"error": error_msg, "traceback": traceback.format_exc()} + except Exception as e: + error_msg = f"Google API error during non-stream chat completion: {e}" + logger.error(error_msg, exc_info=True) + # Return error dict + return {"error": error_msg, "traceback": traceback.format_exc()} + + +# --- Helper for Streaming --- +def _create_chat_completion_stream( + provider, + model: str, + google_messages: list[ContentDict], + generation_config: GenerationConfigDict, +) -> Iterable[GenerateContentResponse | dict[str, Any]]: # Return Iterable of response chunks or error dict + """Handles the streaming API call and yields results.""" + try: + logger.debug("Calling client.models.generate_content_stream...") + # Use the client instance stored on the provider + response_iterator = provider.client_module.models.generate_content_stream( + model=f"models/{model}", + contents=google_messages, + config=generation_config, + ) + logger.debug("generate_content_stream call successful, yielding from iterator.") + # Yield from the SDK's iterator which produces GenerateContentResponse chunks + yield from response_iterator + except ValueError as ve: + error_msg = f"Google API request validation error: {ve}" + logger.error(error_msg, exc_info=True) + # Yield error as a dict matching non-streaming error structure + yield {"error": error_msg, "traceback": traceback.format_exc()} + except Exception as e: + error_msg = f"Google API error during stream chat completion: {e}" + logger.error(error_msg, exc_info=True) + # Yield error as a dict + yield {"error": error_msg, "traceback": traceback.format_exc()} + + +# --- Main Function --- +# Renamed original function to avoid conflict if needed, though overwrite is fine +def create_chat_completion( + provider, # Provider instance is passed in messages: list[dict[str, Any]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, - tools: list[dict[str, Any]] | None = None, -) -> Any: + tools: list[dict[str, Any]] | None = None, # Expects intermediate dict format +) -> Any: # Return type depends on stream flag """ Creates a chat completion using the Google Gemini API. - - Args: - provider: The instance of the GoogleProvider. - messages: A list of message dictionaries in the standard format. - model: The model ID to use (e.g., "gemini-1.5-flash"). - temperature: The sampling temperature. - max_tokens: The maximum number of tokens to generate. - stream: Whether to stream the response. - tools: A list of tool definitions in the MCP format. - - Returns: - The response object from the Google API (could be a stream iterator or - a GenerateContentResponse object), or an error dictionary/iterator. + Delegates to streaming or non-streaming helpers. Contains NO yield itself. """ - logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + # Check if client exists on the provider instance if provider.client_module is None: - error_msg = "Google Generative AI SDK not configured or installed." + error_msg = "Google Generative AI client not initialized on provider." logger.error(error_msg) - # Return an error structure compatible with both streaming and non-streaming expectations - if stream: - return iter([json.dumps({"error": error_msg})]) - else: - return {"error": error_msg} + # Return error dict directly for non-stream, create iterator for stream + return iter([{"error": error_msg}]) if stream else {"error": error_msg} try: - # 1. Convert messages to Google's format + # 1. Convert messages (Common logic) google_messages, system_prompt = convert_messages(messages) logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}") - # 2. Prepare generation configuration - generation_config: dict[str, Any] = {"temperature": temperature} + # 2. Prepare generation configuration (Common logic) + # Use GenerationConfigDict for better type hinting if possible + generation_config: GenerationConfigDict = {"temperature": temperature} if max_tokens is not None: - # Google uses 'max_output_tokens' generation_config["max_output_tokens"] = max_tokens logger.debug(f"Setting max_output_tokens: {max_tokens}") else: - logger.debug("No max_tokens specified.") + # Google requires max_output_tokens, set a default if None + # Defaulting to a reasonable value, e.g., 8192, check model limits if needed + default_max_tokens = 8192 + generation_config["max_output_tokens"] = default_max_tokens + logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.") - # 3. Convert tools if provided + # 3. Convert tools if provided (Common logic) google_tool_objects: list[Tool] | None = None if tools: try: - # Step 3a: Convert MCP tools to intermediate Google dict format - tool_dict_list = convert_to_google_tools(tools) - # Step 3b: Convert intermediate dict format to Google Tool objects - google_tool_objects = convert_to_google_tool_objects(tool_dict_list) + # Convert intermediate dict format to Google Tool objects + google_tool_objects = convert_to_google_tool_objects(tools) if google_tool_objects: - logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.") + num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations) + logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.") else: logger.warning("Tool conversion resulted in no valid Google Tool objects.") except Exception as tool_conv_err: logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) - # Decide whether to proceed without tools or raise an error - # Proceeding without tools for now, but logging the error. - google_tool_objects = None + google_tool_objects = None # Continue without tools on conversion error else: logger.debug("No tools provided for conversion.") - # 4. Initialize the Google Generative Model - # Ensure client_module is callable and has GenerativeModel - if not hasattr(provider.client_module, "GenerativeModel"): - raise AttributeError("Configured Google client module does not have 'GenerativeModel'") + # 4. Add system prompt and tools to generation_config (Common logic) + if system_prompt: + # Ensure system_instruction is ContentDict or compatible type + generation_config["system_instruction"] = system_prompt + logger.debug("Added system_instruction to generation_config.") + if google_tool_objects: + # Assign the list of Tool objects directly + generation_config["tools"] = google_tool_objects + logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.") - gemini_model = provider.client_module.GenerativeModel( - model_name=model, - system_instruction=system_prompt, # Pass extracted system prompt - tools=google_tool_objects, # Pass converted Tool objects (or None) - # Add safety_settings if needed: safety_settings=... - ) - logger.debug(f"Initialized Google GenerativeModel for '{model}'.") - - # 5. Log parameters before API call + # 5. Log parameters before API call (Common logic) log_params = { "model": model, "stream": stream, "temperature": temperature, "max_output_tokens": generation_config.get("max_output_tokens"), "system_prompt_present": bool(system_prompt), - "num_tools": len(google_tool_objects) if google_tool_objects else 0, + "num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0, "num_messages": len(google_messages), } - logger.info(f"Calling Google generate_content API with params: {log_params}") - # Avoid logging full message content unless necessary for debugging specific issues - # logger.debug(f"Google messages being sent: {google_messages}") + logger.info(f"Calling Google API via helper with params: {log_params}") - # 6. Call the Google API - response = gemini_model.generate_content( - contents=google_messages, - generation_config=generation_config, - stream=stream, - # tool_config={"function_calling_config": "AUTO"} # AUTO is default - ) - logger.debug("Google API call successful, returning response object.") - return response - - except ValueError as ve: # Catch specific errors like invalid message sequence - error_msg = f"Google API request validation error: {ve}" - logger.error(error_msg, exc_info=True) + # 6. Delegate to appropriate helper if stream: - # Yield a JSON error message in an iterator - yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) + # Return the generator/iterator from the streaming helper + # This helper uses 'yield from' + return _create_chat_completion_stream(provider, model, google_messages, generation_config) else: - # Return an error dictionary - return {"error": error_msg, "traceback": traceback.format_exc()} + # Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper + # This helper uses 'return' + return _create_chat_completion_non_stream(provider, model, google_messages, generation_config) + except Exception as e: - # Catch any other exceptions during setup or API call - error_msg = f"Google API error during chat completion: {e}" + # Catch errors during common setup (message/tool conversion etc.) + error_msg = f"Error during Google completion setup: {e}" logger.error(error_msg, exc_info=True) - if stream: - # Yield a JSON error message in an iterator - yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) - else: - # Return an error dictionary - return {"error": error_msg, "traceback": traceback.format_exc()} + # Return error dict directly for non-stream, create iterator for stream + return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()} diff --git a/src/providers/google_provider/response.py b/src/providers/google_provider/response.py index f8b7315..d1abbe8 100644 --- a/src/providers/google_provider/response.py +++ b/src/providers/google_provider/response.py @@ -124,37 +124,46 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str: The concatenated text content, or an error message string. """ try: - # Handle error dictionary case + # Check if it's an error dictionary passed from upstream (e.g., completion helper) if isinstance(response, dict) and "error" in response: - logger.error(f"Cannot get content from error response: {response['error']}") + logger.error(f"Cannot get content from error dict: {response['error']}") return f"[Error: {response['error']}]" - # Handle successful GenerateContentResponse object - if hasattr(response, "text"): - # The `.text` attribute usually provides the concatenated text content directly + # Ensure it's a GenerateContentResponse object before accessing attributes + if not isinstance(response, GenerateContentResponse): + logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}") + return f"[Error: Unexpected response type {type(response)}]" + + # --- Access GenerateContentResponse attributes --- + # Prioritize response.text if available and not empty + if hasattr(response, "text") and response.text: content = response.text logger.debug(f"Extracted content (length {len(content)}) from response.text.") return content - elif hasattr(response, "candidates") and response.candidates: - # Fallback: manually concatenate text from parts if .text is missing + + # Fallback: manually concatenate text from parts if .text is missing/empty + if hasattr(response, "candidates") and response.candidates: first_candidate = response.candidates[0] - if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): - text_parts = [] - for part in first_candidate.content.parts: - if hasattr(part, "text"): - text_parts.append(part.text) - # We are only interested in text content here, ignore function calls etc. - content = "".join(text_parts) - logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.") - return content + # Check candidate content and parts carefully + if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts: + text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")] + if text_parts: + content = "".join(text_parts) + logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.") + return content + else: + logger.warning("Google response candidate parts contained no text.") + return "" # Return empty if parts exist but have no text else: - logger.warning("Google response candidate has no content or parts.") - return "" # Return empty string if no text found + logger.warning("Google response candidate has no valid content or parts.") + return "" # Return empty string if no valid content/parts else: - logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}") + # If neither .text nor valid candidates are found + logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}") return "" # Return empty string if no text found + except AttributeError as ae: - logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True) + logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True) return f"[Error extracting content: Attribute missing - {str(ae)}]" except Exception as e: logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True) @@ -173,32 +182,34 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i usage information is unavailable or an error occurred. """ try: - # Handle error dictionary case + # Check if it's an error dictionary passed from upstream if isinstance(response, dict) and "error" in response: - logger.warning("Cannot get usage from error response.") + logger.warning(f"Cannot get usage from error dict: {response['error']}") return None - # Check for usage metadata in the response object - if hasattr(response, "usage_metadata"): - metadata = response.usage_metadata + # Ensure it's a GenerateContentResponse object before accessing attributes + if not isinstance(response, GenerateContentResponse): + logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}") + return None + + # Safely access usage metadata + metadata = getattr(response, "usage_metadata", None) + if metadata: # Google uses prompt_token_count and candidates_token_count + prompt_tokens = getattr(metadata, "prompt_token_count", 0) + completion_tokens = getattr(metadata, "candidates_token_count", 0) usage = { - "prompt_tokens": getattr(metadata, "prompt_token_count", 0), - "completion_tokens": getattr(metadata, "candidates_token_count", 0), - # Google also provides total_token_count, could be added if needed - # "total_tokens": getattr(metadata, "total_token_count", 0), + "prompt_tokens": int(prompt_tokens), + "completion_tokens": int(completion_tokens), } - # Ensure values are integers - usage = {k: int(v) for k, v in usage.items()} logger.debug(f"Extracted usage from Google response metadata: {usage}") return usage else: - # Log a warning only if it's not clearly an error dict already handled - if not (isinstance(response, dict) and "error" in response): - logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.") + logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}") return None + except AttributeError as ae: - logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True) + logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True) return None except Exception as e: logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True) diff --git a/src/providers/google_provider/tools.py b/src/providers/google_provider/tools.py index f256724..dc23289 100644 --- a/src/providers/google_provider/tools.py +++ b/src/providers/google_provider/tools.py @@ -13,7 +13,7 @@ import json import logging from typing import Any -from google.genai.types import FunctionDeclaration, Schema, Tool +from google.genai.types import FunctionDeclaration, Schema, Tool, Type logger = logging.getLogger(__name__) @@ -95,50 +95,66 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non Handles type mapping and nested structures. Returns None on failure. """ - if Schema is None: - logger.error("Cannot create Schema object: google.genai types not available.") + if Schema is None or Type is None: + logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.") return None if not isinstance(schema_dict, dict): - logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") - return Schema() # Return empty schema to avoid breaking the parent + logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.") + return None # Return None on invalid input - # Map JSON Schema types to Google's Type enum strings + # Map JSON Schema types to Google's Type enum members type_mapping = { - "string": "STRING", - "number": "NUMBER", - "integer": "INTEGER", - "boolean": "BOOLEAN", - "array": "ARRAY", - "object": "OBJECT", - # Add other mappings if necessary + "string": Type.STRING, + "number": Type.NUMBER, + "integer": Type.INTEGER, + "boolean": Type.BOOLEAN, + "array": Type.ARRAY, + "object": Type.OBJECT, } original_type = schema_dict.get("type") google_type = type_mapping.get(str(original_type).lower()) if original_type else None + if not google_type: + logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.") + return None # Return None if type is invalid/missing + # Prepare arguments for Schema constructor, filtering out None values schema_args = { - "type": google_type, + "type": google_type, # Use the Type enum member "format": schema_dict.get("format"), "description": schema_dict.get("description"), - "nullable": schema_dict.get("nullable"), + "nullable": schema_dict.get("nullable"), # Note: Google's Schema might not directly support nullable in constructor "enum": schema_dict.get("enum"), - "items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None, - "properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None, - "required": schema_dict.get("required") if google_type == "OBJECT" else None, + # Recursively create nested schemas, ensuring None is handled if recursion fails + "items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None, + "properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None} + if google_type == Type.OBJECT and schema_dict.get("properties") + else None, + "required": schema_dict.get("required") if google_type == Type.OBJECT else None, } - # Remove keys with None values + + # Remove keys with None values before passing to Schema constructor schema_args = {k: v for k, v in schema_args.items() if v is not None} - if not schema_args.get("type"): - logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.") - return Schema() # Return empty schema + # Handle specific cases for ARRAY and OBJECT where items/properties might be needed + if google_type == Type.ARRAY and "items" not in schema_args: + logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.") + return None # Array schema requires items + if google_type == Type.OBJECT and "properties" not in schema_args: + # Allow object schema without properties initially, might be handled later + pass + # logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.") + # schema_args["properties"] = {} # Or return None if properties are strictly required try: - return Schema(**schema_args) + # Create the Schema object + created_schema = Schema(**schema_args) + # logger.debug(f"Successfully created Schema: {created_schema}") + return created_schema except Exception as schema_creation_err: logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True) - return Schema() # Return empty schema on error + return None # Return None on creation error def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None: @@ -169,31 +185,71 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T return None for func_dict in func_declarations_list: + func_name = func_dict.get("name", "Unknown") try: - params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) - # Ensure parameters is a valid schema dict for the recursive creator + params_schema_dict = func_dict.get("parameters", {}) + + # Ensure parameters is a dict and defaults to object type if missing if not isinstance(params_schema_dict, dict): - logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.") + logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.") params_schema_dict = {"type": "object", "properties": {}} - elif params_schema_dict.get("type") != "object": - logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.") - params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties + elif "type" not in params_schema_dict: + params_schema_dict["type"] = "object" # Default to object if type is missing + elif params_schema_dict["type"] != "object": + logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.") + # Attempt to salvage properties if the top level isn't object + original_properties = params_schema_dict.get("properties", {}) + if not isinstance(original_properties, dict): + original_properties = {} + params_schema_dict = {"type": "object", "properties": original_properties} - parameters_schema = _create_google_schema_recursive(params_schema_dict) - - # Only proceed if schema creation was somewhat successful - if parameters_schema is not None: - declaration = FunctionDeclaration( - name=func_dict["name"], - description=func_dict.get("description", ""), - parameters=parameters_schema, - ) - all_func_declarations.append(declaration) + properties_dict = params_schema_dict.get("properties", {}) + google_properties = {} + if isinstance(properties_dict, dict): + for prop_name, prop_schema_dict in properties_dict.items(): + prop_schema = _create_google_schema_recursive(prop_schema_dict) + if prop_schema: + google_properties[prop_name] = prop_schema + else: + logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.") else: - logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'") + logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.") + + # Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty + if not google_properties: + logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.") + google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")} + # Clear required list if properties are empty/dummy + required_list = [] + else: + # Validate required list against actual properties + original_required = params_schema_dict.get("required", []) + if isinstance(original_required, list): + required_list = [req for req in original_required if req in google_properties] + if len(required_list) != len(original_required): + logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}") + else: + logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.") + required_list = [] + + # Create the top-level parameters schema, ensuring it's OBJECT type + parameters_schema = Schema( + type=Type.OBJECT, + properties=google_properties, + required=required_list if required_list else None, # Pass None if empty list + ) + + # Create the FunctionDeclaration + declaration = FunctionDeclaration( + name=func_name, + description=func_dict.get("description", ""), + parameters=parameters_schema, + ) + all_func_declarations.append(declaration) + logger.debug(f"Successfully created FunctionDeclaration for: {func_name}") except Exception as decl_err: - logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) + logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True) else: logger.error(f"Invalid tool_configs structure provided: {tool_configs}") diff --git a/src/providers/google_provider/utils.py b/src/providers/google_provider/utils.py index da36ff7..b979a2f 100644 --- a/src/providers/google_provider/utils.py +++ b/src/providers/google_provider/utils.py @@ -108,16 +108,16 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str # Include any text content alongside the function calls if content and isinstance(content, str): - parts.append(Part.from_text(content)) + parts.append(Part(text=content)) # Use direct instantiation elif content: # Regular user or assistant message content if isinstance(content, str): - parts.append(Part.from_text(content)) + parts.append(Part(text=content)) # Use direct instantiation # TODO: Handle potential image content if needed in the future else: logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") - parts.append(Part.from_text(str(content))) + parts.append(Part(text=str(content))) # Use direct instantiation # Add the constructed Content object if parts were generated if parts: diff --git a/src/providers/openai_provider/__init__.py b/src/providers/openai_provider/__init__.py index 96995ff..5bc69c5 100644 --- a/src/providers/openai_provider/__init__.py +++ b/src/providers/openai_provider/__init__.py @@ -32,7 +32,7 @@ class OpenAIProvider(BaseProvider): self, messages: list[dict[str, str]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None, diff --git a/src/providers/openai_provider/completion.py b/src/providers/openai_provider/completion.py index 78652c6..731e56f 100644 --- a/src/providers/openai_provider/completion.py +++ b/src/providers/openai_provider/completion.py @@ -14,7 +14,7 @@ def create_chat_completion( provider, # The OpenAIProvider instance messages: list[dict[str, str]], model: str, - temperature: float = 0.4, + temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None,