# airflow-wingman/src/airflow_wingman/providers/anthropic_provider.py
"""
Anthropic provider implementation for Airflow Wingman.
This module contains the Anthropic provider implementation that handles
API requests, tool conversion, and response processing for Anthropic's Claude models.
"""

import json
import logging
import traceback
from typing import Any

from anthropic import Anthropic

from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools

# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")


class AnthropicProvider(BaseLLMProvider):
"""
Anthropic provider implementation.
This class handles API requests, tool conversion, and response processing
for the Anthropic API (Claude models).
"""

    def __init__(self, api_key: str):
"""
Initialize the Anthropic provider.
Args:
api_key: API key for Anthropic
"""
self.api_key = api_key
self.client = Anthropic(api_key=api_key)

    def convert_tools(self, airflow_tools: list) -> list:
"""
Convert Airflow tools to Anthropic format.
Args:
airflow_tools: List of Airflow tools from MCP server
Returns:
List of Anthropic tool definitions
"""
return convert_to_anthropic_tools(airflow_tools)
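
    # Illustrative sketch of the expected output shape (hypothetical "list_dags"
    # tool): Anthropic tool definitions carry "name", "description", and a JSON
    # Schema under "input_schema", e.g.
    #   {
    #       "name": "list_dags",
    #       "description": "List DAGs in the Airflow instance",
    #       "input_schema": {"type": "object", "properties": {"limit": {"type": "integer"}}},
    #   }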

    def create_chat_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
) -> Any:
"""
Make API request to Anthropic.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
tools: List of tool definitions in Anthropic format
Returns:
Anthropic response object
Raises:
Exception: If the API request fails
"""
        # Anthropic requires max_tokens, so fall back to a default when the caller does not provide one
        max_tokens_param = max_tokens if max_tokens is not None else 4096
# Convert messages from ChatML format to Anthropic's format
anthropic_messages = self._convert_to_anthropic_messages(messages)
try:
logger.info(f"Sending chat completion request to Anthropic with model: {model}")
# Create request parameters
params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream}
# Add tools if provided
if tools and len(tools) > 0:
params["tools"] = tools
else:
logger.warning("No tools included in request")
            # Log the request parameters, omitting message contents, which may contain sensitive data
            log_params = {k: v for k, v in params.items() if k != "messages"}
            logger.info(f"Request parameters: {json.dumps(log_params)}")
# Make the API request
response = self.client.messages.create(**params)
logger.info("Received response from Anthropic")
# Log the response (with sensitive information redacted)
logger.info(f"Anthropic response type: {type(response).__name__}")
            # Log the serialized response body if the SDK object supports it
            if hasattr(response, "model_dump_json"):
                try:
                    logger.info(f"Anthropic response json: {response.model_dump_json()}")
                except Exception as json_err:
                    logger.warning(f"Could not serialize response: {str(json_err)}")
# Log response attributes
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
logger.info(f"Anthropic response attributes: {response_attrs}")
return response
except Exception as e:
error_msg = str(e)
logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}")
raise
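
    # Usage sketch (hypothetical model name, messages, and airflow_tools):
    #   provider = AnthropicProvider(api_key="sk-ant-...")
    #   response = provider.create_chat_completion(
    #       messages=[{"role": "user", "content": "Which DAGs failed today?"}],
    #       model="claude-3-5-sonnet-latest",
    #       tools=provider.convert_tools(airflow_tools),
    #   )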

    def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Convert messages from ChatML format to Anthropic's format.
Args:
messages: List of message dictionaries in ChatML format
Returns:
List of message dictionaries in Anthropic format
"""
anthropic_messages = []
for message in messages:
role = message["role"]
content = message["content"]
# Map ChatML roles to Anthropic roles
            if role == "system":
                # Anthropic takes system prompts via a top-level `system` parameter rather
                # than a message role, so this provider inlines them as a tagged user message
                anthropic_messages.append({"role": "user", "content": f"<system>\n{content}\n</system>"})
elif role == "user":
anthropic_messages.append({"role": "user", "content": content})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": content})
elif role == "tool":
# Tool messages in ChatML become part of the user message in Anthropic
# We'll handle this in the follow-up completion
continue
return anthropic_messages
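
    # Example (illustrative): a ChatML history such as
    #   [{"role": "system", "content": "Be concise."}, {"role": "user", "content": "Hi"}]
    # converts to
    #   [{"role": "user", "content": "<system>\nBe concise.\n</system>"}, {"role": "user", "content": "Hi"}]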

    def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.
Args:
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
logger.info(f"Checking for tool calls in response of type: {type(response)}")
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse):
logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
if response.tool_call is not None:
logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
return True
else:
logger.info("StreamingResponse has None tool_call")
else:
logger.info("Response is not a StreamingResponse")
        # Check if any content block is a tool_use block (for non-streaming responses)
        if hasattr(response, "content"):
            logger.info(f"Response has content attribute with {len(response.content)} blocks")
            for i, block in enumerate(response.content):
                logger.info(f"Checking content block {i}: {type(block)}")
                # Content blocks may be SDK objects (e.g. ToolUseBlock) or plain dicts
                block_type = block.get("type") if isinstance(block, dict) else getattr(block, "type", None)
                if block_type == "tool_use":
                    logger.info(f"Found tool_use block: {block}")
                    return True
        else:
            logger.info("Response does not have content attribute")
logger.info("No tool calls found in response")
return False

    def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.
Args:
response: Anthropic response object or StreamingResponse with tool_call attribute
Returns:
List of tool call objects in a standardized format
"""
tool_calls = []
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
        # Otherwise, extract tool calls from response content (for non-streaming responses)
        elif hasattr(response, "content"):
            logger.info("Extracting tool calls from response content")
            for block in response.content:
                # Content blocks may be SDK objects (e.g. ToolUseBlock) or plain dicts
                if isinstance(block, dict) and block.get("type") == "tool_use":
                    tool_calls.append({"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})})
                elif getattr(block, "type", None) == "tool_use":
                    tool_calls.append({"id": getattr(block, "id", ""), "name": getattr(block, "name", ""), "input": getattr(block, "input", {})})
return tool_calls
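
    # Each extracted tool call uses one standardized dict shape regardless of
    # whether it came from a streaming or non-streaming response (illustrative values):
    #   {"id": "toolu_0123", "name": "list_dags", "input": {"limit": 10}}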

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
        response: Anthropic response object or StreamingResponse with tool_call attribute
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
if not self.has_tool_calls(response):
return results
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")
        for tool_call in tool_calls:
            # get_tool_calls already returns one standardized dict shape for both
            # streaming and non-streaming responses
            tool_id = tool_call.get("id")
            tool_name = tool_call.get("name")
            tool_input = tool_call.get("input", {})
try:
# Execute the Airflow tool with the provided arguments and cookie
logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}")
result = execute_airflow_tool(tool_name, tool_input, cookie)
logger.info(f"Tool execution result: {result}")
results[tool_id] = {"status": "success", "result": result}
except Exception as e:
error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}
return results
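
    # Example results mapping (illustrative):
    #   {"toolu_0123": {"status": "success", "result": "..."},
    #    "toolu_0456": {"status": "error", "message": "Error executing tool: ..."}}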

    def create_follow_up_completion(
self,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
        tool_results: dict[str, Any] | None = None,
original_response: Any = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Create a follow-up completion with tool results.
Args:
messages: Original messages
model: Model identifier
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
stream: Whether to stream the response
tools: List of tool definitions in Anthropic format
Returns:
Anthropic response object or generator if streaming
"""
if not original_response or not tool_results:
return original_response
# Extract tool call from the StreamingResponse or content blocks from Anthropic response
tool_use_blocks = []
if isinstance(original_response, StreamingResponse) and original_response.tool_call:
# For StreamingResponse, create a tool_use block from the tool_call
logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
tool_call = original_response.tool_call
tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
        elif hasattr(original_response, "content"):
            # For a regular Anthropic response, extract from content blocks
            # (which may be SDK objects or plain dicts)
            logger.info("Extracting tool_use blocks from response content")
            for block in original_response.content:
                if isinstance(block, dict) and block.get("type") == "tool_use":
                    tool_use_blocks.append(block)
                elif getattr(block, "type", None) == "tool_use":
                    tool_use_blocks.append({"type": "tool_use", "id": getattr(block, "id", ""), "name": getattr(block, "name", ""), "input": getattr(block, "input", {})})
        # Create tool result blocks, using the tool output on success and the
        # error message on failure
        tool_result_blocks = []
        for tool_id, result in tool_results.items():
            content = result.get("result") if result.get("status") == "success" else result.get("message", str(result))
            tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": content})
# Convert original messages to Anthropic format
anthropic_messages = self._convert_to_anthropic_messages(messages)
# Add the assistant response with tool use
anthropic_messages.append({"role": "assistant", "content": tool_use_blocks})
# Add the user message with tool results
anthropic_messages.append({"role": "user", "content": tool_result_blocks})
# Make a second request to get the final response
logger.info(f"Making second request with tool results (stream={stream})")
return self.create_chat_completion(
messages=anthropic_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
tools=tools,
)
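
    # The two appended turns follow Anthropic's tool-use protocol: the assistant
    # turn replays the tool_use block(s) and the next user turn answers with the
    # matching tool_result block(s), e.g. (illustrative):
    #   {"role": "assistant", "content": [{"type": "tool_use", "id": "toolu_0123", "name": "list_dags", "input": {}}]}
    #   {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_0123", "content": "..."}]}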

    def get_content(self, response: Any) -> str:
"""
Extract content from the response.
Args:
response: Anthropic response object
Returns:
Content string from the response
"""
if not hasattr(response, "content"):
return ""
        # Combine all text blocks into a single string; blocks may be SDK objects
        # (e.g. TextBlock), plain dicts, or raw strings
        content_parts = []
        for block in response.content:
            if isinstance(block, dict) and block.get("type") == "text":
                content_parts.append(block.get("text", ""))
            elif getattr(block, "type", None) == "text":
                content_parts.append(getattr(block, "text", ""))
            elif isinstance(block, str):
                content_parts.append(block)
return "".join(content_parts)
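
    # Example (illustrative): content of
    #   [{"type": "text", "text": "Hello"}, {"type": "text", "text": " world"}]
    # yields "Hello world".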

    def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.
Args:
response: Anthropic streaming response object
Returns:
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting Anthropic streaming response processing")
# Track only the first tool call detected during streaming
tool_call = None
tool_use_detected = False
# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)
def generate():
nonlocal tool_call, tool_use_detected
for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
# Check for content_block_start events with type "tool_use"
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
if chunk.content_block.type == "tool_use":
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
tool_use_detected = True
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
# We don't signal to the frontend during streaming
# The tool will only be executed after streaming ends
continue
                # Handle content_block_delta events for tool_use (input updates)
                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and getattr(chunk.delta, "type", None) == "input_json_delta":
                        if getattr(chunk.delta, "partial_json", None):
                            logger.info(f"Tool use input update: {chunk.delta.partial_json}")
                            # partial_json events carry fragments of one JSON document, so
                            # buffer them and parse once the content block stops
                            input_json_buffer.append(chunk.delta.partial_json)
                        continue
                # Handle content_block_stop events for tool_use
                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop":
                    logger.info("Tool use block completed")
                    if tool_call and input_json_buffer:
                        try:
                            # Parse the accumulated input JSON into the tool call
                            tool_call["input"] = json.loads("".join(input_json_buffer))
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse accumulated tool input JSON: {''.join(input_json_buffer)}")
                    if tool_call:
                        logger.info(f"Completed tool call: {json.dumps(tool_call)}")
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = tool_call
                    continue
# Handle message_delta events with stop_reason "tool_use"
if hasattr(chunk, "type") and chunk.type == "message_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
if chunk.delta.stop_reason == "tool_use":
logger.info("Message stopped due to tool use")
# Update the StreamingResponse object's tool_call attribute one last time
if tool_call:
streaming_response.tool_call = tool_call
continue
                # Handle regular text content chunks
                content = None
                if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                    content = chunk.delta.text
                elif hasattr(chunk, "content") and chunk.content:
                    for block in chunk.content:
                        # Blocks may be SDK objects or plain dicts
                        if isinstance(block, dict) and block.get("type") == "text":
                            content = block.get("text", "")
                        elif getattr(block, "type", None) == "text":
                            content = getattr(block, "text", "")
                if content:
                    # Yield the raw text unchanged; the caller handles any formatting
                    yield content
# Create the generator
gen = generate()
# Set the generator in the StreamingResponse object
streaming_response.generator = gen
# Return the StreamingResponse object
return streaming_response
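
    # Usage sketch (hypothetical caller): drain the stream first, then check
    # whether a tool call was captured and execute it.
    #   streaming = provider.get_streaming_content(raw_stream)
    #   for text in streaming.generator:
    #       forward_to_client(text)  # hypothetical sink
    #   if provider.has_tool_calls(streaming):
    #       results = provider.process_tool_calls(streaming, cookie)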