fix tool call until cookies are legit

2025-03-02 07:44:24 +00:00
parent 6cb60f1bbd
commit 5491ba71aa
6 changed files with 394 additions and 79 deletions


@@ -12,7 +12,7 @@ from typing import Any
from anthropic import Anthropic
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools
@@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider):
# Log as much information as possible
if hasattr(response, "json"):
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
if callable(response.json):
# If json is a method, call it
try:
logger.info(f"Anthropic response json: {json.dumps(response.json())}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json(): {str(json_err)}")
else:
# If json is a property, use it directly
try:
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json: {str(json_err)}")
# Log response attributes
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
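Note: the new branch exists because the `json` attribute on a response may be an httpx-style method or a plain property depending on which SDK object is in hand. A minimal sketch of the same defensive pattern as a reusable helper; the name safe_json_repr is hypothetical and not part of this commit:
import json
def safe_json_repr(obj: object) -> str:
    """Best-effort JSON text for logging; `json` may be a method or a property."""
    attr = getattr(obj, "json", None)
    if attr is None:
        return str(obj)
    try:
        payload = attr() if callable(attr) else attr
        return json.dumps(payload)
    except Exception as err:  # logging only, so never raise
        return f"<unserializable: {err}>"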
@@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider):
Check if the response contains tool calls.
Args:
response: Anthropic response object or generator with tool_call attribute
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
return True
logger.info(f"Checking for tool calls in response of type: {type(response)}")
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse):
logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
if response.tool_call is not None:
logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
return True
else:
logger.info("StreamingResponse has None tool_call")
else:
logger.info("Response is not a StreamingResponse")
# Check if any content block is a tool_use block (for non-streaming responses)
if hasattr(response, "content"):
logger.info(f"Response has content attribute with {len(response.content)} blocks")
for i, block in enumerate(response.content):
logger.info(f"Checking content block {i}: {type(block)}")
if isinstance(block, dict) and block.get("type") == "tool_use":
logger.info(f"Found tool_use block: {block}")
return True
else:
logger.info("Response does not have content attribute")
logger.info("No tool calls found in response")
return False
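The StreamingResponse wrapper is imported from airflow_wingman.providers.base and its definition is not shown in this diff. A minimal sketch consistent with how it is used here, with both fields mutable so the generator can fill in tool_call while streaming; this is an assumption about the real class, not its actual definition:
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any
@dataclass
class StreamingResponse:
    """Pairs a content generator with any tool call detected mid-stream."""
    generator: Generator[str, None, None] | None = None
    tool_call: dict[str, Any] | None = None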
def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.
Args:
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
List of tool call objects in a standardized format
"""
tool_calls = []
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Extracting tool calls from response content")
for block in response.content:
if isinstance(block, dict) and block.get("type") == "tool_use":
return True
tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
tool_calls.append(tool_call)
return False
return tool_calls
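Both branches now yield the same standardized dict shape, so callers need no streaming/non-streaming branching. A hypothetical usage sketch; the values and the dispatch() helper are illustrative, not from the commit:
for call in provider.get_tool_calls(response):
    # Each call is {"id": ..., "name": ..., "input": ...}, e.g.
    # {"id": "toolu_123", "name": "list_dags", "input": {"limit": 10}}
    dispatch(call["name"], call["input"])  # dispatch() is a stand-in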
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
@@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider):
if not self.has_tool_calls(response):
return results
tool_calls = []
# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
logger.info(f"Processing tool call from generator: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Processing tool calls from response content")
tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")
for tool_call in tool_calls:
# Extract tool details - handle both formats (generator's tool_call and content block)
@@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider):
return results
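Taken together, the refactored methods appear intended to compose as below. This is a sketch of the calling sequence, not code from the commit; send_to_client, session_cookie, and the surrounding call site are assumptions:
response = provider.create_chat_completion(messages=messages, model=model, stream=True, tools=tools)
streaming = provider.get_streaming_content(response)
for chunk in streaming.generator:
    send_to_client(chunk)  # hypothetical transport write
if provider.has_tool_calls(streaming):
    results = provider.process_tool_calls(streaming, cookie=session_cookie)
    final = provider.create_follow_up_completion(
        messages=messages, model=model, tool_results=results, original_response=streaming
    )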
def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
self,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
tool_results: dict[str, Any] | None = None,
original_response: Any = None,
stream: bool = True,
) -> Any:
"""
Create a follow-up completion with tool results.
@@ -233,15 +285,25 @@ class AnthropicProvider(BaseLLMProvider):
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
stream: Whether to stream the response
Returns:
Anthropic response object
Anthropic response object or generator if streaming
"""
if not original_response or not tool_results:
return original_response
# Extract tool_use blocks from the original response
tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Extract tool call from the StreamingResponse or content blocks from Anthropic response
tool_use_blocks = []
if isinstance(original_response, StreamingResponse) and original_response.tool_call:
# For StreamingResponse, create a tool_use block from the tool_call
logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
tool_call = original_response.tool_call
tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
elif hasattr(original_response, "content"):
# For regular Anthropic response, extract from content blocks
logger.info("Extracting tool_use blocks from response content")
tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Create tool result blocks
tool_result_blocks = []
@@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider):
anthropic_messages.append({"role": "user", "content": tool_result_blocks})
# Make a second request to get the final response
logger.info("Making second request with tool results")
logger.info(f"Making second request with tool results (stream={stream})")
return self.create_chat_completion(
messages=anthropic_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
stream=stream,
tools=None, # No tools needed for follow-up
)
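For reference, the follow-up request pairs each tool_use block from the assistant turn with a tool_result block in a subsequent user turn, which is the shape the Anthropic Messages API expects. The values below are illustrative:
anthropic_messages = [
    {"role": "user", "content": "How many DAG runs failed today?"},
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_123", "name": "list_dag_runs", "input": {"state": "failed"}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_123", "content": "3 failed runs"},
    ]},
]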
@@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider):
return "".join(content_parts)
def get_streaming_content(self, response: Any) -> Any:
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.
@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider):
response: Anthropic streaming response object
Returns:
Generator yielding content chunks with tool_call attribute if detected
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting Anthropic streaming response processing")
@@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider):
tool_call = None
tool_use_detected = False
# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)
def generate():
nonlocal tool_call, tool_use_detected
for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
# Check for content_block_start events with type "tool_use"
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
if chunk.content_block.type == "tool_use":
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
tool_use_detected = True
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
# We don't signal to the frontend during streaming
# The tool will only be executed after streaming ends
continue
@@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider):
# Handle content_block_delta events for tool_use (input updates)
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
if hasattr(chunk.delta, "partial_json"):
if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
# Update the current tool call input
if tool_call:
@@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider):
# Try to parse the partial JSON and update the input
partial_input = json.loads(chunk.delta.partial_json)
tool_call["input"].update(partial_input)
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
except json.JSONDecodeError:
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
continue
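Each input_json_delta event carries a fragment of the tool-input JSON document, so json.loads on a single fragment will usually raise the JSONDecodeError handled above. A common alternative, sketched below but not what this commit does, is to buffer fragments and parse once at content_block_stop:
import json
class ToolInputAccumulator:
    """Buffers input_json_delta fragments; parses once the block stops."""
    def __init__(self) -> None:
        self.buffer = ""
    def add(self, partial_json: str) -> None:
        self.buffer += partial_json
    def finish(self) -> dict:
        return json.loads(self.buffer) if self.buffer else {}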
@@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider):
# Log the complete tool call for debugging
if tool_call:
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
continue
# Handle message_delta events with stop_reason "tool_use"
@@ -353,6 +425,9 @@ class AnthropicProvider(BaseLLMProvider):
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
if chunk.delta.stop_reason == "tool_use":
logger.info("Message stopped due to tool use")
# Update the StreamingResponse object's tool_call attribute one last time
if tool_call:
streaming_response.tool_call = tool_call
continue
# Handle regular content chunks
@@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider):
# Create the generator
gen = generate()
# Attach the single tool_call to the generator object for later reference
# This will be used after streaming is complete
gen.tool_call = tool_call
# Set the generator in the StreamingResponse object
streaming_response.generator = gen
# Return the enhanced generator
return gen
# Return the StreamingResponse object
return streaming_response
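A hypothetical caller, showing why tool_call is only reliable after the generator is exhausted: the generator mutates the wrapper in place as events arrive, which is also why the commit moved the attribute off the raw generator (plain generator objects do not accept arbitrary attribute assignment) and onto StreamingResponse. emit() is a stand-in for the app's actual transport write:
streaming = provider.get_streaming_content(raw_stream)
for text in streaming.generator:
    emit(text)  # e.g. an SSE or websocket write
if streaming.tool_call:
    print("Tool requested:", streaming.tool_call["name"])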