Compare commits
2 Commits
2fb6c5af3c
...
51e3058961
| Author | SHA1 | Date | |
|---|---|---|---|
|
51e3058961
|
|||
|
ccf750fed4
|
@@ -56,7 +56,7 @@ class LLMClient:
|
|||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
) -> Generator[str, None, None] | dict[str, Any]:
|
) -> Generator[str, None, None] | dict[str, Any]:
|
||||||
@@ -97,6 +97,7 @@ class LLMClient:
|
|||||||
stream=stream,
|
stream=stream,
|
||||||
tools=provider_tools,
|
tools=provider_tools,
|
||||||
)
|
)
|
||||||
|
print(f"Response: {response}") # Debugging line to check the response
|
||||||
logger.info("Received response from provider.")
|
logger.info("Received response from provider.")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def create_chat_completion(
|
def create_chat_completion(
|
||||||
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
|
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.6, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
|
||||||
) -> Stream | Message:
|
) -> Stream | Message:
|
||||||
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||||
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
|
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class BaseProvider(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
@@ -39,7 +39,7 @@ class BaseProvider(abc.ABC):
|
|||||||
Args:
|
Args:
|
||||||
messages: List of message dictionaries with 'role' and 'content'.
|
messages: List of message dictionaries with 'role' and 'content'.
|
||||||
model: Model identifier.
|
model: Model identifier.
|
||||||
temperature: Sampling temperature (0-1).
|
temperature: Sampling temperature (0-2).
|
||||||
max_tokens: Maximum tokens to generate.
|
max_tokens: Maximum tokens to generate.
|
||||||
stream: Whether to stream the response.
|
stream: Whether to stream the response.
|
||||||
tools: Optional list of tools in the provider-specific format.
|
tools: Optional list of tools in the provider-specific format.
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ import logging
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
# Import Generator type for isinstance check - Keep this import for type hints
|
||||||
from google.genai.types import GenerateContentResponse
|
from google.genai.types import GenerateContentResponse
|
||||||
|
|
||||||
from providers.google_provider.client import initialize_client
|
from providers.google_provider.client import initialize_client
|
||||||
|
|
||||||
|
# Correctly import the renamed function directly
|
||||||
from providers.google_provider.completion import create_chat_completion
|
from providers.google_provider.completion import create_chat_completion
|
||||||
from providers.google_provider.response import get_content, get_streaming_content, get_usage
|
from providers.google_provider.response import get_content, get_streaming_content, get_usage
|
||||||
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
|
from providers.google_provider.tools import convert_to_google_tools, format_google_tool_results, has_google_tool_calls, parse_google_tool_calls
|
||||||
@@ -28,7 +31,7 @@ class GoogleProvider(BaseProvider):
|
|||||||
api_key: The Google API key.
|
api_key: The Google API key.
|
||||||
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
|
base_url: Base URL (typically not used by Google client config, but kept for interface consistency).
|
||||||
"""
|
"""
|
||||||
# initialize_client returns the configured genai module
|
# initialize_client returns the client instance now
|
||||||
self.client_module = initialize_client(api_key, base_url)
|
self.client_module = initialize_client(api_key, base_url)
|
||||||
self.api_key = api_key # Store if needed later
|
self.api_key = api_key # Store if needed later
|
||||||
self.base_url = base_url # Store if needed later
|
self.base_url = base_url # Store if needed later
|
||||||
@@ -38,18 +41,24 @@ class GoogleProvider(BaseProvider):
|
|||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
|
) -> Any: # Return type is complex: iterator for stream, GenerateContentResponse otherwise, or error dict/iterator
|
||||||
"""Creates a chat completion using the Google Gemini API."""
|
"""Creates a chat completion using the Google Gemini API."""
|
||||||
# Pass self (provider instance) to the helper function
|
# Pass self (provider instance) to the helper function
|
||||||
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
|
raw_response = create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
|
||||||
|
print(f"Raw response type: {type(raw_response)}") # Debugging line to check the type of raw_response
|
||||||
|
print(f"Raw response: {raw_response}") # Debugging line to check the content of raw_response
|
||||||
|
|
||||||
|
# The completion helper function handles returning the correct type or an error dict.
|
||||||
|
# No need for generator handling here anymore.
|
||||||
|
return raw_response
|
||||||
|
|
||||||
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||||
"""Extracts content chunks from a Google streaming response."""
|
"""Extracts content chunks from a Google streaming response."""
|
||||||
# Response is expected to be an iterator from generate_content(stream=True)
|
# Response is expected to be an iterator from generate_content_stream
|
||||||
return get_streaming_content(response)
|
return get_streaming_content(response)
|
||||||
|
|
||||||
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
|
def get_content(self, response: GenerateContentResponse | dict[str, Any]) -> str:
|
||||||
|
|||||||
@@ -12,16 +12,17 @@ def initialize_client(api_key: str, base_url: str | None = None) -> Any:
|
|||||||
logger.info("Initializing Google Generative AI client")
|
logger.info("Initializing Google Generative AI client")
|
||||||
|
|
||||||
if genai is None:
|
if genai is None:
|
||||||
logger.error("Google Generative AI SDK (google-generativeai) is not installed.")
|
logger.error("Google Generative AI SDK (google-genai) is not installed.")
|
||||||
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
|
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Configure the client
|
# Instantiate the client directly using the API key
|
||||||
genai.configure(api_key=api_key)
|
client = genai.Client(api_key=api_key)
|
||||||
|
logger.info("Google Generative AI client instantiated.")
|
||||||
if base_url:
|
if base_url:
|
||||||
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client configuration.")
|
logger.warning(f"base_url '{base_url}' provided but not typically used by Google client instantiation.")
|
||||||
# Return the configured module itself, as it's used directly
|
# Return the client instance
|
||||||
return genai
|
return client
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
|
logger.error(f"Failed to instantiate Google Generative AI client: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,140 +1,174 @@
|
|||||||
# src/providers/google_provider/completion.py
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
from collections.abc import Iterable # Added Iterable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from google.genai.types import Tool
|
# Import specific types for better hinting
|
||||||
|
from google.genai.types import ContentDict, GenerateContentResponse, GenerationConfigDict, Tool
|
||||||
|
|
||||||
from providers.google_provider.tools import convert_to_google_tool_objects, convert_to_google_tools
|
# Removed convert_to_google_tools import as it's handled later
|
||||||
|
from providers.google_provider.tools import convert_to_google_tool_objects
|
||||||
from providers.google_provider.utils import convert_messages
|
from providers.google_provider.utils import convert_messages
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_chat_completion(
|
# --- Helper for Non-Streaming ---
|
||||||
|
def _create_chat_completion_non_stream(
|
||||||
provider,
|
provider,
|
||||||
|
model: str,
|
||||||
|
google_messages: list[ContentDict],
|
||||||
|
generation_config: GenerationConfigDict,
|
||||||
|
) -> GenerateContentResponse | dict[str, Any]:
|
||||||
|
"""Handles the non-streaming API call."""
|
||||||
|
try:
|
||||||
|
logger.debug("Calling client.models.generate_content...")
|
||||||
|
# Use the client instance stored on the provider
|
||||||
|
response = provider.client_module.models.generate_content(
|
||||||
|
model=f"models/{model}",
|
||||||
|
contents=google_messages,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
logger.debug("generate_content call successful, returning raw response object.")
|
||||||
|
# Return the direct response object
|
||||||
|
return response
|
||||||
|
except ValueError as ve:
|
||||||
|
error_msg = f"Google API request validation error: {ve}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
# Return error dict
|
||||||
|
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Google API error during non-stream chat completion: {e}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
# Return error dict
|
||||||
|
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper for Streaming ---
|
||||||
|
def _create_chat_completion_stream(
|
||||||
|
provider,
|
||||||
|
model: str,
|
||||||
|
google_messages: list[ContentDict],
|
||||||
|
generation_config: GenerationConfigDict,
|
||||||
|
) -> Iterable[GenerateContentResponse | dict[str, Any]]: # Return Iterable of response chunks or error dict
|
||||||
|
"""Handles the streaming API call and yields results."""
|
||||||
|
try:
|
||||||
|
logger.debug("Calling client.models.generate_content_stream...")
|
||||||
|
# Use the client instance stored on the provider
|
||||||
|
response_iterator = provider.client_module.models.generate_content_stream(
|
||||||
|
model=f"models/{model}",
|
||||||
|
contents=google_messages,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
logger.debug("generate_content_stream call successful, yielding from iterator.")
|
||||||
|
# Yield from the SDK's iterator which produces GenerateContentResponse chunks
|
||||||
|
yield from response_iterator
|
||||||
|
except ValueError as ve:
|
||||||
|
error_msg = f"Google API request validation error: {ve}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
# Yield error as a dict matching non-streaming error structure
|
||||||
|
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Google API error during stream chat completion: {e}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
# Yield error as a dict
|
||||||
|
yield {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Function ---
|
||||||
|
# Renamed original function to avoid conflict if needed, though overwrite is fine
|
||||||
|
def create_chat_completion(
|
||||||
|
provider, # Provider instance is passed in
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None, # Expects intermediate dict format
|
||||||
) -> Any:
|
) -> Any: # Return type depends on stream flag
|
||||||
"""
|
"""
|
||||||
Creates a chat completion using the Google Gemini API.
|
Creates a chat completion using the Google Gemini API.
|
||||||
|
Delegates to streaming or non-streaming helpers. Contains NO yield itself.
|
||||||
Args:
|
|
||||||
provider: The instance of the GoogleProvider.
|
|
||||||
messages: A list of message dictionaries in the standard format.
|
|
||||||
model: The model ID to use (e.g., "gemini-1.5-flash").
|
|
||||||
temperature: The sampling temperature.
|
|
||||||
max_tokens: The maximum number of tokens to generate.
|
|
||||||
stream: Whether to stream the response.
|
|
||||||
tools: A list of tool definitions in the MCP format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The response object from the Google API (could be a stream iterator or
|
|
||||||
a GenerateContentResponse object), or an error dictionary/iterator.
|
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
logger.debug(f"Google create_chat_completion_inner called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||||
|
|
||||||
|
# Check if client exists on the provider instance
|
||||||
if provider.client_module is None:
|
if provider.client_module is None:
|
||||||
error_msg = "Google Generative AI SDK not configured or installed."
|
error_msg = "Google Generative AI client not initialized on provider."
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
# Return an error structure compatible with both streaming and non-streaming expectations
|
# Return error dict directly for non-stream, create iterator for stream
|
||||||
if stream:
|
return iter([{"error": error_msg}]) if stream else {"error": error_msg}
|
||||||
return iter([json.dumps({"error": error_msg})])
|
|
||||||
else:
|
|
||||||
return {"error": error_msg}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Convert messages to Google's format
|
# 1. Convert messages (Common logic)
|
||||||
google_messages, system_prompt = convert_messages(messages)
|
google_messages, system_prompt = convert_messages(messages)
|
||||||
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
|
logger.debug(f"Converted {len(messages)} messages to {len(google_messages)} Google Content objects. System prompt present: {bool(system_prompt)}")
|
||||||
|
|
||||||
# 2. Prepare generation configuration
|
# 2. Prepare generation configuration (Common logic)
|
||||||
generation_config: dict[str, Any] = {"temperature": temperature}
|
# Use GenerationConfigDict for better type hinting if possible
|
||||||
|
generation_config: GenerationConfigDict = {"temperature": temperature}
|
||||||
if max_tokens is not None:
|
if max_tokens is not None:
|
||||||
# Google uses 'max_output_tokens'
|
|
||||||
generation_config["max_output_tokens"] = max_tokens
|
generation_config["max_output_tokens"] = max_tokens
|
||||||
logger.debug(f"Setting max_output_tokens: {max_tokens}")
|
logger.debug(f"Setting max_output_tokens: {max_tokens}")
|
||||||
else:
|
else:
|
||||||
logger.debug("No max_tokens specified.")
|
# Google requires max_output_tokens, set a default if None
|
||||||
|
# Defaulting to a reasonable value, e.g., 8192, check model limits if needed
|
||||||
|
default_max_tokens = 8192
|
||||||
|
generation_config["max_output_tokens"] = default_max_tokens
|
||||||
|
logger.warning(f"max_tokens not provided, defaulting to {default_max_tokens} for Google API.")
|
||||||
|
|
||||||
# 3. Convert tools if provided
|
# 3. Convert tools if provided (Common logic)
|
||||||
google_tool_objects: list[Tool] | None = None
|
google_tool_objects: list[Tool] | None = None
|
||||||
if tools:
|
if tools:
|
||||||
try:
|
try:
|
||||||
# Step 3a: Convert MCP tools to intermediate Google dict format
|
# Convert intermediate dict format to Google Tool objects
|
||||||
tool_dict_list = convert_to_google_tools(tools)
|
google_tool_objects = convert_to_google_tool_objects(tools)
|
||||||
# Step 3b: Convert intermediate dict format to Google Tool objects
|
|
||||||
google_tool_objects = convert_to_google_tool_objects(tool_dict_list)
|
|
||||||
if google_tool_objects:
|
if google_tool_objects:
|
||||||
logger.debug(f"Successfully converted {len(tools)} MCP tools to {len(google_tool_objects)} Google Tool objects.")
|
num_declarations = sum(len(t.function_declarations) for t in google_tool_objects if t.function_declarations)
|
||||||
|
logger.debug(f"Successfully converted intermediate tool config to {len(google_tool_objects)} Google Tool objects with {num_declarations} declarations.")
|
||||||
else:
|
else:
|
||||||
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
|
logger.warning("Tool conversion resulted in no valid Google Tool objects.")
|
||||||
except Exception as tool_conv_err:
|
except Exception as tool_conv_err:
|
||||||
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
|
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
|
||||||
# Decide whether to proceed without tools or raise an error
|
google_tool_objects = None # Continue without tools on conversion error
|
||||||
# Proceeding without tools for now, but logging the error.
|
|
||||||
google_tool_objects = None
|
|
||||||
else:
|
else:
|
||||||
logger.debug("No tools provided for conversion.")
|
logger.debug("No tools provided for conversion.")
|
||||||
|
|
||||||
# 4. Initialize the Google Generative Model
|
# 4. Add system prompt and tools to generation_config (Common logic)
|
||||||
# Ensure client_module is callable and has GenerativeModel
|
if system_prompt:
|
||||||
if not hasattr(provider.client_module, "GenerativeModel"):
|
# Ensure system_instruction is ContentDict or compatible type
|
||||||
raise AttributeError("Configured Google client module does not have 'GenerativeModel'")
|
generation_config["system_instruction"] = system_prompt
|
||||||
|
logger.debug("Added system_instruction to generation_config.")
|
||||||
|
if google_tool_objects:
|
||||||
|
# Assign the list of Tool objects directly
|
||||||
|
generation_config["tools"] = google_tool_objects
|
||||||
|
logger.debug(f"Added {len(google_tool_objects)} tool objects to generation_config.")
|
||||||
|
|
||||||
gemini_model = provider.client_module.GenerativeModel(
|
# 5. Log parameters before API call (Common logic)
|
||||||
model_name=model,
|
|
||||||
system_instruction=system_prompt, # Pass extracted system prompt
|
|
||||||
tools=google_tool_objects, # Pass converted Tool objects (or None)
|
|
||||||
# Add safety_settings if needed: safety_settings=...
|
|
||||||
)
|
|
||||||
logger.debug(f"Initialized Google GenerativeModel for '{model}'.")
|
|
||||||
|
|
||||||
# 5. Log parameters before API call
|
|
||||||
log_params = {
|
log_params = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||||
"system_prompt_present": bool(system_prompt),
|
"system_prompt_present": bool(system_prompt),
|
||||||
"num_tools": len(google_tool_objects) if google_tool_objects else 0,
|
"num_tools": len(generation_config.get("tools", [])) if "tools" in generation_config else 0,
|
||||||
"num_messages": len(google_messages),
|
"num_messages": len(google_messages),
|
||||||
}
|
}
|
||||||
logger.info(f"Calling Google generate_content API with params: {log_params}")
|
logger.info(f"Calling Google API via helper with params: {log_params}")
|
||||||
# Avoid logging full message content unless necessary for debugging specific issues
|
|
||||||
# logger.debug(f"Google messages being sent: {google_messages}")
|
|
||||||
|
|
||||||
# 6. Call the Google API
|
# 6. Delegate to appropriate helper
|
||||||
response = gemini_model.generate_content(
|
|
||||||
contents=google_messages,
|
|
||||||
generation_config=generation_config,
|
|
||||||
stream=stream,
|
|
||||||
# tool_config={"function_calling_config": "AUTO"} # AUTO is default
|
|
||||||
)
|
|
||||||
logger.debug("Google API call successful, returning response object.")
|
|
||||||
return response
|
|
||||||
|
|
||||||
except ValueError as ve: # Catch specific errors like invalid message sequence
|
|
||||||
error_msg = f"Google API request validation error: {ve}"
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
if stream:
|
if stream:
|
||||||
# Yield a JSON error message in an iterator
|
# Return the generator/iterator from the streaming helper
|
||||||
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
|
# This helper uses 'yield from'
|
||||||
|
return _create_chat_completion_stream(provider, model, google_messages, generation_config)
|
||||||
else:
|
else:
|
||||||
# Return an error dictionary
|
# Return the direct result (GenerateContentResponse or error dict) from the non-streaming helper
|
||||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
# This helper uses 'return'
|
||||||
|
return _create_chat_completion_non_stream(provider, model, google_messages, generation_config)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch any other exceptions during setup or API call
|
# Catch errors during common setup (message/tool conversion etc.)
|
||||||
error_msg = f"Google API error during chat completion: {e}"
|
error_msg = f"Error during Google completion setup: {e}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
if stream:
|
# Return error dict directly for non-stream, create iterator for stream
|
||||||
# Yield a JSON error message in an iterator
|
return iter([{"error": error_msg, "traceback": traceback.format_exc()}]) if stream else {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
|
|
||||||
else:
|
|
||||||
# Return an error dictionary
|
|
||||||
return {"error": error_msg, "traceback": traceback.format_exc()}
|
|
||||||
|
|||||||
@@ -124,37 +124,46 @@ def get_content(response: GenerateContentResponse | dict[str, Any]) -> str:
|
|||||||
The concatenated text content, or an error message string.
|
The concatenated text content, or an error message string.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Handle error dictionary case
|
# Check if it's an error dictionary passed from upstream (e.g., completion helper)
|
||||||
if isinstance(response, dict) and "error" in response:
|
if isinstance(response, dict) and "error" in response:
|
||||||
logger.error(f"Cannot get content from error response: {response['error']}")
|
logger.error(f"Cannot get content from error dict: {response['error']}")
|
||||||
return f"[Error: {response['error']}]"
|
return f"[Error: {response['error']}]"
|
||||||
|
|
||||||
# Handle successful GenerateContentResponse object
|
# Ensure it's a GenerateContentResponse object before accessing attributes
|
||||||
if hasattr(response, "text"):
|
if not isinstance(response, GenerateContentResponse):
|
||||||
# The `.text` attribute usually provides the concatenated text content directly
|
logger.error(f"Cannot get content: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||||
|
return f"[Error: Unexpected response type {type(response)}]"
|
||||||
|
|
||||||
|
# --- Access GenerateContentResponse attributes ---
|
||||||
|
# Prioritize response.text if available and not empty
|
||||||
|
if hasattr(response, "text") and response.text:
|
||||||
content = response.text
|
content = response.text
|
||||||
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
|
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
|
||||||
return content
|
return content
|
||||||
elif hasattr(response, "candidates") and response.candidates:
|
|
||||||
# Fallback: manually concatenate text from parts if .text is missing
|
# Fallback: manually concatenate text from parts if .text is missing/empty
|
||||||
|
if hasattr(response, "candidates") and response.candidates:
|
||||||
first_candidate = response.candidates[0]
|
first_candidate = response.candidates[0]
|
||||||
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
|
# Check candidate content and parts carefully
|
||||||
text_parts = []
|
if hasattr(first_candidate, "content") and first_candidate.content and hasattr(first_candidate.content, "parts") and first_candidate.content.parts:
|
||||||
for part in first_candidate.content.parts:
|
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
|
||||||
if hasattr(part, "text"):
|
if text_parts:
|
||||||
text_parts.append(part.text)
|
content = "".join(text_parts)
|
||||||
# We are only interested in text content here, ignore function calls etc.
|
logger.debug(f"Extracted content (length {len(content)}) from response candidate parts.")
|
||||||
content = "".join(text_parts)
|
return content
|
||||||
logger.debug(f"Extracted content (length {len(content)}) from response candidates' parts.")
|
else:
|
||||||
return content
|
logger.warning("Google response candidate parts contained no text.")
|
||||||
|
return "" # Return empty if parts exist but have no text
|
||||||
else:
|
else:
|
||||||
logger.warning("Google response candidate has no content or parts.")
|
logger.warning("Google response candidate has no valid content or parts.")
|
||||||
return "" # Return empty string if no text found
|
return "" # Return empty string if no valid content/parts
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Could not extract content from Google response: No 'text' or valid 'candidates'. Response type: {type(response)}")
|
# If neither .text nor valid candidates are found
|
||||||
|
logger.warning(f"Could not extract content from Google response: No .text or valid candidates found. Response: {response}")
|
||||||
return "" # Return empty string if no text found
|
return "" # Return empty string if no text found
|
||||||
|
|
||||||
except AttributeError as ae:
|
except AttributeError as ae:
|
||||||
logger.error(f"Attribute error extracting content from Google response: {ae}. Response object: {response}", exc_info=True)
|
logger.error(f"Attribute error extracting content from Google response: {ae}. Response type: {type(response)}", exc_info=True)
|
||||||
return f"[Error extracting content: Attribute missing - {str(ae)}]"
|
return f"[Error extracting content: Attribute missing - {str(ae)}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
|
logger.error(f"Unexpected error extracting content from Google response: {e}", exc_info=True)
|
||||||
@@ -173,32 +182,34 @@ def get_usage(response: GenerateContentResponse | dict[str, Any]) -> dict[str, i
|
|||||||
usage information is unavailable or an error occurred.
|
usage information is unavailable or an error occurred.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Handle error dictionary case
|
# Check if it's an error dictionary passed from upstream
|
||||||
if isinstance(response, dict) and "error" in response:
|
if isinstance(response, dict) and "error" in response:
|
||||||
logger.warning("Cannot get usage from error response.")
|
logger.warning(f"Cannot get usage from error dict: {response['error']}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check for usage metadata in the response object
|
# Ensure it's a GenerateContentResponse object before accessing attributes
|
||||||
if hasattr(response, "usage_metadata"):
|
if not isinstance(response, GenerateContentResponse):
|
||||||
metadata = response.usage_metadata
|
logger.warning(f"Cannot get usage: Expected GenerateContentResponse or error dict, got {type(response)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Safely access usage metadata
|
||||||
|
metadata = getattr(response, "usage_metadata", None)
|
||||||
|
if metadata:
|
||||||
# Google uses prompt_token_count and candidates_token_count
|
# Google uses prompt_token_count and candidates_token_count
|
||||||
|
prompt_tokens = getattr(metadata, "prompt_token_count", 0)
|
||||||
|
completion_tokens = getattr(metadata, "candidates_token_count", 0)
|
||||||
usage = {
|
usage = {
|
||||||
"prompt_tokens": getattr(metadata, "prompt_token_count", 0),
|
"prompt_tokens": int(prompt_tokens),
|
||||||
"completion_tokens": getattr(metadata, "candidates_token_count", 0),
|
"completion_tokens": int(completion_tokens),
|
||||||
# Google also provides total_token_count, could be added if needed
|
|
||||||
# "total_tokens": getattr(metadata, "total_token_count", 0),
|
|
||||||
}
|
}
|
||||||
# Ensure values are integers
|
|
||||||
usage = {k: int(v) for k, v in usage.items()}
|
|
||||||
logger.debug(f"Extracted usage from Google response metadata: {usage}")
|
logger.debug(f"Extracted usage from Google response metadata: {usage}")
|
||||||
return usage
|
return usage
|
||||||
else:
|
else:
|
||||||
# Log a warning only if it's not clearly an error dict already handled
|
logger.warning(f"Could not extract usage from Google response object: No 'usage_metadata' attribute found. Response: {response}")
|
||||||
if not (isinstance(response, dict) and "error" in response):
|
|
||||||
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata' attribute found.")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except AttributeError as ae:
|
except AttributeError as ae:
|
||||||
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response object: {response}", exc_info=True)
|
logger.error(f"Attribute error extracting usage from Google response: {ae}. Response type: {type(response)}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
|
logger.error(f"Unexpected error extracting usage from Google response: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from google.genai.types import FunctionDeclaration, Schema, Tool
|
from google.genai.types import FunctionDeclaration, Schema, Tool, Type
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -95,50 +95,66 @@ def _create_google_schema_recursive(schema_dict: dict[str, Any]) -> Schema | Non
|
|||||||
|
|
||||||
Handles type mapping and nested structures. Returns None on failure.
|
Handles type mapping and nested structures. Returns None on failure.
|
||||||
"""
|
"""
|
||||||
if Schema is None:
|
if Schema is None or Type is None:
|
||||||
logger.error("Cannot create Schema object: google.genai types not available.")
|
logger.error("Cannot create Schema object: google.genai types (Schema or Type) not available.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not isinstance(schema_dict, dict):
|
if not isinstance(schema_dict, dict):
|
||||||
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
|
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning None.")
|
||||||
return Schema() # Return empty schema to avoid breaking the parent
|
return None # Return None on invalid input
|
||||||
|
|
||||||
# Map JSON Schema types to Google's Type enum strings
|
# Map JSON Schema types to Google's Type enum members
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
"string": "STRING",
|
"string": Type.STRING,
|
||||||
"number": "NUMBER",
|
"number": Type.NUMBER,
|
||||||
"integer": "INTEGER",
|
"integer": Type.INTEGER,
|
||||||
"boolean": "BOOLEAN",
|
"boolean": Type.BOOLEAN,
|
||||||
"array": "ARRAY",
|
"array": Type.ARRAY,
|
||||||
"object": "OBJECT",
|
"object": Type.OBJECT,
|
||||||
# Add other mappings if necessary
|
|
||||||
}
|
}
|
||||||
original_type = schema_dict.get("type")
|
original_type = schema_dict.get("type")
|
||||||
google_type = type_mapping.get(str(original_type).lower()) if original_type else None
|
google_type = type_mapping.get(str(original_type).lower()) if original_type else None
|
||||||
|
|
||||||
|
if not google_type:
|
||||||
|
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Returning None.")
|
||||||
|
return None # Return None if type is invalid/missing
|
||||||
|
|
||||||
# Prepare arguments for Schema constructor, filtering out None values
|
# Prepare arguments for Schema constructor, filtering out None values
|
||||||
schema_args = {
|
schema_args = {
|
||||||
"type": google_type,
|
"type": google_type, # Use the Type enum member
|
||||||
"format": schema_dict.get("format"),
|
"format": schema_dict.get("format"),
|
||||||
"description": schema_dict.get("description"),
|
"description": schema_dict.get("description"),
|
||||||
"nullable": schema_dict.get("nullable"),
|
"nullable": schema_dict.get("nullable"), # Note: Google's Schema might not directly support nullable in constructor
|
||||||
"enum": schema_dict.get("enum"),
|
"enum": schema_dict.get("enum"),
|
||||||
"items": _create_google_schema_recursive(schema_dict["items"]) if "items" in schema_dict and google_type == "ARRAY" else None,
|
# Recursively create nested schemas, ensuring None is handled if recursion fails
|
||||||
"properties": {k: _create_google_schema_recursive(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") and google_type == "OBJECT" else None,
|
"items": _create_google_schema_recursive(schema_dict["items"]) if google_type == Type.ARRAY and "items" in schema_dict else None,
|
||||||
"required": schema_dict.get("required") if google_type == "OBJECT" else None,
|
"properties": {k: prop_schema for k, v in schema_dict.get("properties", {}).items() if (prop_schema := _create_google_schema_recursive(v)) is not None}
|
||||||
|
if google_type == Type.OBJECT and schema_dict.get("properties")
|
||||||
|
else None,
|
||||||
|
"required": schema_dict.get("required") if google_type == Type.OBJECT else None,
|
||||||
}
|
}
|
||||||
# Remove keys with None values
|
|
||||||
|
# Remove keys with None values before passing to Schema constructor
|
||||||
schema_args = {k: v for k, v in schema_args.items() if v is not None}
|
schema_args = {k: v for k, v in schema_args.items() if v is not None}
|
||||||
|
|
||||||
if not schema_args.get("type"):
|
# Handle specific cases for ARRAY and OBJECT where items/properties might be needed
|
||||||
logger.warning(f"Schema dictionary missing 'type' or type '{original_type}' is not recognized: {schema_dict}. Creating empty Schema.")
|
if google_type == Type.ARRAY and "items" not in schema_args:
|
||||||
return Schema() # Return empty schema
|
logger.warning(f"Array schema missing 'items': {schema_dict}. Returning None.")
|
||||||
|
return None # Array schema requires items
|
||||||
|
if google_type == Type.OBJECT and "properties" not in schema_args:
|
||||||
|
# Allow object schema without properties initially, might be handled later
|
||||||
|
pass
|
||||||
|
# logger.warning(f"Object schema missing 'properties': {schema_dict}. Creating empty properties.")
|
||||||
|
# schema_args["properties"] = {} # Or return None if properties are strictly required
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return Schema(**schema_args)
|
# Create the Schema object
|
||||||
|
created_schema = Schema(**schema_args)
|
||||||
|
# logger.debug(f"Successfully created Schema: {created_schema}")
|
||||||
|
return created_schema
|
||||||
except Exception as schema_creation_err:
|
except Exception as schema_creation_err:
|
||||||
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
|
logger.error(f"Failed to create Schema object with args {schema_args}: {schema_creation_err}", exc_info=True)
|
||||||
return Schema() # Return empty schema on error
|
return None # Return None on creation error
|
||||||
|
|
||||||
|
|
||||||
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
|
def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
|
||||||
@@ -169,31 +185,71 @@ def convert_to_google_tool_objects(tool_configs: list[dict[str, Any]]) -> list[T
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
for func_dict in func_declarations_list:
|
for func_dict in func_declarations_list:
|
||||||
|
func_name = func_dict.get("name", "Unknown")
|
||||||
try:
|
try:
|
||||||
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
|
params_schema_dict = func_dict.get("parameters", {})
|
||||||
# Ensure parameters is a valid schema dict for the recursive creator
|
|
||||||
|
# Ensure parameters is a dict and defaults to object type if missing
|
||||||
if not isinstance(params_schema_dict, dict):
|
if not isinstance(params_schema_dict, dict):
|
||||||
logger.warning(f"Invalid 'parameters' format for tool {func_dict.get('name')}: {params_schema_dict}. Using empty object schema.")
|
logger.warning(f"Invalid 'parameters' format for tool {func_name}: {params_schema_dict}. Using empty object schema.")
|
||||||
params_schema_dict = {"type": "object", "properties": {}}
|
params_schema_dict = {"type": "object", "properties": {}}
|
||||||
elif params_schema_dict.get("type") != "object":
|
elif "type" not in params_schema_dict:
|
||||||
logger.warning(f"Tool {func_dict.get('name')} parameters schema is not type 'object'. Forcing object type.")
|
params_schema_dict["type"] = "object" # Default to object if type is missing
|
||||||
params_schema_dict = {"type": "object", "properties": params_schema_dict.get("properties", {})} # Attempt to salvage properties
|
elif params_schema_dict["type"] != "object":
|
||||||
|
logger.warning(f"Tool {func_name} parameters schema is not type 'object' ({params_schema_dict.get('type')}). Google requires 'object'. Attempting to wrap properties.")
|
||||||
|
# Attempt to salvage properties if the top level isn't object
|
||||||
|
original_properties = params_schema_dict.get("properties", {})
|
||||||
|
if not isinstance(original_properties, dict):
|
||||||
|
original_properties = {}
|
||||||
|
params_schema_dict = {"type": "object", "properties": original_properties}
|
||||||
|
|
||||||
parameters_schema = _create_google_schema_recursive(params_schema_dict)
|
properties_dict = params_schema_dict.get("properties", {})
|
||||||
|
google_properties = {}
|
||||||
# Only proceed if schema creation was somewhat successful
|
if isinstance(properties_dict, dict):
|
||||||
if parameters_schema is not None:
|
for prop_name, prop_schema_dict in properties_dict.items():
|
||||||
declaration = FunctionDeclaration(
|
prop_schema = _create_google_schema_recursive(prop_schema_dict)
|
||||||
name=func_dict["name"],
|
if prop_schema:
|
||||||
description=func_dict.get("description", ""),
|
google_properties[prop_name] = prop_schema
|
||||||
parameters=parameters_schema,
|
else:
|
||||||
)
|
logger.warning(f"Failed to create schema for property '{prop_name}' in tool '{func_name}'. Skipping property.")
|
||||||
all_func_declarations.append(declaration)
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to create parameters Schema for FunctionDeclaration '{func_dict.get('name', 'Unknown')}'")
|
logger.warning(f"'properties' for tool {func_name} is not a dictionary: {properties_dict}. Ignoring properties.")
|
||||||
|
|
||||||
|
# Handle empty properties - Google requires parameters to be OBJECT, and properties cannot be null/empty
|
||||||
|
if not google_properties:
|
||||||
|
logger.warning(f"Function '{func_name}' has no valid properties defined. Adding dummy property for Google compatibility.")
|
||||||
|
google_properties = {"_dummy_param": Schema(type=Type.STRING, description="Placeholder parameter as properties cannot be empty.")}
|
||||||
|
# Clear required list if properties are empty/dummy
|
||||||
|
required_list = []
|
||||||
|
else:
|
||||||
|
# Validate required list against actual properties
|
||||||
|
original_required = params_schema_dict.get("required", [])
|
||||||
|
if isinstance(original_required, list):
|
||||||
|
required_list = [req for req in original_required if req in google_properties]
|
||||||
|
if len(required_list) != len(original_required):
|
||||||
|
logger.warning(f"Some required properties for '{func_name}' were invalid or missing from properties: {set(original_required) - set(required_list)}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"'required' field for '{func_name}' is not a list: {original_required}. Ignoring required field.")
|
||||||
|
required_list = []
|
||||||
|
|
||||||
|
# Create the top-level parameters schema, ensuring it's OBJECT type
|
||||||
|
parameters_schema = Schema(
|
||||||
|
type=Type.OBJECT,
|
||||||
|
properties=google_properties,
|
||||||
|
required=required_list if required_list else None, # Pass None if empty list
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the FunctionDeclaration
|
||||||
|
declaration = FunctionDeclaration(
|
||||||
|
name=func_name,
|
||||||
|
description=func_dict.get("description", ""),
|
||||||
|
parameters=parameters_schema,
|
||||||
|
)
|
||||||
|
all_func_declarations.append(declaration)
|
||||||
|
logger.debug(f"Successfully created FunctionDeclaration for: {func_name}")
|
||||||
|
|
||||||
except Exception as decl_err:
|
except Exception as decl_err:
|
||||||
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
|
logger.error(f"Failed to create FunctionDeclaration object for tool '{func_name}': {decl_err}", exc_info=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
|
logger.error(f"Invalid tool_configs structure provided: {tool_configs}")
|
||||||
|
|||||||
@@ -108,16 +108,16 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[list[Content], str
|
|||||||
|
|
||||||
# Include any text content alongside the function calls
|
# Include any text content alongside the function calls
|
||||||
if content and isinstance(content, str):
|
if content and isinstance(content, str):
|
||||||
parts.append(Part.from_text(content))
|
parts.append(Part(text=content)) # Use direct instantiation
|
||||||
|
|
||||||
elif content:
|
elif content:
|
||||||
# Regular user or assistant message content
|
# Regular user or assistant message content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
parts.append(Part.from_text(content))
|
parts.append(Part(text=content)) # Use direct instantiation
|
||||||
# TODO: Handle potential image content if needed in the future
|
# TODO: Handle potential image content if needed in the future
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
|
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
|
||||||
parts.append(Part.from_text(str(content)))
|
parts.append(Part(text=str(content))) # Use direct instantiation
|
||||||
|
|
||||||
# Add the constructed Content object if parts were generated
|
# Add the constructed Content object if parts were generated
|
||||||
if parts:
|
if parts:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class OpenAIProvider(BaseProvider):
|
|||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def create_chat_completion(
|
|||||||
provider, # The OpenAIProvider instance
|
provider, # The OpenAIProvider instance
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
model: str,
|
model: str,
|
||||||
temperature: float = 0.4,
|
temperature: float = 0.6,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user