fix: update temperature parameter to 0.6 across multiple providers and add debugging output

This commit is contained in:
2025-03-27 19:02:52 +00:00
parent ccf750fed4
commit 51e3058961
11 changed files with 292 additions and 180 deletions

View File

@@ -56,7 +56,7 @@ class LLMClient:
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
) -> Generator[str, None, None] | dict[str, Any]: ) -> Generator[str, None, None] | dict[str, Any]:
@@ -97,6 +97,7 @@ class LLMClient:
stream=stream, stream=stream,
tools=provider_tools, tools=provider_tools,
) )
print(f"Response: {response}") # Debugging line to check the response
logger.info("Received response from provider.") logger.info("Received response from provider.")
if stream: if stream:

View File

@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
def create_chat_completion( def create_chat_completion(
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
) -> Stream | Message: ) -> Stream | Message:
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
temp_system_prompt, temp_anthropic_messages = convert_messages(messages) temp_system_prompt, temp_anthropic_messages = convert_messages(messages)

View File

@@ -28,7 +28,7 @@ class BaseProvider(abc.ABC):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
@@ -39,7 +39,7 @@ class BaseProvider(abc.ABC):
Args: Args:
messages: List of message dictionaries with 'role' and 'content'. messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier. model: Model identifier.
temperature: Sampling temperature (0-1). temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens to generate. max_tokens: Maximum tokens to generate.
stream: Whether to stream the response. stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format. tools: Optional list of tools in the provider-specific format.

View File

