fix: update temperature parameter to 0.6 across multiple providers and add debugging output
@@ -56,7 +56,7 @@ class LLMClient:
         self,
         messages: list[dict[str, str]],
         model: str,
-        temperature: float = 0.4,
+        temperature: float = 0.6,
         max_tokens: int | None = None,
         stream: bool = True,
     ) -> Generator[str, None, None] | dict[str, Any]:
@@ -97,6 +97,7 @@ class LLMClient:
             stream=stream,
             tools=provider_tools,
         )
+        print(f"Response: {response}")  # Debugging line to check the response
         logger.info("Received response from provider.")

         if stream:
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)


 def create_chat_completion(
-    provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
+    provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
 ) -> Stream | Message:
     logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
     temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
@@ -28,7 +28,7 @@ class BaseProvider(abc.ABC):
         self,
         messages: list[dict[str, str]],
         model: str,
-        temperature: float = 0.4,
+        temperature: float = 0.6,
         max_tokens: int | None = None,
         stream: bool = True,
         tools: list[dict[str, Any]] | None = None,
@@ -39,7 +39,7 @@ class BaseProvider(abc.ABC):
         Args:
             messages: List of message dictionaries with 'role' and 'content'.
             model: Model identifier.
-            temperature: Sampling temperature (0-1).
+            temperature: Sampling temperature (0-2).
             max_tokens: Maximum tokens to generate.
             stream: Whether to stream the response.
             tools: Optional list of tools in the provider-specific format.
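The documented range widens here from (0-1) to (0-2): OpenAI-style chat APIs accept sampling temperatures up to 2.0, and the new 0.6 default sits inside either range. A minimal guard reflecting the new docstring (a hypothetical helper, not part of this commit):

def _validate_temperature(temperature: float) -> float:
    # Hypothetical guard matching the documented 0-2 range
    if not 0.0 <= temperature <= 2.0:
        raise ValueError(f"temperature must be within [0.0, 2.0], got {temperature}")
    return temperature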
@@ -3,9 +3,12 @@ import logging
 from collections.abc import Generator
 from typing import Any

+# Import Generator type for isinstance check - Keep this import for type hints
 from google.genai.types import GenerateContentResponse

 from providers.google_provider.client import initialize_client
+
+# Correctly import the renamed function directly
 from providers.google_provider.completion import create_chat_completion
 from providers.google_provider.response import get_content, get_streaming_content, get_usage
 from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
@@ -28,7 +31,7 @@ class GoogleProvider(BaseProvider):
             api_key: The Google API key.
             base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
         """
-        # initialize_client returns the configured genai module
+        # initialize_client returns the client instance now
         self.client_module = initialize_client(api_key, base_url)
         self.api_key = api_key  # Store if needed later
         self.base_url = base_url  # Store if needed later
@@ -38,18 +41,24 @@ class GoogleProvider(BaseProvider):
         self,
         messages: list[dict[str, Any]],
         model: str,
-        temperature: float = 0.4,
+        temperature: float = 0.6,
         max_tokens: int | None = None,
         stream: bool = True,
         tools: list[dict[str, Any]] | None = None,
     ) -> Any:  # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
         """Creates a chat completion using the Google Gemini API."""
         # Pass self (provider instance) to the helper function
-        return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
+        raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
+        print(f"Raw response type: {type(raw_response)}")  # Debugging line to check the type of raw_response
+        print(f"Raw response: {raw_response}")  # Debugging line to check the content of raw_response
+
+        # The completion helper function handles returning the correct type or an error dict.
+        # No need for generator handling here anymore.
+        return raw_response

     def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
         """Extracts content chunks from a Google streaming response."""
-        # Response is expected to be an iterator from generate_content(stream=True)
+        # Response is expected to be an iterator from generate_content_stream
         return get_streaming_content(response)

     def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
@@ -16,12 +16,13 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
         raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")

     try:
-        # Configure the client
-        genai.configure(api_key=api_key)
+        # Instantiate the client directly using the API key
+        client = genai.Client(api_key=api_key)
+        logger.info("Google Generative AI client instantiated.")
         if base_url:
-            logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.")
-        # Return the configured module itself, as it's used directly
-        return genai
+            logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
+        # Return the client instance
+        return client
     except Exception as e:
-        logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
+        logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
         raise
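This hunk migrates from the module-level genai.configure(...) pattern to the client-object pattern of the newer google-genai SDK. A minimal sketch of the call style the provider code relies on, assuming the google-genai package is installed:

from google import genai

# One client per API key, instead of configuring the SDK module globally
client = genai.Client(api_key="YOUR_API_KEY")

# Requests hang off client.models, matching provider.client_module.models.* above
response = client.models.generate_content(
    model="gemini-1.5-flash",
    contents="Say hello in one sentence.",
)
print(response.text)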
@@ -1,140 +1,174 @@
 # src/providers/google_provider/completion.py
 import json
 import logging
 import traceback
+from collections.abc import Iterable  # Added Iterable
 from typing import Any

-from google.genai.types import Tool
+# Import specific types for better hinting
+from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool

-from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools
+# Removed convert_to_google_tools import as it's handled later
+from providers.google_provider.tools import convert_to_google_tool_objects
 from providers.google_provider.utils import convert_messages

 logger = logging.getLogger(__name__)


-def create_chat_completion(
+# --- Helper for Non-Streaming ---
+def _create_chat_completion_non_stream(
     provider,
+    model: str,
+    google_messages: list[ContentDict],
+    generation_config: GenerationConfigDict,
+) -> GenerateContentResponse | dict[str, Any]:
+    """Handles the non-streaming API call."""
+    try:
+        logger.debug("Calling client.models.generate_content...")
+        # Use the client instance stored on the provider
+        response = provider.client_module.models.generate_content(
+            model=f"models/{model}",
+            contents=google_messages,
+            config=generation_config,
+        )
+        logger.debug("generate_content call successful, returning raw response object.")
+        # Return the direct response object
+        return response
+    except ValueError as ve:
+        error_msg = f"Google API request validation error: {ve}"
+        logger.error(error_msg, exc_info=True)
+        # Return error dict
+        return {"error": error_msg, "traceback": traceback.format_exc()}
+    except Exception as e:
+        error_msg = f"Google API error during non-stream chat completion: {e}"
+        logger.error(error_msg, exc_info=True)
+        # Return error dict
+        return {"error": error_msg, "traceback": traceback.format_exc()}
+
+
+# --- Helper for Streaming ---
+def _create_chat_completion_stream(
+    provider,
+    model: str,
+    google_messages: list[ContentDict],
+    generation_config: GenerationConfigDict,
+) -> Iterable[GenerateContentResponse | dict[str, Any]]:  # Return Iterable of response chunks or error dict
+    """Handles the streaming API call and yields results."""
+    try:
+        logger.debug("Calling client.models.generate_content_stream...")
+        # Use the client instance stored on the provider
+        response_iterator = provider.client_module.models.generate_content_stream(
+            model=f"models/{model}",
+            contents=google_messages,
+            config=generation_config,
+        )
+        logger.debug("generate_content_stream call successful, yielding from iterator.")
+        # Yield from the SDK's iterator which produces GenerateContentResponse chunks
+        yield from response_iterator
+    except ValueError as ve:
+        error_msg = f"Google API request validation error: {ve}"
+        logger.error(error_msg, exc_info=True)
+        # Yield error as a dict matching non-streaming error structure
+        yield {"error": error_msg, "traceback": traceback.format_exc()}
+    except Exception as e:
+        error_msg = f"Google API error during stream chat completion: {e}"
+        logger.error(error_msg, exc_info=True)
+        # Yield error as a dict
+        yield {"error": error_msg, "traceback": traceback.format_exc()}
+
+
+# --- Main Function ---
+# Renamed original function to avoid conflict if needed, though overwrite is fine
+def create_chat_completion(
+    provider,  # Provider instance is passed in
     messages: list[dict[str, Any]],
     model: str,
-    temperature: float = 0.4,
+    temperature: float = 0.6,
     max_tokens: int | None = None,
     stream: bool = True,
-    tools: list[dict[str, Any]] | None = None,
-) -> Any:
+    tools: list[dict[str, Any]] | None = None,  # Expects intermediate dict format
+) -> Any:  # Return type depends on stream flag
     """
     Creates a chat completion using the Google Gemini API.

     Args:
         provider: The instance of the GoogleProvider.
         messages: A list of message dictionaries in the standard format.
         model: The model ID to use (e.g., "gemini-1.5-flash").
         temperature: The sampling temperature.
         max_tokens: The maximum number of tokens to generate.
         stream: Whether to stream the response.
         tools: A list of tool definitions in the MCP format.

     Returns:
         The response object from the Google API (could be a stream iterator or
         a GenerateContentResponse object), or an error dictionary/iterator.
+        Delegates to streaming or non-streaming helpers. Contains NO yield itself.
     """
-    logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
+    logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

     # Check if client exists on the provider instance
     if provider.client_module is None:
-        error_msg = "Google Generative AI SDK not configured or installed."
+        error_msg = "Google Generative AI client not initialized on provider."
         logger.error(error_msg)
-        # Return an error structure compatible with both streaming and non-streaming expectations
-        if stream:
-            return iter([json.dumps({"error": error_msg})])
-        else:
-            return {"error": error_msg}
+        # Return error dict directly for non-stream, create iterator for stream
+        return iter([{"error": error_msg}]) if stream else {"error": error_msg}

     try:
-        # 1. Convert messages to Google's format
+        # 1. Convert messages (Common logic)
        google_messages, system_prompt = convert_messages(messages)
        logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")

-        # 2. Prepare generation configuration
-        generation_config: dict[str, Any] = {"temperature": temperature}
+        # 2. Prepare generation configuration (Common logic)
+        # Use GenerationConfigDict for better type hinting if possible
+        generation_config: GenerationConfigDict = {"temperature": temperature}
         if max_tokens is not None:
             # Google uses 'max_output_tokens'
             generation_config["max_output_tokens"] = max_tokens
             logger.debug(f"Setting max_output_tokens: {max_tokens}")
         else:
-            logger.debug("No max_tokens specified.")
+            # Google requires max_output_tokens, set a default if None
+            # Defaulting to a reasonable value, e.g., 8192, check model limits if needed
+            default_max_tokens = 8192
+            generation_config["max_output_tokens"] = default_max_tokens
+            logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")

-        # 3. Convert tools if provided
+        # 3. Convert tools if provided (Common logic)
         google_tool_objects: list[Tool] | None = None
         if tools:
             try:
-                # Step 3a: Convert MCP tools to intermediate Google dict format
-                tool_dict_list = convert_to_google_tools(tools)
-                # Step 3b: Convert intermediate dict format to Google Tool objects
-                google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
+                # Convert intermediate dict format to Google Tool objects
+                google_tool_objects = convert_to_google_tool_objects(tools)
                 if google_tool_objects:
-                    logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.")
+                    num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
+                    logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
                 else:
                     logger.warning("Tool conversion resulted in no valid Google Tool objects.")
             except Exception as tool_conv_err:
                 logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
-                # Decide whether to proceed without tools or raise an error
-                # Proceeding without tools for now, but logging the error.
-                google_tool_objects = None
+                google_tool_objects = None  # Continue without tools on conversion error
         else:
             logger.debug("No tools provided for conversion.")

-        # 4. Initialize the Google Generative Model
-        # Ensure client_module is callable and has GenerativeModel
-        if not hasattr(provider.client_module, "GenerativeModel"):
-            raise AttributeError("Configured Google client module does not have 'GenerativeModel'")
+        # 4. Add system prompt and tools to generation_config (Common logic)
+        if system_prompt:
+            # Ensure system_instruction is ContentDict or compatible type
+            generation_config["system_instruction"] = system_prompt
+            logger.debug("Added system_instruction to generation_config.")
+        if google_tool_objects:
+            # Assign the list of Tool objects directly
+            generation_config["tools"] = google_tool_objects
+            logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")

-        gemini_model = provider.client_module.GenerativeModel(
-            model_name=model,
-            system_instruction=system_prompt,  # Pass extracted system prompt
-            tools=google_tool_objects,  # Pass converted Tool objects (or None)
-            # Add safety_settings if needed: safety_settings=...
-        )
-        logger.debug(f"Initialized Google GenerativeModel for '{model}'.")

-        # 5. Log parameters before API call
+        # 5. Log parameters before API call (Common logic)
         log_params = {
             "model": model,
             "stream": stream,
             "temperature": temperature,
             "max_output_tokens": generation_config.get("max_output_tokens"),
             "system_prompt_present": bool(system_prompt),
-            "num_tools": len(google_tool_objects) if google_tool_objects else 0,
+            "num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
             "num_messages": len(google_messages),
         }
-        logger.info(f"Calling Google generate_content API with params: {log_params}")
+        # Avoid logging full message content unless necessary for debugging specific issues
+        # logger.debug(f"Google messages being sent: {google_messages}")
+        logger.info(f"Calling Google API via helper with params: {log_params}")

-        # 6. Call the Google API
-        response = gemini_model.generate_content(
-            contents=google_messages,
-            generation_config=generation_config,
-            stream=stream,
-            # tool_config={"function_calling_config": "AUTO"} # AUTO is default
-        )
-        logger.debug("Google API call successful, returning response object.")
-        return response
-
-    except ValueError as ve:  # Catch specific errors like invalid message sequence
-        error_msg = f"Google API request validation error: {ve}"
-        logger.error(error_msg, exc_info=True)
+        # 6. Delegate to appropriate helper
         if stream:
-            # Yield a JSON error message in an iterator
-            yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
+            # Return the generator/iterator from the streaming helper
+            # This helper uses 'yield from'
+            return _create_chat_completion_stream(provider, model, google_messages, generation_config)
         else:
-            # Return an error dictionary
-            return {"error": error_msg, "traceback": traceback.format_exc()}
+            # Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper
+            # This helper uses 'return'
+            return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)

     except Exception as e:
-        # Catch any other exceptions during setup or API call
-        error_msg = f"Google API error during chat completion: {e}"
+        # Catch errors during common setup (message/tool conversion etc.)
+        error_msg = f"Error during Google completion setup: {e}"
         logger.error(error_msg, exc_info=True)
-        if stream:
-            # Yield a JSON error message in an iterator
-            yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
-        else:
-            # Return an error dictionary
-            return {"error": error_msg, "traceback": traceback.format_exc()}
+        # Return error dict directly for non-stream, create iterator for stream
+        return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
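The structural point of this refactor: any function whose body contains yield is compiled as a generator, so the old create_chat_completion, which mixed yield (streaming errors) with return value (non-streaming results), never actually returned those values to its caller. A minimal sketch of the pitfall and the delegation fix:

def broken(stream: bool):
    if stream:
        yield "chunk"  # a single yield anywhere makes the whole function a generator
    return {"error": "oops"}  # becomes StopIteration.value, never a normal return value

def _stream_chunks():
    yield "chunk"

def fixed(stream: bool):
    # No yield in this body, so each branch returns exactly what it says
    if stream:
        return _stream_chunks()
    return {"error": "oops"}

print(type(broken(False)))  # <class 'generator'> - the dict is unreachable via a normal call
print(type(fixed(False)))   # <class 'dict'>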
@@ -124,37 +124,46 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
         The concatenated text content, or an error message string.
     """
     try:
-        # Handle error dictionary case
+        # Check if it's an error dictionary passed from upstream (e.g., completion helper)
         if isinstance(response, dict) and "error" in response:
-            logger.error(f"Cannot get content from error response: {response['error']}")
+            logger.error(f"Cannot get content from error dict: {response['error']}")
             return f"[Error: {response['error']}]"

-        # Handle successful GenerateContentResponse object
-        if hasattr(response, "text"):
-            # The `.text` attribute usually provides the concatenated text content directly
+        # Ensure it's a GenerateContentResponse object before accessing attributes
+        if not isinstance(response, GenerateContentResponse):
+            logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
+            return f"[Error: Unexpected response type {type(response)}]"
+
+        # --- Access GenerateContentResponse attributes ---
+        # Prioritize response.text if available and not empty
+        if hasattr(response, "text") and response.text:
             content = response.text
             logger.debug(f"Extracted content (length {len(content)}) from response.text.")
             return content
-        elif hasattr(response, "candidates") and response.candidates:
-            # Fallback: manually concatenate text from parts if .text is missing
+
+        # Fallback: manually concatenate text from parts if .text is missing/empty
+        if hasattr(response, "candidates") and response.candidates:
             first_candidate = response.candidates[0]
-            if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
-                text_parts = []
-                for part in first_candidate.content.parts:
-                    if hasattr(part, "text"):
-                        text_parts.append(part.text)
-                # We are only interested in text content here, ignore function calls etc.
-                content = "".join(text_parts)
-                logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.")
-                return content
+            # Check candidate content and parts carefully
+            if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
+                text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
+                if text_parts:
+                    content = "".join(text_parts)
+                    logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
+                    return content
+                else:
+                    logger.warning("Google response candidate parts contained no text.")
+                    return ""  # Return empty if parts exist but have no text
             else:
-                logger.warning("Google response candidate has no content or parts.")
-                return ""  # Return empty string if no text found
+                logger.warning("Google response candidate has no valid content or parts.")
+                return ""  # Return empty string if no valid content/parts
         else:
-            logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}")
+            # If neither .text nor valid candidates are found
+            logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
             return ""  # Return empty string if no text found

     except AttributeError as ae:
-        logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True)
+        logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
         return f"[Error extracting content: Attribute missing - {str(ae)}]"
     except Exception as e:
         logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
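Condensed, the extraction order above is: error dict, then type check, then response.text, then a manual walk of candidates[0].content.parts. A standalone sketch of that chain (illustration only, duck-typed rather than isinstance-checked):

def extract_text(response) -> str:
    if isinstance(response, dict) and "error" in response:
        return f"[Error: {response['error']}]"
    if getattr(response, "text", None):
        return response.text  # SDK-provided concatenation, preferred when non-empty
    candidates = getattr(response, "candidates", None)
    if candidates and candidates[0].content and candidates[0].content.parts:
        return "".join(p.text for p in candidates[0].content.parts if getattr(p, "text", None))
    return ""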
@@ -173,32 +182,34 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, int] | None:
         usage information is unavailable or an error occurred.
     """
     try:
-        # Handle error dictionary case
+        # Check if it's an error dictionary passed from upstream
         if isinstance(response, dict) and "error" in response:
-            logger.warning("Cannot get usage from error response.")
+            logger.warning(f"Cannot get usage from error dict: {response['error']}")
             return None

-        # Check for usage metadata in the response object
-        if hasattr(response, "usage_metadata"):
-            metadata = response.usage_metadata
+        # Ensure it's a GenerateContentResponse object before accessing attributes
+        if not isinstance(response, GenerateContentResponse):
+            logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
+            return None
+
+        # Safely access usage metadata
+        metadata = getattr(response, "usage_metadata", None)
         if metadata:
             # Google uses prompt_token_count and candidates_token_count
+            prompt_tokens = getattr(metadata, "prompt_token_count", 0)
+            completion_tokens = getattr(metadata, "candidates_token_count", 0)
             usage = {
-                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
-                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
-                # Google also provides total_token_count, could be added if needed
-                # "total_tokens": getattr(metadata, "total_token_count", 0),
+                "prompt_tokens": int(prompt_tokens),
+                "completion_tokens": int(completion_tokens),
             }
-            # Ensure values are integers
-            usage = {k: int(v) for k, v in usage.items()}
             logger.debug(f"Extracted usage from Google response metadata: {usage}")
             return usage
         else:
-            # Log a warning only if it's not clearly an error dict already handled
-            if not (isinstance(response, dict) and "error" in response):
-                logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
+            logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
             return None

     except AttributeError as ae:
-        logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True)
+        logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
         return None
     except Exception as e:
         logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
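The same defensive getattr pattern, reduced to its core (illustrative sketch; the field names match the usage_metadata attributes read above):

def extract_usage(response) -> dict[str, int] | None:
    metadata = getattr(response, "usage_metadata", None)
    if metadata is None:
        return None
    return {
        "prompt_tokens": int(getattr(metadata, "prompt_token_count", 0)),
        "completion_tokens": int(getattr(metadata, "candidates_token_count", 0)),
    }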
@@ -13,7 +13,7 @@ import json
 import logging
 from typing import Any

-from google.genai.types import FunctionDeclaration, Schema, Tool
+from google.genai.types import FunctionDeclaration, Schema, Tool, Type

 logger = logging.getLogger(__name__)

@@ -95,50 +95,66 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | None:

     Handles type mapping and nested structures. Returns None on failure.
     """
-    if Schema is None:
-        logger.error("Cannot create Schema object: google.genai types not available.")
+    if Schema is None or Type is None:
+        logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
         return None

     if not isinstance(schema_dict, dict):
-        logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
-        return Schema()  # Return empty schema to avoid breaking the parent
+        logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
+        return None  # Return None on invalid input

-    # Map JSON Schema types to Google's Type enum strings
+    # Map JSON Schema types to Google's Type enum members
     type_mapping = {
-        "string": "STRING",
-        "number": "NUMBER",
-        "integer": "INTEGER",
-        "boolean": "BOOLEAN",
-        "array": "ARRAY",
-        "object": "OBJECT",
-        # Add other mappings if necessary
+        "string": Type.STRING,
+        "number": Type.NUMBER,
+        "integer": Type.INTEGER,
+        "boolean": Type.BOOLEAN,
+        "array": Type.ARRAY,
+        "object": Type.OBJECT,
     }
     original_type = schema_dict.get("type")
     google_type = type_mapping.get(str(original_type).lower()) if original_type else None

+    if not google_type:
+        logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
+        return None  # Return None if type is invalid/missing
+
     # Prepare arguments for Schema constructor, filtering out None values
     schema_args = {
-        "type": google_type,
+        "type": google_type,  # Use the Type enum member
         "format": schema_dict.get("format"),
         "description": schema_dict.get("description"),
-        "nullable": schema_dict.get("nullable"),
+        "nullable": schema_dict.get("nullable"),  # Note: Google's Schema might not directly support nullable in constructor
         "enum": schema_dict.get("enum"),
-        "items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None,
-        "properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None,
-        "required": schema_dict.get("required") if google_type == "OBJECT" else None,
+        # Recursively create nested schemas, ensuring None is handled if recursion fails
+        "items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
+        "properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
+        if google_type == Type.OBJECT and schema_dict.get("properties")
+        else None,
+        "required": schema_dict.get("required") if google_type == Type.OBJECT else None,
     }
-    # Remove keys with None values
+
+    # Remove keys with None values before passing to Schema constructor
     schema_args = {k: v for k, v in schema_args.items() if v is not None}

-    if not schema_args.get("type"):
-        logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.")
-        return Schema()  # Return empty schema
+    # Handle specific cases for ARRAY and OBJECT where items/properties might be needed
+    if google_type == Type.ARRAY and "items" not in schema_args:
+        logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
+        return None  # Array schema requires items
+    if google_type == Type.OBJECT and "properties" not in schema_args:
+        # Allow object schema without properties initially, might be handled later
+        pass
+        # logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.")
+        # schema_args["properties"] = {} # Or return None if properties are strictly required

     try:
-        return Schema(**schema_args)
+        # Create the Schema object
+        created_schema = Schema(**schema_args)
+        # logger.debug(f"Successfully created Schema: {created_schema}")
+        return created_schema
     except Exception as schema_creation_err:
         logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
-        return Schema()  # Return empty schema on error
+        return None  # Return None on creation error


 def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
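With the Type enum in place, a small worked example of the JSON-Schema-to-Schema mapping this recursion performs, hand-built here rather than recursive and assuming the google-genai types module:

from google.genai.types import Schema, Type

# Equivalent of {"type": "object", "properties": {"city": {"type": "string"},
#   "tags": {"type": "array", "items": {"type": "string"}}}, "required": ["city"]}
params = Schema(
    type=Type.OBJECT,
    properties={
        "city": Schema(type=Type.STRING, description="City name"),
        "tags": Schema(type=Type.ARRAY, items=Schema(type=Type.STRING)),
    },
    required=["city"],
)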
@@ -169,31 +185,71 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
         return None

     for func_dict in func_declarations_list:
+        func_name = func_dict.get("name", "Unknown")
         try:
-            params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
+            # Ensure parameters is a valid schema dict for the recursive creator
+            params_schema_dict = func_dict.get("parameters", {})
+
+            # Ensure parameters is a dict and defaults to object type if missing
             if not isinstance(params_schema_dict, dict):
-                logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.")
+                logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
                 params_schema_dict = {"type": "object", "properties": {}}
-            elif params_schema_dict.get("type") != "object":
-                logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.")
-                params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})}  # Attempt to salvage properties
+            elif "type" not in params_schema_dict:
+                params_schema_dict["type"] = "object"  # Default to object if type is missing
+            elif params_schema_dict["type"] != "object":
+                logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
+                # Attempt to salvage properties if the top level isn't object
+                original_properties = params_schema_dict.get("properties", {})
+                if not isinstance(original_properties, dict):
+                    original_properties = {}
+                params_schema_dict = {"type": "object", "properties": original_properties}

-            parameters_schema = _create_google_schema_recursive(params_schema_dict)
-
-            # Only proceed if schema creation was somewhat successful
-            if parameters_schema is not None:
-                declaration = FunctionDeclaration(
-                    name=func_dict["name"],
-                    description=func_dict.get("description", ""),
-                    parameters=parameters_schema,
-                )
-                all_func_declarations.append(declaration)
+            properties_dict = params_schema_dict.get("properties", {})
+            google_properties = {}
+            if isinstance(properties_dict, dict):
+                for prop_name, prop_schema_dict in properties_dict.items():
+                    prop_schema = _create_google_schema_recursive(prop_schema_dict)
+                    if prop_schema:
+                        google_properties[prop_name] = prop_schema
+                    else:
+                        logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
             else:
-                logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")
+                logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
+
+            # Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty
+            if not google_properties:
+                logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
+                google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
+                # Clear required list if properties are empty/dummy
+                required_list = []
+            else:
+                # Validate required list against actual properties
+                original_required = params_schema_dict.get("required", [])
+                if isinstance(original_required, list):
+                    required_list = [req for req in original_required if req in google_properties]
+                    if len(required_list) != len(original_required):
+                        logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
+                else:
+                    logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
+                    required_list = []
+
+            # Create the top-level parameters schema, ensuring it's OBJECT type
+            parameters_schema = Schema(
+                type=Type.OBJECT,
+                properties=google_properties,
+                required=required_list if required_list else None,  # Pass None if empty list
+            )
+
+            # Create the FunctionDeclaration
+            declaration = FunctionDeclaration(
+                name=func_name,
+                description=func_dict.get("description", ""),
+                parameters=parameters_schema,
+            )
+            all_func_declarations.append(declaration)
+            logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")

         except Exception as decl_err:
-            logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
+            logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)

     else:
         logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
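End to end, the function above turns one intermediate tool dict into a FunctionDeclaration wrapped in a Tool. A hand-built equivalent for a hypothetical get_weather tool, showing the OBJECT-typed top-level parameters schema the code enforces:

from google.genai.types import FunctionDeclaration, Schema, Tool, Type

declaration = FunctionDeclaration(
    name="get_weather",  # hypothetical example tool, not part of this commit
    description="Look up the current weather for a city.",
    parameters=Schema(
        type=Type.OBJECT,  # Google requires the top-level parameters schema to be OBJECT
        properties={"city": Schema(type=Type.STRING)},
        required=["city"],
    ),
)
tool = Tool(function_declarations=[declaration])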
@@ -108,16 +108,16 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:

             # Include any text content alongside the function calls
             if content and isinstance(content, str):
-                parts.append(Part.from_text(content))
+                parts.append(Part(text=content))  # Use direct instantiation

         elif content:
             # Regular user or assistant message content
             if isinstance(content, str):
-                parts.append(Part.from_text(content))
+                parts.append(Part(text=content))  # Use direct instantiation
             # TODO: Handle potential image content if needed in the future
             else:
                 logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
-                parts.append(Part.from_text(str(content)))
+                parts.append(Part(text=str(content)))  # Use direct instantiation

         # Add the constructed Content object if parts were generated
         if parts:
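The switch here is from the Part.from_text(...) factory to direct construction. A minimal sketch of building a message with the direct form, assuming google-genai types:

from google.genai.types import Content, Part

# Equivalent to the parts.append(Part(text=...)) calls in the hunk above
message = Content(role="user", parts=[Part(text="Hello there")])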
@@ -32,7 +32,7 @@ class OpenAIProvider(BaseProvider):
         self,
         messages: list[dict[str, str]],
         model: str,
-        temperature: float = 0.4,
+        temperature: float = 0.6,
         max_tokens: int | None = None,
         stream: bool = True,
         tools: list[dict[str, Any]] | None = None,
@@ -14,7 +14,7 @@ def create_chat_completion(
     provider,  # The OpenAIProvider instance
     messages: list[dict[str, str]],
     model: str,
-    temperature: float = 0.4,
+    temperature: float = 0.6,
     max_tokens: int | None = None,
     stream: bool = True,
     tools: list[dict[str, Any]] | None = None,