feat: add GoogleProvider implementation and update conversion utilities for Google tools
This commit is contained in:
@@ -3,9 +3,9 @@ import logging
|
|||||||
|
|
||||||
from providers.anthropic_provider import AnthropicProvider
|
from providers.anthropic_provider import AnthropicProvider
|
||||||
from providers.base import BaseProvider
|
from providers.base import BaseProvider
|
||||||
|
from providers.google_provider import GoogleProvider
|
||||||
from providers.openai_provider import OpenAIProvider
|
from providers.openai_provider import OpenAIProvider
|
||||||
|
|
||||||
# from providers.google_provider import GoogleProvider
|
|
||||||
# from providers.openrouter_provider import OpenRouterProvider
|
# from providers.openrouter_provider import OpenRouterProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|||||||
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
|
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
|
||||||
"openai": OpenAIProvider,
|
"openai": OpenAIProvider,
|
||||||
"anthropic": AnthropicProvider,
|
"anthropic": AnthropicProvider,
|
||||||
# "google": GoogleProvider,
|
"google": GoogleProvider,
|
||||||
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
|
# "openrouter": OpenRouterProvider, # OpenRouter can often use OpenAIProvider with custom base_url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
483
src/providers/google_provider.py
Normal file
483
src/providers/google_provider.py
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
# src/providers/google_provider.py
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai.types import (
|
||||||
|
Content,
|
||||||
|
FunctionDeclaration,
|
||||||
|
Part,
|
||||||
|
Schema,
|
||||||
|
Tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.llm_models import MODELS
|
||||||
|
from src.providers.base import BaseProvider
|
||||||
|
from src.tools.conversion import convert_to_google_tools
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleProvider(BaseProvider):
|
||||||
|
"""Provider implementation for Google Gemini models."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str | None = None):
|
||||||
|
# Google client typically doesn't use a base_url, but we accept it for consistency
|
||||||
|
effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
|
||||||
|
super().__init__(api_key, effective_base_url)
|
||||||
|
logger.info("Initializing GoogleProvider")
|
||||||
|
|
||||||
|
if genai is None:
|
||||||
|
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Configure the client
|
||||||
|
genai.configure(api_key=self.api_key)
|
||||||
|
self.client_module = genai
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_context_window(self, model: str) -> int:
|
||||||
|
"""Retrieves the context window size for a given Google model."""
|
||||||
|
default_window = 1000000 # Default fallback for Gemini
|
||||||
|
try:
|
||||||
|
provider_models = MODELS.get("google", {}).get("models", [])
|
||||||
|
for m in provider_models:
|
||||||
|
if m.get("id") == model:
|
||||||
|
return m.get("context_window", default_window)
|
||||||
|
logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
|
||||||
|
return default_window
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
|
||||||
|
return default_window
|
||||||
|
|
||||||
|
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
|
||||||
|
"""
|
||||||
|
Converts standard message format to Google's format, extracting system prompt.
|
||||||
|
Handles mapping roles and structuring tool calls/results.
|
||||||
|
"""
|
||||||
|
google_messages: list[Content] = []
|
||||||
|
system_prompt: str | None = None
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
role = message.get("role")
|
||||||
|
content = message.get("content")
|
||||||
|
tool_calls = message.get("tool_calls")
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
if i == 0:
|
||||||
|
system_prompt = content
|
||||||
|
logger.debug("Extracted system prompt for Google.")
|
||||||
|
else:
|
||||||
|
logger.warning("System message found not at the beginning. Merging into subsequent user message.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role)
|
||||||
|
|
||||||
|
if not google_role:
|
||||||
|
logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
parts: list[Part | str] = []
|
||||||
|
if role == "tool":
|
||||||
|
if tool_call_id and content:
|
||||||
|
try:
|
||||||
|
response_content_dict = json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
|
||||||
|
response_content_dict = {"result": content}
|
||||||
|
|
||||||
|
func_name = "unknown_function"
|
||||||
|
if i > 0 and messages[i - 1].get("role") == "assistant":
|
||||||
|
prev_tool_calls = messages[i - 1].get("tool_calls")
|
||||||
|
if prev_tool_calls:
|
||||||
|
for tc in prev_tool_calls:
|
||||||
|
if tc.get("id") == tool_call_id:
|
||||||
|
func_name = tc.get("function_name", "unknown_function")
|
||||||
|
break
|
||||||
|
|
||||||
|
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
|
||||||
|
google_role = "function"
|
||||||
|
else:
|
||||||
|
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif role == "assistant" and tool_calls:
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
args = tool_call.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = json.loads(args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
|
||||||
|
args = {"error": "failed to parse arguments"}
|
||||||
|
func_name = tool_call.get("function_name", "unknown_function")
|
||||||
|
parts.append(Part.from_function_call(name=func_name, args=args))
|
||||||
|
if content:
|
||||||
|
parts.append(Part.from_text(content))
|
||||||
|
|
||||||
|
elif content:
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(Part.from_text(content))
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
|
||||||
|
parts.append(Part.from_text(str(content)))
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
google_messages.append(Content(role=google_role, parts=parts))
|
||||||
|
else:
|
||||||
|
logger.debug(f"No parts generated for message: {message}")
|
||||||
|
|
||||||
|
last_role = None
|
||||||
|
valid_alternation = True
|
||||||
|
for msg in google_messages:
|
||||||
|
current_role = msg.role
|
||||||
|
if current_role == last_role and current_role in ["user", "model"]:
|
||||||
|
valid_alternation = False
|
||||||
|
logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.")
|
||||||
|
break
|
||||||
|
if last_role == "function" and current_role != "user":
|
||||||
|
valid_alternation = False
|
||||||
|
logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.")
|
||||||
|
break
|
||||||
|
last_role = current_role
|
||||||
|
|
||||||
|
if not valid_alternation:
|
||||||
|
logger.error("Message list does not follow required user/model alternation for Google API.")
|
||||||
|
raise ValueError("Invalid message sequence for Google API.")
|
||||||
|
|
||||||
|
return google_messages, system_prompt
|
||||||
|
|
||||||
|
def create_chat_completion(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
temperature: float = 0.4,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Creates a chat completion using the Google Gemini API."""
|
||||||
|
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||||
|
|
||||||
|
if self.client_module is None:
|
||||||
|
return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})])
|
||||||
|
|
||||||
|
try:
|
||||||
|
google_messages, system_prompt = self._convert_messages(messages)
|
||||||
|
generation_config: dict[str, Any] = {"temperature": temperature}
|
||||||
|
if max_tokens is not None:
|
||||||
|
generation_config["max_output_tokens"] = max_tokens
|
||||||
|
|
||||||
|
google_tools = None
|
||||||
|
if tools:
|
||||||
|
try:
|
||||||
|
tool_dict_list = convert_to_google_tools(tools)
|
||||||
|
google_tools = self._convert_to_tool_objects(tool_dict_list)
|
||||||
|
logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.")
|
||||||
|
except Exception as tool_conv_err:
|
||||||
|
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
|
||||||
|
google_tools = None
|
||||||
|
|
||||||
|
gemini_model = self.client_module.GenerativeModel(
|
||||||
|
model_name=model,
|
||||||
|
system_instruction=system_prompt,
|
||||||
|
tools=google_tools if google_tools else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
log_params = {
|
||||||
|
"model": model,
|
||||||
|
"stream": stream,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"system_prompt_present": bool(system_prompt),
|
||||||
|
"num_tools": len(google_tools) if google_tools else 0,
|
||||||
|
"num_messages": len(google_messages),
|
||||||
|
}
|
||||||
|
logger.debug(f"Calling Google API with params: {log_params}")
|
||||||
|
|
||||||
|
response = gemini_model.generate_content(
|
||||||
|
contents=google_messages,
|
||||||
|
generation_config=generation_config,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
logger.debug("Google API call successful.")
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Google API error: {e}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
if stream:
|
||||||
|
yield json.dumps({"error": error_msg, "traceback": traceback.format_exc()})
|
||||||
|
else:
|
||||||
|
return {"error": error_msg, "traceback": traceback.format_exc()}
|
||||||
|
|
||||||
|
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
|
||||||
|
"""Yields content chunks from a Google streaming response."""
|
||||||
|
logger.debug("Processing Google stream...")
|
||||||
|
full_delta = ""
|
||||||
|
try:
|
||||||
|
if isinstance(response, dict) and "error" in response:
|
||||||
|
yield json.dumps(response)
|
||||||
|
return
|
||||||
|
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
|
||||||
|
yield from response
|
||||||
|
return
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
if isinstance(chunk, dict) and "error" in chunk:
|
||||||
|
yield json.dumps(chunk)
|
||||||
|
continue
|
||||||
|
if hasattr(chunk, "text"):
|
||||||
|
delta = chunk.text
|
||||||
|
if delta:
|
||||||
|
full_delta += delta
|
||||||
|
yield delta
|
||||||
|
elif hasattr(chunk, "candidates") and chunk.candidates:
|
||||||
|
for part in chunk.candidates[0].content.parts:
|
||||||
|
if hasattr(part, "function_call") and part.function_call:
|
||||||
|
logger.debug(f"Function call detected during stream: {part.function_call.name}")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing Google stream: {e}", exc_info=True)
|
||||||
|
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||||
|
|
||||||
|
def get_content(self, response: Any) -> str:
|
||||||
|
"""Extracts content from a non-streaming Google response."""
|
||||||
|
try:
|
||||||
|
if isinstance(response, dict) and "error" in response:
|
||||||
|
logger.error(f"Cannot get content from error response: {response['error']}")
|
||||||
|
return f"[Error: {response['error']}]"
|
||||||
|
if hasattr(response, "text"):
|
||||||
|
content = response.text
|
||||||
|
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
|
||||||
|
return content
|
||||||
|
elif hasattr(response, "candidates") and response.candidates:
|
||||||
|
first_candidate = response.candidates[0]
|
||||||
|
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
|
||||||
|
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
|
||||||
|
content = "".join(text_parts)
|
||||||
|
logger.debug(f"Extracted content (length {len(content)}) from response candidates.")
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
logger.warning("Google response candidate has no content or parts.")
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
|
||||||
|
return f"[Error extracting content: {str(e)}]"
|
||||||
|
|
||||||
|
def has_tool_calls(self, response: Any) -> bool:
|
||||||
|
"""Checks if the Google response contains tool calls (function calls)."""
|
||||||
|
try:
|
||||||
|
if isinstance(response, dict) and "error" in response:
|
||||||
|
return False
|
||||||
|
if hasattr(response, "candidates") and response.candidates:
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
|
||||||
|
for part in candidate.content.parts:
|
||||||
|
if hasattr(part, "function_call") and part.function_call:
|
||||||
|
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
|
||||||
|
return True
|
||||||
|
logger.debug("No tool calls (FunctionCall) detected in Google response.")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
|
||||||
|
"""Parses tool calls (function calls) from a non-streaming Google response."""
|
||||||
|
parsed_calls = []
|
||||||
|
try:
|
||||||
|
if not (hasattr(response, "candidates") and response.candidates):
|
||||||
|
logger.warning("Cannot parse tool calls: Response has no candidates.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
|
||||||
|
logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.debug("Parsing tool calls (FunctionCall) from Google response.")
|
||||||
|
call_index = 0
|
||||||
|
for part in candidate.content.parts:
|
||||||
|
if hasattr(part, "function_call") and part.function_call:
|
||||||
|
func_call = part.function_call
|
||||||
|
call_id = f"call_{call_index}"
|
||||||
|
call_index += 1
|
||||||
|
|
||||||
|
full_name = func_call.name
|
||||||
|
parts = full_name.split("__", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
server_name, func_name = parts
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
|
||||||
|
server_name = None
|
||||||
|
func_name = full_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
args_str = json.dumps(func_call.args or {})
|
||||||
|
except Exception as json_err:
|
||||||
|
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
|
||||||
|
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
|
||||||
|
|
||||||
|
parsed_calls.append({
|
||||||
|
"id": call_id,
|
||||||
|
"server_name": server_name,
|
||||||
|
"function_name": func_name,
|
||||||
|
"arguments": args_str,
|
||||||
|
})
|
||||||
|
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
|
||||||
|
|
||||||
|
return parsed_calls
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||||
|
"""Formats a tool result for a Google follow-up request."""
|
||||||
|
try:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
content_str = json.dumps(result)
|
||||||
|
else:
|
||||||
|
content_str = str(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
|
||||||
|
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||||
|
|
||||||
|
logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"content": content_str,
|
||||||
|
"function_name": "unknown_function",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
|
||||||
|
"""Extracts the assistant's message containing tool calls for Google."""
|
||||||
|
try:
|
||||||
|
if not (hasattr(response, "candidates") and response.candidates):
|
||||||
|
return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}
|
||||||
|
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
|
||||||
|
return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}
|
||||||
|
|
||||||
|
tool_calls_formatted = []
|
||||||
|
text_content_parts = []
|
||||||
|
for part in candidate.content.parts:
|
||||||
|
if hasattr(part, "function_call") and part.function_call:
|
||||||
|
func_call = part.function_call
|
||||||
|
args = func_call.args or {}
|
||||||
|
tool_calls_formatted.append({
|
||||||
|
"function_name": func_call.name,
|
||||||
|
"arguments": args,
|
||||||
|
})
|
||||||
|
elif hasattr(part, "text"):
|
||||||
|
text_content_parts.append(part.text)
|
||||||
|
|
||||||
|
message = {"role": "assistant"}
|
||||||
|
if tool_calls_formatted:
|
||||||
|
message["tool_calls"] = tool_calls_formatted
|
||||||
|
text_content = "".join(text_content_parts)
|
||||||
|
if text_content:
|
||||||
|
message["content"] = text_content
|
||||||
|
elif not tool_calls_formatted:
|
||||||
|
message["content"] = ""
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
|
||||||
|
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||||
|
|
||||||
|
def get_usage(self, response: Any) -> dict[str, int] | None:
|
||||||
|
"""Extracts token usage from a Google response."""
|
||||||
|
try:
|
||||||
|
if isinstance(response, dict) and "error" in response:
|
||||||
|
return None
|
||||||
|
if hasattr(response, "usage_metadata"):
|
||||||
|
metadata = response.usage_metadata
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": getattr(metadata, "prompt_token_count", 0),
|
||||||
|
"completion_tokens": getattr(metadata, "candidates_token_count", 0),
|
||||||
|
}
|
||||||
|
logger.debug(f"Extracted usage from Google response metadata: {usage}")
|
||||||
|
return usage
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
|
||||||
|
"""Convert dictionary-format tools into Google's Tool objects."""
|
||||||
|
if not tool_configs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
all_func_declarations = []
|
||||||
|
for config in tool_configs:
|
||||||
|
if "function_declarations" in config:
|
||||||
|
for func_dict in config["function_declarations"]:
|
||||||
|
try:
|
||||||
|
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
|
||||||
|
if params_schema_dict.get("type") != "object":
|
||||||
|
logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.")
|
||||||
|
params_schema_dict = {"type": "object", "properties": params_schema_dict}
|
||||||
|
|
||||||
|
def create_schema(schema_dict):
|
||||||
|
if not isinstance(schema_dict, dict):
|
||||||
|
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
|
||||||
|
return Schema()
|
||||||
|
schema_args = {
|
||||||
|
"type": schema_dict.get("type"),
|
||||||
|
"format": schema_dict.get("format"),
|
||||||
|
"description": schema_dict.get("description"),
|
||||||
|
"nullable": schema_dict.get("nullable"),
|
||||||
|
"enum": schema_dict.get("enum"),
|
||||||
|
"items": create_schema(schema_dict["items"]) if "items" in schema_dict else None,
|
||||||
|
"properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None,
|
||||||
|
"required": schema_dict.get("required"),
|
||||||
|
}
|
||||||
|
schema_args = {k: v for k, v in schema_args.items() if v is not None}
|
||||||
|
if "type" in schema_args:
|
||||||
|
type_mapping = {
|
||||||
|
"string": "STRING",
|
||||||
|
"number": "NUMBER",
|
||||||
|
"integer": "INTEGER",
|
||||||
|
"boolean": "BOOLEAN",
|
||||||
|
"array": "ARRAY",
|
||||||
|
"object": "OBJECT",
|
||||||
|
}
|
||||||
|
schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"])
|
||||||
|
try:
|
||||||
|
return Schema(**schema_args)
|
||||||
|
except Exception as schema_creation_err:
|
||||||
|
logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True)
|
||||||
|
return Schema()
|
||||||
|
|
||||||
|
parameters_schema = create_schema(params_schema_dict)
|
||||||
|
declaration = FunctionDeclaration(
|
||||||
|
name=func_dict["name"],
|
||||||
|
description=func_dict.get("description", ""),
|
||||||
|
parameters=parameters_schema,
|
||||||
|
)
|
||||||
|
all_func_declarations.append(declaration)
|
||||||
|
except Exception as decl_err:
|
||||||
|
logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
|
||||||
|
|
||||||
|
if not all_func_declarations:
|
||||||
|
logger.warning("No valid function declarations found after conversion.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return [Tool(function_declarations=all_func_declarations)]
|
||||||
@@ -164,11 +164,12 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, A
|
|||||||
function_declarations.append(function_declaration)
|
function_declarations.append(function_declaration)
|
||||||
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
|
logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
|
||||||
|
|
||||||
# Google API expects a list containing one Tool object dict
|
# Google API expects a list containing one dictionary with 'function_declarations'
|
||||||
google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []
|
# The provider's _convert_to_tool_objects will handle creating Tool objects from this.
|
||||||
|
google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
|
||||||
|
|
||||||
logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
|
logger.debug(f"Final Google tool config structure: {google_tool_config}")
|
||||||
return google_tools_wrapper
|
return google_tool_config
|
||||||
|
|
||||||
|
|
||||||
# Note: The _handle_schema_construct helper from the reference code is not strictly
|
# Note: The _handle_schema_construct helper from the reference code is not strictly
|
||||||
|
|||||||
Reference in New Issue
Block a user