Compare commits

2 commits: 15ecb9fc48 ... ab8d5fe074

Commits:
  ab8d5fe074
  246d921743
@@ -3,9 +3,9 @@ import logging
 from providers.anthropic_provider import AnthropicProvider
 from providers.base import BaseProvider
+from providers.google_provider import GoogleProvider
 from providers.openai_provider import OpenAIProvider

-# from providers.google_provider import GoogleProvider
 # from providers.openrouter_provider import OpenRouterProvider

 logger = logging.getLogger(__name__)

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 PROVIDER_MAP: dict[str, type[BaseProvider]] = {
     "openai": OpenAIProvider,
     "anthropic": AnthropicProvider,
-    # "google": GoogleProvider,
+    "google": GoogleProvider,
     # "openrouter": OpenRouterProvider,  # OpenRouter can often use OpenAIProvider with custom base_url
 }
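With "google" uncommented, the registry now resolves three providers. As a minimal sketch of how such a map is typically consumed; the get_provider helper and the api_key plumbing below are illustrative assumptions, not code from this diff:

    def get_provider(name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
        # PROVIDER_MAP is the registry from the hunk above; this helper is hypothetical.
        provider_cls = PROVIDER_MAP.get(name)
        if provider_cls is None:
            raise ValueError(f"Unknown provider: {name!r}")
        return provider_cls(api_key, base_url)

    # provider = get_provider("google", api_key="...")  # now resolves to GoogleProvider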
src/providers/anthropic_provider.py (deleted file)
@@ -1,453 +0,0 @@
# src/providers/anthropic_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any

from anthropic import Anthropic, APIError, Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta

from providers.base import BaseProvider
from src.llm_models import MODELS
from src.tools.conversion import convert_to_anthropic_tools

logger = logging.getLogger(__name__)


class AnthropicProvider(BaseProvider):
    """Provider implementation for Anthropic Claude models."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # The Anthropic client doesn't use base_url in the same way; fall back to the
        # default Anthropic endpoint if base_url is not provided or relevant.
        effective_base_url = base_url or MODELS.get("anthropic", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)  # parent stores it, though the Anthropic client may ignore it
        logger.info("Initializing AnthropicProvider")
        try:
            # Note: the Anthropic client doesn't take base_url during init
            self.client = Anthropic(api_key=self.api_key)
        except Exception as e:
            logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
            raise

    def _get_context_window(self, model: str) -> int:
        """Retrieves the context window size for a given Anthropic model."""
        default_window = 100000  # Default fallback for Anthropic
        try:
            provider_models = MODELS.get("anthropic", {}).get("models", [])
            for m in provider_models:
                if m.get("id") == model:
                    return m.get("context_window", default_window)
            logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
            return default_window
        except Exception as e:
            logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
            return default_window

    def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
        """Counts tokens for Anthropic messages using the official client."""
        # Note: Anthropic's count_tokens might not directly accept the message list
        # format used for creation; it often expects plain text, so the content is
        # concatenated here. This is a simplification and may not be perfectly
        # accurate, especially with tool calls/results. A more robust approach might
        # format messages into a single string representation.
        text_to_count = ""
        if system_prompt:
            text_to_count += f"System: {system_prompt}\n\n"
        for message in messages:
            role = message.get("role")
            content = message.get("content")
            # Simple concatenation - may need refinement for complex content types (tool calls/results)
            if isinstance(content, str):
                text_to_count += f"{role}: {content}\n"
            elif isinstance(content, list):  # Handle tool results/calls if represented as a list
                try:
                    content_str = json.dumps(content)
                    text_to_count += f"{role}: {content_str}\n"
                except Exception:
                    text_to_count += f"{role}: [Unserializable Content]\n"

        try:
            # Use the client's count_tokens method if available and it works with text.
            # Check the Anthropic documentation for the correct usage; this assumes
            # self.client.count_tokens exists and takes text.
            count = self.client.count_tokens(text=text_to_count)
            logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
            return count
        except APIError as api_err:
            # Handle potential errors if count_tokens itself is an API call or fails
            logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
            estimated_tokens = math.ceil(len(text_to_count) / 4.0)  # Same approximation as OpenAI
            logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
            return estimated_tokens
        except AttributeError:
            # Fallback if the count_tokens method doesn't exist or works differently
            logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
            estimated_tokens = math.ceil(len(text_to_count) / 4.0)  # Same approximation as OpenAI
            return estimated_tokens
        except Exception as e:
            logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
            estimated_tokens = math.ceil(len(text_to_count) / 4.0)  # Fallback approximation
            logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
            return estimated_tokens

    def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
        """
        Truncates messages for Anthropic, preserving the system prompt.

        Returns:
            - Potentially truncated list of messages.
            - Original system prompt (or None).
            - Initial token count.
            - Final token count.
        """
        context_limit = self._get_context_window(model)
        buffer = 200  # Safety buffer
        effective_limit = context_limit - buffer

        initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
        final_token_count = initial_token_count

        truncated_messages = list(messages)  # Copy

        # Anthropic requires alternating user/assistant messages, so truncation needs
        # care. We remove from the beginning (the system prompt is handled separately).
        while final_token_count > effective_limit and len(truncated_messages) > 0:
            # Always remove the oldest message (index 0)
            removed_message = truncated_messages.pop(0)
            logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")

            # Ensuring alternation after removal can be complex; for simplicity, just
            # remove and recount. A more robust approach might remove user/assistant pairs.
            final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
            logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")

            # Safety break
            if not truncated_messages:
                logger.warning("Truncation resulted in empty message list for Anthropic.")
                break

        if initial_token_count != final_token_count:
            logger.info(
                f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
            )
        else:
            logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")

        # Ensure the remaining messages start with the 'user' role if there is no system prompt
        if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
            logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
            # This may indicate an issue with the simple pop(0) logic if pairs weren't removed.
            truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
            # Recounting after adding the placeholder could exceed the limit again, so it is skipped.

        return truncated_messages, system_prompt, initial_token_count, final_token_count

    def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
        """Converts standard message format to Anthropic's format, extracting the system prompt."""
        anthropic_messages = []
        system_prompt = None
        for i, message in enumerate(messages):
            role = message.get("role")
            content = message.get("content")

            if role == "system":
                if i == 0:
                    system_prompt = content
                    logger.debug("Extracted system prompt for Anthropic.")
                else:
                    # Handle a system message not at the start: treat it as a user message
                    logger.warning("System message found not at the beginning. Treating as user message.")
                    anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
                continue

            # Handle tool results specifically
            if role == "tool":
                # The corresponding tool_use block lives in the preceding assistant
                # message; this requires careful handling in the follow-up logic.
                tool_use_id = message.get("tool_call_id")
                tool_content = content
                # Format as a tool_result content block
                anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
                continue

            # Handle an assistant message potentially containing tool_use blocks:
            # structured content (e.g., a list from a previous tool call response)
            # and regular text content are both passed through unchanged.
            if role == "assistant":
                anthropic_messages.append({"role": "assistant", "content": content})
                continue

            # Regular user messages
            if role == "user":
                anthropic_messages.append({"role": "user", "content": content})
                continue

            logger.warning(f"Unsupported role '{role}' in message conversion for Anthropic.")

        # Ensure the conversation starts with a user message if no system prompt was used
        if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
            logger.warning("Anthropic conversation must start with a user message. Prepending placeholder.")
            anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"})

        return system_prompt, anthropic_messages

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,  # Anthropic requires max_tokens
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Stream[MessageStreamEvent] | Message:
        """Creates a chat completion using the Anthropic API, handling context truncation."""
        logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

        # --- Context truncation ---
        # First convert to Anthropic format to separate the system prompt,
        # then truncate based on token count.
        temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
        truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
        # --------------------------

        # Anthropic requires max_tokens
        if max_tokens is None:
            max_tokens = 4096  # Default value if not provided
            logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")

        try:
            completion_params = {
                "model": model,
                "messages": truncated_anthropic_msgs,  # Use truncated messages
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
            }
            if final_system_prompt:  # Use the potentially modified system prompt
                completion_params["system"] = final_system_prompt
            if tools:
                completion_params["tools"] = tools
                # Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call

            # Remove None values (though Anthropic requires max_tokens)
            completion_params = {k: v for k, v in completion_params.items() if v is not None}

            log_params = completion_params.copy()
            if "messages" in log_params:
                log_params["messages"] = [{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} for msg in log_params["messages"][-2:]]
            tools_log = log_params.get("tools", "Not Present")
            logger.debug(f"Calling Anthropic API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, System: {bool(log_params.get('system'))}, Tools: {tools_log}")
            logger.debug(f"Full API params (messages summarized): {log_params}")

            response = self.client.messages.create(**completion_params)
            logger.debug("Anthropic API call successful.")

            # --- Capture actual usage ---
            actual_usage = None
            if isinstance(response, Message) and response.usage:
                actual_usage = {
                    "prompt_tokens": response.usage.input_tokens,  # Anthropic uses input_tokens
                    "completion_tokens": response.usage.output_tokens,  # Anthropic uses output_tokens
                    # Anthropic doesn't provide total_tokens directly in the usage block
                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
                }
                logger.info(f"Actual Anthropic API usage: {actual_usage}")
            # TODO: How to get usage for streaming responses? Anthropic may send it in a final 'message_stop' event; needs investigation.

            return response
            # --------------------------

        except Exception as e:
            logger.error(f"Anthropic API error: {e}", exc_info=True)
            raise

    def get_streaming_content(self, response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
        """Yields content chunks from an Anthropic streaming response."""
        logger.debug("Processing Anthropic stream...")
        full_delta = ""
        try:
            # Iterate through events in the stream
            for event in response:
                if event.type == "content_block_delta":
                    # Check that the delta is for text content before accessing .text;
                    # other delta types like InputJSONDelta are ignored for text streaming.
                    if isinstance(event.delta, TextDelta):
                        delta_text = event.delta.text
                        if delta_text:
                            full_delta += delta_text
                            yield delta_text
                elif event.type == "message_start":
                    logger.debug(f"Anthropic stream started. Model: {event.message.model}")
                elif event.type == "message_stop":
                    # The stop_reason may be available on the 'message' object associated
                    # with the stream, not on the stop event itself; we only log that the
                    # event occurred. Accessing the actual reason would require inspecting
                    # the final message state.
                    logger.debug("Anthropic stream message_stop event received.")
                elif event.type == "content_block_start":
                    if event.content_block.type == "tool_use":
                        logger.debug(f"Anthropic stream detected tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
                elif event.type == "content_block_stop":
                    logger.debug(f"Anthropic stream detected content block stop. Index: {event.index}")

            logger.debug(f"Anthropic stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing Anthropic stream: {e}", exc_info=True)
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: Message) -> str:
        """Extracts content from a non-streaming Anthropic response."""
        try:
            # Combine text content from all text blocks
            text_content = "".join([block.text for block in response.content if block.type == "text"])
            logger.debug(f"Extracted content (length {len(text_content)}) from non-streaming Anthropic response.")
            return text_content
        except Exception as e:
            logger.error(f"Error extracting content from Anthropic response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Stream[MessageStreamEvent] | Message) -> bool:
        """Checks whether the Anthropic response contains tool calls."""
        try:
            if isinstance(response, Message):  # Non-streaming
                # Check the stop reason and content blocks
                has_tool_use_block = any(block.type == "tool_use" for block in response.content)
                has_calls = response.stop_reason == "tool_use" or has_tool_use_block
                logger.debug(f"Non-streaming Anthropic response check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
                return has_calls
            elif isinstance(response, Stream):
                # An unconsumed stream cannot be checked reliably without consuming it.
                # The LLMClient should check after consumption, or based on stop_reason if available post-stream.
                logger.warning("has_tool_calls check on an Anthropic stream is unreliable before consumption.")
                return False
            else:
                logger.warning(f"has_tool_calls received unexpected type for Anthropic: {type(response)}")
                return False
        except Exception as e:
            logger.error(f"Error checking for Anthropic tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: Message) -> list[dict[str, Any]]:
        """Parses tool calls from a non-streaming Anthropic response."""
        parsed_calls = []
        try:
            if not isinstance(response, Message):
                logger.error(f"parse_tool_calls expects Anthropic Message, got {type(response)}")
                return []

            if response.stop_reason != "tool_use":
                logger.debug("No tool use indicated by stop_reason.")
                # A response might still contain tool_use blocks even if stop_reason
                # isn't 'tool_use', so check the content anyway.

            tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
            if not tool_use_blocks:
                logger.debug("No 'tool_use' content blocks found in Anthropic response.")
                return []

            logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks from Anthropic response.")
            for block in tool_use_blocks:
                # Split the server/tool name if needed (similar to the OpenAI provider),
                # assuming Anthropic tool names may also be prefixed like "server__tool".
                parts = block.name.split("__", 1)
                if len(parts) == 2:
                    server_name, func_name = parts
                else:
                    logger.warning(f"Could not determine server_name from Anthropic tool name '{block.name}'.")
                    server_name = None
                    func_name = block.name

                parsed_calls.append({
                    "id": block.id,
                    "server_name": server_name,
                    "function_name": func_name,
                    # Anthropic's input is already a dict; dump it to a string as the
                    # OpenAI provider expects.
                    "arguments": json.dumps(block.input),
                })

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing Anthropic tool calls: {e}", exc_info=True)
            return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for an Anthropic follow-up request."""
        # Anthropic expects a 'tool_result' content block whose content is typically a string.
        try:
            if isinstance(result, dict):
                content_str = json.dumps(result)
            else:
                content_str = str(result)
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for Anthropic {tool_call_id}: {e}")
            content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting Anthropic tool result for call ID {tool_call_id}")
        # This block needs to be placed inside a "user" role message's content list
        return {
            "type": "tool_result",
            "tool_use_id": tool_call_id,
            "content": content_str,
            # Optionally add is_error=True if the result indicates an error
        }

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Converts the internal tool format to Anthropic's format."""
        logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
        try:
            # The conversion function needs to handle the server__tool prefixing
            anthropic_tools = convert_to_anthropic_tools(tools)
            logger.debug(f"Tool conversion result: {anthropic_tools}")
            return anthropic_tools
        except Exception as e:
            logger.error(f"Error during Anthropic tool conversion: {e}", exc_info=True)
            return []

    # Helper needed by LLMClient's current tool handling logic (if adapting OpenAI's pattern)
    def get_original_message_with_calls(self, response: Message) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls for Anthropic."""
        try:
            if isinstance(response, Message) and any(block.type == "tool_use" for block in response.content):
                # Anthropic's response structure differs: the 'message' itself is the
                # assistant's turn, so return a representation of this turn including
                # the tool_use blocks. Convert Pydantic models within content to dicts.
                content_list = [block.model_dump(exclude_unset=True) for block in response.content]
                return {"role": "assistant", "content": content_list}
            else:
                logger.warning("Could not extract original message with tool calls from Anthropic response.")
                return {"role": "assistant", "content": "[Could not extract tool calls message]"}
        except Exception as e:
            logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

    def get_usage(self, response: Any) -> dict[str, int] | None:
        """Extracts token usage from a non-streaming Anthropic response."""
        try:
            if isinstance(response, Message) and response.usage:
                usage = {
                    "prompt_tokens": response.usage.input_tokens,
                    "completion_tokens": response.usage.output_tokens,
                    # "total_tokens": response.usage.input_tokens + response.usage.output_tokens,  # Optional
                }
                logger.debug(f"Extracted usage from Anthropic response: {usage}")
                return usage
            else:
                logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
                return None
        except Exception as e:
            logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
            return None
src/providers/anthropic_provider/__init__.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from providers.anthropic_provider.client import initialize_client
from providers.anthropic_provider.completion import create_chat_completion
from providers.anthropic_provider.response import get_content, get_streaming_content, get_usage
from providers.anthropic_provider.tools import convert_tools, format_tool_results, has_tool_calls, parse_tool_calls
from providers.base import BaseProvider


class AnthropicProvider(BaseProvider):
    def __init__(self, api_key: str, base_url: str | None = None):
        self.client = initialize_client(api_key, base_url)

    def create_chat_completion(self, messages, model, temperature=0.4, max_tokens=None, stream=True, tools=None):
        return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)

    def get_streaming_content(self, response):
        return get_streaming_content(response)

    def get_content(self, response):
        return get_content(response)

    def has_tool_calls(self, response):
        return has_tool_calls(response)

    def parse_tool_calls(self, response):
        return parse_tool_calls(response)

    def format_tool_results(self, tool_call_id, result):
        return format_tool_results(tool_call_id, result)

    def convert_tools(self, tools):
        return convert_tools(tools)

    def get_usage(self, response):
        return get_usage(response)
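The package facade keeps the old import path working: callers still write from providers.anthropic_provider import AnthropicProvider, and each method delegates to the new modules. A minimal usage sketch; the API key and model id below are placeholders, not values from this diff:

    from providers.anthropic_provider import AnthropicProvider

    provider = AnthropicProvider(api_key="sk-ant-...")  # placeholder key
    response = provider.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        model="claude-3-5-sonnet-latest",  # illustrative model id
        stream=False,
    )
    print(provider.get_content(response))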
src/providers/anthropic_provider/client.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import logging

from anthropic import Anthropic

logger = logging.getLogger(__name__)


def initialize_client(api_key: str, base_url: str | None = None) -> Anthropic:
    logger.info("Initializing Anthropic client")
    try:
        client = Anthropic(api_key=api_key)
        if base_url:
            logger.warning(f"base_url '{base_url}' provided but not used by Anthropic client")
        return client
    except Exception as e:
        logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
        raise
src/providers/anthropic_provider/completion.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import logging
from typing import Any

from anthropic import Stream
from anthropic.types import Message

from providers.anthropic_provider.messages import convert_messages, truncate_messages

logger = logging.getLogger(__name__)


def create_chat_completion(
    provider,
    messages: list[dict[str, Any]],
    model: str,
    temperature: float = 0.4,
    max_tokens: int | None = None,
    stream: bool = True,
    tools: list[dict[str, Any]] | None = None,
) -> Stream | Message:
    logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
    temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
    truncated_messages, final_system_prompt, _, _ = truncate_messages(provider, temp_anthropic_messages, temp_system_prompt, model)
    if max_tokens is None:
        max_tokens = 4096
        logger.warning(f"max_tokens not provided, defaulting to {max_tokens}")
    completion_params = {
        "model": model,
        "messages": truncated_messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    if final_system_prompt:
        completion_params["system"] = final_system_prompt
    if tools:
        completion_params["tools"] = tools
    try:
        response = provider.client.messages.create(**completion_params)
        logger.debug("Anthropic API call successful.")
        return response
    except Exception as e:
        logger.error(f"Anthropic API error: {e}", exc_info=True)
        raise
src/providers/anthropic_provider/messages.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import logging
from typing import Any

from providers.anthropic_provider.utils import count_anthropic_tokens, get_context_window

logger = logging.getLogger(__name__)


def convert_messages(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
    anthropic_messages = []
    system_prompt = None
    for i, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")
        if role == "system":
            if i == 0:
                system_prompt = content
            else:
                logger.warning("System message not at beginning. Treating as user message.")
                anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
            continue
        if role == "tool":
            tool_use_id = message.get("tool_call_id")
            tool_content = content
            anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
            continue
        if role == "assistant":
            # Structured content (tool_use blocks) and plain text both pass through unchanged
            anthropic_messages.append({"role": "assistant", "content": content})
            continue
        if role == "user":
            anthropic_messages.append({"role": "user", "content": content})
            continue
        logger.warning(f"Unsupported role '{role}' in message conversion.")
    if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
        logger.warning("Conversation must start with user message. Prepending placeholder.")
        anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"})
    return system_prompt, anthropic_messages


def truncate_messages(provider, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
    context_limit = get_context_window(model)
    buffer = 200
    effective_limit = context_limit - buffer
    initial_token_count = count_anthropic_tokens(provider.client, messages, system_prompt)
    final_token_count = initial_token_count
    truncated_messages = list(messages)
    while final_token_count > effective_limit and len(truncated_messages) > 0:
        removed_message = truncated_messages.pop(0)
        logger.debug(f"Truncating message (Role: {removed_message.get('role')})")
        final_token_count = count_anthropic_tokens(provider.client, truncated_messages, system_prompt)
    if initial_token_count != final_token_count:
        logger.info(f"Truncated messages. Initial tokens: {initial_token_count}, Final: {final_token_count}")
    else:
        logger.debug(f"No truncation needed. Tokens: {final_token_count}")
    if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
        logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
        truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
    return truncated_messages, system_prompt, initial_token_count, final_token_count
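To make convert_messages concrete, a small input/output sketch derived directly from the branches above (the tool_use id is a placeholder):

    messages = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Hi"},
        {"role": "tool", "tool_call_id": "toolu_01", "content": "42"},
    ]
    system_prompt, converted = convert_messages(messages)
    # system_prompt == "Be terse."
    # converted == [
    #     {"role": "user", "content": "Hi"},
    #     {"role": "user", "content": [{"type": "tool_result",
    #                                   "tool_use_id": "toolu_01", "content": "42"}]},
    # ]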
src/providers/anthropic_provider/response.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import json
import logging
from collections.abc import Generator
from typing import Any

from anthropic import Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta

logger = logging.getLogger(__name__)


def get_streaming_content(response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
    logger.debug("Processing Anthropic stream...")
    full_delta = ""
    try:
        for event in response:
            if event.type == "content_block_delta":
                if isinstance(event.delta, TextDelta):
                    delta_text = event.delta.text
                    if delta_text:
                        full_delta += delta_text
                        yield delta_text
            elif event.type == "message_start":
                logger.debug(f"Stream started. Model: {event.message.model}")
            elif event.type == "message_stop":
                logger.debug("Stream message_stop event received.")
            elif event.type == "content_block_start":
                if event.content_block.type == "tool_use":
                    logger.debug(f"Tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
            elif event.type == "content_block_stop":
                logger.debug(f"Content block stop. Index: {event.index}")
        logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
    except Exception as e:
        logger.error(f"Error processing stream: {e}", exc_info=True)
        yield json.dumps({"error": f"Stream processing error: {str(e)}"})


def get_content(response: Message) -> str:
    try:
        text_content = "".join([block.text for block in response.content if block.type == "text"])
        logger.debug(f"Extracted content (length {len(text_content)})")
        return text_content
    except Exception as e:
        logger.error(f"Error extracting content: {e}", exc_info=True)
        return f"[Error extracting content: {str(e)}]"


def get_usage(response: Any) -> dict[str, int] | None:
    try:
        if isinstance(response, Message) and response.usage:
            usage = {
                "prompt_tokens": response.usage.input_tokens,
                "completion_tokens": response.usage.output_tokens,
            }
            logger.debug(f"Extracted usage: {usage}")
            return usage
        else:
            logger.warning(f"Could not extract usage from {type(response)}")
            return None
    except Exception as e:
        logger.error(f"Error extracting usage: {e}", exc_info=True)
        return None
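Since get_streaming_content is a plain generator of text deltas, a caller can consume it directly. A sketch, assuming a provider and message list set up as in the facade example above:

    stream = provider.create_chat_completion(messages=messages, model=model_id, stream=True)
    for chunk in get_streaming_content(stream):
        print(chunk, end="", flush=True)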
src/providers/anthropic_provider/tools.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import json
import logging
from typing import Any

from anthropic.types import Message

from src.tools.conversion import convert_to_anthropic_tools

logger = logging.getLogger(__name__)


def has_tool_calls(response: Any) -> bool:
    try:
        if isinstance(response, Message):
            has_tool_use_block = any(block.type == "tool_use" for block in response.content)
            has_calls = response.stop_reason == "tool_use" or has_tool_use_block
            logger.debug(f"Tool calls check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
            return has_calls
        else:
            logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
            return False
    except Exception as e:
        logger.error(f"Error checking for tool calls: {e}", exc_info=True)
        return False


def parse_tool_calls(response: Message) -> list[dict[str, Any]]:
    parsed_calls = []
    try:
        if not isinstance(response, Message):
            logger.error(f"parse_tool_calls expects Message, got {type(response)}")
            return []
        tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
        if not tool_use_blocks:
            logger.debug("No 'tool_use' content blocks found.")
            return []
        logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks.")
        for block in tool_use_blocks:
            parts = block.name.split("__", 1)
            if len(parts) == 2:
                server_name, func_name = parts
            else:
                logger.warning(f"Could not determine server_name from tool name '{block.name}'.")
                server_name = None
                func_name = block.name
            parsed_calls.append({"id": block.id, "server_name": server_name, "function_name": func_name, "arguments": block.input})
        return parsed_calls
    except Exception as e:
        logger.error(f"Error parsing tool calls: {e}", exc_info=True)
        return []


def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
    try:
        if isinstance(result, dict):
            content_str = json.dumps(result)
        else:
            content_str = str(result)
    except Exception as e:
        logger.error(f"Error encoding tool result for {tool_call_id}: {e}")
        content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
    logger.debug(f"Formatting tool result for call ID {tool_call_id}")
    return {"type": "tool_result", "tool_use_id": tool_call_id, "content": content_str}


def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
    try:
        anthropic_tools = convert_to_anthropic_tools(tools)
        logger.debug(f"Tool conversion result: {anthropic_tools}")
        return anthropic_tools
    except Exception as e:
        logger.error(f"Error during tool conversion: {e}", exc_info=True)
        return []
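The server__tool naming convention in parse_tool_calls splits on the first double underscore; for example (the tool names here are illustrative):

    "filesystem__read_file".split("__", 1)  # -> ["filesystem", "read_file"]
    # so server_name == "filesystem" and function_name == "read_file"
    "read_file".split("__", 1)              # -> ["read_file"], so server_name is None

Note that unlike the deleted monolith, which JSON-encoded block.input into a string, this module passes arguments through as a dict.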
src/providers/anthropic_provider/utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import json
import logging
import math
from typing import Any

from anthropic import Anthropic

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    default_window = 100000
    try:
        provider_models = MODELS.get("anthropic", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        logger.warning(f"Context window for Anthropic model '{model}' not found. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window


def count_anthropic_tokens(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
    text_to_count = ""
    if system_prompt:
        text_to_count += f"System: {system_prompt}\n\n"
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if isinstance(content, str):
            text_to_count += f"{role}: {content}\n"
        elif isinstance(content, list):
            try:
                content_str = json.dumps(content)
                text_to_count += f"{role}: {content_str}\n"
            except Exception:
                text_to_count += f"{role}: [Unserializable Content]\n"
    try:
        count = client.count_tokens(text=text_to_count)
        logger.debug(f"Counted Anthropic tokens: {count}")
        return count
    except Exception as e:
        logger.error(f"Error counting Anthropic tokens: {e}", exc_info=True)
        estimated_tokens = math.ceil(len(text_to_count) / 4.0)
        logger.warning(f"Falling back to approximation: {estimated_tokens}")
        return estimated_tokens
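The fallback path assumes roughly 4 characters per token, so for example:

    import math
    text = "x" * 1000
    math.ceil(len(text) / 4.0)  # -> 250 estimated tokens for 1000 characters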
src/providers/google_provider.py (new file, 483 lines)
@@ -0,0 +1,483 @@
# src/providers/google_provider.py
import json
import logging
import traceback
from collections.abc import Generator
from typing import Any

from google import genai
from google.genai.types import (
    Content,
    FunctionDeclaration,
    Part,
    Schema,
    Tool,
)

from src.llm_models import MODELS
from src.providers.base import BaseProvider
from src.tools.conversion import convert_to_google_tools

logger = logging.getLogger(__name__)


class GoogleProvider(BaseProvider):
    """Provider implementation for Google Gemini models."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # The Google client typically doesn't use a base_url, but we accept it for consistency
        effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)
        logger.info("Initializing GoogleProvider")

        if genai is None:
            raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")

        try:
            # Configure the client
            genai.configure(api_key=self.api_key)
            self.client_module = genai
        except Exception as e:
            logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
            raise

    def _get_context_window(self, model: str) -> int:
        """Retrieves the context window size for a given Google model."""
        default_window = 1000000  # Default fallback for Gemini
        try:
            provider_models = MODELS.get("google", {}).get("models", [])
            for m in provider_models:
                if m.get("id") == model:
                    return m.get("context_window", default_window)
            logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
            return default_window
        except Exception as e:
            logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
            return default_window

    def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
        """
        Converts standard message format to Google's format, extracting the system prompt.
        Handles mapping roles and structuring tool calls/results.
        """
        google_messages: list[Content] = []
        system_prompt: str | None = None

        for i, message in enumerate(messages):
            role = message.get("role")
            content = message.get("content")
            tool_calls = message.get("tool_calls")
            tool_call_id = message.get("tool_call_id")

            if role == "system":
                if i == 0:
                    system_prompt = content
                    logger.debug("Extracted system prompt for Google.")
                else:
                    logger.warning("System message found not at the beginning. Merging into subsequent user message.")
                continue

            google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role)

            if not google_role:
                logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
                continue

            parts: list[Part | str] = []
            if role == "tool":
                if tool_call_id and content:
                    try:
                        response_content_dict = json.loads(content)
                    except json.JSONDecodeError:
                        logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
                        response_content_dict = {"result": content}

                    # Recover the function name from the preceding assistant message's tool calls
                    func_name = "unknown_function"
                    if i > 0 and messages[i - 1].get("role") == "assistant":
                        prev_tool_calls = messages[i - 1].get("tool_calls")
                        if prev_tool_calls:
                            for tc in prev_tool_calls:
                                if tc.get("id") == tool_call_id:
                                    func_name = tc.get("function_name", "unknown_function")
                                    break

                    parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
                    google_role = "function"
                else:
                    logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
                    continue

            elif role == "assistant" and tool_calls:
                for tool_call in tool_calls:
                    args = tool_call.get("arguments", {})
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
                            args = {"error": "failed to parse arguments"}
                    func_name = tool_call.get("function_name", "unknown_function")
                    parts.append(Part.from_function_call(name=func_name, args=args))
                if content:
                    parts.append(Part.from_text(content))

            elif content:
                if isinstance(content, str):
                    parts.append(Part.from_text(content))
                else:
                    logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
                    parts.append(Part.from_text(str(content)))

            if parts:
                google_messages.append(Content(role=google_role, parts=parts))
            else:
                logger.debug(f"No parts generated for message: {message}")

        # Validate the user/model alternation that the Google API requires
        last_role = None
        valid_alternation = True
        for msg in google_messages:
            current_role = msg.role
            if current_role == last_role and current_role in ["user", "model"]:
                valid_alternation = False
                logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.")
                break
            if last_role == "function" and current_role != "user":
                valid_alternation = False
                logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.")
                break
            last_role = current_role

        if not valid_alternation:
            logger.error("Message list does not follow required user/model alternation for Google API.")
            raise ValueError("Invalid message sequence for Google API.")

        return google_messages, system_prompt

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """Creates a chat completion using the Google Gemini API."""
        logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

        if self.client_module is None:
            return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})])

        try:
            google_messages, system_prompt = self._convert_messages(messages)
            generation_config: dict[str, Any] = {"temperature": temperature}
            if max_tokens is not None:
                generation_config["max_output_tokens"] = max_tokens

            google_tools = None
            if tools:
                try:
                    tool_dict_list = convert_to_google_tools(tools)
                    google_tools = self._convert_to_tool_objects(tool_dict_list)
                    logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.")
                except Exception as tool_conv_err:
                    logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
                    google_tools = None

            gemini_model = self.client_module.GenerativeModel(
                model_name=model,
                system_instruction=system_prompt,
                tools=google_tools if google_tools else None,
            )

            log_params = {
                "model": model,
                "stream": stream,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "system_prompt_present": bool(system_prompt),
                "num_tools": len(google_tools) if google_tools else 0,
                "num_messages": len(google_messages),
            }
            logger.debug(f"Calling Google API with params: {log_params}")

            response = gemini_model.generate_content(
                contents=google_messages,
                generation_config=generation_config,
                stream=stream,
            )
            logger.debug("Google API call successful.")
            return response

        except Exception as e:
            error_msg = f"Google API error: {e}"
            logger.error(error_msg, exc_info=True)
            error_payload = {"error": error_msg, "traceback": traceback.format_exc()}
            # Yielding here would turn the whole method into a generator, so wrap the
            # error in an iterator for the streaming case instead (matching the
            # SDK-missing branch above).
            return iter([json.dumps(error_payload)]) if stream else error_payload

    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """Yields content chunks from a Google streaming response."""
        logger.debug("Processing Google stream...")
        full_delta = ""
        try:
            if isinstance(response, dict) and "error" in response:
                yield json.dumps(response)
                return
            # An error iterator (see create_chat_completion) has no 'candidates'
            if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
                yield from response
                return

            for chunk in response:
                if isinstance(chunk, dict) and "error" in chunk:
                    yield json.dumps(chunk)
                    continue
                if hasattr(chunk, "text"):
                    delta = chunk.text
                    if delta:
                        full_delta += delta
                        yield delta
                elif hasattr(chunk, "candidates") and chunk.candidates:
                    for part in chunk.candidates[0].content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            logger.debug(f"Function call detected during stream: {part.function_call.name}")
                            break

            logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing Google stream: {e}", exc_info=True)
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: Any) -> str:
        """Extracts content from a non-streaming Google response."""
        try:
            if isinstance(response, dict) and "error" in response:
                logger.error(f"Cannot get content from error response: {response['error']}")
                return f"[Error: {response['error']}]"
            if hasattr(response, "text"):
                content = response.text
                logger.debug(f"Extracted content (length {len(content)}) from response.text.")
                return content
            elif hasattr(response, "candidates") and response.candidates:
                first_candidate = response.candidates[0]
                if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
                    text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
                    content = "".join(text_parts)
                    logger.debug(f"Extracted content (length {len(content)}) from response candidates.")
                    return content
                else:
                    logger.warning("Google response candidate has no content or parts.")
                    return ""
            else:
                logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
                return ""
        except Exception as e:
            logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Any) -> bool:
        """Checks whether the Google response contains tool calls (function calls)."""
        try:
            if isinstance(response, dict) and "error" in response:
                return False
            if hasattr(response, "candidates") and response.candidates:
                candidate = response.candidates[0]
                if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
                    for part in candidate.content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
                            return True
            logger.debug("No tool calls (FunctionCall) detected in Google response.")
            return False
        except Exception as e:
            logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        """Parses tool calls (function calls) from a non-streaming Google response."""
        parsed_calls = []
        try:
            if not (hasattr(response, "candidates") and response.candidates):
                logger.warning("Cannot parse tool calls: Response has no candidates.")
                return []

            candidate = response.candidates[0]
            if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
                logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
                return []

            logger.debug("Parsing tool calls (FunctionCall) from Google response.")
            call_index = 0
            for part in candidate.content.parts:
                if hasattr(part, "function_call") and part.function_call:
                    func_call = part.function_call
                    # Google doesn't supply call IDs, so synthesize one per call
                    call_id = f"call_{call_index}"
                    call_index += 1

                    full_name = func_call.name
                    parts = full_name.split("__", 1)
                    if len(parts) == 2:
                        server_name, func_name = parts
                    else:
                        logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
                        server_name = None
                        func_name = full_name

                    try:
                        args_str = json.dumps(func_call.args or {})
                    except Exception as json_err:
                        logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
                        args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})

                    parsed_calls.append({
                        "id": call_id,
                        "server_name": server_name,
                        "function_name": func_name,
                        "arguments": args_str,
                    })
                    logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
            return []

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for a Google follow-up request."""
        try:
            if isinstance(result, dict):
                content_str = json.dumps(result)
            else:
                content_str = str(result)
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
            content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content_str,
            # The real function name is recovered later in _convert_messages
            "function_name": "unknown_function",
        }

    def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls for Google."""
        try:
            if not (hasattr(response, "candidates") and response.candidates):
                return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}

            candidate = response.candidates[0]
            if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
                return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}

            tool_calls_formatted = []
            text_content_parts = []
            for part in candidate.content.parts:
                if hasattr(part, "function_call") and part.function_call:
                    func_call = part.function_call
                    args = func_call.args or {}
                    tool_calls_formatted.append({
                        "function_name": func_call.name,
                        "arguments": args,
                    })
                elif hasattr(part, "text"):
                    text_content_parts.append(part.text)

            message = {"role": "assistant"}
            if tool_calls_formatted:
                message["tool_calls"] = tool_calls_formatted
            text_content = "".join(text_content_parts)
            if text_content:
                message["content"] = text_content
            elif not tool_calls_formatted:
                message["content"] = ""

            return message

        except Exception as e:
            logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

    def get_usage(self, response: Any) -> dict[str, int] | None:
        """Extracts token usage from a Google response."""
        try:
            if isinstance(response, dict) and "error" in response:
                return None
            if hasattr(response, "usage_metadata"):
                metadata = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                    "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                }
                logger.debug(f"Extracted usage from Google response metadata: {usage}")
                return usage
            else:
                logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
                return None
        except Exception as e:
            logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
            return None

    def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
        """Converts dictionary-format tools into Google's Tool objects."""
        if not tool_configs:
            return None

        all_func_declarations = []
        for config in tool_configs:
            if "function_declarations" in config:
                for func_dict in config["function_declarations"]:
                    try:
                        params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
                        if params_schema_dict.get("type") != "object":
                            logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.")
                            params_schema_dict = {"type": "object", "properties": params_schema_dict}

                        def create_schema(schema_dict):
                            if not isinstance(schema_dict, dict):
                                logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
                                return Schema()
                            schema_args = {
                                "type": schema_dict.get("type"),
                                "format": schema_dict.get("format"),
                                "description": schema_dict.get("description"),
                                "nullable": schema_dict.get("nullable"),
                                "enum": schema_dict.get("enum"),
                                "items": create_schema(schema_dict["items"]) if "items" in schema_dict else None,
                                "properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None,
                                "required": schema_dict.get("required"),
                            }
                            schema_args = {k: v for k, v in schema_args.items() if v is not None}
                            if "type" in schema_args:
                                type_mapping = {
                                    "string": "STRING",
                                    "number": "NUMBER",
                                    "integer": "INTEGER",
                                    "boolean": "BOOLEAN",
                                    "array": "ARRAY",
                                    "object": "OBJECT",
                                }
                                schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"])
                            try:
                                return Schema(**schema_args)
                            except Exception as schema_creation_err:
                                logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True)
                                return Schema()

                        parameters_schema = create_schema(params_schema_dict)
                        declaration = FunctionDeclaration(
                            name=func_dict["name"],
                            description=func_dict.get("description", ""),
                            parameters=parameters_schema,
                        )
                        all_func_declarations.append(declaration)
                    except Exception as decl_err:
                        logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)

        if not all_func_declarations:
            logger.warning("No valid function declarations found after conversion.")
            return None

        return [Tool(function_declarations=all_func_declarations)]
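As an illustration of create_schema's JSON-schema-to-Schema mapping (type names are upper-cased via the type_mapping table; the parameter shape below is a hypothetical tool, not one from this diff):

    params = {
        "type": "object",
        "properties": {"path": {"type": "string", "description": "File to read"}},
        "required": ["path"],
    }
    # create_schema(params) builds (approximately):
    # Schema(type="OBJECT",
    #        properties={"path": Schema(type="STRING", description="File to read")},
    #        required=["path"])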
src/tools/conversion.py
@@ -164,11 +164,12 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
         function_declarations.append(function_declaration)
         logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")

-    # Google API expects a list containing one Tool object dict
-    google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []
+    # The Google API expects a list containing one dictionary with 'function_declarations'.
+    # The provider's _convert_to_tool_objects will handle creating Tool objects from this.
+    google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []

-    logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
-    return google_tools_wrapper
+    logger.debug(f"Final Google tool config structure: {google_tool_config}")
+    return google_tool_config


 # Note: The _handle_schema_construct helper from the reference code is not strictly