fix tool call until cookies are legit

This commit is contained in:
2025-03-02 07:44:24 +00:00
parent 6cb60f1bbd
commit 5491ba71aa
6 changed files with 394 additions and 79 deletions

View File

@@ -12,7 +12,7 @@ from typing import Any
from openai import OpenAI
from airflow_wingman.providers.base import BaseLLMProvider
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools
@@ -116,35 +116,75 @@ class OpenAIProvider(BaseLLMProvider):
Check if the response contains tool calls.
Args:
response: OpenAI response object
response: OpenAI response object or StreamingResponse with tool_call attribute
Returns:
True if the response contains tool calls, False otherwise
"""
message = response.choices[0].message
return hasattr(message, "tool_calls") and message.tool_calls
# Check if response is a StreamingResponse with a tool_call attribute
if isinstance(response, StreamingResponse) and response.tool_call is not None:
return True
# For non-streaming responses
if hasattr(response, "choices") and len(response.choices) > 0:
message = response.choices[0].message
return hasattr(message, "tool_calls") and message.tool_calls
return False
def get_tool_calls(self, response: Any) -> list:
    """
    Extract tool calls from the response.

    Args:
        response: OpenAI response object or StreamingResponse with tool_call attribute

    Returns:
        List of tool call objects in a standardized format
        (dicts with ``id``, ``name`` and ``input`` keys)
    """
    # A StreamingResponse carries at most one already-standardized tool call.
    if isinstance(response, StreamingResponse) and response.tool_call is not None:
        logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
        return [response.tool_call]

    extracted = []
    # Non-streaming: pull tool calls off the first choice's message, if present.
    if hasattr(response, "choices") and len(response.choices) > 0:
        message = response.choices[0].message
        if hasattr(message, "tool_calls") and message.tool_calls:
            extracted = [
                {
                    "id": call.id,
                    "name": call.function.name,
                    # OpenAI delivers arguments as a JSON string; decode to a dict.
                    "input": json.loads(call.function.arguments),
                }
                for call in message.tool_calls
            ]
        logger.info(f"Extracted {len(extracted)} tool calls from OpenAI response")
    return extracted
def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
"""
Process tool calls from the response.
Args:
response: OpenAI response object
response: OpenAI response object or StreamingResponse with tool_call attribute
cookie: Airflow cookie for authentication
Returns:
Dictionary mapping tool call IDs to results
"""
results = {}
message = response.choices[0].message
if not self.has_tool_calls(response):
return results
for tool_call in message.tool_calls:
tool_id = tool_call.id
function_name = tool_call.function.name
arguments = json.loads(tool_call.function.arguments)
# Get tool calls using the standardized method
tool_calls = self.get_tool_calls(response)
logger.info(f"Processing {len(tool_calls)} tool calls")
for tool_call in tool_calls:
tool_id = tool_call["id"]
function_name = tool_call["name"]
arguments = tool_call["input"]
try:
# Execute the Airflow tool with the provided arguments and cookie
@@ -220,7 +260,7 @@ class OpenAIProvider(BaseLLMProvider):
"""
return response.choices[0].message.content
def get_streaming_content(self, response: Any) -> StreamingResponse:
    """
    Get a generator for streaming content from the response.

    Args:
        response: OpenAI streaming response object (iterable of chunks)

    Returns:
        StreamingResponse object wrapping a generator that yields content chunks
        and can also store tool call information detected during streaming

    Notes:
        OpenAI streams a tool call's ``function.arguments`` as *partial JSON
        fragments* spread across many chunks.  Parsing each fragment on its own
        (the previous behavior) almost always fails; instead the fragments are
        accumulated in a buffer and decoded once, when the stream finishes with
        ``finish_reason == "tool_calls"``.
    """
    logger.info("Starting OpenAI streaming response processing")

    # Track only the first tool call detected during streaming.
    current_tool_call = None
    # Buffer for the incrementally streamed JSON argument fragments.
    arguments_buffer = ""

    # Create the StreamingResponse first so the generator closure can update
    # its tool_call attribute as chunks arrive.
    streaming_response = StreamingResponse(generator=None, tool_call=None)

    def generate():
        nonlocal current_tool_call, arguments_buffer
        for chunk in response:
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            delta = choice.delta

            # --- tool-call delta chunks -------------------------------------
            if getattr(delta, "tool_calls", None):
                delta_tool_call = delta.tool_calls[0]
                function = getattr(delta_tool_call, "function", None)

                if current_tool_call is None:
                    logger.info("Tool call detected in streaming response")
                    current_tool_call = {
                        "id": getattr(delta_tool_call, "id", "") or "",
                        "name": (getattr(function, "name", "") or "") if function is not None else "",
                        "input": {},
                    }
                    streaming_response.tool_call = current_tool_call
                else:
                    # Later chunks may fill in id/name that the first chunk lacked.
                    if getattr(delta_tool_call, "id", None):
                        current_tool_call["id"] = delta_tool_call.id
                    if function is not None and getattr(function, "name", None):
                        current_tool_call["name"] = function.name

                # Accumulate argument fragments from EVERY chunk (including the
                # first one, which the previous implementation dropped).
                if function is not None and getattr(function, "arguments", None):
                    arguments_buffer += function.arguments

                # Tool-call chunks carry no user-visible content: don't yield.
                continue

            # --- end-of-stream for a tool call ------------------------------
            if getattr(choice, "finish_reason", None) == "tool_calls":
                logger.info("Streaming response finished with tool_calls reason")
                if current_tool_call is not None:
                    if arguments_buffer:
                        try:
                            parsed = json.loads(arguments_buffer)
                            if isinstance(parsed, dict):
                                current_tool_call["input"] = parsed
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse arguments: {arguments_buffer}")
                    logger.info(f"Final tool call: {json.dumps(current_tool_call)}")
                    streaming_response.tool_call = current_tool_call
                continue

            # --- regular content chunks -------------------------------------
            content = getattr(delta, "content", None)
            if content:
                yield content

    # Attach the generator and hand back the wrapper.
    streaming_response.generator = generate()
    return streaming_response