fix tool call until cookies are legit
@@ -12,7 +12,7 @@ from typing import Any
 
 from anthropic import Anthropic
 
-from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
 from airflow_wingman.tools import execute_airflow_tool
 from airflow_wingman.tools.conversion import convert_to_anthropic_tools
 
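The new import pulls in StreamingResponse from airflow_wingman.providers.base, whose definition is not part of this diff. A minimal sketch, assuming a plain dataclass consistent with how the commit uses it (constructed with generator/tool_call keywords, both mutated after construction); the real class may carry more:

    # Hypothetical sketch of StreamingResponse; the actual class lives in
    # airflow_wingman.providers.base and may differ.
    from collections.abc import Generator
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class StreamingResponse:
        # Generator yielding text chunks; filled in once generate() exists.
        generator: Generator[str, None, None] | None = None
        # Tool call detected mid-stream: {"id": ..., "name": ..., "input": {...}}
        tool_call: dict[str, Any] | None = None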
@@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider):
 
         # Log as much information as possible
         if hasattr(response, "json"):
-            logger.info(f"Anthropic response json: {json.dumps(response.json)}")
+            if callable(response.json):
+                # If json is a method, call it
+                try:
+                    logger.info(f"Anthropic response json: {json.dumps(response.json())}")
+                except Exception as json_err:
+                    logger.warning(f"Could not serialize response.json(): {str(json_err)}")
+            else:
+                # If json is a property, use it directly
+                try:
+                    logger.info(f"Anthropic response json: {json.dumps(response.json)}")
+                except Exception as json_err:
+                    logger.warning(f"Could not serialize response.json: {str(json_err)}")
 
         # Log response attributes
         response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
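The branch above exists because response.json may be a bound method on some SDK objects and a plain attribute on others. The same probe works generically; a standalone sketch using only the standard library (log_json_safely is a hypothetical helper, not part of the commit):

    import json
    import logging

    logger = logging.getLogger(__name__)

    def log_json_safely(obj: object) -> None:
        # Handles obj.json as either a method (call it) or a property (use as-is).
        raw = getattr(obj, "json", None)
        if raw is None:
            return
        try:
            payload = raw() if callable(raw) else raw
            logger.info("response json: %s", json.dumps(payload))
        except Exception as err:  # serialization or call failure
            logger.warning("could not serialize response json: %s", err)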
@@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider):
         Check if the response contains tool calls.
 
         Args:
-            response: Anthropic response object or generator with tool_call attribute
+            response: Anthropic response object or StreamingResponse with tool_call attribute
 
         Returns:
             True if the response contains tool calls, False otherwise
         """
-        # Check if response is a generator with a tool_call attribute
-        if hasattr(response, "tool_call") and response.tool_call is not None:
-            return True
+        logger.info(f"Checking for tool calls in response of type: {type(response)}")
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse):
+            logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
+            if response.tool_call is not None:
+                logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
+                return True
+            else:
+                logger.info("StreamingResponse has None tool_call")
+        else:
+            logger.info("Response is not a StreamingResponse")
 
         # Check if any content block is a tool_use block (for non-streaming responses)
         if hasattr(response, "content"):
-            for block in response.content:
-                if isinstance(block, dict) and block.get("type") == "tool_use":
-                    return True
+            logger.info(f"Response has content attribute with {len(response.content)} blocks")
+            for i, block in enumerate(response.content):
+                logger.info(f"Checking content block {i}: {type(block)}")
+                if isinstance(block, dict) and block.get("type") == "tool_use":
+                    logger.info(f"Found tool_use block: {block}")
+                    return True
+        else:
+            logger.info("Response does not have content attribute")
 
+        logger.info("No tool calls found in response")
         return False
 
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: Anthropic response object or StreamingResponse with tool_call attribute
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+        # Otherwise, extract tool calls from response content (for non-streaming responses)
+        elif hasattr(response, "content"):
+            logger.info("Extracting tool calls from response content")
+            for block in response.content:
+                if isinstance(block, dict) and block.get("type") == "tool_use":
+                    tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
+                    tool_calls.append(tool_call)
+
+        return tool_calls
 
     def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
         """
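Whatever the response type, get_tool_calls normalizes each call to a single dict shape, which is what lets has_tool_calls and get_tool_calls act as a uniform check-then-extract pair across the streaming and non-streaming paths. The shape, with placeholder values (list_dags is a hypothetical tool name):

    call = {
        "id": "toolu_0123",      # Anthropic tool_use block id (placeholder)
        "name": "list_dags",     # hypothetical Airflow tool name
        "input": {"limit": 10},  # parsed tool arguments
    }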
@@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider):
         if not self.has_tool_calls(response):
             return results
 
-        tool_calls = []
-        # Check if response is a generator with a tool_call attribute
-        if hasattr(response, "tool_call") and response.tool_call is not None:
-            logger.info(f"Processing tool call from generator: {response.tool_call}")
-            tool_calls.append(response.tool_call)
-        # Otherwise, extract tool calls from response content (for non-streaming responses)
-        elif hasattr(response, "content"):
-            logger.info("Processing tool calls from response content")
-            tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+        # Get tool calls using the standardized method
+        tool_calls = self.get_tool_calls(response)
         logger.info(f"Processing {len(tool_calls)} tool calls")
 
         for tool_call in tool_calls:
             # Extract tool details - handle both formats (generator's tool_call and content block)
@@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider):
         return results
 
     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.4,
+        max_tokens: int | None = None,
+        tool_results: dict[str, Any] = None,
+        original_response: Any = None,
+        stream: bool = True,
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -233,15 +285,25 @@ class AnthropicProvider(BaseLLMProvider):
             max_tokens: Maximum tokens to generate
             tool_results: Results of tool executions
             original_response: Original response with tool calls
+            stream: Whether to stream the response
 
         Returns:
-            Anthropic response object
+            Anthropic response object or generator if streaming
         """
         if not original_response or not tool_results:
             return original_response
 
-        # Extract tool_use blocks from the original response
-        tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+        # Extract tool call from the StreamingResponse or content blocks from Anthropic response
+        tool_use_blocks = []
+        if isinstance(original_response, StreamingResponse) and original_response.tool_call:
+            # For StreamingResponse, create a tool_use block from the tool_call
+            logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
+            tool_call = original_response.tool_call
+            tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
+        elif hasattr(original_response, "content"):
+            # For regular Anthropic response, extract from content blocks
+            logger.info("Extracting tool_use blocks from response content")
+            tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
 
         # Create tool result blocks
         tool_result_blocks = []
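The block-building code that follows tool_result_blocks = [] sits outside this hunk, but the Messages API convention it feeds is fixed: each executed tool answers with a tool_result content block whose tool_use_id matches the originating tool_use block, sent back in a user-role message. A sketch with placeholder values:

    tool_result_block = {
        "type": "tool_result",
        "tool_use_id": "toolu_0123",  # must match the tool_use block's id
        "content": '{"dags": []}',    # stringified output of the executed tool
    }
    follow_up_message = {"role": "user", "content": [tool_result_block]}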
@@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider):
         anthropic_messages.append({"role": "user", "content": tool_result_blocks})
 
         # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
         return self.create_chat_completion(
             messages=anthropic_messages,
             model=model,
             temperature=temperature,
             max_tokens=max_tokens,
-            stream=False,
+            stream=stream,
             tools=None,  # No tools needed for follow-up
         )
@@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider):
 
         return "".join(content_parts)
 
-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
         """
         Get a generator for streaming content from the response.
 
@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider):
             response: Anthropic streaming response object
 
         Returns:
-            Generator yielding content chunks with tool_call attribute if detected
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
         """
         logger.info("Starting Anthropic streaming response processing")
 
@@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider):
         tool_call = None
         tool_use_detected = False
 
+        # Create the StreamingResponse object first
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
         def generate():
             nonlocal tool_call, tool_use_detected
 
             for chunk in response:
                 logger.debug(f"Chunk type: {type(chunk)}")
-                logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
 
                 # Check for content_block_start events with type "tool_use"
                 if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
                     if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
                         if chunk.content_block.type == "tool_use":
-                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
                             tool_use_detected = True
                             tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
+                            # Update the StreamingResponse object's tool_call attribute
+                            streaming_response.tool_call = tool_call
                             # We don't signal to the frontend during streaming
                             # The tool will only be executed after streaming ends
                             continue
@@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider):
                 # Handle content_block_delta events for tool_use (input updates)
                 if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
                     if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
-                        if hasattr(chunk.delta, "partial_json"):
+                        if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
                             logger.info(f"Tool use input update: {chunk.delta.partial_json}")
                             # Update the current tool call input
                             if tool_call:
@@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider):
                                 # Try to parse the partial JSON and update the input
                                 partial_input = json.loads(chunk.delta.partial_json)
                                 tool_call["input"].update(partial_input)
+                                # Update the StreamingResponse object's tool_call attribute
+                                streaming_response.tool_call = tool_call
                             except json.JSONDecodeError:
                                 logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
                         continue
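A caveat on parsing each input_json_delta in isolation: Anthropic streams tool input as JSON fragments, so a single partial_json string is frequently not valid JSON on its own, and json.loads will keep raising until the closing brace arrives; the except json.JSONDecodeError branch above absorbs those failures. A sketch of the alternative accumulate-then-parse pattern (not what this commit does):

    import json

    class ToolInputAccumulator:
        """Collects input_json_delta fragments; parses once, at block stop."""

        def __init__(self) -> None:
            self.buffer = ""

        def add(self, partial_json: str) -> None:
            # Individual fragments are not guaranteed to parse on their own.
            self.buffer += partial_json

        def finish(self) -> dict:
            # The concatenation of all fragments is complete JSON.
            return json.loads(self.buffer) if self.buffer else {}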
@@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider):
                     # Log the complete tool call for debugging
                     if tool_call:
                         logger.info(f"Completed tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
                     continue
 
                 # Handle message_delta events with stop_reason "tool_use"
@@ -353,6 +425,9 @@ class AnthropicProvider(BaseLLMProvider):
                 if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
                     if chunk.delta.stop_reason == "tool_use":
                         logger.info("Message stopped due to tool use")
+                        # Update the StreamingResponse object's tool_call attribute one last time
+                        if tool_call:
+                            streaming_response.tool_call = tool_call
                     continue
 
                 # Handle regular content chunks
@@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider):
         # Create the generator
         gen = generate()
 
-        # Attach the single tool_call to the generator object for later reference
-        # This will be used after streaming is complete
-        gen.tool_call = tool_call
+        # Set the generator in the StreamingResponse object
+        streaming_response.generator = gen
 
-        # Return the enhanced generator
-        return gen
+        # Return the StreamingResponse object
+        return streaming_response
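Putting the commit together, the intended lifecycle appears to be: drain the wrapped generator (which, as a side effect, populates streaming_response.tool_call), then check, execute, and follow up. A sketch under those assumptions; raw_stream, emit_to_client, session_cookie, messages, and model are hypothetical stand-ins for the surrounding plugin plumbing:

    streaming = provider.get_streaming_content(raw_stream)

    # Drain first: tool_call is only populated as chunks are consumed.
    for chunk in streaming.generator:
        emit_to_client(chunk)  # hypothetical sink (SSE, websocket, ...)

    if provider.has_tool_calls(streaming):
        results = provider.process_tool_calls(streaming, cookie=session_cookie)
        final = provider.create_follow_up_completion(
            messages=messages,
            model=model,
            tool_results=results,
            original_response=streaming,
            stream=True,  # new in this commit: the follow-up can stream too
        )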