From 5491ba71aa6d0a3396706e555aeabb36d23f11fa Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sun, 2 Mar 2025 07:44:24 +0000 Subject: [PATCH] fix tool call until cookies are legit --- src/airflow_wingman/llm_client.py | 11 +- .../providers/anthropic_provider.py | 142 +++++++++++++----- src/airflow_wingman/providers/base.py | 111 ++++++++++++-- .../providers/openai_provider.py | 142 ++++++++++++++++-- src/airflow_wingman/tools/execution.py | 48 +++++- src/airflow_wingman/views.py | 19 ++- 6 files changed, 394 insertions(+), 79 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index b81a19b..c13b2d7 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -81,10 +81,8 @@ class LLMClient: # If streaming, handle based on return_response_obj flag if stream: logger.info(f"Using streaming response from {self.provider_name}") - if return_response_obj: - return response, self.provider.get_streaming_content(response) - else: - return self.provider.get_streaming_content(response) + streaming_content = self.provider.get_streaming_content(response) + return streaming_content # For non-streaming responses, handle tool calls if present if self.provider.has_tool_calls(response): @@ -143,7 +141,7 @@ class LLMClient: return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) - def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5): + def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None): """ Process tool calls recursively from a response and make follow-up requests until there are no more tool calls or max_iterations is reached. @@ -156,6 +154,7 @@ class LLMClient: temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate max_iterations: Maximum number of tool call iterations to prevent infinite loops + cookie: Airflow cookie for authentication (optional, will try to get from session if not provided) Returns: Generator for streaming the final follow-up response @@ -163,8 +162,8 @@ class LLMClient: try: iteration = 0 current_response = response - cookie = session.get("airflow_cookie") + # Check if we have a cookie if not cookie: error_msg = "No Airflow cookie available" logger.error(error_msg) diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index 6996e30..94a25c6 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -12,7 +12,7 @@ from typing import Any from anthropic import Anthropic -from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_anthropic_tools @@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider): # Log as much information as possible if hasattr(response, "json"): - logger.info(f"Anthropic response json: {json.dumps(response.json)}") + if callable(response.json): + # If json is a method, call it + try: + logger.info(f"Anthropic response json: {json.dumps(response.json())}") + except Exception as json_err: + logger.warning(f"Could not serialize response.json(): {str(json_err)}") + else: + # If json is a property, use it directly + try: + logger.info(f"Anthropic response json: {json.dumps(response.json)}") + 
except Exception as json_err: + logger.warning(f"Could not serialize response.json: {str(json_err)}") # Log response attributes response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))] @@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider): Check if the response contains tool calls. Args: - response: Anthropic response object or generator with tool_call attribute + response: Anthropic response object or StreamingResponse with tool_call attribute Returns: True if the response contains tool calls, False otherwise """ - # Check if response is a generator with a tool_call attribute - if hasattr(response, "tool_call") and response.tool_call is not None: - return True + logger.info(f"Checking for tool calls in response of type: {type(response)}") + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse): + logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}") + if response.tool_call is not None: + logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}") + return True + else: + logger.info("StreamingResponse has None tool_call") + else: + logger.info("Response is not a StreamingResponse") # Check if any content block is a tool_use block (for non-streaming responses) if hasattr(response, "content"): + logger.info(f"Response has content attribute with {len(response.content)} blocks") + for i, block in enumerate(response.content): + logger.info(f"Checking content block {i}: {type(block)}") + if isinstance(block, dict) and block.get("type") == "tool_use": + logger.info(f"Found tool_use block: {block}") + return True + else: + logger.info("Response does not have content attribute") + + logger.info("No tool calls found in response") + return False + + def get_tool_calls(self, response: Any) -> list: + """ + Extract tool calls from the response. 
+ + Args: + response: Anthropic response object or StreamingResponse with tool_call attribute + + Returns: + List of tool call objects in a standardized format + """ + tool_calls = [] + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}") + tool_calls.append(response.tool_call) + # Otherwise, extract tool calls from response content (for non-streaming responses) + elif hasattr(response, "content"): + logger.info("Extracting tool calls from response content") for block in response.content: if isinstance(block, dict) and block.get("type") == "tool_use": - return True + tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})} + tool_calls.append(tool_call) - return False + return tool_calls def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: """ @@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider): if not self.has_tool_calls(response): return results - tool_calls = [] - - # Check if response is a generator with a tool_call attribute - if hasattr(response, "tool_call") and response.tool_call is not None: - logger.info(f"Processing tool call from generator: {response.tool_call}") - tool_calls.append(response.tool_call) - # Otherwise, extract tool calls from response content (for non-streaming responses) - elif hasattr(response, "content"): - logger.info("Processing tool calls from response content") - tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + # Get tool calls using the standardized method + tool_calls = self.get_tool_calls(response) + logger.info(f"Processing {len(tool_calls)} tool calls") for tool_call in tool_calls: # Extract tool details - handle both formats (generator's tool_call and content block) @@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, + messages: list[dict[str, Any]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + tool_results: dict[str, Any] = None, + original_response: Any = None, + stream: bool = True, ) -> Any: """ Create a follow-up completion with tool results. 
@@ -233,15 +285,25 @@ class AnthropicProvider(BaseLLMProvider): max_tokens: Maximum tokens to generate tool_results: Results of tool executions original_response: Original response with tool calls + stream: Whether to stream the response Returns: - Anthropic response object + Anthropic response object or generator if streaming """ if not original_response or not tool_results: return original_response - # Extract tool_use blocks from the original response - tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + # Extract tool call from the StreamingResponse or content blocks from Anthropic response + tool_use_blocks = [] + if isinstance(original_response, StreamingResponse) and original_response.tool_call: + # For StreamingResponse, create a tool_use block from the tool_call + logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}") + tool_call = original_response.tool_call + tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})}) + elif hasattr(original_response, "content"): + # For regular Anthropic response, extract from content blocks + logger.info("Extracting tool_use blocks from response content") + tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] # Create tool result blocks tool_result_blocks = [] @@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider): anthropic_messages.append({"role": "user", "content": tool_result_blocks}) # Make a second request to get the final response - logger.info("Making second request with tool results") + logger.info(f"Making second request with tool results (stream={stream})") return self.create_chat_completion( messages=anthropic_messages, model=model, temperature=temperature, max_tokens=max_tokens, - stream=False, + stream=stream, tools=None, # No tools needed for follow-up ) @@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider): return "".join(content_parts) - def get_streaming_content(self, response: Any) -> Any: + def get_streaming_content(self, response: Any) -> StreamingResponse: """ Get a generator for streaming content from the response. 
@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider): response: Anthropic streaming response object Returns: - Generator yielding content chunks with tool_call attribute if detected + StreamingResponse object wrapping a generator that yields content chunks + and can also store tool call information detected during streaming """ logger.info("Starting Anthropic streaming response processing") @@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider): tool_call = None tool_use_detected = False + # Create the StreamingResponse object first + streaming_response = StreamingResponse(generator=None, tool_call=None) + def generate(): nonlocal tool_call, tool_use_detected for chunk in response: logger.debug(f"Chunk type: {type(chunk)}") - logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}") + logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}") # Check for content_block_start events with type "tool_use" if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start": if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"): if chunk.content_block.type == "tool_use": - logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}") + logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}") tool_use_detected = True tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})} + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = tool_call # We don't signal to the frontend during streaming # The tool will only be executed after streaming ends continue @@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider): # Handle content_block_delta events for tool_use (input updates) if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta": if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta": - if hasattr(chunk.delta, "partial_json"): + if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json: logger.info(f"Tool use input update: {chunk.delta.partial_json}") # Update the current tool call input if tool_call: @@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider): # Try to parse the partial JSON and update the input partial_input = json.loads(chunk.delta.partial_json) tool_call["input"].update(partial_input) + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = tool_call except json.JSONDecodeError: logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}") continue @@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider): # Log the complete tool call for debugging if tool_call: logger.info(f"Completed tool call: {json.dumps(tool_call)}") + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = tool_call continue # Handle message_delta events with stop_reason "tool_use" @@ -353,6 +425,9 @@ class AnthropicProvider(BaseLLMProvider): if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"): if chunk.delta.stop_reason == "tool_use": logger.info("Message stopped due to tool use") + # Update the StreamingResponse object's tool_call attribute one last time + if tool_call: + 
streaming_response.tool_call = tool_call continue # Handle regular content chunks @@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider): # Create the generator gen = generate() - # Attach the single tool_call to the generator object for later reference - # This will be used after streaming is complete - gen.tool_call = tool_call + # Set the generator in the StreamingResponse object + streaming_response.generator = gen - # Return the enhanced generator - return gen + # Return the StreamingResponse object + return streaming_response diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py index 87f4e7c..18195dc 100644 --- a/src/airflow_wingman/providers/base.py +++ b/src/airflow_wingman/providers/base.py @@ -6,8 +6,45 @@ must adhere to. It defines the methods required for tool conversion, API request and response processing. """ +import json from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Generator, Iterator +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + + +class StreamingResponse(Generic[T]): + """ + Wrapper for streaming responses that can hold tool call information. + + This class wraps a generator and provides an iterator interface while also + storing tool call information. This allows us to associate metadata with + a generator without modifying the generator itself. + """ + + def __init__(self, generator: Generator[T, None, None], tool_call: dict = None): + """ + Initialize the streaming response. + + Args: + generator: The underlying generator yielding content chunks + tool_call: Optional tool call information detected during streaming + """ + self.generator = generator + self.tool_call = tool_call + + def __iter__(self) -> Iterator[T]: + """ + Return self as iterator. + """ + return self + + def __next__(self) -> T: + """ + Get the next item from the generator. + """ + return next(self.generator) class BaseLLMProvider(ABC): @@ -51,32 +88,85 @@ class BaseLLMProvider(ABC): """ pass - @abstractmethod def has_tool_calls(self, response: Any) -> bool: """ Check if the response contains tool calls. Args: - response: Provider-specific response object + response: Provider-specific response object or StreamingResponse Returns: True if the response contains tool calls, False otherwise """ - pass + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + return True + + # Provider-specific implementation should handle other cases + return False + + def get_tool_calls(self, response: Any) -> list: + """ + Extract tool calls from the response. + + Args: + response: Provider-specific response object or StreamingResponse + + Returns: + List of tool call objects in a standardized format + """ + tool_calls = [] + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + tool_calls.append(response.tool_call) + + # Provider-specific implementation should handle other cases + return tool_calls - @abstractmethod def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: """ Process tool calls from the response. 
Args: - response: Provider-specific response object + response: Provider-specific response object or StreamingResponse cookie: Airflow cookie for authentication Returns: Dictionary mapping tool call IDs to results """ - pass + tool_calls = self.get_tool_calls(response) + results = {} + + for tool_call in tool_calls: + tool_name = tool_call.get("name", "") + tool_input = tool_call.get("input", {}) + tool_id = tool_call.get("id", "") + + try: + import logging + + logger = logging.getLogger(__name__) + logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}") + + from airflow_wingman.tools import execute_airflow_tool + + result = execute_airflow_tool(tool_name, tool_input, cookie) + + logger.info(f"Tool result: {json.dumps(result)}") + results[tool_id] = { + "name": tool_name, + "input": tool_input, + "output": result, + } + except Exception as e: + import traceback + + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + results[tool_id] = {"status": "error", "message": error_msg} + + return results @abstractmethod def create_follow_up_completion( @@ -112,14 +202,15 @@ class BaseLLMProvider(ABC): pass @abstractmethod - def get_streaming_content(self, response: Any) -> Any: + def get_streaming_content(self, response: Any) -> StreamingResponse: """ - Get a generator for streaming content from the response. + Get a StreamingResponse for streaming content from the response. Args: response: Provider-specific response object Returns: - Generator yielding content chunks + StreamingResponse object wrapping a generator that yields content chunks + and can also store tool call information detected during streaming """ pass diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py index a6d5c0c..d1847da 100644 --- a/src/airflow_wingman/providers/openai_provider.py +++ b/src/airflow_wingman/providers/openai_provider.py @@ -12,7 +12,7 @@ from typing import Any from openai import OpenAI -from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_openai_tools @@ -116,35 +116,75 @@ class OpenAIProvider(BaseLLMProvider): Check if the response contains tool calls. Args: - response: OpenAI response object + response: OpenAI response object or StreamingResponse with tool_call attribute Returns: True if the response contains tool calls, False otherwise """ - message = response.choices[0].message - return hasattr(message, "tool_calls") and message.tool_calls + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + return True + + # For non-streaming responses + if hasattr(response, "choices") and len(response.choices) > 0: + message = response.choices[0].message + return hasattr(message, "tool_calls") and message.tool_calls + + return False + + def get_tool_calls(self, response: Any) -> list: + """ + Extract tool calls from the response. 
+ + Args: + response: OpenAI response object or StreamingResponse with tool_call attribute + + Returns: + List of tool call objects in a standardized format + """ + tool_calls = [] + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}") + tool_calls.append(response.tool_call) + return tool_calls + + # For non-streaming responses + if hasattr(response, "choices") and len(response.choices) > 0: + message = response.choices[0].message + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)} + tool_calls.append(standardized_tool_call) + + logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response") + return tool_calls def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: """ Process tool calls from the response. Args: - response: OpenAI response object + response: OpenAI response object or StreamingResponse with tool_call attribute cookie: Airflow cookie for authentication Returns: Dictionary mapping tool call IDs to results """ results = {} - message = response.choices[0].message if not self.has_tool_calls(response): return results - for tool_call in message.tool_calls: - tool_id = tool_call.id - function_name = tool_call.function.name - arguments = json.loads(tool_call.function.arguments) + # Get tool calls using the standardized method + tool_calls = self.get_tool_calls(response) + logger.info(f"Processing {len(tool_calls)} tool calls") + + for tool_call in tool_calls: + tool_id = tool_call["id"] + function_name = tool_call["name"] + arguments = tool_call["input"] try: # Execute the Airflow tool with the provided arguments and cookie @@ -220,7 +260,7 @@ class OpenAIProvider(BaseLLMProvider): """ return response.choices[0].message.content - def get_streaming_content(self, response: Any) -> Any: + def get_streaming_content(self, response: Any) -> StreamingResponse: """ Get a generator for streaming content from the response. 
@@ -228,15 +268,87 @@ class OpenAIProvider(BaseLLMProvider): response: OpenAI streaming response object Returns: - Generator yielding content chunks + StreamingResponse object wrapping a generator that yields content chunks + and can also store tool call information detected during streaming """ logger.info("Starting OpenAI streaming response processing") + # Track only the first tool call detected during streaming + tool_call = None + tool_use_detected = False + current_tool_call = None + + # Create the StreamingResponse object first + streaming_response = StreamingResponse(generator=None, tool_call=None) + def generate(): + nonlocal tool_call, tool_use_detected, current_tool_call + for chunk in response: - if chunk.choices and chunk.choices[0].delta.content: - # Don't do any newline replacement here + # Check for tool call in the delta + if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls: + # Tool call detected + if not tool_use_detected: + tool_use_detected = True + logger.info("Tool call detected in streaming response") + + # Initialize the tool call + delta_tool_call = chunk.choices[0].delta.tool_calls[0] + current_tool_call = { + "id": getattr(delta_tool_call, "id", ""), + "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "", + "input": {}, + } + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = current_tool_call + else: + # Update the existing tool call + delta_tool_call = chunk.choices[0].delta.tool_calls[0] + + # Update the tool call ID if it's provided in this chunk + if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call: + current_tool_call["id"] = delta_tool_call.id + + # Update the function name if it's provided in this chunk + if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call: + current_tool_call["name"] = delta_tool_call.function.name + + # Update the arguments if they're provided in this chunk + if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call: + try: + # Try to parse the arguments JSON + arguments = json.loads(delta_tool_call.function.arguments) + if isinstance(arguments, dict): + current_tool_call["input"].update(arguments) + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = current_tool_call + except json.JSONDecodeError: + # If the arguments are not valid JSON, just log a warning + logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}") + + # Skip yielding content for tool call chunks + continue + + # For the final chunk, set the tool_call attribute + if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls": + logger.info("Streaming response finished with tool_calls reason") + if current_tool_call: + tool_call = current_tool_call + logger.info(f"Final tool call: {json.dumps(tool_call)}") + # Update the StreamingResponse object's tool_call attribute + streaming_response.tool_call = tool_call + continue + + # Handle regular content chunks + if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content yield content - return generate() + # Create the generator + gen = generate() + + # Set the generator in the 
StreamingResponse object + streaming_response.generator = gen + + # Return the StreamingResponse object + return streaming_response diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py index 6242972..5dc88b8 100644 --- a/src/airflow_wingman/tools/execution.py +++ b/src/airflow_wingman/tools/execution.py @@ -31,7 +31,14 @@ async def _list_airflow_tools_async(cookie: str) -> list: # Set up configuration base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" logger.info(f"Setting up AirflowConfig with base_url: {base_url}") - config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Format the cookie properly if it doesn't already have the 'session=' prefix + formatted_cookie = cookie + if cookie and not cookie.startswith("session="): + formatted_cookie = f"session={cookie}" + logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...") + + config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None) # Get available tools logger.info("Getting Airflow tools...") @@ -73,20 +80,32 @@ async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: s # Set up configuration base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" logger.info(f"Setting up AirflowConfig with base_url: {base_url}") - config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Format the cookie properly if it doesn't already have the 'session=' prefix + formatted_cookie = cookie + if cookie and not cookie.startswith("session="): + formatted_cookie = f"session={cookie}" + logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...") + + config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None) # Get the tool logger.info(f"Getting tool: {tool_name}") - tool = await get_tool(config=config, tool_name=tool_name) + tool = await get_tool(config=config, name=tool_name) if not tool: error_msg = f"Tool not found: {tool_name}" logger.error(error_msg) return json.dumps({"error": error_msg}) - # Execute the tool + # Execute the tool - ensure the client is in an async context logger.info(f"Executing tool: {tool_name} with arguments: {arguments}") - result = await tool.run(arguments) + + # The AirflowClient needs to be used as an async context manager + # to properly initialize its session + async with tool.client as client: # noqa F841 + # Now the client has a _session attribute and is in an async context + result = await tool.run(arguments) # Convert result to string if isinstance(result, dict | list): @@ -114,4 +133,21 @@ def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str: Returns: Result of the tool execution as a string """ - return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie)) + # Create a new event loop for this execution + # This ensures we're always in a clean async context + loop = asyncio.new_event_loop() + + try: + # Set the event loop for this thread + asyncio.set_event_loop(loop) + + # Run the async function in the new event loop + result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie)) + return result + except Exception as e: + error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return json.dumps({"error": error_msg}) + finally: + # Always close the loop to free resources + loop.close() diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index 
bc60804..14417f5 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -114,16 +114,18 @@ class WingmanView(AppBuilderBaseView): """Handle streaming response.""" try: logger.info("Beginning streaming response") - # Use the enhanced chat_completion method with return_response_obj=True - response_obj, generator = client.chat_completion( - messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True - ) + # Get the cookie at the beginning of the request handler + airflow_cookie = request.cookies.get("session") + logger.info(f"Got airflow_cookie: {airflow_cookie is not None}") - def stream_response(): + # Use the enhanced chat_completion method with return_response_obj=True + streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) + + def stream_response(cookie=airflow_cookie): complete_response = "" # Stream the initial response - for chunk in generator: + for chunk in streaming_response: if chunk: complete_response += chunk yield f"data: {chunk}\n\n" @@ -134,7 +136,7 @@ class WingmanView(AppBuilderBaseView): logger.info("<<< COMPLETE RESPONSE END") # Check for tool calls and make follow-up if needed - if client.provider.has_tool_calls(response_obj): + if client.provider.has_tool_calls(streaming_response): # Signal tool processing start - frontend should disable send button yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n" @@ -142,9 +144,10 @@ class WingmanView(AppBuilderBaseView): yield f"data: {json.dumps({'event': 'replace_content'})}\n\n" logger.info("Response contains tool calls, making follow-up request") + logger.info(f"Using cookie from closure: {cookie is not None}") # Process tool calls and get follow-up response (handles recursive tool calls) - follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"]) + follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie) # Stream the follow-up response for chunk in follow_up_response:
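
Illustrative note (not part of the patch): the diff above routes all streaming through the new StreamingResponse wrapper added in providers/base.py, so tool-call metadata discovered mid-stream can be read back once the generator is exhausted, and views.py then calls process_tool_calls_and_follow_up with the cookie captured in the closure. Below is a minimal, self-contained Python sketch of that consumption pattern under stated assumptions: only the wrapper mirrors the class added by the patch; the fake_stream generator, its sample chunks, and the "list_dags" tool name are invented purely for illustration.

    from collections.abc import Generator, Iterator


    class StreamingResponse:
        """Mirror of the wrapper added in providers/base.py: iterate the chunks,
        then read .tool_call after the stream is exhausted."""

        def __init__(self, generator: Generator[str, None, None] | None, tool_call: dict | None = None):
            self.generator = generator
            self.tool_call = tool_call

        def __iter__(self) -> Iterator[str]:
            return self

        def __next__(self) -> str:
            return next(self.generator)


    def wrap_fake_stream() -> StreamingResponse:
        # Hypothetical provider stream: yields text chunks and records a tool call
        # on the wrapper when one is detected, as the providers do in the patch.
        response = StreamingResponse(generator=None)

        def fake_stream() -> Generator[str, None, None]:
            yield "Checking DAGs... "
            # A real provider would populate this from a content_block_start event
            # (Anthropic) or a delta.tool_calls chunk (OpenAI); values here are made up.
            response.tool_call = {"id": "call_1", "name": "list_dags", "input": {"limit": 5}}
            yield "done."

        response.generator = fake_stream()
        return response


    if __name__ == "__main__":
        streaming_response = wrap_fake_stream()
        for chunk in streaming_response:   # stream content to the client first
            print(chunk, end="")
        print()
        if streaming_response.tool_call:   # only meaningful after the stream ends
            print(f"Follow-up needed for tool: {streaming_response.tool_call['name']}")

The design choice this sketch highlights is the one the patch makes: instead of attaching an attribute to a raw generator (which the previous code attempted and which is lost once the generator object is replaced), the metadata lives on a wrapper object that both the streaming loop and the later has_tool_calls / follow-up logic can see.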