Implement provider-agnostic tool calling for Anthropic streaming responses.

This commit is contained in:
2025-03-02 05:24:16 +00:00
parent 7df5e3c55e
commit 6cb60f1bbd

View File

@@ -96,6 +96,17 @@ class AnthropicProvider(BaseLLMProvider):
response = self.client.messages.create(**params)
logger.info("Received response from Anthropic")
# Log the response (with sensitive information redacted)
logger.info(f"Anthropic response type: {type(response).__name__}")
# Log as much information as possible
if hasattr(response, "json"):
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
# Log response attributes
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
logger.info(f"Anthropic response attributes: {response_attrs}")
return response
except Exception as e:
error_msg = str(e)
@@ -139,15 +150,17 @@ class AnthropicProvider(BaseLLMProvider):
Check if the response contains tool calls.
Args:
response: Anthropic response object
response: Anthropic response object or generator with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
# Check if any content block is a tool_use block
if not hasattr(response, "content"):
return False
# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
return True
# Check if any content block is a tool_use block (for non-streaming responses)
if hasattr(response, "content"):
for block in response.content:
if isinstance(block, dict) and block.get("type") == "tool_use":
return True
@@ -159,7 +172,7 @@ class AnthropicProvider(BaseLLMProvider):
Process tool calls from the response.
Args:
response: Anthropic response object
response: Anthropic response object or generator with tool_call attribute
cookie: Airflow cookie for authentication
Returns:
@@ -170,13 +183,29 @@ class AnthropicProvider(BaseLLMProvider):
if not self.has_tool_calls(response):
return results
# Extract tool_use blocks
tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
tool_calls = []
for block in tool_use_blocks:
tool_id = block.get("id")
tool_name = block.get("name")
tool_input = block.get("input", {})
# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
logger.info(f"Processing tool call from generator: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Processing tool calls from response content")
tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
for tool_call in tool_calls:
# Extract tool details - handle both formats (generator's tool_call and content block)
if isinstance(tool_call, dict) and "id" in tool_call:
# This is from the generator's tool_call attribute
tool_id = tool_call.get("id")
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
else:
# This is from the content blocks
tool_id = tool_call.get("id")
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
try:
# Execute the Airflow tool with the provided arguments and cookie
@@ -270,15 +299,63 @@ class AnthropicProvider(BaseLLMProvider):
response: Anthropic streaming response object
Returns:
Generator yielding content chunks
Generator yielding content chunks with tool_call attribute if detected
"""
logger.info("Starting Anthropic streaming response processing")
# Track only the first tool call detected during streaming
tool_call = None
tool_use_detected = False
def generate():
nonlocal tool_call, tool_use_detected
for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
# Handle different types of chunks from Anthropic API
# Check for content_block_start events with type "tool_use"
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
if chunk.content_block.type == "tool_use":
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
tool_use_detected = True
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
# We don't signal to the frontend during streaming
# The tool will only be executed after streaming ends
continue
# Handle content_block_delta events for tool_use (input updates)
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
if hasattr(chunk.delta, "partial_json"):
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
# Update the current tool call input
if tool_call:
try:
# Try to parse the partial JSON and update the input
partial_input = json.loads(chunk.delta.partial_json)
tool_call["input"].update(partial_input)
except json.JSONDecodeError:
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
continue
# Handle content_block_stop events for tool_use
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
logger.info("Tool use block completed")
# Log the complete tool call for debugging
if tool_call:
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
continue
# Handle message_delta events with stop_reason "tool_use"
if hasattr(chunk, "type") and chunk.type == "message_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
if chunk.delta.stop_reason == "tool_use":
logger.info("Message stopped due to tool use")
continue
# Handle regular content chunks
content = None
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
@@ -294,4 +371,12 @@ class AnthropicProvider(BaseLLMProvider):
# Don't do any newline replacement here
yield content
return generate()
# Create the generator
gen = generate()
# Attach the single tool_call to the generator object for later reference
# This will be used after streaming is complete
gen.tool_call = tool_call
# Return the enhanced generator
return gen