Compare commits

..

2 Commits

11 changed files with 293 additions and 181 deletions

View File

@@ -56,7 +56,7 @@ class LLMClient:
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
) -> Generator[str, None, None] | dict[str, Any]:
@@ -97,6 +97,7 @@ class LLMClient:
stream=stream,
tools=provider_tools,
)
print(f"Response: {response}") # Debugging line to check the response
logger.info("Received response from provider.")
if stream:

View File

@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
def create_chat_completion(
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
) -> Stream | Message:
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)

View File

@@ -28,7 +28,7 @@ class BaseProvider(abc.ABC):
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
@@ -39,7 +39,7 @@ class BaseProvider(abc.ABC):
Args:
messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier.
temperature: Sampling temperature (0-1).
temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format.

View File

@@ -3,9 +3,12 @@ import logging
from collections.abc import Generator
from typing import Any
# Import Generator type for isinstance check - Keep this import for type hints
from google.genai.types import GenerateContentResponse
from providers.google_provider.client import initialize_client
# Correctly import the renamed function directly
from providers.google_provider.completion import create_chat_completion
from providers.google_provider.response import get_content, get_streaming_content, get_usage
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
@@ -28,7 +31,7 @@ class GoogleProvider(BaseProvider):
api_key: The Google API key.
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
"""
# initialize_client returns the configured genai module
# initialize_client returns the client instance now
self.client_module = initialize_client(api_key, base_url)
self.api_key = api_key # Store if needed later
self.base_url = base_url # Store if needed later
@@ -38,18 +41,24 @@ class GoogleProvider(BaseProvider):
self,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
"""Creates a chat completion using the Google Gemini API."""
# Pass self (provider instance) to the helper function
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
print(f"Raw response type: {type(raw_response)}") # Debugging line to check the type of raw_response
print(f"Raw response: {raw_response}") # Debugging line to check the content of raw_response
# The completion helper function handles returning the correct type or an error dict.
# No need for generator handling here anymore.
return raw_response
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Extracts content chunks from a Google streaming response."""
# Response is expected to be an iterator from generate_content(stream=True)
# Response is expected to be an iterator from generate_content_stream
return get_streaming_content(response)
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:

View File

@@ -12,16 +12,17 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
logger.info("Initializing Google Generative AI client")
if genai is None:
logger.error("Google Generative AI SDK (google-generativeai) is not installed.")
logger.error("Google Generative AI SDK (google-genai) is not installed.")
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try:
# Configure the client
genai.configure(api_key=api_key)
# Instantiate the client directly using the API key
client = genai.Client(api_key=api_key)
logger.info("Google Generative AI client instantiated.")
if base_url:
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.")
# Return the configured module itself, as it's used directly
return genai
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
# Return the client instance
return client
except Exception as e:
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
raise

View File

