# airflow-wingman/src/airflow_wingman/providers/openai_provider.py
"""
OpenAI provider implementation for Airflow Wingman.
This module contains the OpenAI provider implementation that handles
API requests, tool conversion, and response processing for OpenAI.
"""
import json
import logging
import traceback
from typing import Any
from openai import OpenAI
from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
from airflow_wingman.tools import execute_airflow_tool
from airflow_wingman.tools.conversion import convert_to_openai_tools
# Create a properly namespaced logger for the Airflow plugin
logger = logging.getLogger("airflow.plugins.wingman")


class OpenAIProvider(BaseLLMProvider):
    """
    OpenAI provider implementation.

    This class handles API requests, tool conversion, and response processing
    for the OpenAI API. It can also be used for OpenRouter with a custom base URL.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the OpenAI provider.

        Args:
            api_key: API key for OpenAI
            base_url: Optional base URL for the API (used for OpenRouter)
        """
        self.api_key = api_key
        # Ensure the base_url doesn't end with /chat/completions to prevent URL
        # duplication: the OpenAI client appends that path itself
        if base_url and "/chat/completions" in base_url:
            # Strip the /chat/completions part and ensure we have a proper base URL
            base_url = base_url.split("/chat/completions")[0]
            if not base_url.endswith("/v1"):
                base_url = f"{base_url}/v1" if not base_url.endswith("/") else f"{base_url}v1"
            logger.info(f"Modified base_url to prevent endpoint duplication: {base_url}")
        self.client = OpenAI(api_key=api_key, base_url=base_url)
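
    # Illustrative usage (hedged; keys and URL below are placeholders, not real values):
    #
    #   provider = OpenAIProvider(api_key="sk-...")
    #   via_openrouter = OpenAIProvider(
    #       api_key="sk-or-...",
    #       base_url="https://openrouter.ai/api/v1/chat/completions",
    #   )
    #
    # The second call relies on the normalization above: the fully qualified endpoint
    # URL is reduced to https://openrouter.ai/api/v1 before the client is built.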

    def convert_tools(self, airflow_tools: list) -> list:
        """
        Convert Airflow tools to OpenAI format.

        Args:
            airflow_tools: List of Airflow tools from MCP server

        Returns:
            List of OpenAI tool definitions
        """
        return convert_to_openai_tools(airflow_tools)
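
    # For reference, the converted entries are assumed to follow the standard OpenAI
    # function-tool shape (field values here are illustrative, not from the MCP server):
    #
    #   {
    #       "type": "function",
    #       "function": {
    #           "name": "get_dag_runs",
    #           "description": "List runs for a DAG",
    #           "parameters": {"type": "object", "properties": {...}, "required": [...]},
    #       },
    #   }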

    def create_chat_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = False,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Make API request to OpenAI.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object

        Raises:
            Exception: If the API request fails
        """
        # Only include tools if we have any
        has_tools = tools is not None and len(tools) > 0
        tool_choice = "auto" if has_tools else None
        try:
            logger.info(f"Sending chat completion request to OpenAI with model: {model}")
            # Log information about tools
            if not has_tools:
                logger.warning("No tools included in request")
            # Log request parameters
            request_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
                "tools": tools if has_tools else None,
                "tool_choice": tool_choice,
            }
            logger.info(f"Request parameters: {json.dumps(request_params)}")
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=stream,
                tools=tools if has_tools else None,
                tool_choice=tool_choice,
            )
            logger.info("Received response from OpenAI")
            return response
        except Exception as e:
            # If the API call fails because tools aren't supported, retry without tools
            error_msg = str(e)
            logger.warning(f"Error in OpenAI API call: {error_msg}")
            if "tools" in error_msg.lower():
                logger.info("Retrying without tools")
                return self.client.chat.completions.create(
                    model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream
                )
            logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
            raise
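
    # Illustrative call (model name and inputs are placeholders):
    #
    #   response = provider.create_chat_completion(
    #       messages=[{"role": "user", "content": "List my DAGs"}],
    #       model="gpt-4o-mini",
    #       tools=provider.convert_tools(airflow_tools),
    #   )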

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            return True
        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            # Coerce to bool: message.tool_calls is a list (or None), and the
            # annotated return type is bool, not a truthy list
            return bool(hasattr(message, "tool_calls") and message.tool_calls)
        return False

    def get_tool_calls(self, response: Any) -> list:
        """
        Extract tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            List of tool call objects in a standardized format
        """
        tool_calls = []
        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
            tool_calls.append(response.tool_call)
            return tool_calls
        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            if hasattr(message, "tool_calls") and message.tool_calls:
                for tool_call in message.tool_calls:
                    standardized_tool_call = {
                        "id": tool_call.id,
                        "name": tool_call.function.name,
                        "input": json.loads(tool_call.function.arguments),
                    }
                    tool_calls.append(standardized_tool_call)
        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
        return tool_calls
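
    # A standardized tool call produced above looks like (illustrative values):
    #
    #   {"id": "call_abc123", "name": "get_dag_runs", "input": {"dag_id": "example_dag"}}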

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}
        if not self.has_tool_calls(response):
            return results
        # Get tool calls using the standardized method
        tool_calls = self.get_tool_calls(response)
        logger.info(f"Processing {len(tool_calls)} tool calls")
        for tool_call in tool_calls:
            tool_id = tool_call["id"]
            function_name = tool_call["name"]
            arguments = tool_call["input"]
            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
                result = execute_airflow_tool(function_name, arguments, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {e}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}
        return results
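
    # The returned mapping pairs each tool call ID with its outcome, e.g.
    # (illustrative values):
    #
    #   {"call_abc123": {"status": "success", "result": "..."},
    #    "call_def456": {"status": "error", "message": "Error executing tool: ..."}}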

    def create_follow_up_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        tool_results: dict[str, Any] | None = None,
        original_response: Any = None,
        stream: bool = False,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object or StreamingResponse if streaming
        """
        if not original_response or not tool_results:
            return original_response

        if isinstance(original_response, StreamingResponse):
            # Handle StreamingResponse objects: rebuild the assistant message from the
            # single tool call captured during streaming
            logger.info("Processing StreamingResponse in create_follow_up_completion")
            tool_calls = []
            if original_response.tool_call is not None:
                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
                tool_call = original_response.tool_call
                # Create a simplified tool call structure for the assistant message
                tool_calls.append(
                    {
                        "id": tool_call.get("id", ""),
                        "type": "function",
                        "function": {
                            "name": tool_call.get("name", ""),
                            "arguments": json.dumps(tool_call.get("input", {})),
                        },
                    }
                )
            # Create a new message with the tool calls
            assistant_message = {
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls,
            }
        else:
            # Handle regular OpenAI response objects
            logger.info("Processing regular OpenAI response in create_follow_up_completion")
            # Get the original message with tool calls
            original_message = original_response.choices[0].message
            # Create a new message with the tool calls
            assistant_message = {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                    for tc in original_message.tool_calls
                ],
            }

        # Create tool result messages; the API requires string content
        tool_messages = []
        for tool_call_id, result in tool_results.items():
            tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": str(result.get("result", result))})

        # Add the original messages, assistant message, and tool results
        new_messages = messages + [assistant_message] + tool_messages

        # Make a second request to get the final response
        logger.info(f"Making second request with tool results (stream={stream})")
        return self.create_chat_completion(
            messages=new_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
            tools=tools,  # Pass tools so the model can chain further calls if needed
        )
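
    # The follow-up request therefore carries the standard OpenAI tool-calling message
    # sequence (illustrative values):
    #
    #   [...original messages...,
    #    {"role": "assistant", "content": None, "tool_calls": [
    #        {"id": "call_abc123", "type": "function",
    #         "function": {"name": "get_dag_runs", "arguments": '{"dag_id": "example_dag"}'}}]},
    #    {"role": "tool", "tool_call_id": "call_abc123", "content": "..."}]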

    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: OpenAI response object

        Returns:
            Content string from the response
        """
        return response.choices[0].message.content

    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a generator for streaming content from the response.

        Args:
            response: OpenAI streaming response object

        Returns:
            StreamingResponse object wrapping a generator that yields content chunks
            and can also store tool call information detected during streaming
        """
        logger.info("Starting OpenAI streaming response processing")
        # Track only the first tool call detected during streaming
        tool_call = None
        tool_use_detected = False
        current_tool_call = None
        # Create the StreamingResponse object first so the nested generator can
        # publish tool call state onto it as chunks arrive
        streaming_response = StreamingResponse(generator=None, tool_call=None)

        def generate():
            nonlocal tool_call, tool_use_detected, current_tool_call
            # Flag to track if we've yielded any content
            has_yielded_content = False
            for chunk in response:
                # Check for tool call in the delta
                if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
                    delta_tool_call = chunk.choices[0].delta.tool_calls[0]
                    if not tool_use_detected:
                        # Tool call detected: initialize it from the first delta
                        tool_use_detected = True
                        logger.info("Tool call detected in streaming response")
                        current_tool_call = {
                            "id": getattr(delta_tool_call, "id", "") or "",
                            "name": (getattr(delta_tool_call.function, "name", "") or "") if hasattr(delta_tool_call, "function") else "",
                            "input": {},
                        }
                        # The first delta can already carry an arguments fragment;
                        # start accumulating it here so it isn't silently dropped
                        if hasattr(delta_tool_call, "function") and getattr(delta_tool_call.function, "arguments", None):
                            current_tool_call["_raw_arguments"] = delta_tool_call.function.arguments
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = current_tool_call
                    else:
                        # Update the existing tool call with whatever this chunk provides
                        if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
                            current_tool_call["id"] = delta_tool_call.id
                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
                            current_tool_call["name"] = delta_tool_call.function.name
                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
                            # Instead of trying to parse each chunk as JSON, accumulate the
                            # arguments and only parse the complete JSON at the end
                            if "_raw_arguments" not in current_tool_call:
                                current_tool_call["_raw_arguments"] = ""
                            # Accumulate the raw arguments
                            current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments
                            # Try to parse the accumulated arguments
                            try:
                                arguments = json.loads(current_tool_call["_raw_arguments"])
                                if isinstance(arguments, dict):
                                    # Successfully parsed the complete JSON
                                    current_tool_call["input"] = arguments  # Replace instead of update
                                    # Update the StreamingResponse object's tool_call attribute
                                    streaming_response.tool_call = current_tool_call
                            except json.JSONDecodeError:
                                # Expected for partial JSON; we'll try again with the next chunk
                                logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")
                    # Skip yielding content for tool call chunks
                    continue
                # For the final chunk, set the tool_call attribute
                if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                    logger.info("Streaming response finished with tool_calls reason")
                    # If we haven't yielded any content yet and we're finishing with tool_calls,
                    # yield a placeholder message so the frontend has something to display
                    if not has_yielded_content and tool_use_detected:
                        logger.info("Yielding placeholder content for tool call")
                        yield "I'll help you with that."  # Simple placeholder message
                        has_yielded_content = True
                    if current_tool_call:
                        # One final attempt to parse the accumulated raw arguments
                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
                            try:
                                arguments = json.loads(current_tool_call["_raw_arguments"])
                                if isinstance(arguments, dict):
                                    current_tool_call["input"] = arguments
                            except json.JSONDecodeError:
                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
                                # If we still can't parse it, use an empty dict as fallback
                                if not current_tool_call["input"]:
                                    current_tool_call["input"] = {}
                        # Remove the raw arguments from the final tool call
                        if "_raw_arguments" in current_tool_call:
                            del current_tool_call["_raw_arguments"]
                        tool_call = current_tool_call
                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = tool_call
                    continue
                # Handle regular content chunks
                if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
                    has_yielded_content = True

        # Attach the generator and return the wrapper object
        streaming_response.generator = generate()
        return streaming_response
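

# Hedged end-to-end sketch of the streaming tool-call round trip. It assumes a caller
# that already holds an Airflow session cookie and the MCP tool list; `messages`,
# `airflow_tools`, and `cookie` below are placeholders supplied elsewhere in the plugin.
#
#   provider = OpenAIProvider(api_key="sk-...")
#   tools = provider.convert_tools(airflow_tools)
#   first = provider.create_chat_completion(messages=messages, model="gpt-4o-mini", stream=True, tools=tools)
#   streaming = provider.get_streaming_content(first)
#   for text in streaming.generator:
#       print(text, end="")  # forward chunks to the UI as they arrive
#   if provider.has_tool_calls(streaming):
#       results = provider.process_tool_calls(streaming, cookie)
#       final = provider.create_follow_up_completion(
#           messages=messages, model="gpt-4o-mini", tool_results=results,
#           original_response=streaming, stream=False, tools=tools,
#       )
#       print(provider.get_content(final))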