Implement provider-agnostic tool calling for Anthropic streaming responses.
This commit is contained in:
@@ -96,6 +96,17 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
logger.info("Received response from Anthropic")
|
||||
# Log the response (with sensitive information redacted)
|
||||
logger.info(f"Anthropic response type: {type(response).__name__}")
|
||||
|
||||
# Log as much information as possible
|
||||
if hasattr(response, "json"):
|
||||
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
|
||||
|
||||
# Log response attributes
|
||||
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
|
||||
logger.info(f"Anthropic response attributes: {response_attrs}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@@ -139,18 +150,20 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
Check if the response contains tool calls.
|
||||
|
||||
Args:
|
||||
response: Anthropic response object
|
||||
response: Anthropic response object or generator with tool_call attribute
|
||||
|
||||
Returns:
|
||||
True if the response contains tool calls, False otherwise
|
||||
"""
|
||||
# Check if any content block is a tool_use block
|
||||
if not hasattr(response, "content"):
|
||||
return False
|
||||
# Check if response is a generator with a tool_call attribute
|
||||
if hasattr(response, "tool_call") and response.tool_call is not None:
|
||||
return True
|
||||
|
||||
for block in response.content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
return True
|
||||
# Check if any content block is a tool_use block (for non-streaming responses)
|
||||
if hasattr(response, "content"):
|
||||
for block in response.content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -159,7 +172,7 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
Process tool calls from the response.
|
||||
|
||||
Args:
|
||||
response: Anthropic response object
|
||||
response: Anthropic response object or generator with tool_call attribute
|
||||
cookie: Airflow cookie for authentication
|
||||
|
||||
Returns:
|
||||
@@ -170,13 +183,29 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
if not self.has_tool_calls(response):
|
||||
return results
|
||||
|
||||
# Extract tool_use blocks
|
||||
tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
|
||||
tool_calls = []
|
||||
|
||||
for block in tool_use_blocks:
|
||||
tool_id = block.get("id")
|
||||
tool_name = block.get("name")
|
||||
tool_input = block.get("input", {})
|
||||
# Check if response is a generator with a tool_call attribute
|
||||
if hasattr(response, "tool_call") and response.tool_call is not None:
|
||||
logger.info(f"Processing tool call from generator: {response.tool_call}")
|
||||
tool_calls.append(response.tool_call)
|
||||
# Otherwise, extract tool calls from response content (for non-streaming responses)
|
||||
elif hasattr(response, "content"):
|
||||
logger.info("Processing tool calls from response content")
|
||||
tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
# Extract tool details - handle both formats (generator's tool_call and content block)
|
||||
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||
# This is from the generator's tool_call attribute
|
||||
tool_id = tool_call.get("id")
|
||||
tool_name = tool_call.get("name")
|
||||
tool_input = tool_call.get("input", {})
|
||||
else:
|
||||
# This is from the content blocks
|
||||
tool_id = tool_call.get("id")
|
||||
tool_name = tool_call.get("name")
|
||||
tool_input = tool_call.get("input", {})
|
||||
|
||||
try:
|
||||
# Execute the Airflow tool with the provided arguments and cookie
|
||||
@@ -270,15 +299,63 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
response: Anthropic streaming response object
|
||||
|
||||
Returns:
|
||||
Generator yielding content chunks
|
||||
Generator yielding content chunks with tool_call attribute if detected
|
||||
"""
|
||||
logger.info("Starting Anthropic streaming response processing")
|
||||
|
||||
# Track only the first tool call detected during streaming
|
||||
tool_call = None
|
||||
tool_use_detected = False
|
||||
|
||||
def generate():
|
||||
nonlocal tool_call, tool_use_detected
|
||||
|
||||
for chunk in response:
|
||||
logger.debug(f"Chunk type: {type(chunk)}")
|
||||
logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
|
||||
|
||||
# Handle different types of chunks from Anthropic API
|
||||
# Check for content_block_start events with type "tool_use"
|
||||
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
|
||||
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
|
||||
if chunk.content_block.type == "tool_use":
|
||||
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
|
||||
tool_use_detected = True
|
||||
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
|
||||
# We don't signal to the frontend during streaming
|
||||
# The tool will only be executed after streaming ends
|
||||
continue
|
||||
|
||||
# Handle content_block_delta events for tool_use (input updates)
|
||||
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
|
||||
if hasattr(chunk.delta, "partial_json"):
|
||||
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
|
||||
# Update the current tool call input
|
||||
if tool_call:
|
||||
try:
|
||||
# Try to parse the partial JSON and update the input
|
||||
partial_input = json.loads(chunk.delta.partial_json)
|
||||
tool_call["input"].update(partial_input)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
|
||||
continue
|
||||
|
||||
# Handle content_block_stop events for tool_use
|
||||
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
|
||||
logger.info("Tool use block completed")
|
||||
# Log the complete tool call for debugging
|
||||
if tool_call:
|
||||
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
|
||||
continue
|
||||
|
||||
# Handle message_delta events with stop_reason "tool_use"
|
||||
if hasattr(chunk, "type") and chunk.type == "message_delta":
|
||||
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
|
||||
if chunk.delta.stop_reason == "tool_use":
|
||||
logger.info("Message stopped due to tool use")
|
||||
continue
|
||||
|
||||
# Handle regular content chunks
|
||||
content = None
|
||||
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
|
||||
@@ -294,4 +371,12 @@ class AnthropicProvider(BaseLLMProvider):
|
||||
# Don't do any newline replacement here
|
||||
yield content
|
||||
|
||||
return generate()
|
||||
# Create the generator
|
||||
gen = generate()
|
||||
|
||||
# Attach the single tool_call to the generator object for later reference
|
||||
# This will be used after streaming is complete
|
||||
gen.tool_call = tool_call
|
||||
|
||||
# Return the enhanced generator
|
||||
return gen
|
||||
|
||||
Reference in New Issue
Block a user