fix tool call until cookies are legit

2025-03-02 07:44:24 +00:00
parent 6cb60f1bbd
commit 5491ba71aa
6 changed files with 394 additions and 79 deletions

View File

@@ -81,10 +81,8 @@ class LLMClient:
        # If streaming, handle based on return_response_obj flag
        if stream:
            logger.info(f"Using streaming response from {self.provider_name}")
-            if return_response_obj:
-                return response, self.provider.get_streaming_content(response)
-            else:
-                return self.provider.get_streaming_content(response)
+            streaming_content = self.provider.get_streaming_content(response)
+            return streaming_content

        # For non-streaming responses, handle tool calls if present
        if self.provider.has_tool_calls(response):
@@ -143,7 +141,7 @@ class LLMClient:
        return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)

-    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5):
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None):
        """
        Process tool calls recursively from a response and make follow-up requests until
        there are no more tool calls or max_iterations is reached.
@@ -156,6 +154,7 @@ class LLMClient:
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            max_iterations: Maximum number of tool call iterations to prevent infinite loops
+            cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)

        Returns:
            Generator for streaming the final follow-up response
@@ -163,8 +162,8 @@ class LLMClient:
        try:
            iteration = 0
            current_response = response
-            cookie = session.get("airflow_cookie")
-            # Check if we have a cookie
            if not cookie:
                error_msg = "No Airflow cookie available"
                logger.error(error_msg)
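
Note: with the hunks above, a streaming chat_completion call now always returns the provider's StreamingResponse wrapper, and the Airflow cookie has to be passed to process_tool_calls_and_follow_up explicitly instead of being read from the Flask session inside the client. A minimal sketch of the resulting call flow; client, messages, model and airflow_cookie are placeholders, not code from this commit:

    streaming_response = client.chat_completion(
        messages=messages, model=model, temperature=0.4, max_tokens=1024, stream=True
    )
    for chunk in streaming_response:  # StreamingResponse is iterable like the old generator
        print(chunk, end="")

    if client.provider.has_tool_calls(streaming_response):
        # The cookie is now an explicit argument; None triggers the "No Airflow cookie available" error above.
        follow_up = client.process_tool_calls_and_follow_up(
            streaming_response, messages, model, 0.4, 1024, cookie=airflow_cookie
        )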

View File

@@ -12,7 +12,7 @@ from typing import Any
from anthropic import Anthropic

-from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools
@@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider):
            # Log as much information as possible
            if hasattr(response, "json"):
+                if callable(response.json):
+                    # If json is a method, call it
+                    try:
+                        logger.info(f"Anthropic response json: {json.dumps(response.json())}")
+                    except Exception as json_err:
+                        logger.warning(f"Could not serialize response.json(): {str(json_err)}")
+                else:
+                    # If json is a property, use it directly
+                    try:
                        logger.info(f"Anthropic response json: {json.dumps(response.json)}")
+                    except Exception as json_err:
+                        logger.warning(f"Could not serialize response.json: {str(json_err)}")

            # Log response attributes
            response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
@@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider):
        Check if the response contains tool calls.

        Args:
-            response: Anthropic response object or generator with tool_call attribute
+            response: Anthropic response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
-        # Check if response is a generator with a tool_call attribute
-        if hasattr(response, "tool_call") and response.tool_call is not None:
+        logger.info(f"Checking for tool calls in response of type: {type(response)}")
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse):
+            logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
+            if response.tool_call is not None:
+                logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
                return True
+            else:
+                logger.info("StreamingResponse has None tool_call")
+        else:
+            logger.info("Response is not a StreamingResponse")

        # Check if any content block is a tool_use block (for non-streaming responses)
        if hasattr(response, "content"):
+            logger.info(f"Response has content attribute with {len(response.content)} blocks")
+            for i, block in enumerate(response.content):
+                logger.info(f"Checking content block {i}: {type(block)}")
+                if isinstance(block, dict) and block.get("type") == "tool_use":
+                    logger.info(f"Found tool_use block: {block}")
+                    return True
+        else:
+            logger.info("Response does not have content attribute")
+
+        logger.info("No tool calls found in response")
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: Anthropic response object or StreamingResponse with tool_call attribute
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+        # Otherwise, extract tool calls from response content (for non-streaming responses)
+        elif hasattr(response, "content"):
+            logger.info("Extracting tool calls from response content")
            for block in response.content:
                if isinstance(block, dict) and block.get("type") == "tool_use":
-                    return True
+                    tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
+                    tool_calls.append(tool_call)

-        return False
+        return tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
@@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider):
        if not self.has_tool_calls(response):
            return results

-        tool_calls = []
-
-        # Check if response is a generator with a tool_call attribute
-        if hasattr(response, "tool_call") and response.tool_call is not None:
-            logger.info(f"Processing tool call from generator: {response.tool_call}")
-            tool_calls.append(response.tool_call)
-        # Otherwise, extract tool calls from response content (for non-streaming responses)
-        elif hasattr(response, "content"):
-            logger.info("Processing tool calls from response content")
-            tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
+        # Get tool calls using the standardized method
+        tool_calls = self.get_tool_calls(response)
+        logger.info(f"Processing {len(tool_calls)} tool calls")

        for tool_call in tool_calls:
            # Extract tool details - handle both formats (generator's tool_call and content block)
@@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider):
        return results

    def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.4,
+        max_tokens: int | None = None,
+        tool_results: dict[str, Any] = None,
+        original_response: Any = None,
+        stream: bool = True,
    ) -> Any:
        """
        Create a follow-up completion with tool results.
@@ -233,14 +285,24 @@ class AnthropicProvider(BaseLLMProvider):
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls
+            stream: Whether to stream the response

        Returns:
-            Anthropic response object
+            Anthropic response object or generator if streaming
        """
        if not original_response or not tool_results:
            return original_response

-        # Extract tool_use blocks from the original response
+        # Extract tool call from the StreamingResponse or content blocks from Anthropic response
+        tool_use_blocks = []
+        if isinstance(original_response, StreamingResponse) and original_response.tool_call:
+            # For StreamingResponse, create a tool_use block from the tool_call
+            logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
+            tool_call = original_response.tool_call
+            tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
+        elif hasattr(original_response, "content"):
+            # For regular Anthropic response, extract from content blocks
+            logger.info("Extracting tool_use blocks from response content")
            tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]

        # Create tool result blocks
@@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider):
        anthropic_messages.append({"role": "user", "content": tool_result_blocks})

        # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
        return self.create_chat_completion(
            messages=anthropic_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
-            stream=False,
+            stream=stream,
            tools=None,  # No tools needed for follow-up
        )
@@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider):
        return "".join(content_parts)

-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a generator for streaming content from the response.
@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider):
            response: Anthropic streaming response object

        Returns:
-            Generator yielding content chunks with tool_call attribute if detected
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
        """
        logger.info("Starting Anthropic streaming response processing")
@@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider):
        tool_call = None
        tool_use_detected = False

+        # Create the StreamingResponse object first
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
        def generate():
            nonlocal tool_call, tool_use_detected

            for chunk in response:
                logger.debug(f"Chunk type: {type(chunk)}")
-                logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")

                # Check for content_block_start events with type "tool_use"
                if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
                    if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
                        if chunk.content_block.type == "tool_use":
-                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
                            tool_use_detected = True
                            tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
+                            # Update the StreamingResponse object's tool_call attribute
+                            streaming_response.tool_call = tool_call

                            # We don't signal to the frontend during streaming
                            # The tool will only be executed after streaming ends
                            continue
@@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider):
                # Handle content_block_delta events for tool_use (input updates)
                if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
-                        if hasattr(chunk.delta, "partial_json"):
+                        if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
                            logger.info(f"Tool use input update: {chunk.delta.partial_json}")
                            # Update the current tool call input
                            if tool_call:
@@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider):
                                    # Try to parse the partial JSON and update the input
                                    partial_input = json.loads(chunk.delta.partial_json)
                                    tool_call["input"].update(partial_input)
+                                    # Update the StreamingResponse object's tool_call attribute
+                                    streaming_response.tool_call = tool_call
                                except json.JSONDecodeError:
                                    logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
                            continue
@@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider):
                    # Log the complete tool call for debugging
                    if tool_call:
                        logger.info(f"Completed tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
                    continue

                # Handle message_delta events with stop_reason "tool_use"
@@ -353,6 +425,9 @@ class AnthropicProvider(BaseLLMProvider):
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
                        if chunk.delta.stop_reason == "tool_use":
                            logger.info("Message stopped due to tool use")
+                            # Update the StreamingResponse object's tool_call attribute one last time
+                            if tool_call:
+                                streaming_response.tool_call = tool_call
                            continue

                # Handle regular content chunks
@@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider):
        # Create the generator
        gen = generate()

-        # Attach the single tool_call to the generator object for later reference
-        # This will be used after streaming is complete
-        gen.tool_call = tool_call
+        # Set the generator in the StreamingResponse object
+        streaming_response.generator = gen

-        # Return the enhanced generator
-        return gen
+        # Return the StreamingResponse object
+        return streaming_response

View File

@@ -6,8 +6,45 @@ must adhere to. It defines the methods required for tool conversion, API request
and response processing.
"""

+import json
from abc import ABC, abstractmethod
-from typing import Any
+from collections.abc import Generator, Iterator
+from typing import Any, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class StreamingResponse(Generic[T]):
+    """
+    Wrapper for streaming responses that can hold tool call information.
+
+    This class wraps a generator and provides an iterator interface while also
+    storing tool call information. This allows us to associate metadata with
+    a generator without modifying the generator itself.
+    """
+
+    def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
+        """
+        Initialize the streaming response.
+
+        Args:
+            generator: The underlying generator yielding content chunks
+            tool_call: Optional tool call information detected during streaming
+        """
+        self.generator = generator
+        self.tool_call = tool_call
+
+    def __iter__(self) -> Iterator[T]:
+        """
+        Return self as iterator.
+        """
+        return self
+
+    def __next__(self) -> T:
+        """
+        Get the next item from the generator.
+        """
+        return next(self.generator)
+

class BaseLLMProvider(ABC):
@@ -51,32 +88,85 @@ class BaseLLMProvider(ABC):
        """
        pass

-    @abstractmethod
    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
-            response: Provider-specific response object
+            response: Provider-specific response object or StreamingResponse

        Returns:
            True if the response contains tool calls, False otherwise
        """
-        pass
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            return True
+        # Provider-specific implementation should handle other cases
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: Provider-specific response object or StreamingResponse
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            tool_calls.append(response.tool_call)
+        # Provider-specific implementation should handle other cases
+        return tool_calls

-    @abstractmethod
    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
-            response: Provider-specific response object
+            response: Provider-specific response object or StreamingResponse
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
-        pass
+        tool_calls = self.get_tool_calls(response)
+        results = {}
+        for tool_call in tool_calls:
+            tool_name = tool_call.get("name", "")
+            tool_input = tool_call.get("input", {})
+            tool_id = tool_call.get("id", "")
+            try:
+                import logging
+
+                logger = logging.getLogger(__name__)
+                logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")
+
+                from airflow_wingman.tools import execute_airflow_tool
+
+                result = execute_airflow_tool(tool_name, tool_input, cookie)
+                logger.info(f"Tool result: {json.dumps(result)}")
+
+                results[tool_id] = {
+                    "name": tool_name,
+                    "input": tool_input,
+                    "output": result,
+                }
+            except Exception as e:
+                import traceback
+
+                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+                logger.error(error_msg)
+                results[tool_id] = {"status": "error", "message": error_msg}
+
+        return results

    @abstractmethod
    def create_follow_up_completion(
@@ -112,14 +202,15 @@ class BaseLLMProvider(ABC):
        pass

    @abstractmethod
-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
-        Get a generator for streaming content from the response.
+        Get a StreamingResponse for streaming content from the response.

        Args:
            response: Provider-specific response object

        Returns:
-            Generator yielding content chunks
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
        """
        pass
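
Note: StreamingResponse is what lets callers iterate the stream first and only afterwards ask whether a tool call was detected, because the provider mutates the wrapper's tool_call attribute while the generator runs. A small usage sketch (assumes the package is importable; the tool name is made up):

    from airflow_wingman.providers.base import StreamingResponse

    resp = StreamingResponse(generator=None, tool_call=None)

    def chunks():
        yield "Hello, "
        # a provider would set this while parsing stream events
        resp.tool_call = {"id": "t1", "name": "list_dags", "input": {}}
        yield "world"

    resp.generator = chunks()
    text = "".join(resp)         # consumes the stream like a plain generator
    print(text, resp.tool_call)  # tool_call is populated once the stream has been drained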

View File

@@ -12,7 +12,7 @@ from typing import Any
from openai import OpenAI

-from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools
@@ -116,35 +116,75 @@
        Check if the response contains tool calls.

        Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            return True
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            return hasattr(message, "tool_calls") and message.tool_calls
+
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: OpenAI response object or StreamingResponse with tool_call attribute
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+            return tool_calls
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message = response.choices[0].message
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
+                    tool_calls.append(standardized_tool_call)
+
+        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
+        return tool_calls

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with tool_call attribute
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}
-        message = response.choices[0].message

        if not self.has_tool_calls(response):
            return results

-        for tool_call in message.tool_calls:
-            tool_id = tool_call.id
-            function_name = tool_call.function.name
-            arguments = json.loads(tool_call.function.arguments)
+        # Get tool calls using the standardized method
+        tool_calls = self.get_tool_calls(response)
+        logger.info(f"Processing {len(tool_calls)} tool calls")
+
+        for tool_call in tool_calls:
+            tool_id = tool_call["id"]
+            function_name = tool_call["name"]
+            arguments = tool_call["input"]

            try:
                # Execute the Airflow tool with the provided arguments and cookie
@@ -220,7 +260,7 @@
        """
        return response.choices[0].message.content

-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a generator for streaming content from the response.
@@ -228,15 +268,87 @@
            response: OpenAI streaming response object

        Returns:
-            Generator yielding content chunks
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
        """
        logger.info("Starting OpenAI streaming response processing")

+        # Track only the first tool call detected during streaming
+        tool_call = None
+        tool_use_detected = False
+        current_tool_call = None
+
+        # Create the StreamingResponse object first
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
        def generate():
+            nonlocal tool_call, tool_use_detected, current_tool_call
+
            for chunk in response:
-                if chunk.choices and chunk.choices[0].delta.content:
-                    # Don't do any newline replacement here
+                # Check for tool call in the delta
+                if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
+                    # Tool call detected
+                    if not tool_use_detected:
+                        tool_use_detected = True
+                        logger.info("Tool call detected in streaming response")
+
+                        # Initialize the tool call
+                        delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+                        current_tool_call = {
+                            "id": getattr(delta_tool_call, "id", ""),
+                            "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
+                            "input": {},
+                        }
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = current_tool_call
+                    else:
+                        # Update the existing tool call
+                        delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+
+                        # Update the tool call ID if it's provided in this chunk
+                        if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
+                            current_tool_call["id"] = delta_tool_call.id
+
+                        # Update the function name if it's provided in this chunk
+                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
+                            current_tool_call["name"] = delta_tool_call.function.name
+
+                        # Update the arguments if they're provided in this chunk
+                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                            try:
+                                # Try to parse the arguments JSON
+                                arguments = json.loads(delta_tool_call.function.arguments)
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"].update(arguments)
+                                    # Update the StreamingResponse object's tool_call attribute
+                                    streaming_response.tool_call = current_tool_call
+                            except json.JSONDecodeError:
+                                # If the arguments are not valid JSON, just log a warning
+                                logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+
+                    # Skip yielding content for tool call chunks
+                    continue
+
+                # For the final chunk, set the tool_call attribute
+                if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
+                    logger.info("Streaming response finished with tool_calls reason")
+                    if current_tool_call:
+                        tool_call = current_tool_call
+                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
+                    continue
+
+                # Handle regular content chunks
+                if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content

-        return generate()
+        # Create the generator
+        gen = generate()
+
+        # Set the generator in the StreamingResponse object
+        streaming_response.generator = gen
+
+        # Return the StreamingResponse object
+        return streaming_response
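
Note: OpenAI streams a tool call across many chunks, so the generator above assembles one current_tool_call incrementally: the id and function name typically arrive in the first delta and the arguments arrive in later deltas. A toy reconstruction of that merging (the delta dicts below are fabricated stand-ins for chunk.choices[0].delta.tool_calls[0]):

    import json

    deltas = [
        {"id": "call_1", "name": "trigger_dag", "arguments": ""},
        {"id": "", "name": "", "arguments": '{"dag_id": "example"}'},
    ]

    tool_call = {"id": "", "name": "", "input": {}}
    for d in deltas:
        if d["id"]:
            tool_call["id"] = d["id"]
        if d["name"]:
            tool_call["name"] = d["name"]
        if d["arguments"]:
            try:
                tool_call["input"].update(json.loads(d["arguments"]))
            except json.JSONDecodeError:
                pass  # partial fragment; the real code logs a warning and keeps going

    print(tool_call)  # {'id': 'call_1', 'name': 'trigger_dag', 'input': {'dag_id': 'example'}}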

View File

@@ -31,7 +31,14 @@ async def _list_airflow_tools_async(cookie: str) -> list:
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
-        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
+
+        # Format the cookie properly if it doesn't already have the 'session=' prefix
+        formatted_cookie = cookie
+        if cookie and not cookie.startswith("session="):
+            formatted_cookie = f"session={cookie}"
+            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

        # Get available tools
        logger.info("Getting Airflow tools...")
@@ -73,19 +80,31 @@ async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: s
        # Set up configuration
        base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
        logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
-        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
+
+        # Format the cookie properly if it doesn't already have the 'session=' prefix
+        formatted_cookie = cookie
+        if cookie and not cookie.startswith("session="):
+            formatted_cookie = f"session={cookie}"
+            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

        # Get the tool
        logger.info(f"Getting tool: {tool_name}")
-        tool = await get_tool(config=config, tool_name=tool_name)
+        tool = await get_tool(config=config, name=tool_name)

        if not tool:
            error_msg = f"Tool not found: {tool_name}"
            logger.error(error_msg)
            return json.dumps({"error": error_msg})

-        # Execute the tool
+        # Execute the tool - ensure the client is in an async context
        logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
+
+        # The AirflowClient needs to be used as an async context manager
+        # to properly initialize its session
+        async with tool.client as client:  # noqa F841
+            # Now the client has a _session attribute and is in an async context
            result = await tool.run(arguments)

        # Convert result to string
@@ -114,4 +133,21 @@ def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
    Returns:
        Result of the tool execution as a string
    """
-    return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie))
+    # Create a new event loop for this execution
+    # This ensures we're always in a clean async context
+    loop = asyncio.new_event_loop()
+    try:
+        # Set the event loop for this thread
+        asyncio.set_event_loop(loop)
+
+        # Run the async function in the new event loop
+        result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
+        return result
+    except Exception as e:
+        error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
+        logger.error(error_msg)
+        return json.dumps({"error": error_msg})
+    finally:
+        # Always close the loop to free resources
+        loop.close()
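
Note: two behaviours are added to this module: the session cookie is normalized to the "session=<value>" form before it is handed to AirflowConfig, and tool execution runs on a dedicated event loop instead of asyncio.run so repeated calls from the plugin always get a clean loop. A standalone sketch of both (values are examples, not from the commit):

    import asyncio

    def format_cookie(cookie: str) -> str:
        # Prefix a bare session value; leave an already-prefixed cookie alone
        if cookie and not cookie.startswith("session="):
            return f"session={cookie}"
        return cookie

    async def demo() -> str:
        return format_cookie("abc123")

    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        print(loop.run_until_complete(demo()))  # session=abc123
    finally:
        loop.close()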

View File

@@ -114,16 +114,18 @@ class WingmanView(AppBuilderBaseView):
        """Handle streaming response."""
        try:
            logger.info("Beginning streaming response")
-            # Use the enhanced chat_completion method with return_response_obj=True
-            response_obj, generator = client.chat_completion(
-                messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True
-            )
+            # Get the cookie at the beginning of the request handler
+            airflow_cookie = request.cookies.get("session")
+            logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")

-            def stream_response():
+            # Use the enhanced chat_completion method with return_response_obj=True
+            streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
+
+            def stream_response(cookie=airflow_cookie):
                complete_response = ""

                # Stream the initial response
-                for chunk in generator:
+                for chunk in streaming_response:
                    if chunk:
                        complete_response += chunk
                        yield f"data: {chunk}\n\n"
@@ -134,7 +136,7 @@ class WingmanView(AppBuilderBaseView):
                logger.info("<<< COMPLETE RESPONSE END")

                # Check for tool calls and make follow-up if needed
-                if client.provider.has_tool_calls(response_obj):
+                if client.provider.has_tool_calls(streaming_response):
                    # Signal tool processing start - frontend should disable send button
                    yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"
@@ -142,9 +144,10 @@ class WingmanView(AppBuilderBaseView):
                    yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"

                    logger.info("Response contains tool calls, making follow-up request")
+                    logger.info(f"Using cookie from closure: {cookie is not None}")

                    # Process tool calls and get follow-up response (handles recursive tool calls)
-                    follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"])
+                    follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie)

                    # Stream the follow-up response
                    for chunk in follow_up_response:
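
Note: the view-level fix is to read the session cookie while the Flask request is still active and bind it as a default argument of stream_response, so the value survives inside the streamed generator after the request context has been torn down. A minimal Flask-style sketch of the pattern (route and app names are illustrative, not the plugin's actual view):

    from flask import Flask, Response, request

    app = Flask(__name__)

    @app.route("/stream")
    def stream():
        airflow_cookie = request.cookies.get("session")  # read while the request context exists

        def stream_response(cookie=airflow_cookie):      # bound now, usable after the handler returns
            yield f"data: cookie present: {cookie is not None}\n\n"

        return Response(stream_response(), mimetype="text/event-stream")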