Compare commits

...

2 Commits

11 changed files with 826 additions and 459 deletions

View File

@@ -3,9 +3,9 @@ import logging
 from providers.anthropic_provider import AnthropicProvider
 from providers.base import BaseProvider
+from providers.google_provider import GoogleProvider
 from providers.openai_provider import OpenAIProvider
-# from providers.google_provider import GoogleProvider
 # from providers.openrouter_provider import OpenRouterProvider
 logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 PROVIDER_MAP: dict[str, type[BaseProvider]] = {
     "openai": OpenAIProvider,
     "anthropic": AnthropicProvider,
-    # "google": GoogleProvider,
+    "google": GoogleProvider,
     # "openrouter": OpenRouterProvider,  # OpenRouter can often use OpenAIProvider with a custom base_url
 }
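
For orientation, a minimal sketch of how a registry like `PROVIDER_MAP` is typically consumed; the `get_provider` helper below is illustrative and not part of this change:

```python
from providers.base import BaseProvider


def get_provider(name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
    """Hypothetical factory: resolve a provider class by name and instantiate it."""
    provider_cls = PROVIDER_MAP.get(name)
    if provider_cls is None:
        raise ValueError(f"Unknown provider {name!r}; available: {sorted(PROVIDER_MAP)}")
    return provider_cls(api_key=api_key, base_url=base_url)
```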

View File: src/providers/anthropic_provider.py (deleted)

@@ -1,453 +0,0 @@
# src/providers/anthropic_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any
from anthropic import Anthropic, APIError, Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
from providers.base import BaseProvider
from src.llm_models import MODELS
from src.tools.conversion import convert_to_anthropic_tools
logger = logging.getLogger(__name__)
class AnthropicProvider(BaseProvider):
"""Provider implementation for Anthropic Claude models."""
def __init__(self, api_key: str, base_url: str | None = None):
# Anthropic client doesn't use base_url in the same way, but store it if needed
# Use default Anthropic endpoint if base_url is not provided or relevant
effective_base_url = base_url or MODELS.get("anthropic", {}).get("endpoint")
super().__init__(api_key, effective_base_url) # Pass base_url to parent, though Anthropic client might ignore it
logger.info("Initializing AnthropicProvider")
try:
self.client = Anthropic(api_key=self.api_key)
# Note: Anthropic client doesn't take base_url during init
except Exception as e:
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Anthropic model."""
default_window = 100000 # Default fallback for Anthropic
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _count_anthropic_tokens(self, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
"""Counts tokens for Anthropic messages using the official client."""
# Note: Anthropic's count_tokens expects plain text rather than the structured
# message list used for creation, so we concatenate message contents into a
# single string. This is an approximation and may be inaccurate for
# structured content such as tool calls/results.
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
# Simple concatenation - might need refinement for complex content types (tool calls/results)
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list): # Handle tool results/calls if represented as list
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
# Older versions of the anthropic SDK expose client.count_tokens(text=...);
# if the method is missing or fails, the handlers below fall back to a
# character-count approximation.
count = self.client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens using client.count_tokens: {count}")
return count
except APIError as api_err:
# Handle potential errors if count_tokens itself is an API call or fails
logger.error(f"Anthropic API error during token count: {api_err}", exc_info=True)
# Fall back to the character-count approximation if the official count fails.
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
logger.warning(f"Falling back to character count approximation for Anthropic: {estimated_tokens}")
return estimated_tokens
except AttributeError:
# Fallback if count_tokens method doesn't exist or works differently
logger.warning("self.client.count_tokens not available or failed. Falling back to character count approximation.")
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Same approximation as OpenAI
return estimated_tokens
except Exception as e:
logger.error(f"Unexpected error during Anthropic token count: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0) # Fallback approximation
logger.warning(f"Falling back to character count approximation due to unexpected error: {estimated_tokens}")
return estimated_tokens
def _truncate_messages(self, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
"""
Truncates messages for Anthropic, preserving system prompt.
Returns:
- Potentially truncated list of messages.
- Original system prompt (or None).
- Initial token count.
- Final token count.
"""
context_limit = self._get_context_window(model)
buffer = 200 # Safety buffer
effective_limit = context_limit - buffer
initial_token_count = self._count_anthropic_tokens(messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages) # Copy
# Anthropic requires alternating user/assistant messages, so truncation needs care.
# We remove from the beginning of the list (the system prompt is held separately),
# dropping the oldest message at index 0 on each iteration.
while final_token_count > effective_limit and len(truncated_messages) > 0:
# Always remove the oldest message (index 0)
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating Anthropic message at index 0 (Role: {removed_message.get('role')}) due to context limit.")
# Popping a single message can break user/assistant alternation.
# For simplicity, remove one message and recount; a more robust approach
# would remove whole user/assistant pairs.
final_token_count = self._count_anthropic_tokens(truncated_messages, system_prompt)
logger.debug(f"Recalculated Anthropic tokens: {final_token_count}")
# Safety break
if not truncated_messages:
logger.warning("Truncation resulted in empty message list for Anthropic.")
break
if initial_token_count != final_token_count:
logger.info(
f"Truncated messages for Anthropic model {model}. Initial tokens: {initial_token_count}, Final tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})"
)
else:
logger.debug(f"No truncation needed for Anthropic model {model}. Tokens: {final_token_count}, Limit: {context_limit} (Effective: {effective_limit})")
# Ensure the remaining messages start with 'user' role if no system prompt
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
# This might indicate an issue with the simple pop(0) logic if pairs weren't removed.
# For now, prepend a basic user message.
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
# The placeholder adds a few tokens without a recount; the safety buffer
# above is expected to absorb them.
return truncated_messages, system_prompt, initial_token_count, final_token_count
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
"""Converts standard message format to Anthropic's format, extracting system prompt."""
anthropic_messages = []
system_prompt = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Anthropic.")
else:
# Handle system message not at the start (append to previous user message or add as user)
logger.warning("System message found not at the beginning. Treating as user message.")
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
continue
# Handle tool results specifically
if role == "tool":
# Find the preceding assistant message with the corresponding tool_use block
# This requires careful handling in the follow-up logic
tool_use_id = message.get("tool_call_id")
tool_content = content
# Format as a tool_result content block
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
continue
# Handle assistant message potentially containing tool_use blocks
if role == "assistant":
# Check if content is structured (e.g., from a previous tool call response)
if isinstance(content, list): # Assuming tool calls might be represented as a list
anthropic_messages.append({"role": "assistant", "content": content})
else:
anthropic_messages.append({"role": "assistant", "content": content}) # Regular text content
continue
# Regular user messages
if role == "user":
anthropic_messages.append({"role": "user", "content": content})
continue
logger.warning(f"Unsupported role '{role}' in message conversion for Anthropic.")
# Ensure conversation starts with a user message if no system prompt was used
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
logger.warning("Anthropic conversation must start with a user message. Prepending empty user message.")
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"}) # Or handle differently
return system_prompt, anthropic_messages
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None, # Anthropic requires max_tokens
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[MessageStreamEvent] | Message:
"""Creates a chat completion using the Anthropic API, handling context truncation."""
logger.debug(f"Anthropic create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
# --- Context Truncation ---
# First, convert to Anthropic format to separate system prompt
temp_system_prompt, temp_anthropic_messages = self._convert_messages(messages)
# Then, truncate based on token count
truncated_anthropic_msgs, final_system_prompt, _, _ = self._truncate_messages(temp_anthropic_messages, temp_system_prompt, model)
# --------------------------
# Anthropic requires max_tokens
if max_tokens is None:
max_tokens = 4096 # Default value if not provided
logger.warning(f"max_tokens not provided for Anthropic, defaulting to {max_tokens}")
try:
completion_params = {
"model": model,
"messages": truncated_anthropic_msgs, # Use truncated messages
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if final_system_prompt: # Use potentially modified system prompt
completion_params["system"] = final_system_prompt
if tools:
completion_params["tools"] = tools
# Anthropic doesn't have an explicit 'tool_choice' like OpenAI's 'auto' in the main API call
# Remove None values (though Anthropic requires max_tokens)
completion_params = {k: v for k, v in completion_params.items() if v is not None}
log_params = completion_params.copy()
if "messages" in log_params:
log_params["messages"] = [{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} for msg in log_params["messages"][-2:]]
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling Anthropic API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, System: {bool(log_params.get('system'))}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
response = self.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
# --- Capture Actual Usage ---
actual_usage = None
if isinstance(response, Message) and response.usage:
actual_usage = {
"prompt_tokens": response.usage.input_tokens, # Anthropic uses input_tokens
"completion_tokens": response.usage.output_tokens, # Anthropic uses output_tokens
# Anthropic doesn't typically provide total_tokens directly in usage block
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
}
logger.info(f"Actual Anthropic API usage: {actual_usage}")
# TODO: How to get usage for streaming responses? Anthropic might send it in a final 'message_stop' event? Needs investigation.
return response
# --------------------------
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise
def get_streaming_content(self, response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
"""Yields content chunks from an Anthropic streaming response."""
logger.debug("Processing Anthropic stream...")
full_delta = ""
try:
# Iterate through events in the stream
for event in response:
if event.type == "content_block_delta":
# Check if the delta is for text content before accessing .text
if isinstance(event.delta, TextDelta):
delta_text = event.delta.text
if delta_text:
full_delta += delta_text
yield delta_text
# Ignore other delta types like InputJSONDelta for text streaming
# Other event types like 'message_start', 'content_block_start', etc., can be logged or handled if needed
elif event.type == "message_start":
logger.debug(f"Anthropic stream started. Model: {event.message.model}")
elif event.type == "message_stop":
# The stop_reason might be available on the 'message' object associated with the stream,
# not directly on the stop event itself. We log that the stop event occurred.
# Accessing the actual reason might require inspecting the final message state if needed.
logger.debug("Anthropic stream message_stop event received.")
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
logger.debug(f"Anthropic stream detected tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
elif event.type == "content_block_stop":
logger.debug(f"Anthropic stream detected content block stop. Index: {event.index}")
logger.debug(f"Anthropic stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing Anthropic stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: Message) -> str:
"""Extracts content from a non-streaming Anthropic response."""
try:
# Combine text content from all text blocks
text_content = "".join([block.text for block in response.content if block.type == "text"])
logger.debug(f"Extracted content (length {len(text_content)}) from non-streaming Anthropic response.")
return text_content
except Exception as e:
logger.error(f"Error extracting content from Anthropic response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Stream[MessageStreamEvent] | Message) -> bool:
"""Checks if the Anthropic response contains tool calls."""
try:
if isinstance(response, Message): # Non-streaming
# Check stop reason and content blocks
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
logger.debug(f"Non-streaming Anthropic response check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
return has_calls
elif isinstance(response, Stream):
# Cannot reliably check an unconsumed stream without consuming it.
# The LLMClient should handle this by checking after consumption or based on stop_reason if available post-stream.
logger.warning("has_tool_calls check on an Anthropic stream is unreliable before consumption.")
return False
else:
logger.warning(f"has_tool_calls received unexpected type for Anthropic: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for Anthropic tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: Message) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming Anthropic response."""
parsed_calls = []
try:
if not isinstance(response, Message):
logger.error(f"parse_tool_calls expects Anthropic Message, got {type(response)}")
return []
if response.stop_reason != "tool_use":
logger.debug("No tool use indicated by stop_reason.")
# A response can still contain tool_use blocks even when stop_reason is not
# "tool_use", so inspect the content blocks regardless.
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
if not tool_use_blocks:
logger.debug("No 'tool_use' content blocks found in Anthropic response.")
return []
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks from Anthropic response.")
for block in tool_use_blocks:
# Adapt server/tool name splitting if needed (similar to OpenAI provider)
# Assuming Anthropic tool names might also be prefixed like "server__tool"
parts = block.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Anthropic tool name '{block.name}'.")
server_name = None
func_name = block.name
parsed_calls.append({
"id": block.id,
"server_name": server_name,
"function_name": func_name,
"arguments": json.dumps(block.input), # Anthropic input is already a dict, dump to string like OpenAI provider expects? Or keep as dict? Let's keep as dict for now.
# "arguments": block.input, # Keep as dict? Let's try this first.
})
return parsed_calls
except Exception as e:
logger.error(f"Error parsing Anthropic tool calls: {e}", exc_info=True)
return []
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an Anthropic follow-up request."""
# Anthropic expects a 'tool_result' content block
# The content of the result block should typically be a string.
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for Anthropic {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting Anthropic tool result for call ID {tool_call_id}")
# This needs to be placed inside a "user" role message's content list
return {
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": content_str,
# Optionally add is_error=True if result indicates an error
}
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to Anthropic's format."""
# Use the conversion function, assuming it's correctly placed and imported
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
try:
# The conversion function needs to handle the server__tool prefixing
anthropic_tools = convert_to_anthropic_tools(tools)
logger.debug(f"Tool conversion result: {anthropic_tools}")
return anthropic_tools
except Exception as e:
logger.error(f"Error during Anthropic tool conversion: {e}", exc_info=True)
return []
# Helper needed by LLMClient's current tool handling logic (if adapting OpenAI's pattern)
def get_original_message_with_calls(self, response: Message) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls for Anthropic."""
try:
if isinstance(response, Message) and any(block.type == "tool_use" for block in response.content):
# Anthropic's response structure is different. The 'message' itself is the assistant's turn.
# We need to return a representation of this turn, including the tool_use blocks.
# Convert Pydantic models within content to dicts
content_list = [block.model_dump(exclude_unset=True) for block in response.content]
return {"role": "assistant", "content": content_list}
else:
logger.warning("Could not extract original message with tool calls from Anthropic response.")
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original Anthropic message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a non-streaming Anthropic response."""
try:
if isinstance(response, Message) and response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
# "total_tokens": response.usage.input_tokens + response.usage.output_tokens, # Optional
}
logger.debug(f"Extracted usage from Anthropic response: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Anthropic response object of type {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage from Anthropic response: {e}", exc_info=True)
return None

View File: providers/anthropic_provider/__init__.py

@@ -0,0 +1,34 @@
from providers.anthropic_provider.client import initialize_client
from providers.anthropic_provider.completion import create_chat_completion
from providers.anthropic_provider.response import get_content, get_streaming_content, get_usage
from providers.anthropic_provider.tools import convert_tools, format_tool_results, has_tool_calls, parse_tool_calls
from providers.base import BaseProvider
class AnthropicProvider(BaseProvider):
def __init__(self, api_key: str, base_url: str | None = None):
self.client = initialize_client(api_key, base_url)
def create_chat_completion(self, messages, model, temperature=0.4, max_tokens=None, stream=True, tools=None):
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
def get_streaming_content(self, response):
return get_streaming_content(response)
def get_content(self, response):
return get_content(response)
def has_tool_calls(self, response):
return has_tool_calls(response)
def parse_tool_calls(self, response):
return parse_tool_calls(response)
def format_tool_results(self, tool_call_id, result):
return format_tool_results(tool_call_id, result)
def convert_tools(self, tools):
return convert_tools(tools)
def get_usage(self, response):
return get_usage(response)
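
A short usage sketch of the refactored facade (the API key and model ID are placeholders; the streaming behavior it relies on is implemented in response.py below):

```python
provider = AnthropicProvider(api_key="sk-ant-...")  # placeholder key

response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    model="claude-3-5-sonnet-latest",  # placeholder; use a model ID from MODELS
    stream=True,
)
for chunk in provider.get_streaming_content(response):
    print(chunk, end="", flush=True)
```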

View File: providers/anthropic_provider/client.py

@@ -0,0 +1,17 @@
import logging
from anthropic import Anthropic
logger = logging.getLogger(__name__)
def initialize_client(api_key: str, base_url: str | None = None) -> Anthropic:
logger.info("Initializing Anthropic client")
try:
client = Anthropic(api_key=api_key)
if base_url:
logger.warning(f"base_url '{base_url}' provided but not used by Anthropic client")
return client
except Exception as e:
logger.error(f"Failed to initialize Anthropic client: {e}", exc_info=True)
raise

View File: providers/anthropic_provider/completion.py

@@ -0,0 +1,38 @@
import logging
from typing import Any
from anthropic import Stream
from anthropic.types import Message
from providers.anthropic_provider.messages import convert_messages, truncate_messages
logger = logging.getLogger(__name__)
def create_chat_completion(
provider, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, tools: list[dict[str, Any]] | None = None
) -> Stream | Message:
logger.debug(f"Creating Anthropic chat completion. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
temp_system_prompt, temp_anthropic_messages = convert_messages(messages)
truncated_messages, final_system_prompt, _, _ = truncate_messages(provider, temp_anthropic_messages, temp_system_prompt, model)
if max_tokens is None:
max_tokens = 4096
logger.warning(f"max_tokens not provided, defaulting to {max_tokens}")
completion_params = {
"model": model,
"messages": truncated_messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if final_system_prompt:
completion_params["system"] = final_system_prompt
if tools:
completion_params["tools"] = tools
try:
response = provider.client.messages.create(**completion_params)
logger.debug("Anthropic API call successful.")
return response
except Exception as e:
logger.error(f"Anthropic API error: {e}", exc_info=True)
raise

View File: providers/anthropic_provider/messages.py

@@ -0,0 +1,61 @@
import logging
from typing import Any
from providers.anthropic_provider.utils import count_anthropic_tokens, get_context_window
logger = logging.getLogger(__name__)
def convert_messages(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
anthropic_messages = []
system_prompt = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role == "system":
if i == 0:
system_prompt = content
else:
logger.warning("System message not at beginning. Treating as user message.")
anthropic_messages.append({"role": "user", "content": f"[System Note]\n{content}"})
continue
if role == "tool":
tool_use_id = message.get("tool_call_id")
tool_content = content
anthropic_messages.append({"role": "user", "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": tool_content}]})
continue
if role == "assistant":
if isinstance(content, list):
anthropic_messages.append({"role": "assistant", "content": content})
else:
anthropic_messages.append({"role": "assistant", "content": content})
continue
if role == "user":
anthropic_messages.append({"role": "user", "content": content})
continue
logger.warning(f"Unsupported role '{role}' in message conversion.")
if not system_prompt and anthropic_messages and anthropic_messages[0]["role"] != "user":
logger.warning("Conversation must start with user message. Prepending placeholder.")
anthropic_messages.insert(0, {"role": "user", "content": "[Start of conversation]"})
return system_prompt, anthropic_messages
def truncate_messages(provider, messages: list[dict[str, Any]], system_prompt: str | None, model: str) -> tuple[list[dict[str, Any]], str | None, int, int]:
context_limit = get_context_window(model)
buffer = 200
effective_limit = context_limit - buffer
initial_token_count = count_anthropic_tokens(provider.client, messages, system_prompt)
final_token_count = initial_token_count
truncated_messages = list(messages)
while final_token_count > effective_limit and len(truncated_messages) > 0:
removed_message = truncated_messages.pop(0)
logger.debug(f"Truncating message (Role: {removed_message.get('role')})")
final_token_count = count_anthropic_tokens(provider.client, truncated_messages, system_prompt)
if initial_token_count != final_token_count:
logger.info(f"Truncated messages. Initial tokens: {initial_token_count}, Final: {final_token_count}")
else:
logger.debug(f"No truncation needed. Tokens: {final_token_count}")
if not system_prompt and truncated_messages and truncated_messages[0].get("role") != "user":
logger.warning("First message after truncation is not 'user'. Prepending placeholder.")
truncated_messages.insert(0, {"role": "user", "content": "[Context truncated]"})
return truncated_messages, system_prompt, initial_token_count, final_token_count
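
To make the conversion concrete, here is what `convert_messages` produces for a small history, traced from the code above (the tool-use ID is a placeholder):

```python
from providers.anthropic_provider.messages import convert_messages

system_prompt, msgs = convert_messages([
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "tool", "tool_call_id": "toolu_01", "content": "4"},
])
# system_prompt == "Be terse."
# msgs == [
#     {"role": "user", "content": "What is 2 + 2?"},
#     {"role": "user", "content": [
#         {"type": "tool_result", "tool_use_id": "toolu_01", "content": "4"},
#     ]},
# ]
```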

View File: providers/anthropic_provider/response.py

@@ -0,0 +1,62 @@
import json
import logging
from collections.abc import Generator
from typing import Any
from anthropic import Stream
from anthropic.types import Message, MessageStreamEvent, TextDelta
logger = logging.getLogger(__name__)
def get_streaming_content(response: Stream[MessageStreamEvent]) -> Generator[str, None, None]:
logger.debug("Processing Anthropic stream...")
full_delta = ""
try:
for event in response:
if event.type == "content_block_delta":
if isinstance(event.delta, TextDelta):
delta_text = event.delta.text
if delta_text:
full_delta += delta_text
yield delta_text
elif event.type == "message_start":
logger.debug(f"Stream started. Model: {event.message.model}")
elif event.type == "message_stop":
logger.debug("Stream message_stop event received.")
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
logger.debug(f"Tool use start: ID {event.content_block.id}, Name: {event.content_block.name}")
elif event.type == "content_block_stop":
logger.debug(f"Content block stop. Index: {event.index}")
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(response: Message) -> str:
try:
text_content = "".join([block.text for block in response.content if block.type == "text"])
logger.debug(f"Extracted content (length {len(text_content)})")
return text_content
except Exception as e:
logger.error(f"Error extracting content: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def get_usage(response: Any) -> dict[str, int] | None:
try:
if isinstance(response, Message) and response.usage:
usage = {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
}
logger.debug(f"Extracted usage: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from {type(response)}")
return None
except Exception as e:
logger.error(f"Error extracting usage: {e}", exc_info=True)
return None
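
One caveat for callers: on failure the stream generator yields a JSON-encoded error string rather than raising, so consumers need to watch for that sentinel. A sketch assuming that contract:

```python
import json
import logging

logger = logging.getLogger(__name__)


def consume(chunks) -> str:
    """Collect streamed text, stopping on the JSON error sentinel yielded above."""
    parts = []
    for chunk in chunks:
        if chunk.startswith('{"error"'):  # shape of json.dumps({"error": ...}) from the except branch
            logger.error("Stream failed: %s", json.loads(chunk)["error"])
            break
        parts.append(chunk)
    return "".join(parts)
```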

View File: providers/anthropic_provider/tools.py

@@ -0,0 +1,74 @@
import json
import logging
from typing import Any
from anthropic.types import Message
from src.tools.conversion import convert_to_anthropic_tools
logger = logging.getLogger(__name__)
def has_tool_calls(response: Any) -> bool:
try:
if isinstance(response, Message):
has_tool_use_block = any(block.type == "tool_use" for block in response.content)
has_calls = response.stop_reason == "tool_use" or has_tool_use_block
logger.debug(f"Tool calls check: stop_reason='{response.stop_reason}', has_tool_use_block={has_tool_use_block}. Result: {has_calls}")
return has_calls
else:
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(response: Message) -> list[dict[str, Any]]:
parsed_calls = []
try:
if not isinstance(response, Message):
logger.error(f"parse_tool_calls expects Message, got {type(response)}")
return []
tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
if not tool_use_blocks:
logger.debug("No 'tool_use' content blocks found.")
return []
logger.debug(f"Parsing {len(tool_use_blocks)} 'tool_use' blocks.")
for block in tool_use_blocks:
parts = block.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from tool name '{block.name}'.")
server_name = None
func_name = block.name
parsed_calls.append({"id": block.id, "server_name": server_name, "function_name": func_name, "arguments": block.input})
return parsed_calls
except Exception as e:
logger.error(f"Error parsing tool calls: {e}", exc_info=True)
return []
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error encoding tool result for {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {"type": "tool_result", "tool_use_id": tool_call_id, "content": content_str}
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
logger.debug(f"Converting {len(tools)} tools to Anthropic format.")
try:
anthropic_tools = convert_to_anthropic_tools(tools)
logger.debug(f"Tool conversion result: {anthropic_tools}")
return anthropic_tools
except Exception as e:
logger.error(f"Error during tool conversion: {e}", exc_info=True)
return []
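
The `server__tool` prefix convention assumed here splits only on the first double underscore, so tool names containing single underscores survive intact:

```python
print("filesystem__read_file".split("__", 1))  # ['filesystem', 'read_file']
print("read_file".split("__", 1))              # ['read_file'] -> server_name stays None
```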

View File: providers/anthropic_provider/utils.py

@@ -0,0 +1,50 @@
import json
import logging
import math
from typing import Any
from anthropic import Anthropic
from src.llm_models import MODELS
logger = logging.getLogger(__name__)
def get_context_window(model: str) -> int:
default_window = 100000
try:
provider_models = MODELS.get("anthropic", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Anthropic model '{model}' not found. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def count_anthropic_tokens(client: Anthropic, messages: list[dict[str, Any]], system_prompt: str | None) -> int:
text_to_count = ""
if system_prompt:
text_to_count += f"System: {system_prompt}\n\n"
for message in messages:
role = message.get("role")
content = message.get("content")
if isinstance(content, str):
text_to_count += f"{role}: {content}\n"
elif isinstance(content, list):
try:
content_str = json.dumps(content)
text_to_count += f"{role}: {content_str}\n"
except Exception:
text_to_count += f"{role}: [Unserializable Content]\n"
try:
count = client.count_tokens(text=text_to_count)
logger.debug(f"Counted Anthropic tokens: {count}")
return count
except Exception as e:
logger.error(f"Error counting Anthropic tokens: {e}", exc_info=True)
estimated_tokens = math.ceil(len(text_to_count) / 4.0)
logger.warning(f"Falling back to approximation: {estimated_tokens}")
return estimated_tokens
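
The fallback estimate assumes roughly four characters per token (the same heuristic the old monolithic provider used); for example:

```python
import math

text = "System: Be terse.\n\nuser: hi\n"  # 28 characters
print(math.ceil(len(text) / 4.0))         # -> 7 estimated tokens
```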

View File: src/providers/google_provider.py (new)

@@ -0,0 +1,483 @@
# src/providers/google_provider.py
import json
import logging
import traceback
from collections.abc import Generator
from typing import Any
# The Google SDK is optional at import time; __init__ raises a clear error if it is missing.
try:
    from google import genai
    from google.genai.types import (
        Content,
        FunctionDeclaration,
        Part,
        Schema,
        Tool,
    )
except ImportError:
    genai = None
    Content = FunctionDeclaration = Part = Schema = Tool = None  # placeholders so annotations still resolve
from src.llm_models import MODELS
from src.providers.base import BaseProvider
from src.tools.conversion import convert_to_google_tools
logger = logging.getLogger(__name__)
class GoogleProvider(BaseProvider):
"""Provider implementation for Google Gemini models."""
def __init__(self, api_key: str, base_url: str | None = None):
# Google client typically doesn't use a base_url, but we accept it for consistency
effective_base_url = base_url or MODELS.get("google", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info("Initializing GoogleProvider")
if genai is None:
raise ImportError("Google Generative AI SDK is required for GoogleProvider. Please install google-generativeai.")
try:
# Configure the client
genai.configure(api_key=self.api_key)
self.client_module = genai
except Exception as e:
logger.error(f"Failed to configure Google Generative AI client: {e}", exc_info=True)
raise
def _get_context_window(self, model: str) -> int:
"""Retrieves the context window size for a given Google model."""
default_window = 1000000 # Default fallback for Gemini
try:
provider_models = MODELS.get("google", {}).get("models", [])
for m in provider_models:
if m.get("id") == model:
return m.get("context_window", default_window)
logger.warning(f"Context window for Google model '{model}' not found in MODELS config. Using default: {default_window}")
return default_window
except Exception as e:
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
return default_window
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[list[Content], str | None]:
"""
Converts standard message format to Google's format, extracting system prompt.
Handles mapping roles and structuring tool calls/results.
"""
google_messages: list[Content] = []
system_prompt: str | None = None
for i, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls")
tool_call_id = message.get("tool_call_id")
if role == "system":
if i == 0:
system_prompt = content
logger.debug("Extracted system prompt for Google.")
else:
logger.warning("System message found not at the beginning. Merging into subsequent user message.")
continue
google_role = {"user": "user", "assistant": "model", "tool": "user"}.get(role)
if not google_role:
logger.warning(f"Unsupported role '{role}' for Google provider, skipping message.")
continue
parts: list[Part | str] = []
if role == "tool":
if tool_call_id and content:
try:
response_content_dict = json.loads(content)
except json.JSONDecodeError:
logger.warning(f"Could not decode tool result content for {tool_call_id}, sending as raw string.")
response_content_dict = {"result": content}
func_name = "unknown_function"
if i > 0 and messages[i - 1].get("role") == "assistant":
prev_tool_calls = messages[i - 1].get("tool_calls")
if prev_tool_calls:
for tc in prev_tool_calls:
if tc.get("id") == tool_call_id:
func_name = tc.get("function_name", "unknown_function")
break
parts.append(Part.from_function_response(name=func_name, response={"content": response_content_dict}))
google_role = "function"
else:
logger.warning(f"Skipping tool message due to missing tool_call_id or content: {message}")
continue
elif role == "assistant" and tool_calls:
for tool_call in tool_calls:
args = tool_call.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
logger.error(f"Failed to parse arguments string for tool call {tool_call.get('id')}: {args}")
args = {"error": "failed to parse arguments"}
func_name = tool_call.get("function_name", "unknown_function")
parts.append(Part.from_function_call(name=func_name, args=args))
if content:
parts.append(Part.from_text(content))
elif content:
if isinstance(content, str):
parts.append(Part.from_text(content))
else:
logger.warning(f"Unsupported content type for role '{role}': {type(content)}. Converting to string.")
parts.append(Part.from_text(str(content)))
if parts:
google_messages.append(Content(role=google_role, parts=parts))
else:
logger.debug(f"No parts generated for message: {message}")
last_role = None
valid_alternation = True
for msg in google_messages:
current_role = msg.role
if current_role == last_role and current_role in ["user", "model"]:
valid_alternation = False
logger.warning(f"Invalid role sequence detected: consecutive '{current_role}' roles.")
break
if last_role == "function" and current_role != "user":
valid_alternation = False
logger.warning(f"Invalid role sequence: '{current_role}' follows 'function'. Expected 'user'.")
break
last_role = current_role
if not valid_alternation:
logger.error("Message list does not follow required user/model alternation for Google API.")
raise ValueError("Invalid message sequence for Google API.")
return google_messages, system_prompt
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""Creates a chat completion using the Google Gemini API."""
logger.debug(f"Google create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
if self.client_module is None:
return {"error": "Google Generative AI SDK not installed."} if not stream else iter([json.dumps({"error": "Google Generative AI SDK not installed."})])
try:
google_messages, system_prompt = self._convert_messages(messages)
generation_config: dict[str, Any] = {"temperature": temperature}
if max_tokens is not None:
generation_config["max_output_tokens"] = max_tokens
google_tools = None
if tools:
try:
tool_dict_list = convert_to_google_tools(tools)
google_tools = self._convert_to_tool_objects(tool_dict_list)
logger.debug(f"Converted {len(tools)} tools to {len(google_tools)} Google Tool objects.")
except Exception as tool_conv_err:
logger.error(f"Failed to convert tools for Google: {tool_conv_err}", exc_info=True)
google_tools = None
gemini_model = self.client_module.GenerativeModel(
model_name=model,
system_instruction=system_prompt,
tools=google_tools if google_tools else None,
)
log_params = {
"model": model,
"stream": stream,
"temperature": temperature,
"max_tokens": max_tokens,
"system_prompt_present": bool(system_prompt),
"num_tools": len(google_tools) if google_tools else 0,
"num_messages": len(google_messages),
}
logger.debug(f"Calling Google API with params: {log_params}")
response = gemini_model.generate_content(
contents=google_messages,
generation_config=generation_config,
stream=stream,
)
logger.debug("Google API call successful.")
return response
except Exception as e:
error_msg = f"Google API error: {e}"
logger.error(error_msg, exc_info=True)
# A bare `yield` here would turn the whole function into a generator and break
# the non-streaming `return response` path above, so return an iterator of the
# error payload instead (mirroring the SDK-missing branch at the top).
if stream:
return iter([json.dumps({"error": error_msg, "traceback": traceback.format_exc()})])
else:
return {"error": error_msg, "traceback": traceback.format_exc()}
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""Yields content chunks from a Google streaming response."""
logger.debug("Processing Google stream...")
full_delta = ""
try:
if isinstance(response, dict) and "error" in response:
yield json.dumps(response)
return
# create_chat_completion's failure path can hand back a plain iterator of JSON
# error strings rather than a Google stream; pass those through unchanged.
if hasattr(response, "__iter__") and not hasattr(response, "candidates"):
yield from response
return
for chunk in response:
if isinstance(chunk, dict) and "error" in chunk:
yield json.dumps(chunk)
continue
if hasattr(chunk, "text"):
delta = chunk.text
if delta:
full_delta += delta
yield delta
elif hasattr(chunk, "candidates") and chunk.candidates:
for part in chunk.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Function call detected during stream: {part.function_call.name}")
break
logger.debug(f"Google stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing Google stream: {e}", exc_info=True)
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: Any) -> str:
"""Extracts content from a non-streaming Google response."""
try:
if isinstance(response, dict) and "error" in response:
logger.error(f"Cannot get content from error response: {response['error']}")
return f"[Error: {response['error']}]"
if hasattr(response, "text"):
content = response.text
logger.debug(f"Extracted content (length {len(content)}) from response.text.")
return content
elif hasattr(response, "candidates") and response.candidates:
first_candidate = response.candidates[0]
if hasattr(first_candidate, "content") and hasattr(first_candidate.content, "parts"):
text_parts = [part.text for part in first_candidate.content.parts if hasattr(part, "text")]
content = "".join(text_parts)
logger.debug(f"Extracted content (length {len(content)}) from response candidates.")
return content
else:
logger.warning("Google response candidate has no content or parts.")
return ""
else:
logger.warning("Could not extract content from Google response: No 'text' or valid 'candidates'.")
return ""
except Exception as e:
logger.error(f"Error extracting content from Google response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Any) -> bool:
"""Checks if the Google response contains tool calls (function calls)."""
try:
if isinstance(response, dict) and "error" in response:
return False
if hasattr(response, "candidates") and response.candidates:
candidate = response.candidates[0]
if hasattr(candidate, "content") and hasattr(candidate.content, "parts"):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
logger.debug(f"Tool call (FunctionCall) detected in Google response part: {part.function_call.name}")
return True
logger.debug("No tool calls (FunctionCall) detected in Google response.")
return False
except Exception as e:
logger.error(f"Error checking for Google tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
"""Parses tool calls (function calls) from a non-streaming Google response."""
parsed_calls = []
try:
if not (hasattr(response, "candidates") and response.candidates):
logger.warning("Cannot parse tool calls: Response has no candidates.")
return []
candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
logger.warning("Cannot parse tool calls: Response candidate has no content or parts.")
return []
logger.debug("Parsing tool calls (FunctionCall) from Google response.")
call_index = 0
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
call_id = f"call_{call_index}"
call_index += 1
full_name = func_call.name
parts = full_name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
logger.warning(f"Could not determine server_name from Google tool name '{full_name}'.")
server_name = None
func_name = full_name
try:
args_str = json.dumps(func_call.args or {})
except Exception as json_err:
logger.error(f"Failed to dump arguments dict to JSON string for {func_name}: {json_err}")
args_str = json.dumps({"error": "Failed to serialize arguments", "original_args": str(func_call.args)})
parsed_calls.append({
"id": call_id,
"server_name": server_name,
"function_name": func_name,
"arguments": args_str,
})
logger.debug(f"Parsed tool call: ID {call_id}, Server {server_name}, Func {func_name}, Args {args_str[:100]}...")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing Google tool calls: {e}", exc_info=True)
return []
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for a Google follow-up request."""
try:
if isinstance(result, dict):
content_str = json.dumps(result)
else:
content_str = str(result)
except Exception as e:
logger.error(f"Error JSON-encoding tool result for Google {tool_call_id}: {e}")
content_str = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting Google tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content_str,
"function_name": "unknown_function",
}
def get_original_message_with_calls(self, response: Any) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls for Google."""
try:
if not (hasattr(response, "candidates") and response.candidates):
return {"role": "assistant", "content": "[Could not extract tool calls message: No candidates]"}
candidate = response.candidates[0]
if not (hasattr(candidate, "content") and hasattr(candidate.content, "parts")):
return {"role": "assistant", "content": "[Could not extract tool calls message: No content/parts]"}
tool_calls_formatted = []
text_content_parts = []
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
func_call = part.function_call
args = func_call.args or {}
tool_calls_formatted.append({
"function_name": func_call.name,
"arguments": args,
})
elif hasattr(part, "text"):
text_content_parts.append(part.text)
message = {"role": "assistant"}
if tool_calls_formatted:
message["tool_calls"] = tool_calls_formatted
text_content = "".join(text_content_parts)
if text_content:
message["content"] = text_content
elif not tool_calls_formatted:
message["content"] = ""
return message
except Exception as e:
logger.error(f"Error extracting original Google message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
def get_usage(self, response: Any) -> dict[str, int] | None:
"""Extracts token usage from a Google response."""
try:
if isinstance(response, dict) and "error" in response:
return None
if hasattr(response, "usage_metadata"):
metadata = response.usage_metadata
usage = {
"prompt_tokens": getattr(metadata, "prompt_token_count", 0),
"completion_tokens": getattr(metadata, "candidates_token_count", 0),
}
logger.debug(f"Extracted usage from Google response metadata: {usage}")
return usage
else:
logger.warning(f"Could not extract usage from Google response object of type {type(response)}. No 'usage_metadata'.")
return None
except Exception as e:
logger.error(f"Error extracting usage from Google response: {e}", exc_info=True)
return None
def _convert_to_tool_objects(self, tool_configs: list[dict[str, Any]]) -> list[Tool] | None:
"""Convert dictionary-format tools into Google's Tool objects."""
if not tool_configs:
return None
all_func_declarations = []
for config in tool_configs:
if "function_declarations" in config:
for func_dict in config["function_declarations"]:
try:
params_schema_dict = func_dict.get("parameters", {"type": "object", "properties": {}})
if params_schema_dict.get("type") != "object":
logger.warning(f"Tool {func_dict['name']} parameters schema is not type 'object'. Forcing object type.")
params_schema_dict = {"type": "object", "properties": params_schema_dict}
def create_schema(schema_dict):
if not isinstance(schema_dict, dict):
logger.warning(f"Invalid schema part encountered: {schema_dict}. Returning empty schema.")
return Schema()
schema_args = {
"type": schema_dict.get("type"),
"format": schema_dict.get("format"),
"description": schema_dict.get("description"),
"nullable": schema_dict.get("nullable"),
"enum": schema_dict.get("enum"),
"items": create_schema(schema_dict["items"]) if "items" in schema_dict else None,
"properties": {k: create_schema(v) for k, v in schema_dict.get("properties", {}).items()} if schema_dict.get("properties") else None,
"required": schema_dict.get("required"),
}
schema_args = {k: v for k, v in schema_args.items() if v is not None}
if "type" in schema_args:
type_mapping = {
"string": "STRING",
"number": "NUMBER",
"integer": "INTEGER",
"boolean": "BOOLEAN",
"array": "ARRAY",
"object": "OBJECT",
}
schema_args["type"] = type_mapping.get(str(schema_args["type"]).lower(), schema_args["type"])
try:
return Schema(**schema_args)
except Exception as schema_creation_err:
logger.error(f"Failed to create Schema object for {func_dict['name']} with args {schema_args}: {schema_creation_err}", exc_info=True)
return Schema()
parameters_schema = create_schema(params_schema_dict)
declaration = FunctionDeclaration(
name=func_dict["name"],
description=func_dict.get("description", ""),
parameters=parameters_schema,
)
all_func_declarations.append(declaration)
except Exception as decl_err:
logger.error(f"Failed to create FunctionDeclaration for tool '{func_dict.get('name', 'Unknown')}': {decl_err}", exc_info=True)
if not all_func_declarations:
logger.warning("No valid function declarations found after conversion.")
return None
return [Tool(function_declarations=all_func_declarations)]
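
As an illustration of the schema conversion above, a JSON-Schema-style parameters dict like this (names are illustrative) is turned into nested `Schema` objects with Google's uppercase type names:

```python
params = {
    "type": "object",
    "properties": {
        "city": {"type": "string", "description": "City name"},
        "days": {"type": "integer"},
    },
    "required": ["city"],
}
# create_schema(params) builds roughly:
# Schema(type="OBJECT",
#        properties={"city": Schema(type="STRING", description="City name"),
#                    "days": Schema(type="INTEGER")},
#        required=["city"])
```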

View File: src/tools/conversion.py

@@ -164,11 +164,12 @@ def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
         function_declarations.append(function_declaration)
         logger.debug(f"Converted MCP tool to Google FunctionDeclaration: {prefixed_tool_name}")
-    # Google API expects a list containing one Tool object dict
-    google_tools_wrapper = [{"function_declarations": function_declarations}] if function_declarations else []
-    logger.debug(f"Final Google tools structure: {google_tools_wrapper}")
-    return google_tools_wrapper
+    # Google API expects a list containing one dictionary with 'function_declarations'.
+    # The provider's _convert_to_tool_objects will handle creating Tool objects from this.
+    google_tool_config = [{"function_declarations": function_declarations}] if function_declarations else []
+    logger.debug(f"Final Google tool config structure: {google_tool_config}")
+    return google_tool_config
# Note: The _handle_schema_construct helper from the reference code is not strictly