From 6cb60f1bbd394d3fce64e89f30fe564e20a8898b Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 05:24:16 +0000
Subject: [PATCH] Implement provider-agnostic tool calling for Anthropic streaming responses.

---
 .../providers/anthropic_provider.py | 119 +++++++++++++++---
 1 file changed, 102 insertions(+), 17 deletions(-)

diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py
index 40e5787..6996e30 100644
--- a/src/airflow_wingman/providers/anthropic_provider.py
+++ b/src/airflow_wingman/providers/anthropic_provider.py
@@ -96,6 +96,17 @@ class AnthropicProvider(BaseLLMProvider):
             response = self.client.messages.create(**params)
             logger.info("Received response from Anthropic")
 
+            # Log response details for debugging
+            logger.info(f"Anthropic response type: {type(response).__name__}")
+
+            # Log the serialized response when the SDK object supports it
+            if hasattr(response, "model_dump_json"):
+                logger.info(f"Anthropic response json: {response.model_dump_json()}")
+
+            # Log the public, non-callable response attributes
+            response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
+            logger.info(f"Anthropic response attributes: {response_attrs}")
+
             return response
         except Exception as e:
             error_msg = str(e)
@@ -139,18 +150,20 @@ class AnthropicProvider(BaseLLMProvider):
         Check if the response contains tool calls.
 
         Args:
-            response: Anthropic response object
+            response: Anthropic response object or streaming response wrapper with a tool_call attribute
 
         Returns:
             True if the response contains tool calls, False otherwise
         """
-        # Check if any content block is a tool_use block
-        if not hasattr(response, "content"):
-            return False
+        # Check if the response is a streaming wrapper carrying a captured tool call
+        if hasattr(response, "tool_call") and response.tool_call is not None:
+            return True
 
-        for block in response.content:
-            if isinstance(block, dict) and block.get("type") == "tool_use":
-                return True
+        # Check if any content block is a tool_use block (for non-streaming responses)
+        if hasattr(response, "content"):
+            for block in response.content:
+                if isinstance(block, dict) and block.get("type") == "tool_use":
+                    return True
 
         return False
 
@@ -159,7 +172,7 @@ class AnthropicProvider(BaseLLMProvider):
         Process tool calls from the response.
 
         Args:
-            response: Anthropic response object
+            response: Anthropic response object or streaming response wrapper with a tool_call attribute
             cookie: Airflow cookie for authentication
 
         Returns:
 
         if not self.has_tool_calls(response):
             return results
 
-        # Extract tool_use blocks
-        tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+        tool_calls = []
 
-        for block in tool_use_blocks:
-            tool_id = block.get("id")
-            tool_name = block.get("name")
-            tool_input = block.get("input", {})
+        # Check if the response is a streaming wrapper carrying a captured tool call
+        if hasattr(response, "tool_call") and response.tool_call is not None:
+            logger.info(f"Processing tool call from streaming response: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+        # Otherwise, extract tool calls from response content (for non-streaming responses)
+        elif hasattr(response, "content"):
+            logger.info("Processing tool calls from response content")
+            tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+
+        for tool_call in tool_calls:
+            # Both the streaming tool_call dict and the non-streaming content blocks
+            # expose the same keys, so a single extraction handles both formats
+            tool_id = tool_call.get("id")
+            tool_name = tool_call.get("name")
+            tool_input = tool_call.get("input", {})
 
             try:
                 # Execute the Airflow tool with the provided arguments and cookie
@@ -270,15 +299,63 @@ class AnthropicProvider(BaseLLMProvider):
             response: Anthropic streaming response object
 
         Returns:
-            Generator yielding content chunks
+            Iterable yielding content chunks, with a tool_call attribute populated once streaming completes
         """
         logger.info("Starting Anthropic streaming response processing")
 
+        # Track only the first tool call detected during streaming
+        tool_call = None
+        tool_use_detected = False
+        tool_input_json = ""
+
         def generate():
+            nonlocal tool_call, tool_use_detected, tool_input_json
+
             for chunk in response:
                 logger.debug(f"Chunk type: {type(chunk)}")
+                logger.debug(f"Chunk content: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else str(chunk)}")
 
-                # Handle different types of chunks from Anthropic API
+                # Check for content_block_start events with type "tool_use"
+                if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
+                    if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
+                        if chunk.content_block.type == "tool_use":
+                            logger.info(f"Tool use detected in streaming response: {chunk.model_dump_json() if hasattr(chunk, 'model_dump_json') else str(chunk)}")
+                            tool_use_detected = True
+                            tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
+                            # We don't signal the tool call to the frontend during streaming;
+                            # the tool will only be executed after streaming ends
+                            continue
+
+                # Handle content_block_delta events for tool_use (input updates)
+                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
+                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
+                        if hasattr(chunk.delta, "partial_json"):
+                            logger.info(f"Tool use input update: {chunk.delta.partial_json}")
+                            # partial_json values are fragments of one JSON document; accumulate
+                            # them and parse once the tool_use block is complete
+                            tool_input_json += chunk.delta.partial_json
+                        continue
+
+                # Handle content_block_stop events for tool_use
+                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
+                    logger.info("Tool use block completed")
+                    # Parse the accumulated input JSON and log the complete tool call
+                    if tool_call and tool_input_json:
+                        try:
+                            tool_call["input"] = json.loads(tool_input_json)
+                        except json.JSONDecodeError:
+                            logger.warning(f"Failed to parse tool input JSON: {tool_input_json}")
+                    if tool_call:
+                        logger.info(f"Completed tool call: {json.dumps(tool_call)}")
+                    continue
+
+                # Handle message_delta events with stop_reason "tool_use"
+                if hasattr(chunk, "type") and chunk.type == "message_delta":
+                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
+                        if chunk.delta.stop_reason == "tool_use":
+                            logger.info("Message stopped due to tool use")
+                    continue
+
+                # Handle regular content chunks
                 content = None
                 if hasattr(chunk, "type") and chunk.type == "content_block_delta":
                     if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
@@ -294,4 +371,12 @@ class AnthropicProvider(BaseLLMProvider):
                     # Don't do any newline replacement here
                     yield content
 
-        return generate()
+        # Plain generator objects do not accept attribute assignment, so wrap the
+        # generator in a small iterator that exposes the captured tool_call once
+        # streaming is complete
+        class StreamingResponseWithToolCall:
+            tool_call = None
+
+            def __iter__(self):
+                yield from generate()
+                # generate() has finished, so any captured tool call is now final
+                self.tool_call = tool_call
+
+        # The tool call (if any) is only executed after streaming ends, via
+        # has_tool_calls/process_tool_calls
+        return StreamingResponseWithToolCall()
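
Not part of the patch: a rough sketch of the calling pattern this diff appears to assume, streaming the text first and executing the captured tool call only afterwards. The constructor signature, model name, example tool schema, message text, and cookie value below are illustrative assumptions; only get_streaming_content, has_tool_calls, process_tool_calls, and the provider's client attribute come from the diff above.

    from airflow_wingman.providers.anthropic_provider import AnthropicProvider

    # Assumed constructor signature and model name; the real provider setup may differ.
    provider = AnthropicProvider(api_key="sk-ant-...")
    params = {
        "model": "claude-3-5-sonnet-latest",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "How many DAG runs failed today?"}],
        # The provider normally injects Airflow tool definitions itself; this
        # hand-written Anthropic tool schema is purely illustrative.
        "tools": [{
            "name": "list_dag_runs",
            "description": "List recent DAG runs",
            "input_schema": {"type": "object", "properties": {"dag_id": {"type": "string"}}},
        }],
        "stream": True,
    }
    raw_stream = provider.client.messages.create(**params)

    # Consume the stream first; the tool_call attribute is only populated once
    # the underlying generator has been exhausted.
    stream = provider.get_streaming_content(raw_stream)
    for chunk in stream:
        print(chunk, end="", flush=True)

    # Only after streaming ends is the captured tool call executed against Airflow.
    if provider.has_tool_calls(stream):
        results = provider.process_tool_calls(stream, cookie="<airflow session cookie>")
        print(results)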