fix tool call until cookies are legit
@@ -12,7 +12,7 @@ from typing import Any
 
 from openai import OpenAI
 
-from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
 from airflow_wingman.tools import execute_airflow_tool
 from airflow_wingman.tools.conversion import convert_to_openai_tools
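
Note: StreamingResponse is imported above but defined in airflow_wingman/providers/base.py, which is outside this diff. Judging only by how it is used below (a generator field plus a mutable tool_call field), a minimal sketch might look like the following; the real class may differ:

    # Hypothetical sketch of the StreamingResponse container assumed by this diff;
    # the actual definition lives in airflow_wingman.providers.base and may differ.
    from collections.abc import Generator
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class StreamingResponse:
        generator: Generator[str, None, None] | None = None
        tool_call: dict[str, Any] | None = None
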
@@ -116,35 +116,75 @@ class OpenAIProvider(BaseLLMProvider):
         Check if the response contains tool calls.
 
         Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with a tool_call attribute
 
         Returns:
             True if the response contains tool calls, False otherwise
         """
-        message = response.choices[0].message
-        return hasattr(message, "tool_calls") and message.tool_calls
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            return True
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message = response.choices[0].message
+            # Wrap in bool() so the method returns True/False as documented
+            return bool(hasattr(message, "tool_calls") and message.tool_calls)
+
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: OpenAI response object or StreamingResponse with a tool_call attribute
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+            return tool_calls
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message = response.choices[0].message
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    standardized_tool_call = {
+                        "id": tool_call.id,
+                        "name": tool_call.function.name,
+                        "input": json.loads(tool_call.function.arguments),
+                    }
+                    tool_calls.append(standardized_tool_call)
+
+        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
+        return tool_calls
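
Regardless of whether a tool call came from a streaming or a non-streaming response, get_tool_calls returns it in the same standardized shape. For illustration (the id, name, and arguments below are made up, not from the repository):

    # Illustrative only: the id, name, and input values are hypothetical.
    standardized_tool_call = {
        "id": "call_abc123",
        "name": "list_dags",
        "input": {"limit": 10},
    }
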
 
     def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
         """
         Process tool calls from the response.
 
         Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with a tool_call attribute
             cookie: Airflow cookie for authentication
 
         Returns:
             Dictionary mapping tool call IDs to results
         """
         results = {}
-        message = response.choices[0].message
 
         if not self.has_tool_calls(response):
             return results
 
-        for tool_call in message.tool_calls:
-            tool_id = tool_call.id
-            function_name = tool_call.function.name
-            arguments = json.loads(tool_call.function.arguments)
+        # Get tool calls using the standardized method
+        tool_calls = self.get_tool_calls(response)
+        logger.info(f"Processing {len(tool_calls)} tool calls")
+
+        for tool_call in tool_calls:
+            tool_id = tool_call["id"]
+            function_name = tool_call["name"]
+            arguments = tool_call["input"]
 
             try:
                 # Execute the Airflow tool with the provided arguments and cookie
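
Taken together, a caller drives detection, extraction, and execution through these methods. A rough sketch, not code from this repository: it assumes a chat_completion() method that returns either an OpenAI response or a StreamingResponse, and a cookie string obtained from the Airflow session:

    # Hypothetical caller; provider, messages, and cookie are assumed to exist.
    response = provider.chat_completion(messages, stream=False)
    if provider.has_tool_calls(response):
        # Maps each tool call ID to the result of execute_airflow_tool(...)
        results = provider.process_tool_calls(response, cookie)
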
@@ -220,7 +260,7 @@ class OpenAIProvider(BaseLLMProvider):
         """
         return response.choices[0].message.content
 
-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
         """
         Get a generator for streaming content from the response.
 
@@ -228,15 +268,87 @@ class OpenAIProvider(BaseLLMProvider):
             response: OpenAI streaming response object
 
         Returns:
-            Generator yielding content chunks
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
         """
+        logger.info("Starting OpenAI streaming response processing")
+
+        # Track only the first tool call detected during streaming
+        tool_call = None
+        tool_use_detected = False
+        current_tool_call = None
+        arguments_buffer = ""
+
+        # Create the StreamingResponse object first so the generator can update it
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
         def generate():
+            nonlocal tool_call, tool_use_detected, current_tool_call, arguments_buffer
+
             for chunk in response:
-                if chunk.choices and chunk.choices[0].delta.content:
-                    # Don't do any newline replacement here
-                    content = chunk.choices[0].delta.content
-                    yield content
+                # Check for a tool call in the delta
+                if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
+                    delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+
+                    if not tool_use_detected:
+                        # First tool call chunk: initialize the standardized tool call
+                        tool_use_detected = True
+                        logger.info("Tool call detected in streaming response")
+                        current_tool_call = {
+                            "id": getattr(delta_tool_call, "id", "") or "",
+                            "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
+                            "input": {},
+                        }
+                        streaming_response.tool_call = current_tool_call
+                    elif current_tool_call:
+                        # Later chunks may fill in the tool call ID or function name
+                        if getattr(delta_tool_call, "id", None):
+                            current_tool_call["id"] = delta_tool_call.id
+                        if hasattr(delta_tool_call, "function") and getattr(delta_tool_call.function, "name", None):
+                            current_tool_call["name"] = delta_tool_call.function.name
+
+                    # OpenAI streams arguments as partial JSON fragments spread
+                    # across chunks, so accumulate them here and parse once at
+                    # the end rather than json.loads-ing each fragment
+                    if hasattr(delta_tool_call, "function") and getattr(delta_tool_call.function, "arguments", None):
+                        arguments_buffer += delta_tool_call.function.arguments
+
+                    # Skip yielding content for tool call chunks
+                    continue
+
+                # On the final chunk, parse the accumulated arguments and
+                # publish the completed tool call
+                if chunk.choices and getattr(chunk.choices[0], "finish_reason", None) == "tool_calls":
+                    logger.info("Streaming response finished with tool_calls reason")
+                    if current_tool_call:
+                        if arguments_buffer:
+                            try:
+                                current_tool_call["input"] = json.loads(arguments_buffer)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse arguments: {arguments_buffer}")
+                        tool_call = current_tool_call
+                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
+                        streaming_response.tool_call = tool_call
+                    continue
+
+                # Handle regular content chunks (no newline replacement here)
+                if chunk.choices and getattr(chunk.choices[0].delta, "content", None):
+                    content = chunk.choices[0].delta.content
+                    yield content
 
-        return generate()
+        # Attach the generator to the StreamingResponse and return the wrapper
+        streaming_response.generator = generate()
+        return streaming_response
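
Downstream code that previously iterated the raw generator now receives the wrapper instead. A sketch of how a consumer might use it, assuming the StreamingResponse fields above; note that tool_call is only reliable once the generator has been exhausted, because it is filled in as chunks arrive:

    # Hypothetical consumer of the StreamingResponse returned above.
    streaming = provider.get_streaming_content(response)
    for content_chunk in streaming.generator:
        print(content_chunk, end="")  # forward each chunk to the client

    # Inspect tool_call only after streaming completes.
    if streaming.tool_call is not None:
        results = provider.process_tool_calls(streaming, cookie)
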