@@ -3,9 +3,12 @@ import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
# Import Generator type for isinstance check - Keep this import for type hints
from google.genai.types import GenerateContentResponse from google.genai.types import GenerateContentResponse
from providers.google_provider.client import initialize_client from providers.google_provider.client import initialize_client
# Correctly import the renamed function directly
from providers.google_provider.completion import create_chat_completion from providers.google_provider.completion import create_chat_completion
from providers.google_provider.response import get_content, get_streaming_content, get_usage from providers.google_provider.response import get_content, get_streaming_content, get_usage
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
@@ -28,7 +31,7 @@ class GoogleProvider(BaseProvider):
api_key: The Google API key. api_key: The Google API key.
base_url: Base URL (typically not used by Google client config, but kept for interface consistency). base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
""" """
# initialize_client returns the configured genai module # initialize_client returns the client instance now
self.client_module = initialize_client(api_key, base_url) self.client_module = initialize_client(api_key, base_url)
self.api_key = api_key # Store if needed later self.api_key = api_key # Store if needed later
self.base_url = base_url # Store if needed later self.base_url = base_url # Store if needed later
@@ -38,18 +41,24 @@ class GoogleProvider(BaseProvider):
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator ) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
"""Creates a chat completion using the Google Gemini API.""" """Creates a chat completion using the Google Gemini API."""
# Pass self (provider instance) to the helper function # Pass self (provider instance) to the helper function
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
print(f"Raw response type: {type(raw_response)}") # Debugging line to check the type of raw_response
print(f"Raw response: {raw_response}") # Debugging line to check the content of raw_response
# The completion helper function handles returning the correct type or an error dict.
# No need for generator handling here anymore.
return raw_response
def get_streaming_content(self, response: Any) -> Generator[str, None, None]: def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Extracts content chunks from a Google streaming response.""" """Extracts content chunks from a Google streaming response."""
# Response is expected to be an iterator from generate_content(stream=True) # Response is expected to be an iterator from generate_content_stream
return get_streaming_content(response) return get_streaming_content(response)
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str: def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:

View File

@@ -16,12 +16,13 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.") raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try: try:
# Configure the client # Instantiate the client directly using the API key
genai.configure(api_key=api_key) client = genai.Client(api_key=api_key)
logger.info("Google Generative AI client instantiated.")
if base_url: if base_url:
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.") logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
# Return the configured module itself, as it's used directly # Return the client instance
return genai return client
except Exception as e: except Exception as e:
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True) logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
raise raise

View File

@@ -1,140 +1,174 @@
# src/providers/google_provider/completion.py
import json
import logging import logging
import traceback import traceback
from collections.abc import Iterable # Added Iterable
from typing import Any from typing import Any
from google.genai.types import Tool # Import specific types for better hinting
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools # Removed convert_to_google_tools import as it's handled later
from providers.google_provider.tools import convert_to_google_tool_objects
from providers.google_provider.utils import convert_messages from providers.google_provider.utils import convert_messages
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_chat_completion( # --- Helper for Non-Streaming ---
def _create_chat_completion_non_stream(
provider, provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> GenerateContentResponse | dict[str, Any]:
"""Handles the non-streaming API call."""
try:
logger.debug("Calling client.models.generate_content...")
# Use the client instance stored on the provider
response = provider.client_module.models.generate_content(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content call successful, returning raw response object.")
# Return the direct response object
return response
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
# Return error dict
return {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during non-stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
# Return error dict
return {"error": error_msg, "traceback": traceback.format_exc()}
# --- Helper for Streaming ---
def _create_chat_completion_stream(
provider,
model: str,
google_messages: list[ContentDict],
generation_config: GenerationConfigDict,
) -> Iterable[GenerateContentResponse | dict[str, Any]]: # Return Iterable of response chunks or error dict
"""Handles the streaming API call and yields results."""
try:
logger.debug("Calling client.models.generate_content_stream...")
# Use the client instance stored on the provider
response_iterator = provider.client_module.models.generate_content_stream(
model=f"models/{model}",
contents=google_messages,
config=generation_config,
)
logger.debug("generate_content_stream call successful, yielding from iterator.")
# Yield from the SDK's iterator which produces GenerateContentResponse chunks
yield from response_iterator
except ValueError as ve:
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
# Yield error as a dict matching non-streaming error structure
yield {"error": error_msg, "traceback": traceback.format_exc()}
except Exception as e:
error_msg = f"Google API error during stream chat completion: {e}"
logger.error(error_msg, exc_info=True)
# Yield error as a dict
yield {"error": error_msg, "traceback": traceback.format_exc()}
# --- Main Function ---
# Renamed original function to avoid conflict if needed, though overwrite is fine
def create_chat_completion(
provider, # Provider instance is passed in
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None, # Expects intermediate dict format
) -> Any: ) -> Any: # Return type depends on stream flag
""" """
Creates a chat completion using the Google Gemini API. Creates a chat completion using the Google Gemini API.
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
Args:
provider: The instance of the GoogleProvider.
messages: A list of message dictionaries in the standard format.
model: The model ID to use (e.g., "gemini-1.5-flash").
temperature: The sampling temperature.
max_tokens: The maximum number of tokens to generate.
stream: Whether to stream the response.
tools: A list of tool definitions in the MCP format.
Returns:
The response object from the Google API (could be a stream iterator or
a GenerateContentResponse object), or an error dictionary/iterator.
""" """
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# Check if client exists on the provider instance
if provider.client_module is None: if provider.client_module is None:
error_msg = "Google Generative AI SDK not configured or installed." error_msg = "Google Generative AI client not initialized on provider."
logger.error(error_msg) logger.error(error_msg)
# Return an error structure compatible with both streaming and non-streaming expectations # Return error dict directly for non-stream, create iterator for stream
if stream: return iter([{"error": error_msg}]) if stream else {"error": error_msg}
return iter([json.dumps({"error": error_msg})])
else:
return {"error": error_msg}
try: try:
# 1. Convert messages to Google's format # 1. Convert messages (Common logic)
google_messages, system_prompt = convert_messages(messages) google_messages, system_prompt = convert_messages(messages)
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}") logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
# 2. Prepare generation configuration # 2. Prepare generation configuration (Common logic)
generation_config: dict[str, Any] = {"temperature": temperature} # Use GenerationConfigDict for better type hinting if possible
generation_config: GenerationConfigDict = {"temperature": temperature}
if max_tokens is not None: if max_tokens is not None:
# Google uses 'max_output_tokens'
generation_config["max_output_tokens"] = max_tokens generation_config["max_output_tokens"] = max_tokens
logger.debug(f"Setting max_output_tokens: {max_tokens}") logger.debug(f"Setting max_output_tokens: {max_tokens}")
else: else:
logger.debug("No max_tokens specified.") # Google requires max_output_tokens, set a default if None
# Defaulting to a reasonable value, e.g., 8192, check model limits if needed
default_max_tokens = 8192
generation_config["max_output_tokens"] = default_max_tokens
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
# 3. Convert tools if provided # 3. Convert tools if provided (Common logic)
google_tool_objects: list[Tool] | None = None google_tool_objects: list[Tool] | None = None
if tools: if tools:
try: try:
# Step 3a: Convert MCP tools to intermediate Google dict format # Convert intermediate dict format to Google Tool objects
tool_dict_list = convert_to_google_tools(tools) google_tool_objects = convert_to_google_tool_objects(tools)
# Step 3b: Convert intermediate dict format to Google Tool objects
google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
if google_tool_objects: if google_tool_objects:
logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.") num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
else: else:
logger.warning("Tool conversion resulted in no valid Google Tool objects.") logger.warning("Tool conversion resulted in no valid Google Tool objects.")
except Exception as tool_conv_err: except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True) logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
# Decide whether to proceed without tools or raise an error google_tool_objects = None # Continue without tools on conversion error
# Proceeding without tools for now, but logging the error.
google_tool_objects = None
else: else:
logger.debug("No tools provided for conversion.") logger.debug("No tools provided for conversion.")
# 4. Initialize the Google Generative Model # 4. Add system prompt and tools to generation_config (Common logic)
# Ensure client_module is callable and has GenerativeModel if system_prompt:
if not hasattr(provider.client_module, "GenerativeModel"): # Ensure system_instruction is ContentDict or compatible type
raise AttributeError("Configured Google client module does not have 'GenerativeModel'") generation_config["system_instruction"] = system_prompt
logger.debug("Added system_instruction to generation_config.")
if google_tool_objects:
# Assign the list of Tool objects directly
generation_config["tools"] = google_tool_objects
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
gemini_model = provider.client_module.GenerativeModel( # 5. Log parameters before API call (Common logic)
model_name=model,
system_instruction=system_prompt, # Pass extracted system prompt
tools=google_tool_objects, # Pass converted Tool objects (or None)
# Add safety_settings if needed: safety_settings=...
)
logger.debug(f"Initialized Google GenerativeModel for '{model}'.")
# 5. Log parameters before API call
log_params = { log_params = {
"model": model, "model": model,
"stream": stream, "stream": stream,
"temperature": temperature, "temperature": temperature,
"max_output_tokens": generation_config.get("max_output_tokens"), "max_output_tokens": generation_config.get("max_output_tokens"),
"system_prompt_present": bool(system_prompt), "system_prompt_present": bool(system_prompt),
"num_tools": len(google_tool_objects) if google_tool_objects else 0, "num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
"num_messages": len(google_messages), "num_messages": len(google_messages),
} }
logger.info(f"Calling Google generate_content API with params: {log_params}") logger.info(f"Calling Google API via helper with params: {log_params}")
# Avoid logging full message content unless necessary for debugging specific issues
# logger.debug(f"Google messages being sent: {google_messages}")
# 6. Call the Google API # 6. Delegate to appropriate helper
response = gemini_model.generate_content(
contents=google_messages,
generation_config=generation_config,
stream=stream,
# tool_config={"function_calling_config": "AUTO"} # AUTO is default
)
logger.debug("Google API call successful, returning response object.")
return response
except ValueError as ve: # Catch specific errors like invalid message sequence
error_msg = f"Google API request validation error: {ve}"
logger.error(error_msg, exc_info=True)
if stream: if stream:
# Yield a JSON error message in an iterator # Return the generator/iterator from the streaming helper
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()}) # This helper uses 'yield from'
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
else: else:
# Return an error dictionary # Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper
return {"error": error_msg, "traceback": traceback.format_exc()} # This helper uses 'return'
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
except Exception as e: except Exception as e:
# Catch any other exceptions during setup or API call # Catch errors during common setup (message/tool conversion etc.)
error_msg = f"Google API error during chat completion: {e}" error_msg = f"Error during Google completion setup: {e}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
if stream: # Return error dict directly for non-stream, create iterator for stream
# Yield a JSON error message in an iterator return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
else:
# Return an error dictionary
return {"error": error_msg, "traceback": traceback.format_exc()}

View File

@@ -124,37 +124,46 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
The concatenated text content, or an error message string. The concatenated text content, or an error message string.
""" """
try: try:
# Handle error dictionary case # Check if it's an error dictionary passed from upstream (e.g., completion helper)
if isinstance(response, dict) and "error" in response: if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error response: {response['error']}") logger.error(f"Cannot get content from error dict: {response['error']}")
return f"[Error: {response['error']}]" return f"[Error: {response['error']}]"
# Handle successful GenerateContentResponse object # Ensure it's a GenerateContentResponse object before accessing attributes
if hasattr(response, "text"): if not isinstance(response, GenerateContentResponse):
# The `.text` attribute usually provides the concatenated text content directly logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
return f"[Error: Unexpected response type {type(response)}]"
# --- Access GenerateContentResponse attributes ---
# Prioritize response.text if available and not empty
if hasattr(response, "text") and response.text:
content = response.text content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.") logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content return content
elif hasattr(response, "candidates") and response.candidates:
# Fallback: manually concatenate text from parts if .text is missing # Fallback: manually concatenate text from parts if .text is missing/empty
if hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0] first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"): # Check candidate content and parts carefully
text_parts = [] if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
for part in first_candidate.content.parts: text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
if hasattr(part, "text"): if text_parts:
text_parts.append(part.text)
# We are only interested in text content here, ignore function calls etc.
content = "".join(text_parts) content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.") logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
return content return content
else: else:
logger.warning("Google response candidate has no content or parts.") logger.warning("Google response candidate parts contained no text.")
return "" # Return empty string if no text found return "" # Return empty if parts exist but have no text
else: else:
logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}") logger.warning("Google response candidate has no valid content or parts.")
return "" # Return empty string if no valid content/parts
else:
# If neither .text nor valid candidates are found
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
return "" # Return empty string if no text found return "" # Return empty string if no text found
except AttributeError as ae: except AttributeError as ae:
logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True) logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return f"[Error extracting content: Attribute missing - {str(ae)}]" return f"[Error extracting content: Attribute missing - {str(ae)}]"
except Exception as e: except Exception as e:
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True) logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
@@ -173,32 +182,34 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i
usage information is unavailable or an error occurred. usage information is unavailable or an error occurred.
""" """
try: try:
# Handle error dictionary case # Check if it's an error dictionary passed from upstream
if isinstance(response, dict) and "error" in response: if isinstance(response, dict) and "error" in response:
logger.warning("Cannot get usage from error response.") logger.warning(f"Cannot get usage from error dict: {response['error']}")
return None return None
# Check for usage metadata in the response object # Ensure it's a GenerateContentResponse object before accessing attributes
if hasattr(response, "usage_metadata"): if not isinstance(response, GenerateContentResponse):
metadata = response.usage_metadata logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
return None
# Safely access usage metadata
metadata = getattr(response, "usage_metadata", None)
if metadata:
# Google uses prompt_token_count and candidates_token_count # Google uses prompt_token_count and candidates_token_count
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
completion_tokens = getattr(metadata, "candidates_token_count", 0)
usage = { usage = {
"prompt_tokens": getattr(metadata, "prompt_token_count", 0), "prompt_tokens": int(prompt_tokens),
"completion_tokens": getattr(metadata, "candidates_token_count", 0), "completion_tokens": int(completion_tokens),
# Google also provides total_token_count, could be added if needed
# "total_tokens": getattr(metadata, "total_token_count", 0),
} }
# Ensure values are integers
usage = {k: int(v) for k, v in usage.items()}
logger.debug(f"Extracted usage from Google response metadata: {usage}") logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage return usage
else: else:
# Log a warning only if it's not clearly an error dict already handled logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
if not (isinstance(response, dict) and "error" in response):
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
return None return None
except AttributeError as ae: except AttributeError as ae:
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True) logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True) logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)

View File

@@ -13,7 +13,7 @@ import json
import logging import logging
from typing import Any from typing import Any
from google.genai.types import FunctionDeclaration, Schema, Tool from google.genai.types import FunctionDeclaration, Schema, Tool, Type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -95,50 +95,66 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
Handles type mapping and nested structures. Returns None on failure. Handles type mapping and nested structures. Returns None on failure.
""" """
if Schema is None: if Schema is None or Type is None:
logger.error("Cannot create Schema object: google.genai types not available.") logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
return None return None
if not isinstance(schema_dict, dict): if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.") logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
return Schema() # Return empty schema to avoid breaking the parent return None # Return None on invalid input
# Map JSON Schema types to Google's Type enum strings # Map JSON Schema types to Google's Type enum members
type_mapping = { type_mapping = {
"string": "STRING", "string": Type.STRING,
"number": "NUMBER", "number": Type.NUMBER,
"integer": "INTEGER", "integer": Type.INTEGER,
"boolean": "BOOLEAN", "boolean": Type.BOOLEAN,
"array": "ARRAY", "array": Type.ARRAY,
"object": "OBJECT", "object": Type.OBJECT,
# Add other mappings if necessary
} }
original_type = schema_dict.get("type") original_type = schema_dict.get("type")
google_type = type_mapping.get(str(original_type).lower()) if original_type else None google_type = type_mapping.get(str(original_type).lower()) if original_type else None
if not google_type:
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
return None # Return None if type is invalid/missing
# Prepare arguments for Schema constructor, filtering out None values # Prepare arguments for Schema constructor, filtering out None values
schema_args = { schema_args = {
"type": google_type, "type": google_type, # Use the Type enum member
"format": schema_dict.get("format"), "format": schema_dict.get("format"),
"description": schema_dict.get("description"), "description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"), "nullable": schema_dict.get("nullable"), # Note: Google's Schema might not directly support nullable in constructor
"enum": schema_dict.get("enum"), "enum": schema_dict.get("enum"),
"items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None, # Recursively create nested schemas, ensuring None is handled if recursion fails
"properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None, "items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
"required": schema_dict.get("required") if google_type == "OBJECT" else None, "properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
if google_type == Type.OBJECT and schema_dict.get("properties")
else None,
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
} }
# Remove keys with None values
# Remove keys with None values before passing to Schema constructor
schema_args = {k: v for k, v in schema_args.items() if v is not None} schema_args = {k: v for k, v in schema_args.items() if v is not None}
if not schema_args.get("type"): # Handle specific cases for ARRAY and OBJECT where items/properties might be needed
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.") if google_type == Type.ARRAY and "items" not in schema_args:
return Schema() # Return empty schema logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
return None # Array schema requires items
if google_type == Type.OBJECT and "properties" not in schema_args:
# Allow object schema without properties initially, might be handled later
pass
# logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.")
# schema_args["properties"] = {} # Or return None if properties are strictly required
try: try:
return Schema(**schema_args) # Create the Schema object
created_schema = Schema(**schema_args)
# logger.debug(f"Successfully created Schema: {created_schema}")
return created_schema
except Exception as schema_creation_err: except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True) logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
return Schema() # Return empty schema on error return None # Return None on creation error
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None: def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
@@ -169,31 +185,71 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
return None return None
for func_dict in func_declarations_list: for func_dict in func_declarations_list:
func_name = func_dict.get("name", "Unknown")
try: try:
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}}) params_schema_dict = func_dict.get("parameters", {})
# Ensure parameters is a valid schema dict for the recursive creator
# Ensure parameters is a dict and defaults to object type if missing
if not isinstance(params_schema_dict, dict): if not isinstance(params_schema_dict, dict):
logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.") logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
params_schema_dict = {"type": "object", "properties": {}} params_schema_dict = {"type": "object", "properties": {}}
elif params_schema_dict.get("type") != "object": elif "type" not in params_schema_dict:
logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.") params_schema_dict["type"] = "object" # Default to object if type is missing
params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties elif params_schema_dict["type"] != "object":
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
# Attempt to salvage properties if the top level isn't object
original_properties = params_schema_dict.get("properties", {})
if not isinstance(original_properties, dict):
original_properties = {}
params_schema_dict = {"type": "object", "properties": original_properties}
parameters_schema = _create_google_schema_recursive(params_schema_dict) properties_dict = params_schema_dict.get("properties", {})
google_properties = {}
if isinstance(properties_dict, dict):
for prop_name, prop_schema_dict in properties_dict.items():
prop_schema = _create_google_schema_recursive(prop_schema_dict)
if prop_schema:
google_properties[prop_name] = prop_schema
else:
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
else:
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
# Only proceed if schema creation was somewhat successful # Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty
if parameters_schema is not None: if not google_properties:
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
# Clear required list if properties are empty/dummy
required_list = []
else:
# Validate required list against actual properties
original_required = params_schema_dict.get("required", [])
if isinstance(original_required, list):
required_list = [req for req in original_required if req in google_properties]
if len(required_list) != len(original_required):
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
else:
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
required_list = []
# Create the top-level parameters schema, ensuring it's OBJECT type
parameters_schema = Schema(
type=Type.OBJECT,
properties=google_properties,
required=required_list if required_list else None, # Pass None if empty list
)
# Create the FunctionDeclaration
declaration = FunctionDeclaration( declaration = FunctionDeclaration(
name=func_dict["name"], name=func_name,
description=func_dict.get("description", ""), description=func_dict.get("description", ""),
parameters=parameters_schema, parameters=parameters_schema,
) )
all_func_declarations.append(declaration) all_func_declarations.append(declaration)
else: logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")
logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")
except Exception as decl_err: except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True) logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)
else: else:
logger.error(f"Invalid tool_configs structure provided: {tool_configs}") logger.error(f"Invalid tool_configs structure provided: {tool_configs}")

View File

@@ -108,16 +108,16 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
# Include any text content alongside the function calls # Include any text content alongside the function calls
if content and isinstance(content, str): if content and isinstance(content, str):
parts.append(Part.from_text(content)) parts.append(Part(text=content)) # Use direct instantiation
elif content: elif content:
# Regular user or assistant message content # Regular user or assistant message content
if isinstance(content, str): if isinstance(content, str):
parts.append(Part.from_text(content)) parts.append(Part(text=content)) # Use direct instantiation
# TODO: Handle potential image content if needed in the future # TODO: Handle potential image content if needed in the future
else: else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.") logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part.from_text(str(content))) parts.append(Part(text=str(content))) # Use direct instantiation
# Add the constructed Content object if parts were generated # Add the constructed Content object if parts were generated
if parts: if parts:

View File

@@ -32,7 +32,7 @@ class OpenAIProvider(BaseProvider):
self, self,
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,

View File

@@ -14,7 +14,7 @@ def create_chat_completion(
provider, # The OpenAIProvider instance provider, # The OpenAIProvider instance
messages: list[dict[str, str]], messages: list[dict[str, str]],
model: str, model: str,
temperature: float = 0.4, temperature: float = 0.6,
max_tokens: int | None = None, max_tokens: int | None = None,
stream: bool = True, stream: bool = True,
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,