@@ -1,140 +1,174 @@
# src/providers/google_provider/completion.py
import json
import logging
import traceback
from collections.abc import Iterable # Added Iterable
from typing import Any
from google.genai.types import Tool
# Import specific types for better hinting
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools
# Removed convert_to_google_tools import as it's handled later
from providers.google_provider.tools import convert_to_google_tool_objects
from providers.google_provider.utils import convert_messages
logger = logging.getLogger(__name__)
def create_chat_completion(
# --- Helper for Non-Streaming ---
def _create_chat_completion_non_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> GenerateContentResponse | dict[str, Any]:
"""Handles the non-streaming API call."""
try:
logger.debug("Calling client.models.generate_content...")
# Use the client instance stored on the provider
response = provider.client_module.models.generate_content(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content call successful, returning raw response object.")
# Return the direct response object
return response
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
# Return error dict
return {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during non-stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
# Return error dict
return {"error": error_msg, "traceback": traceback.format_exc()}
# --- Helper for Streaming ---
def _create_chat_completion_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> Iterable[GenerateContentResponse | dict[str, Any]]: # Return Iterable of response chunks or error dict
"""Handles the streaming API call and yields results."""
try:
logger.debug("Calling client.models.generate_content_stream...")
# Use the client instance stored on the provider
response_iterator = provider.client_module.models.generate_content_stream(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content_stream call successful, yielding from iterator.")
# Yield from the SDK's iterator which produces GenerateContentResponse chunks
yield from response_iterator
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
# Yield error as a dict matching non-streaming error structure
yield {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
# Yield error as a dict
yield {"error": error_msg, "traceback": traceback.format_exc()}
# --- Main Function ---
# Renamed original function to avoid conflict if needed, though overwrite is fine
def create_chat_completion(
provider, # Provider instance is passed in
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
tools: list[dict[str, Any]] | None = None, # Expects intermediate dict format
) -> Any: # Return type depends on stream flag
"""
Creates a chat completion using the Google Gemini API.
Args:
provider: The instance of the GoogleProvider.
messages: A list of message dictionaries in the standard format.
model: The model ID to use (e.g., "gemini-1.5-flash").
temperature: The sampling temperature.
max_tokens: The maximum number of tokens to generate.
stream: Whether to stream the response.
tools: A list of tool definitions in the MCP format.
Returns:
The response object from the Google API (could be a stream iterator or
a GenerateContentResponse object), or an error dictionary/iterator.
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
"""
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# Check if client exists on the provider instance
if provider.client_module is None:
error_msg = "Google Generative AI SDK not configured or installed."
error_msg = "Google Generative AI client not initialized on provider."
logger.error(error_msg)
# Return an error structure compatible with both streaming and non-streaming expectations
if stream:
return iter([json.dumps({"error": error_msg})])
else:
return {"error": error_msg}
# Return error dict directly for non-stream, create iterator for stream
return iter([{"error": error_msg}]) if stream else {"error": error_msg}
try:
# 1. Convert messages to Google's format
# 1. Convert messages (Common logic)
google_messages, system_prompt = convert_messages(messages)
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
# 2. Prepare generation configuration
generation_config: dict[str, Any] = {"temperature": temperature}
# 2. Prepare generation configuration (Common logic)
# Use GenerationConfigDict for better type hinting if possible
generation_config: GenerationConfigDict = {"temperature": temperature}
if max_tokens is not None:
# Google uses 'max_output_tokens'
generation_config["max_output_tokens"] = max_tokens
logger.debug(f"Setting max_output_tokens: {max_tokens}")
else:
logger.debug("No max_tokens specified.")
# Google requires max_output_tokens, set a default if None
# Defaulting to a reasonable value, e.g., 8192, check model limits if needed
default_max_tokens = 8192
generation_config["max_output_tokens"] = default_max_tokens
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
# 3. Convert tools if provided
# 3. Convert tools if provided (Common logic)
google_tool_objects: list[Tool] | None = None
if tools:
try:
# Step 3a: Convert MCP tools to intermediate Google dict format
tool_dict_list = convert_to_google_tools(tools)
# Step 3b: Convert intermediate dict format to Google Tool objects
google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
# Convert intermediate dict format to Google Tool objects
google_tool_objects = convert_to_google_tool_objects(tools)
if google_tool_objects:
logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.")
num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
else:
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
# Decide whether to proceed without tools or raise an error
# Proceeding without tools for now, but logging the error.
google_tool_objects = None
google_tool_objects = None # Continue without tools on conversion error
else:
logger.debug("No tools provided for conversion.")
# 4. Initialize the Google Generative Model
# Ensure client_module is callable and has GenerativeModel
if not hasattr(provider.client_module, "GenerativeModel"):
raise AttributeError("Configured Google client module does not have 'GenerativeModel'")
# 4. Add system prompt and tools to generation_config (Common logic)
if system_prompt:
# Ensure system_instruction is ContentDict or compatible type
generation_config["system_instruction"] = system_prompt
logger.debug("Added system_instruction to generation_config.")
if google_tool_objects:
# Assign the list of Tool objects directly
generation_config["tools"] = google_tool_objects
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
gemini_model = provider.client_module.GenerativeModel(
model_name=model,
system_instruction=system_prompt, # Pass extracted system prompt
tools=google_tool_objects, # Pass converted Tool objects (or None)
# Add safety_settings if needed: safety_settings=...
)
logger.debug(f"Initialized Google GenerativeModel for '{model}'.")
# 5. Log parameters before API call
# 5. Log parameters before API call (Common logic)
log_params = {
"model": model,
"stream": stream,
"temperature": temperature,
"max_output_tokens": generation_config.get("max_output_tokens"),
"system_prompt_present": bool(system_prompt),
"num_tools": len(google_tool_objects) if google_tool_objects else 0,
"num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
"num_messages": len(google_messages),
}
logger.info(f"Calling Google generate_content API with params: {log_params}")
# Avoid logging full message content unless necessary for debugging specific issues
# logger.debug(f"Google messages being sent: {google_messages}")
logger.info(f"Calling Google API via helper with params: {log_params}")
# 6. Call the Google API
response = gemini_model.generate_content(
contents=google_messages,
generation_config=generation_config,
stream=stream,
# tool_config={"function_calling_config": "AUTO"} # AUTO is default
)
logger.debug("Google API call successful, returning response object.")
return response
except ValueError as ve: # Catch specific errors like invalid message sequence
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
# 6. Delegate to appropriate helper
if stream:
# Yield a JSON error message in an iterator
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
# Return the generator/iterator from the streaming helper
# This helper uses 'yield from'
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
else:
# Return an error dictionary
return {"error": error_msg, "traceback": traceback.format_exc()}
# Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper
# This helper uses 'return'
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
except Exception as e:
# Catch any other exceptions during setup or API call
error_msg = f"Google API error during chat completion: {e}"
# Catch errors during common setup (message/tool conversion etc.)
error_msg = f"Error during Google completion setup: {e}"
logger.error(error_msg, exc_info=True)
if stream:
# Yield a JSON error message in an iterator
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
else:
# Return an error dictionary
return {"error": error_msg, "traceback": traceback.format_exc()}
# Return error dict directly for non-stream, create iterator for stream
return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}

View File

@@ -124,37 +124,46 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
The concatenated text content, or an error message string.
"""
try:
# Handle error dictionary case
# Check if it's an error dictionary passed from upstream (e.g., completion helper)
if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error response: {response['error']}")
logger.error(f"Cannot get content from error dict: {response['error']}")
return f"[Error: {response['error']}]"
# Handle successful GenerateContentResponse object
if hasattr(response, "text"):
# The `.text` attribute usually provides the concatenated text content directly
# Ensure it's a GenerateContentResponse object before accessing attributes
if not isinstance(response, GenerateContentResponse):
logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
return f"[Error: Unexpected response type {type(response)}]"
# --- Access GenerateContentResponse attributes ---
# Prioritize response.text if available and not empty
if hasattr(response, "text") and response.text:
content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content
elif hasattr(response, "candidates") and response.candidates:
# Fallback: manually concatenate text from parts if .text is missing
# Fallback: manually concatenate text from parts if .text is missing/empty
if hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
text_parts = []
for part in first_candidate.content.parts:
if hasattr(part, "text"):
text_parts.append(part.text)
# We are only interested in text content here, ignore function calls etc.
content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.")
return content
# Check candidate content and parts carefully
if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
if text_parts:
content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
return content
else:
logger.warning("Google response candidate parts contained no text.")
return "" # Return empty if parts exist but have no text
else:
logger.warning("Google response candidate has no content or parts.")
return "" # Return empty string if no text found
logger.warning("Google response candidate has no valid content or parts.")
return "" # Return empty string if no valid content/parts
else:
logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}")
# If neither .text nor valid candidates are found
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
return "" # Return empty string if no text found
except AttributeError as ae:
logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True)
logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return f"[Error extracting content: Attribute missing - {str(ae)}]"
except Exception as e:
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
@@ -173,32 +182,34 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i
usage information is unavailable or an error occurred.
"""
try:
# Handle error dictionary case
# Check if it's an error dictionary passed from upstream
if isinstance(response, dict) and "error" in response:
logger.warning("Cannot get usage from error response.")
logger.warning(f"Cannot get usage from error dict: {response['error']}")
return None
# Check for usage metadata in the response object
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
# Ensure it's a GenerateContentResponse object before accessing attributes
if not isinstance(response, GenerateContentResponse):
logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
return None
# Safely access usage metadata
metadata = getattr(response, "usage_metadata", None)
if metadata:
# Google uses prompt_token_count and candidates_token_count
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
completion_tokens = getattr(metadata, "candidates_token_count", 0)
usage = {
"prompt_tokens": getattr(metadata, "prompt_token_count", 0),
"completion_tokens": getattr(metadata, "candidates_token_count", 0),
# Google also provides total_token_count, could be added if needed
# "total_tokens": getattr(metadata, "total_token_count", 0),
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
}
# Ensure values are integers
usage = {k: int(v) for k, v in usage.items()}
logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage
else:
# Log a warning only if it's not clearly an error dict already handled
if not (isinstance(response, dict) and "error" in response):
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
return None
except AttributeError as ae:
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True)
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return None
except Exception as e:
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)

View File

@@ -13,7 +13,7 @@ import json
import logging
from typing import Any
from google.genai.types import FunctionDeclaration, Schema, Tool
from google.genai.types import FunctionDeclaration, Schema, Tool, Type
logger = logging.getLogger(__name__)
@@ -95,50 +95,66 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
Handles type mapping and nested structures. Returns None on failure.
"""
if Schema is None:
logger.error("Cannot create Schema object: google.genai types not available.")
if Schema is None or Type is None:
logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
return None
if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
return Schema() # Return empty schema to avoid breaking the parent
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
return None # Return None on invalid input
# Map JSON Schema types to Google's Type enum strings
# Map JSON Schema types to Google's Type enum members
type_mapping = {
"string": "STRING",
"number": "NUMBER",
"integer": "INTEGER",
"boolean": "BOOLEAN",
"array": "ARRAY",
"object": "OBJECT",
# Add other mappings if necessary
"string": Type.STRING,
"number": Type.NUMBER,
"integer": Type.INTEGER,
"boolean": Type.BOOLEAN,
"array": Type.ARRAY,
"object": Type.OBJECT,
}
original_type = schema_dict.get("type")
google_type = type_mapping.get(str(original_type).lower()) if original_type else None
if not google_type:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
return None # Return None if type is invalid/missing
# Prepare arguments for Schema constructor, filtering out None values
schema_args = {
"type": google_type,
"type": google_type, # Use the Type enum member
"format": schema_dict.get("format"),
"description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"),
"nullable": schema_dict.get("nullable"), # Note: Google's Schema might not directly support nullable in constructor
"enum": schema_dict.get("enum"),
"items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None,
"properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None,
"required": schema_dict.get("required") if google_type == "OBJECT" else None,
# Recursively create nested schemas, ensuring None is handled if recursion fails
"items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
"properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
if google_type == Type.OBJECT and schema_dict.get("properties")
else None,
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
}
# Remove keys with None values
# Remove keys with None values before passing to Schema constructor
schema_args = {k: v for k, v in schema_args.items() if v is not None}
if not schema_args.get("type"):
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.")
return Schema() # Return empty schema
# Handle specific cases for ARRAY and OBJECT where items/properties might be needed
if google_type == Type.ARRAY and "items" not in schema_args:
logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
return None # Array schema requires items
if google_type == Type.OBJECT and "properties" not in schema_args:
# Allow object schema without properties initially, might be handled later
pass
# logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.")
# schema_args["properties"] = {} # Or return None if properties are strictly required
try:
return Schema(**schema_args)
# Create the Schema object
created_schema = Schema(**schema_args)
# logger.debug(f"Successfully created Schema: {created_schema}")
return created_schema
except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
return Schema() # Return empty schema on error
return None # Return None on creation error
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
@@ -169,31 +185,71 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
return None
for func_dict in func_declarations_list:
func_name = func_dict.get("name", "Unknown")
try:
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
# Ensure parameters is a valid schema dict for the recursive creator
params_schema_dict = func_dict.get("parameters", {})
# Ensure parameters is a dict and defaults to object type if missing
if not isinstance(params_schema_dict, dict):
logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.")
logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
params_schema_dict = {"type": "object", "properties": {}}
elif params_schema_dict.get("type") != "object":
logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.")
params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties
elif "type" not in params_schema_dict:
params_schema_dict["type"] = "object" # Default to object if type is missing
elif params_schema_dict["type"] != "object":
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
# Attempt to salvage properties if the top level isn't object
original_properties = params_schema_dict.get("properties", {})
if not isinstance(original_properties, dict):
original_properties = {}
params_schema_dict = {"type": "object", "properties": original_properties}
parameters_schema = _create_google_schema_recursive(params_schema_dict)
# Only proceed if schema creation was somewhat successful
if parameters_schema is not None:
declaration = FunctionDeclaration(
name=func_dict["name"],
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
properties_dict = params_schema_dict.get("properties", {})
google_properties = {}
if isinstance(properties_dict, dict):
for prop_name, prop_schema_dict in properties_dict.items():
prop_schema = _create_google_schema_recursive(prop_schema_dict)
if prop_schema:
google_properties[prop_name] = prop_schema
else:
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
else:
logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
# Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty
if not google_properties:
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
# Clear required list if properties are empty/dummy
required_list = []
else:
# Validate required list against actual properties
original_required = params_schema_dict.get("required", [])
if isinstance(original_required, list):
required_list = [req for req in original_required if req in google_properties]
if len(required_list) != len(original_required):
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
else:
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
required_list = []
# Create the top-level parameters schema, ensuring it's OBJECT type
parameters_schema = Schema(
type=Type.OBJECT,
properties=google_properties,
required=required_list if required_list else None, # Pass None if empty list
)
# Create the FunctionDeclaration
declaration = FunctionDeclaration(
name=func_name,
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")
except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)
else:
logger.error(f"Invalid tool_configs structure provided: {tool_configs}")

View File

@@ -108,16 +108,16 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
# Include any text content alongside the function calls
if content and isinstance(content, str):
parts.append(Part.from_text(content))
parts.append(Part(text=content)) # Use direct instantiation
elif content:
# Regular user or assistant message content
if isinstance(content, str):
parts.append(Part.from_text(content))
parts.append(Part(text=content)) # Use direct instantiation
# TODO: Handle potential image content if needed in the future
else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part.from_text(str(content)))
parts.append(Part(text=str(content))) # Use direct instantiation
# Add the constructed Content object if parts were generated
if parts:

View File

@@ -32,7 +32,7 @@ class OpenAIProvider(BaseProvider):
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,

View File

@@ -14,7 +14,7 @@ def create_chat_completion(
provider, # The OpenAIProvider instance
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
temperature: float = 0.6,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,