fix tool call until cookies are legit
@@ -81,10 +81,8 @@ class LLMClient:
# If streaming, handle based on return_response_obj flag
if stream:
logger.info(f"Using streaming response from {self.provider_name}")
if return_response_obj:
return response, self.provider.get_streaming_content(response)
else:
return self.provider.get_streaming_content(response)
streaming_content = self.provider.get_streaming_content(response)
return streaming_content

# For non-streaming responses, handle tool calls if present
if self.provider.has_tool_calls(response):
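For orientation, a rough sketch of how a caller might consume the streamed result after this change; the client variable, message, and model name below are illustrative assumptions, not part of the diff.

# Hedged sketch: assumes an existing LLMClient instance named `client` and that
# chat_completion(stream=True) now returns the StreamingResponse wrapper.
streaming_response = client.chat_completion(
    messages=[{"role": "user", "content": "List my DAGs"}],  # illustrative message
    model="claude-3-5-sonnet",  # hypothetical model name
    temperature=0.4,
    max_tokens=1024,
    stream=True,
)

chunks = []
for chunk in streaming_response:  # StreamingResponse is iterable
    chunks.append(chunk)

# After the stream is exhausted, any tool call detected during streaming
# is available on the wrapper itself.
if streaming_response.tool_call is not None:
    print("tool requested:", streaming_response.tool_call["name"])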
@@ -143,7 +141,7 @@ class LLMClient:

return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)

def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5):
def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None):
"""
Process tool calls recursively from a response and make follow-up requests until
there are no more tool calls or max_iterations is reached.
@@ -156,6 +154,7 @@ class LLMClient:
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
max_iterations: Maximum number of tool call iterations to prevent infinite loops
cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)

Returns:
Generator for streaming the final follow-up response
@@ -163,8 +162,8 @@ class LLMClient:
try:
iteration = 0
current_response = response
cookie = session.get("airflow_cookie")

# Check if we have a cookie
if not cookie:
error_msg = "No Airflow cookie available"
logger.error(error_msg)
@@ -12,7 +12,7 @@ from typing import Any

from anthropic import Anthropic

from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_anthropic_tools

@@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider):

# Log as much information as possible
if hasattr(response, "json"):
if callable(response.json):
# If json is a method, call it
try:
logger.info(f"Anthropic response json: {json.dumps(response.json())}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json(): {str(json_err)}")
else:
# If json is a property, use it directly
try:
logger.info(f"Anthropic response json: {json.dumps(response.json)}")
except Exception as json_err:
logger.warning(f"Could not serialize response.json: {str(json_err)}")

# Log response attributes
response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))]
@@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider):
Check if the response contains tool calls.

Args:
response: Anthropic response object or generator with tool_call attribute
response: Anthropic response object or StreamingResponse with tool_call attribute

Returns:
True if the response contains tool calls, False otherwise
"""
# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
logger.info(f"Checking for tool calls in response of type: {type(response)}")

# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse):
logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}")
if response.tool_call is not None:
logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}")
return True
else:
logger.info("StreamingResponse has None tool_call")
else:
logger.info("Response is not a StreamingResponse")

# Check if any content block is a tool_use block (for non-streaming responses)
if hasattr(response, "content"):
logger.info(f"Response has content attribute with {len(response.content)} blocks")
for i, block in enumerate(response.content):
logger.info(f"Checking content block {i}: {type(block)}")
if isinstance(block, dict) and block.get("type") == "tool_use":
logger.info(f"Found tool_use block: {block}")
return True
else:
logger.info("Response does not have content attribute")

logger.info("No tool calls found in response")
return False

def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.

Args:
response: Anthropic response object or StreamingResponse with tool_call attribute

Returns:
List of tool call objects in a standardized format
"""
tool_calls = []

# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Extracting tool calls from response content")
for block in response.content:
if isinstance(block, dict) and block.get("type") == "tool_use":
return True
tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})}
tool_calls.append(tool_call)

return False
return tool_calls
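Both providers now normalize tool calls into the same dictionary shape before executing them, and process_tool_calls keys its results by tool call id. A small illustrative example of both structures (all values below are invented for illustration):

# Illustrative only: the standardized tool call format returned by
# get_tool_calls() and the results mapping built by process_tool_calls().
example_tool_call = {
    "id": "toolu_123",       # provider-assigned tool call id (hypothetical)
    "name": "list_dags",     # hypothetical Airflow tool name
    "input": {"limit": 10},  # parsed arguments for the tool
}

# process_tool_calls() keys its results by the tool call id.
example_results = {
    "toolu_123": {
        "name": "list_dags",
        "input": {"limit": 10},
        "output": '{"dags": []}',  # string returned by execute_airflow_tool
    }
}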
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
@@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider):
if not self.has_tool_calls(response):
return results

tool_calls = []

# Check if response is a generator with a tool_call attribute
if hasattr(response, "tool_call") and response.tool_call is not None:
logger.info(f"Processing tool call from generator: {response.tool_call}")
tool_calls.append(response.tool_call)
# Otherwise, extract tool calls from response content (for non-streaming responses)
elif hasattr(response, "content"):
logger.info("Processing tool calls from response content")
tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"]
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")

for tool_call in tool_calls:
# Extract tool details - handle both formats (generator's tool_call and content block)
@@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider):
return results

def create_follow_up_completion(
self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
self,
messages: list[dict[str, Any]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
tool_results: dict[str, Any] = None,
original_response: Any = None,
stream: bool = True,
) -> Any:
"""
Create a follow-up completion with tool results.
@@ -233,14 +285,24 @@ class AnthropicProvider(BaseLLMProvider):
max_tokens: Maximum tokens to generate
tool_results: Results of tool executions
original_response: Original response with tool calls
stream: Whether to stream the response

Returns:
Anthropic response object
Anthropic response object or generator if streaming
"""
if not original_response or not tool_results:
return original_response

# Extract tool_use blocks from the original response
# Extract tool call from the StreamingResponse or content blocks from Anthropic response
tool_use_blocks = []
if isinstance(original_response, StreamingResponse) and original_response.tool_call:
# For StreamingResponse, create a tool_use block from the tool_call
logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}")
tool_call = original_response.tool_call
tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})})
elif hasattr(original_response, "content"):
# For regular Anthropic response, extract from content blocks
logger.info("Extracting tool_use blocks from response content")
tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"]

# Create tool result blocks
@@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider):
anthropic_messages.append({"role": "user", "content": tool_result_blocks})

# Make a second request to get the final response
logger.info("Making second request with tool results")
logger.info(f"Making second request with tool results (stream={stream})")
return self.create_chat_completion(
messages=anthropic_messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
stream=stream,
tools=None, # No tools needed for follow-up
)
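The construction of the tool_result blocks sits outside this hunk; as a hedged illustration, the follow-up request typically appends an assistant/user message pair along these lines in Anthropic's Messages API format (ids and contents invented):

# Hedged illustration of the message pair sent before the follow-up request.
# The exact block construction is not shown in this diff; values are invented.
assistant_turn = {
    "role": "assistant",
    "content": [
        {"type": "tool_use", "id": "toolu_123", "name": "list_dags", "input": {"limit": 10}},
    ],
}
user_turn = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_123", "content": '{"dags": []}'},
    ],
}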
@@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider):

return "".join(content_parts)

def get_streaming_content(self, response: Any) -> Any:
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.

@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider):
response: Anthropic streaming response object

Returns:
Generator yielding content chunks with tool_call attribute if detected
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting Anthropic streaming response processing")

@@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider):
tool_call = None
tool_use_detected = False

# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)

def generate():
nonlocal tool_call, tool_use_detected

for chunk in response:
logger.debug(f"Chunk type: {type(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")

# Check for content_block_start events with type "tool_use"
if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
if chunk.content_block.type == "tool_use":
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
tool_use_detected = True
tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
# We don't signal to the frontend during streaming
# The tool will only be executed after streaming ends
continue
@@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider):
# Handle content_block_delta events for tool_use (input updates)
if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
if hasattr(chunk.delta, "partial_json"):
if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
logger.info(f"Tool use input update: {chunk.delta.partial_json}")
# Update the current tool call input
if tool_call:
@@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider):
# Try to parse the partial JSON and update the input
partial_input = json.loads(chunk.delta.partial_json)
tool_call["input"].update(partial_input)
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
except json.JSONDecodeError:
logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
continue
@@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider):
# Log the complete tool call for debugging
if tool_call:
logger.info(f"Completed tool call: {json.dumps(tool_call)}")
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
continue

# Handle message_delta events with stop_reason "tool_use"
@@ -353,6 +425,9 @@ class AnthropicProvider(BaseLLMProvider):
if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
if chunk.delta.stop_reason == "tool_use":
logger.info("Message stopped due to tool use")
# Update the StreamingResponse object's tool_call attribute one last time
if tool_call:
streaming_response.tool_call = tool_call
continue

# Handle regular content chunks
@@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider):
# Create the generator
gen = generate()

# Attach the single tool_call to the generator object for later reference
# This will be used after streaming is complete
gen.tool_call = tool_call
# Set the generator in the StreamingResponse object
streaming_response.generator = gen

# Return the enhanced generator
return gen
# Return the StreamingResponse object
return streaming_response
@@ -6,8 +6,45 @@ must adhere to. It defines the methods required for tool conversion, API request
and response processing.
"""

import json
from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Generator, Iterator
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class StreamingResponse(Generic[T]):
"""
Wrapper for streaming responses that can hold tool call information.

This class wraps a generator and provides an iterator interface while also
storing tool call information. This allows us to associate metadata with
a generator without modifying the generator itself.
"""

def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
"""
Initialize the streaming response.

Args:
generator: The underlying generator yielding content chunks
tool_call: Optional tool call information detected during streaming
"""
self.generator = generator
self.tool_call = tool_call

def __iter__(self) -> Iterator[T]:
"""
Return self as iterator.
"""
return self

def __next__(self) -> T:
"""
Get the next item from the generator.
"""
return next(self.generator)
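A brief usage sketch of the wrapper defined above; the generator contents and tool call values are illustrative. In the providers, tool_call is set while the stream is being consumed rather than by the caller.

# Usage sketch for StreamingResponse: wrap a generator, iterate it as usual,
# then read tool call metadata off the wrapper afterwards.
def fake_stream():
    yield "Hello, "
    yield "world"

wrapped = StreamingResponse(generator=fake_stream())
text = "".join(wrapped)  # __iter__/__next__ delegate to the wrapped generator
wrapped.tool_call = {"id": "t1", "name": "noop", "input": {}}  # illustrative metadata
assert text == "Hello, world"
assert wrapped.tool_call["name"] == "noop"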
class BaseLLMProvider(ABC):
@@ -51,32 +88,85 @@ class BaseLLMProvider(ABC):
"""
pass

@abstractmethod
def has_tool_calls(self, response: Any) -> bool:
"""
Check if the response contains tool calls.

Args:
response: Provider-specific response object
response: Provider-specific response object or StreamingResponse

Returns:
True if the response contains tool calls, False otherwise
"""
pass
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
return True

# Provider-specific implementation should handle other cases
return False

def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.

Args:
response: Provider-specific response object or StreamingResponse

Returns:
List of tool call objects in a standardized format
"""
tool_calls = []

# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
tool_calls.append(response.tool_call)

# Provider-specific implementation should handle other cases
return tool_calls

@abstractmethod
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.

Args:
response: Provider-specific response object
response: Provider-specific response object or StreamingResponse
cookie: Airflow cookie for authentication

Returns:
Dictionary mapping tool call IDs to results
"""
pass
tool_calls = self.get_tool_calls(response)
results = {}

for tool_call in tool_calls:
tool_name = tool_call.get("name", "")
tool_input = tool_call.get("input", {})
tool_id = tool_call.get("id", "")

try:
import logging

logger = logging.getLogger(__name__)
logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")

from airflow_wingman.tools import execute_airflow_tool

result = execute_airflow_tool(tool_name, tool_input, cookie)

logger.info(f"Tool result: {json.dumps(result)}")
results[tool_id] = {
"name": tool_name,
"input": tool_input,
"output": result,
}
except Exception as e:
import traceback

error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
results[tool_id] = {"status": "error", "message": error_msg}

return results

@abstractmethod
def create_follow_up_completion(
@@ -112,14 +202,15 @@ class BaseLLMProvider(ABC):
pass

@abstractmethod
def get_streaming_content(self, response: Any) -> Any:
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.
Get a StreamingResponse for streaming content from the response.

Args:
response: Provider-specific response object

Returns:
Generator yielding content chunks
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
pass
@@ -12,7 +12,7 @@ from typing import Any

from openai import OpenAI

from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools

@@ -116,35 +116,75 @@ class OpenAIProvider(BaseLLMProvider):
Check if the response contains tool calls.

Args:
response: OpenAI response object
response: OpenAI response object or StreamingResponse with tool_call attribute

Returns:
True if the response contains tool calls, False otherwise
"""
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
return True

# For non-streaming responses
if hasattr(response, "choices") and len(response.choices) > 0:
message = response.choices[0].message
return hasattr(message, "tool_calls") and message.tool_calls

return False

def get_tool_calls(self, response: Any) -> list:
"""
Extract tool calls from the response.

Args:
response: OpenAI response object or StreamingResponse with tool_call attribute

Returns:
List of tool call objects in a standardized format
"""
tool_calls = []

# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
tool_calls.append(response.tool_call)
return tool_calls

# For non-streaming responses
if hasattr(response, "choices") and len(response.choices) > 0:
message = response.choices[0].message
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
tool_calls.append(standardized_tool_call)

logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
return tool_calls
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.

Args:
response: OpenAI response object
response: OpenAI response object or StreamingResponse with tool_call attribute
cookie: Airflow cookie for authentication

Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
message = response.choices[0].message

if not self.has_tool_calls(response):
return results

for tool_call in message.tool_calls:
tool_id = tool_call.id
function_name = tool_call.function.name
arguments = json.loads(tool_call.function.arguments)
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")

for tool_call in tool_calls:
tool_id = tool_call["id"]
function_name = tool_call["name"]
arguments = tool_call["input"]

try:
# Execute the Airflow tool with the provided arguments and cookie
@@ -220,7 +260,7 @@ class OpenAIProvider(BaseLLMProvider):
"""
return response.choices[0].message.content

def get_streaming_content(self, response: Any) -> Any:
def get_streaming_content(self, response: Any) -> StreamingResponse:
"""
Get a generator for streaming content from the response.

@@ -228,15 +268,87 @@ class OpenAIProvider(BaseLLMProvider):
response: OpenAI streaming response object

Returns:
Generator yielding content chunks
StreamingResponse object wrapping a generator that yields content chunks
and can also store tool call information detected during streaming
"""
logger.info("Starting OpenAI streaming response processing")

# Track only the first tool call detected during streaming
tool_call = None
tool_use_detected = False
current_tool_call = None

# Create the StreamingResponse object first
streaming_response = StreamingResponse(generator=None, tool_call=None)

def generate():
nonlocal tool_call, tool_use_detected, current_tool_call

for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
# Don't do any newline replacement here
# Check for tool call in the delta
if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
# Tool call detected
if not tool_use_detected:
tool_use_detected = True
logger.info("Tool call detected in streaming response")

# Initialize the tool call
delta_tool_call = chunk.choices[0].delta.tool_calls[0]
current_tool_call = {
"id": getattr(delta_tool_call, "id", ""),
"name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
"input": {},
}
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = current_tool_call
else:
# Update the existing tool call
delta_tool_call = chunk.choices[0].delta.tool_calls[0]

# Update the tool call ID if it's provided in this chunk
if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
current_tool_call["id"] = delta_tool_call.id

# Update the function name if it's provided in this chunk
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
current_tool_call["name"] = delta_tool_call.function.name

# Update the arguments if they're provided in this chunk
if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
try:
# Try to parse the arguments JSON
arguments = json.loads(delta_tool_call.function.arguments)
if isinstance(arguments, dict):
current_tool_call["input"].update(arguments)
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = current_tool_call
except json.JSONDecodeError:
# If the arguments are not valid JSON, just log a warning
logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")

# Skip yielding content for tool call chunks
continue

# For the final chunk, set the tool_call attribute
if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
logger.info("Streaming response finished with tool_calls reason")
if current_tool_call:
tool_call = current_tool_call
logger.info(f"Final tool call: {json.dumps(tool_call)}")
# Update the StreamingResponse object's tool_call attribute
streaming_response.tool_call = tool_call
continue

# Handle regular content chunks
if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content

return generate()
# Create the generator
gen = generate()

# Set the generator in the StreamingResponse object
streaming_response.generator = gen

# Return the StreamingResponse object
return streaming_response
@@ -31,7 +31,14 @@ async def _list_airflow_tools_async(cookie: str) -> list:
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)

# Format the cookie properly if it doesn't already have the 'session=' prefix
formatted_cookie = cookie
if cookie and not cookie.startswith("session="):
formatted_cookie = f"session={cookie}"
logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")

config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

# Get available tools
logger.info("Getting Airflow tools...")
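Both async helpers now apply the same normalization before building the AirflowConfig. As a tiny standalone illustration of that transformation (the helper name and cookie value here are hypothetical, not part of the diff):

# Illustrative only: the cookie normalization both helpers now perform.
# A bare session value gains the "session=" prefix expected in the Cookie header.
def _format_cookie(cookie: str) -> str:
    if cookie and not cookie.startswith("session="):
        return f"session={cookie}"
    return cookie

assert _format_cookie("abc123") == "session=abc123"          # fake value
assert _format_cookie("session=abc123") == "session=abc123"  # already prefixed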
@@ -73,19 +80,31 @@ async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: s
# Set up configuration
base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)

# Format the cookie properly if it doesn't already have the 'session=' prefix
formatted_cookie = cookie
if cookie and not cookie.startswith("session="):
formatted_cookie = f"session={cookie}"
logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")

config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

# Get the tool
logger.info(f"Getting tool: {tool_name}")
tool = await get_tool(config=config, tool_name=tool_name)
tool = await get_tool(config=config, name=tool_name)

if not tool:
error_msg = f"Tool not found: {tool_name}"
logger.error(error_msg)
return json.dumps({"error": error_msg})

# Execute the tool
# Execute the tool - ensure the client is in an async context
logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")

# The AirflowClient needs to be used as an async context manager
# to properly initialize its session
async with tool.client as client: # noqa F841
# Now the client has a _session attribute and is in an async context
result = await tool.run(arguments)

# Convert result to string
@@ -114,4 +133,21 @@ def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
Returns:
Result of the tool execution as a string
"""
return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie))
# Create a new event loop for this execution
# This ensures we're always in a clean async context
loop = asyncio.new_event_loop()

try:
# Set the event loop for this thread
asyncio.set_event_loop(loop)

# Run the async function in the new event loop
result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
return result
except Exception as e:
error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return json.dumps({"error": error_msg})
finally:
# Always close the loop to free resources
loop.close()
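The synchronous wrapper now creates and closes its own event loop instead of calling asyncio.run, so each tool execution runs in a clean async context. A generic, hedged sketch of that pattern in isolation (the helper name and coroutine argument are placeholders, not code from this repository):

# Hedged, generic sketch of the per-call event loop pattern used above.
import asyncio

def run_in_fresh_loop(coro):
    # Create a dedicated event loop, run the coroutine to completion,
    # and always release the loop's resources afterwards.
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)
    finally:
        loop.close()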
@@ -114,16 +114,18 @@ class WingmanView(AppBuilderBaseView):
"""Handle streaming response."""
try:
logger.info("Beginning streaming response")
# Use the enhanced chat_completion method with return_response_obj=True
response_obj, generator = client.chat_completion(
messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True
)
# Get the cookie at the beginning of the request handler
airflow_cookie = request.cookies.get("session")
logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")

def stream_response():
# Use the enhanced chat_completion method with return_response_obj=True
streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)

def stream_response(cookie=airflow_cookie):
complete_response = ""

# Stream the initial response
for chunk in generator:
for chunk in streaming_response:
if chunk:
complete_response += chunk
yield f"data: {chunk}\n\n"
@@ -134,7 +136,7 @@ class WingmanView(AppBuilderBaseView):
logger.info("<<< COMPLETE RESPONSE END")

# Check for tool calls and make follow-up if needed
if client.provider.has_tool_calls(response_obj):
if client.provider.has_tool_calls(streaming_response):
# Signal tool processing start - frontend should disable send button
yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"

@@ -142,9 +144,10 @@ class WingmanView(AppBuilderBaseView):
yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"

logger.info("Response contains tool calls, making follow-up request")
logger.info(f"Using cookie from closure: {cookie is not None}")

# Process tool calls and get follow-up response (handles recursive tool calls)
follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"])
follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie)

# Stream the follow-up response
for chunk in follow_up_response:
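On the wire the view emits Server-Sent Events: plain data: lines for content chunks plus JSON control events such as tool_processing_start and replace_content. A hedged sketch of a client consuming that stream; the endpoint URL and request payload are hypothetical:

# Hedged sketch of a client reading the SSE stream emitted by the view.
import json
import requests

resp = requests.post(
    "http://localhost:8080/wingman/chat",   # hypothetical endpoint
    json={"messages": [{"role": "user", "content": "hi"}]},
    stream=True,
)
for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data: "):
        continue
    payload = raw[len("data: "):]
    try:
        event = json.loads(payload)
        if isinstance(event, dict) and "event" in event:
            print("control event:", event["event"])  # e.g. tool_processing_start
            continue
    except json.JSONDecodeError:
        pass
    print(payload, end="")  # regular content chunk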