Implement provider-agnostic tool calling for Anthropic streaming responses.
This commit is contained in:
@@ -96,6 +96,17 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
response = self.client.messages.create(**params)
|
response = self.client.messages.create(**params)
|
||||||
|
|
||||||
logger.info("Received response from Anthropic")
|
logger.info("Received response from Anthropic")
|
||||||
|
# Log the response (with sensitive information redacted)
|
||||||
|
logger.info(f"Anthropic response type: {type(response).__name__}")
|
||||||
|
|
||||||
|
# Log as much information as possible
|
||||||
|
if hasattr(response, "json"):
|
||||||
|
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
|
||||||
|
|
||||||
|
# Log response attributes
|
||||||
|
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
|
||||||
|
logger.info(f"Anthropic response attributes: {response_attrs}")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
@@ -139,15 +150,17 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
Check if the response contains tool calls.
|
Check if the response contains tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: Anthropic response object
|
response: Anthropic response object or generator with tool_call attribute
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the response contains tool calls, False otherwise
|
True if the response contains tool calls, False otherwise
|
||||||
"""
|
"""
|
||||||
# Check if any content block is a tool_use block
|
# Check if response is a generator with a tool_call attribute
|
||||||
if not hasattr(response, "content"):
|
if hasattr(response, "tool_call") and response.tool_call is not None:
|
||||||
return False
|
return True
|
||||||
|
|
||||||
|
# Check if any content block is a tool_use block (for non-streaming responses)
|
||||||
|
if hasattr(response, "content"):
|
||||||
for block in response.content:
|
for block in response.content:
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||||
return True
|
return True
|
||||||
@@ -159,7 +172,7 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
Process tool calls from the response.
|
Process tool calls from the response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: Anthropic response object
|
response: Anthropic response object or generator with tool_call attribute
|
||||||
cookie: Airflow cookie for authentication
|
cookie: Airflow cookie for authentication
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -170,13 +183,29 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
if not self.has_tool_calls(response):
|
if not self.has_tool_calls(response):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# Extract tool_use blocks
|
tool_calls = []
|
||||||
tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
|
|
||||||
|
|
||||||
for block in tool_use_blocks:
|
# Check if response is a generator with a tool_call attribute
|
||||||
tool_id = block.get("id")
|
if hasattr(response, "tool_call") and response.tool_call is not None:
|
||||||
tool_name = block.get("name")
|
logger.info(f"Processing tool call from generator: {response.tool_call}")
|
||||||
tool_input = block.get("input", {})
|
tool_calls.append(response.tool_call)
|
||||||
|
# Otherwise, extract tool calls from response content (for non-streaming responses)
|
||||||
|
elif hasattr(response, "content"):
|
||||||
|
logger.info("Processing tool calls from response content")
|
||||||
|
tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
# Extract tool details - handle both formats (generator's tool_call and content block)
|
||||||
|
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||||
|
# This is from the generator's tool_call attribute
|
||||||
|
tool_id = tool_call.get("id")
|
||||||
|
tool_name = tool_call.get("name")
|
||||||
|
tool_input = tool_call.get("input", {})
|
||||||
|
else:
|
||||||
|
# This is from the content blocks
|
||||||
|
tool_id = tool_call.get("id")
|
||||||
|
tool_name = tool_call.get("name")
|
||||||
|
tool_input = tool_call.get("input", {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Execute the Airflow tool with the provided arguments and cookie
|
# Execute the Airflow tool with the provided arguments and cookie
|
||||||
@@ -270,15 +299,63 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
response: Anthropic streaming response object
|
response: Anthropic streaming response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generator yielding content chunks
|
Generator yielding content chunks with tool_call attribute if detected
|
||||||
"""
|
"""
|
||||||
logger.info("Starting Anthropic streaming response processing")
|
logger.info("Starting Anthropic streaming response processing")
|
||||||
|
|
||||||
|
# Track only the first tool call detected during streaming
|
||||||
|
tool_call = None
|
||||||
|
tool_use_detected = False
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
|
nonlocal tool_call, tool_use_detected
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
logger.debug(f"Chunk type: {type(chunk)}")
|
logger.debug(f"Chunk type: {type(chunk)}")
|
||||||
|
logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
|
||||||
|
|
||||||
# Handle different types of chunks from Anthropic API
|
# Check for content_block_start events with type "tool_use"
|
||||||
|
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
|
||||||
|
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
|
||||||
|
if chunk.content_block.type == "tool_use":
|
||||||
|
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
|
||||||
|
tool_use_detected = True
|
||||||
|
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
|
||||||
|
# We don't signal to the frontend during streaming
|
||||||
|
# The tool will only be executed after streaming ends
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle content_block_delta events for tool_use (input updates)
|
||||||
|
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||||
|
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
|
||||||
|
if hasattr(chunk.delta, "partial_json"):
|
||||||
|
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
|
||||||
|
# Update the current tool call input
|
||||||
|
if tool_call:
|
||||||
|
try:
|
||||||
|
# Try to parse the partial JSON and update the input
|
||||||
|
partial_input = json.loads(chunk.delta.partial_json)
|
||||||
|
tool_call["input"].update(partial_input)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle content_block_stop events for tool_use
|
||||||
|
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
|
||||||
|
logger.info("Tool use block completed")
|
||||||
|
# Log the complete tool call for debugging
|
||||||
|
if tool_call:
|
||||||
|
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle message_delta events with stop_reason "tool_use"
|
||||||
|
if hasattr(chunk, "type") and chunk.type == "message_delta":
|
||||||
|
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
|
||||||
|
if chunk.delta.stop_reason == "tool_use":
|
||||||
|
logger.info("Message stopped due to tool use")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle regular content chunks
|
||||||
content = None
|
content = None
|
||||||
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
if hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||||
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
|
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
|
||||||
@@ -294,4 +371,12 @@ class AnthropicProvider(BaseLLMProvider):
|
|||||||
# Don't do any newline replacement here
|
# Don't do any newline replacement here
|
||||||
yield content
|
yield content
|
||||||
|
|
||||||
return generate()
|
# Create the generator
|
||||||
|
gen = generate()
|
||||||
|
|
||||||
|
# Attach the single tool_call to the generator object for later reference
|
||||||
|
# This will be used after streaming is complete
|
||||||
|
gen.tool_call = tool_call
|
||||||
|
|
||||||
|
# Return the enhanced generator
|
||||||
|
return gen
|
||||||
|
|||||||
Reference in New Issue
Block a user