diff --git a/src/airflow_wingman/llms_models.py b/src/airflow_wingman/llms_models.py
index ff6f2b9..b70424b 100644
--- a/src/airflow_wingman/llms_models.py
+++ b/src/airflow_wingman/llms_models.py
@@ -1,7 +1,7 @@
 MODELS = {
     "openai": {
         "name": "OpenAI",
-        "endpoint": "https://api.openai.com/v1/chat/completions",
+        "endpoint": "https://api.openai.com/v1",
         "models": [
             {
                 "id": "gpt-4o",
diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py
index d1847da..310dde7 100644
--- a/src/airflow_wingman/providers/openai_provider.py
+++ b/src/airflow_wingman/providers/openai_provider.py
@@ -37,6 +37,15 @@ class OpenAIProvider(BaseLLMProvider):
             base_url: Optional base URL for the API (used for OpenRouter)
         """
         self.api_key = api_key
+
+        # Ensure the base_url doesn't end with /chat/completions to prevent URL duplication
+        if base_url and '/chat/completions' in base_url:
+            # Strip the /chat/completions part and ensure we have a proper base URL
+            base_url = base_url.split('/chat/completions')[0]
+            if not base_url.endswith('/v1'):
+                base_url = f"{base_url}/v1" if not base_url.endswith('/') else f"{base_url}v1"
+            logger.info(f"Modified base_url to prevent endpoint duplication: {base_url}")
+
         self.client = OpenAI(api_key=api_key, base_url=base_url)

     def convert_tools(self, airflow_tools: list) -> list:
@@ -200,7 +209,7 @@ class OpenAIProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -212,22 +221,51 @@ class OpenAIProvider(BaseLLMProvider):
             max_tokens: Maximum tokens to generate
             tool_results: Results of tool executions
             original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in OpenAI format

         Returns:
-            OpenAI response object
+            OpenAI response object or StreamingResponse if streaming
         """
         if not original_response or not tool_results:
             return original_response

-        # Get the original message with tool calls
-        original_message = original_response.choices[0].message
+        # Handle StreamingResponse objects
+        if isinstance(original_response, StreamingResponse):
+            logger.info("Processing StreamingResponse in create_follow_up_completion")
+            # Extract tool calls from StreamingResponse
+            tool_calls = []
+            if original_response.tool_call is not None:
+                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
+                tool_call = original_response.tool_call
+                # Create a simplified tool call structure for the assistant message
+                tool_calls.append({
+                    "id": tool_call.get("id", ""),
+                    "type": "function",
+                    "function": {
+                        "name": tool_call.get("name", ""),
+                        "arguments": json.dumps(tool_call.get("input", {}))
+                    }
+                })
+
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": tool_calls,
+            }
+        else:
+            # Handle regular OpenAI response objects
+            logger.info("Processing regular OpenAI response in create_follow_up_completion")
+            # Get the original message with tool calls
+            original_message = original_response.choices[0].message

-        # Create a new message with the tool calls
-        assistant_message = {
-            "role": "assistant",
-            "content": None,
-            "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
-        }
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
+            }

         # Create tool result messages
         tool_messages = []
@@ -238,14 +276,14 @@ class OpenAIProvider(BaseLLMProvider):
         new_messages = messages + [assistant_message] + tool_messages

         # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
         return self.create_chat_completion(
             messages=new_messages,
             model=model,
             temperature=temperature,
             max_tokens=max_tokens,
-            stream=False,
-            tools=None,  # No tools needed for follow-up
+            stream=stream,
+            tools=tools,  # Pass tools parameter for follow-up
         )

     def get_content(self, response: Any) -> str:
@@ -283,6 +321,9 @@ class OpenAIProvider(BaseLLMProvider):

         def generate():
             nonlocal tool_call, tool_use_detected, current_tool_call
+
+            # Flag to track if we've yielded any content
+            has_yielded_content = False

             for chunk in response:
                 # Check for tool call in the delta
@@ -315,16 +356,25 @@

                         # Update the arguments if they're provided in this chunk
                         if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                            # Instead of trying to parse each chunk as JSON, accumulate the arguments
+                            # and only parse the complete JSON at the end
+                            if "_raw_arguments" not in current_tool_call:
+                                current_tool_call["_raw_arguments"] = ""
+
+                            # Accumulate the raw arguments
+                            current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments
+
+                            # Try to parse the accumulated arguments
                             try:
-                                # Try to parse the arguments JSON
-                                arguments = json.loads(delta_tool_call.function.arguments)
-                                if isinstance(arguments, dict):
-                                    current_tool_call["input"].update(arguments)
+                                arguments = json.loads(current_tool_call["_raw_arguments"])
+                                if isinstance(arguments, dict):
+                                    # Successfully parsed the complete JSON
+                                    current_tool_call["input"] = arguments  # Replace instead of update
                                     # Update the StreamingResponse object's tool_call attribute
                                     streaming_response.tool_call = current_tool_call
                             except json.JSONDecodeError:
-                                # If the arguments are not valid JSON, just log a warning
-                                logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+                                # This is expected for partial JSON - we'll try again with the next chunk
+                                logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")

                         # Skip yielding content for tool call chunks
                         continue
@@ -332,7 +382,30 @@
                 # For the final chunk, set the tool_call attribute
                 if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                     logger.info("Streaming response finished with tool_calls reason")
+
+                    # If we haven't yielded any content yet and we're finishing with tool_calls,
+                    # yield a placeholder message so the frontend has something to display
+                    if not has_yielded_content and tool_use_detected:
+                        logger.info("Yielding placeholder content for tool call")
+                        yield "I'll help you with that."  # Simple placeholder message
+                        has_yielded_content = True
                     if current_tool_call:
+                        # One final attempt to parse the arguments if we have accumulated raw arguments
+                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
+                            try:
+                                arguments = json.loads(current_tool_call["_raw_arguments"])
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"] = arguments
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
+                                # If we still can't parse it, use an empty dict as fallback
+                                if not current_tool_call["input"]:
+                                    current_tool_call["input"] = {}
+
+                        # Remove the raw arguments from the final tool call
+                        if "_raw_arguments" in current_tool_call:
+                            del current_tool_call["_raw_arguments"]
+
                         tool_call = current_tool_call
                         logger.info(f"Final tool call: {json.dumps(tool_call)}")
                         # Update the StreamingResponse object's tool_call attribute
@@ -343,6 +416,7 @@
                 if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                     content = chunk.choices[0].delta.content
                     yield content
+                    has_yielded_content = True

         # Create the generator
         gen = generate()
diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py
index 09010e2..bad53c2 100644
--- a/src/airflow_wingman/tools/conversion.py
+++ b/src/airflow_wingman/tools/conversion.py
@@ -66,6 +66,15 @@ def convert_to_openai_tools(airflow_tools: list) -> list:
             # Add default value if available
             if "default" in param_info and param_info["default"] is not None:
                 param_def["default"] = param_info["default"]
+
+            # Add items property for array types
+            if param_def.get("type") == "array" and "items" not in param_def:
+                # If items is defined in the original schema, use it
+                if "items" in param_info:
+                    param_def["items"] = param_info["items"]
+                else:
+                    # Otherwise, default to string items
+                    param_def["items"] = {"type": "string"}

             # Add to properties
             openai_tool["function"]["parameters"]["properties"][param_name] = param_def
@@ -141,3 +150,7 @@ def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, An
     # If no type was found, default to string
     if "type" not in param_def:
         param_def["type"] = "string"
+
+    # Add items property for array types
+    if param_def.get("type") == "array" and "items" not in param_def:
+        param_def["items"] = {"type": "string"}
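
The base_url normalization added to OpenAIProvider.__init__ can be sketched in isolation. This is a minimal standalone version for illustration only; the helper name normalize_base_url and the OpenRouter URL below are assumptions, not identifiers taken from the codebase.

def normalize_base_url(base_url: str | None) -> str | None:
    """Strip a trailing /chat/completions and make sure the result ends with /v1."""
    if base_url and "/chat/completions" in base_url:
        base_url = base_url.split("/chat/completions")[0]
        if not base_url.endswith("/v1"):
            base_url = f"{base_url}v1" if base_url.endswith("/") else f"{base_url}/v1"
    return base_url

# A full chat-completions URL collapses to the bare API root:
assert normalize_base_url("https://api.openai.com/v1/chat/completions") == "https://api.openai.com/v1"
# URLs that are already bare roots pass through unchanged:
assert normalize_base_url("https://openrouter.ai/api/v1") == "https://openrouter.ai/api/v1"
assert normalize_base_url(None) is None

This is also why the endpoint in llms_models.py drops /chat/completions: the OpenAI client appends that path to base_url itself, so a full endpoint URL would otherwise be duplicated.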
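
The change inside generate() stops parsing each arguments delta on its own and instead buffers the fragments in _raw_arguments, because streamed tool-call arguments arrive as partial JSON. A minimal sketch of that accumulation, with an invented tool name and fragment values:

import json

# Invented fragments standing in for delta_tool_call.function.arguments across chunks.
fragments = ['{"dag_', 'id": "example_dag", ', '"limit": 10}']

current_tool_call = {"name": "get_dag_runs", "input": {}, "_raw_arguments": ""}

for fragment in fragments:
    current_tool_call["_raw_arguments"] += fragment
    try:
        arguments = json.loads(current_tool_call["_raw_arguments"])
        if isinstance(arguments, dict):
            # Replace rather than update: the buffer already holds the complete JSON.
            current_tool_call["input"] = arguments
    except json.JSONDecodeError:
        pass  # Expected while the JSON is still partial; keep accumulating.

# The scratch buffer is removed before the tool call is handed to the follow-up request.
current_tool_call.pop("_raw_arguments", None)
assert current_tool_call["input"] == {"dag_id": "example_dag", "limit": 10}

Only the first two iterations raise JSONDecodeError; the final fragment completes the object, at which point input is replaced wholesale instead of being merged key by key.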
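
The conversion.py additions make sure every array-typed parameter in the generated OpenAI function schema declares an items entry, which JSON Schema-style tool definitions expect. A condensed illustration with a made-up parameter; the real code keeps the explicit if/else shown in the diff:

# Hypothetical parameter converted from an Airflow tool schema.
param_def = {"type": "array", "description": "DAG IDs to pause"}
param_info = {}  # The source schema supplies no "items" in this example.

if param_def.get("type") == "array" and "items" not in param_def:
    # Prefer an items schema from the source definition; otherwise default to strings.
    param_def["items"] = param_info.get("items", {"type": "string"})

assert param_def["items"] == {"type": "string"}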