""" OpenAI provider implementation for Airflow Wingman. This module contains the OpenAI provider implementation that handles API requests, tool conversion, and response processing for OpenAI. """ import json import logging import traceback from typing import Any from openai import OpenAI from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_openai_tools # Create a properly namespaced logger for the Airflow plugin logger = logging.getLogger("airflow.plugins.wingman") class OpenAIProvider(BaseLLMProvider): """ OpenAI provider implementation. This class handles API requests, tool conversion, and response processing for the OpenAI API. It can also be used for OpenRouter with a custom base URL. """ def __init__(self, api_key: str, base_url: str | None = None): """ Initialize the OpenAI provider. Args: api_key: API key for OpenAI base_url: Optional base URL for the API (used for OpenRouter) """ self.api_key = api_key # Ensure the base_url doesn't end with /chat/completions to prevent URL duplication if base_url and '/chat/completions' in base_url: # Strip the /chat/completions part and ensure we have a proper base URL base_url = base_url.split('/chat/completions')[0] if not base_url.endswith('/v1'): base_url = f"{base_url}/v1" if not base_url.endswith('/') else f"{base_url}v1" logger.info(f"Modified base_url to prevent endpoint duplication: {base_url}") self.client = OpenAI(api_key=api_key, base_url=base_url) def convert_tools(self, airflow_tools: list) -> list: """ Convert Airflow tools to OpenAI format. Args: airflow_tools: List of Airflow tools from MCP server Returns: List of OpenAI tool definitions """ return convert_to_openai_tools(airflow_tools) def create_chat_completion( self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to OpenAI. 

    def create_chat_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = False,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Make API request to OpenAI.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object

        Raises:
            Exception: If the API request fails
        """
        # Only include tools if we have any
        has_tools = tools is not None and len(tools) > 0
        tool_choice = "auto" if has_tools else None

        try:
            logger.info(f"Sending chat completion request to OpenAI with model: {model}")

            # Log information about tools
            if not has_tools:
                logger.warning("No tools included in request")

            # Log request parameters
            request_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
                "tools": tools if has_tools else None,
                "tool_choice": tool_choice,
            }
            logger.info(f"Request parameters: {json.dumps(request_params)}")

            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=stream,
                tools=tools if has_tools else None,
                tool_choice=tool_choice,
            )
            logger.info("Received response from OpenAI")
            return response
        except Exception as e:
            # If the API call fails because tools aren't supported, retry without tools
            error_msg = str(e)
            logger.warning(f"Error in OpenAI API call: {error_msg}")
            if "tools" in error_msg.lower():
                logger.info("Retrying without tools")
                return self.client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=stream,
                )
            logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}")
            raise

    def has_tool_calls(self, response: Any) -> bool:
        """
        Check if the response contains tool calls.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            True if the response contains tool calls, False otherwise
        """
        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            return True

        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            return hasattr(message, "tool_calls") and bool(message.tool_calls)

        return False

    def get_tool_calls(self, response: Any) -> list:
        """
        Extract tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute

        Returns:
            List of tool call objects in a standardized format
        """
        tool_calls = []

        # Check if response is a StreamingResponse with a tool_call attribute
        if isinstance(response, StreamingResponse) and response.tool_call is not None:
            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
            tool_calls.append(response.tool_call)
            return tool_calls

        # For non-streaming responses
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = response.choices[0].message
            if hasattr(message, "tool_calls") and message.tool_calls:
                for tool_call in message.tool_calls:
                    standardized_tool_call = {
                        "id": tool_call.id,
                        "name": tool_call.function.name,
                        "input": json.loads(tool_call.function.arguments),
                    }
                    tool_calls.append(standardized_tool_call)

        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
        return tool_calls
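
    # The standardized tool call shape produced by get_tool_calls() above and
    # consumed by process_tool_calls() below, with illustrative id/name/arguments:
    #
    #     {"id": "call_abc123", "name": "list_dags", "input": {"limit": 10}}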

    def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
        """
        Process tool calls from the response.

        Args:
            response: OpenAI response object or StreamingResponse with tool_call attribute
            cookie: Airflow cookie for authentication

        Returns:
            Dictionary mapping tool call IDs to results
        """
        results = {}

        if not self.has_tool_calls(response):
            return results

        # Get tool calls using the standardized method
        tool_calls = self.get_tool_calls(response)
        logger.info(f"Processing {len(tool_calls)} tool calls")

        for tool_call in tool_calls:
            tool_id = tool_call["id"]
            function_name = tool_call["name"]
            arguments = tool_call["input"]

            try:
                # Execute the Airflow tool with the provided arguments and cookie
                logger.info(f"Executing tool: {function_name} with arguments: {arguments}")
                result = execute_airflow_tool(function_name, arguments, cookie)
                logger.info(f"Tool execution result: {result}")
                results[tool_id] = {"status": "success", "result": result}
            except Exception as e:
                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                results[tool_id] = {"status": "error", "message": error_msg}

        return results

    def create_follow_up_completion(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        tool_results: dict[str, Any] | None = None,
        original_response: Any = None,
        stream: bool = False,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Create a follow-up completion with tool results.

        Args:
            messages: Original messages
            model: Model identifier
            temperature: Sampling temperature (0-1)
            max_tokens: Maximum tokens to generate
            tool_results: Results of tool executions
            original_response: Original response with tool calls
            stream: Whether to stream the response
            tools: List of tool definitions in OpenAI format

        Returns:
            OpenAI response object or StreamingResponse if streaming
        """
        if not original_response or not tool_results:
            return original_response

        # Handle StreamingResponse objects
        if isinstance(original_response, StreamingResponse):
            logger.info("Processing StreamingResponse in create_follow_up_completion")

            # Extract tool calls from the StreamingResponse
            tool_calls = []
            if original_response.tool_call is not None:
                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
                tool_call = original_response.tool_call
                # Create a simplified tool call structure for the assistant message
                tool_calls.append(
                    {
                        "id": tool_call.get("id", ""),
                        "type": "function",
                        "function": {
                            "name": tool_call.get("name", ""),
                            "arguments": json.dumps(tool_call.get("input", {})),
                        },
                    }
                )

            # Create a new message with the tool calls
            assistant_message = {
                "role": "assistant",
                "content": None,
                "tool_calls": tool_calls,
            }
        else:
            # Handle regular OpenAI response objects
            logger.info("Processing regular OpenAI response in create_follow_up_completion")

            # Get the original message with tool calls
            original_message = original_response.choices[0].message

            # Create a new message with the tool calls
            assistant_message = {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                    }
                    for tc in original_message.tool_calls
                ],
            }

        # Create tool result messages
        tool_messages = []
        for tool_call_id, result in tool_results.items():
            tool_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": result.get("result", str(result)),
                }
            )

        # Add the original messages, assistant message, and tool results
        new_messages = messages + [assistant_message] + tool_messages

        # Make a second request to get the final response
        logger.info(f"Making second request with tool results (stream={stream})")
        return self.create_chat_completion(
            messages=new_messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream,
            tools=tools,  # Pass tools parameter for follow-up
        )
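
    # The follow-up request appends one assistant message carrying the tool
    # calls plus one "tool" message per result, e.g. (illustrative values):
    #
    #     {"role": "assistant", "content": None,
    #      "tool_calls": [{"id": "call_abc123", "type": "function",
    #                      "function": {"name": "list_dags",
    #                                   "arguments": '{"limit": 10}'}}]}
    #     {"role": "tool", "tool_call_id": "call_abc123", "content": "..."}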

    def get_content(self, response: Any) -> str:
        """
        Extract content from the response.

        Args:
            response: OpenAI response object

        Returns:
            Content string from the response
        """
        return response.choices[0].message.content

    def get_streaming_content(self, response: Any) -> StreamingResponse:
        """
        Get a generator for streaming content from the response.

        Args:
            response: OpenAI streaming response object

        Returns:
            StreamingResponse object wrapping a generator that yields content chunks
            and can also store tool call information detected during streaming
        """
        logger.info("Starting OpenAI streaming response processing")

        # Track only the first tool call detected during streaming
        tool_call = None
        tool_use_detected = False
        current_tool_call = None

        # Create the StreamingResponse object first
        streaming_response = StreamingResponse(generator=None, tool_call=None)

        def generate():
            nonlocal tool_call, tool_use_detected, current_tool_call

            # Flag to track if we've yielded any content
            has_yielded_content = False

            for chunk in response:
                # Check for tool call in the delta
                if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
                    delta_tool_call = chunk.choices[0].delta.tool_calls[0]

                    if not tool_use_detected:
                        # Tool call detected; initialize it
                        tool_use_detected = True
                        logger.info("Tool call detected in streaming response")
                        current_tool_call = {
                            "id": getattr(delta_tool_call, "id", ""),
                            "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
                            "input": {},
                        }
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = current_tool_call
                    else:
                        # Update the existing tool call.
                        # Update the tool call ID if it's provided in this chunk
                        if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
                            current_tool_call["id"] = delta_tool_call.id

                        # Update the function name if it's provided in this chunk
                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
                            current_tool_call["name"] = delta_tool_call.function.name

                        # Update the arguments if they're provided in this chunk
                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
                            # Instead of trying to parse each chunk as JSON, accumulate
                            # the arguments and only parse the complete JSON at the end
                            if "_raw_arguments" not in current_tool_call:
                                current_tool_call["_raw_arguments"] = ""

                            # Accumulate the raw arguments
                            current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments

                            # Try to parse the accumulated arguments
                            try:
                                arguments = json.loads(current_tool_call["_raw_arguments"])
                                if isinstance(arguments, dict):
                                    # Successfully parsed the complete JSON
                                    current_tool_call["input"] = arguments  # Replace instead of update
                                    # Update the StreamingResponse object's tool_call attribute
                                    streaming_response.tool_call = current_tool_call
                            except json.JSONDecodeError:
                                # This is expected for partial JSON - we'll try again with the next chunk
                                logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")

                    # Skip yielding content for tool call chunks
                    continue

                # For the final chunk, set the tool_call attribute
                if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                    logger.info("Streaming response finished with tool_calls reason")

                    # If we haven't yielded any content yet and we're finishing with
                    # tool_calls, yield a placeholder message so the frontend has
                    # something to display
                    if not has_yielded_content and tool_use_detected:
                        logger.info("Yielding placeholder content for tool call")
                        yield "I'll help you with that."  # Simple placeholder message
                        has_yielded_content = True

                    if current_tool_call:
                        # One final attempt to parse the arguments if we have accumulated raw arguments
                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
                            try:
                                arguments = json.loads(current_tool_call["_raw_arguments"])
                                if isinstance(arguments, dict):
                                    current_tool_call["input"] = arguments
                            except json.JSONDecodeError:
                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
                                # If we still can't parse it, use an empty dict as fallback
                                if not current_tool_call["input"]:
                                    current_tool_call["input"] = {}

                        # Remove the raw arguments from the final tool call
                        if "_raw_arguments" in current_tool_call:
                            del current_tool_call["_raw_arguments"]

                        tool_call = current_tool_call
                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
                        # Update the StreamingResponse object's tool_call attribute
                        streaming_response.tool_call = tool_call
                    continue

                # Handle regular content chunks
                if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content
                    has_yielded_content = True

        # Create the generator
        gen = generate()

        # Set the generator in the StreamingResponse object
        streaming_response.generator = gen

        # Return the StreamingResponse object
        return streaming_response
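

def _example_usage(airflow_tools: list, cookie: str) -> None:
    """Illustrative sketch of the full request / tool-call / follow-up flow.

    Hypothetical helper, not part of the provider API: the model name, the
    message, and the OPENAI_API_KEY environment variable are placeholder
    assumptions; ``airflow_tools`` is expected to be the tool list fetched
    from the MCP server and ``cookie`` an authenticated Airflow session cookie.
    """
    import os

    provider = OpenAIProvider(api_key=os.environ["OPENAI_API_KEY"])
    tools = provider.convert_tools(airflow_tools)
    messages = [{"role": "user", "content": "List my Airflow DAGs"}]

    # First request: the model may answer directly or request a tool call
    response = provider.create_chat_completion(messages=messages, model="gpt-4o", tools=tools)

    if provider.has_tool_calls(response):
        # Execute the requested tools, then send the results back for a final answer
        results = provider.process_tool_calls(response, cookie)
        response = provider.create_follow_up_completion(
            messages=messages,
            model="gpt-4o",
            tool_results=results,
            original_response=response,
            tools=tools,
        )

    print(provider.get_content(response))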