From 227e6321bf82e587b7b1b588add6120ae2b99ae2 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Tue, 25 Feb 2025 12:52:35 +0000
Subject: [PATCH 01/14] dependencies sorting for mcp server

---
 pyproject.toml | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 4431ddc..c084abf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,9 +9,10 @@ authors = [
     {name = "Abhishek Bhakat", email = "abhishek.bhakat@hotmail.com"}
 ]
 dependencies = [
+    "airflow-mcp-server>=0.4.0",
+    "anthropic>=0.46.0",
     "apache-airflow>=2.10.0",
     "openai>=1.64.0",
-    "anthropic>=0.46.0"
 ]
 classifiers = [
     "Development Status :: 3 - Alpha",
@@ -31,6 +32,13 @@ Issues = "https://github.com/abhishekbhakat/airflow-wingman/issues"
 [project.entry-points."airflow.plugins"]
 wingman = "airflow_wingman:WingmanPlugin"
 
+[project.optional-dependencies]
+dev = [
+    "build>=1.2.2",
+    "pre-commit>=4.0.1",
+    "ruff>=0.9.2"
+]
+
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"

From a8a3d6d1a12e19cdeca15a450470ea3e0e29f7a0 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Tue, 25 Feb 2025 12:56:47 +0000
Subject: [PATCH 02/14] update to 3.7 sonnet

---
 src/airflow_wingman/llms_models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/airflow_wingman/llms_models.py b/src/airflow_wingman/llms_models.py
index c73cd46..eaf6d16 100644
--- a/src/airflow_wingman/llms_models.py
+++ b/src/airflow_wingman/llms_models.py
@@ -17,8 +17,8 @@ MODELS = {
         "endpoint": "https://api.anthropic.com/v1/messages",
         "models": [
             {
-                "id": "claude-3.5-sonnet",
-                "name": "Claude 3.5 Sonnet",
+                "id": "claude-3.7-sonnet",
+                "name": "Claude 3.7 Sonnet",
                 "default": True,
                 "context_window": 200000,
                 "description": "Input $3/M tokens, Output $15/M tokens",

From 3904bfc6443843b39cfb96ce3af78b66e39e3956 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Wed, 26 Feb 2025 19:51:49 +0000
Subject: [PATCH 03/14] Restructure for 2 main providers and tools conversion

---
 pyproject.toml                            |   3 +-
 src/airflow_wingman/llm_client.py         | 188 +++++++-----
 src/airflow_wingman/llms_models.py        |   4 +-
 src/airflow_wingman/prompt_engineering.py |   6 +-
 src/airflow_wingman/providers/__init__.py |  41 +++
 .../providers/anthropic_provider.py        | 288 ++++++++++++++++++
 src/airflow_wingman/providers/base.py      | 125 ++++++++
 .../providers/openai_provider.py           | 224 ++++++++++++++
 .../templates/wingman_chat.html            |   3 +
 src/airflow_wingman/tools/__init__.py      |  15 +
 src/airflow_wingman/tools/conversion.py    | 139 +++++++++
 src/airflow_wingman/tools/execution.py     | 117 +++++++
 src/airflow_wingman/views.py               |  73 +++--
 13 files changed, 1126 insertions(+), 100 deletions(-)
 create mode 100644 src/airflow_wingman/providers/__init__.py
 create mode 100644 src/airflow_wingman/providers/anthropic_provider.py
 create mode 100644 src/airflow_wingman/providers/base.py
 create mode 100644 src/airflow_wingman/providers/openai_provider.py
 create mode 100644 src/airflow_wingman/tools/__init__.py
 create mode 100644 src/airflow_wingman/tools/conversion.py
 create mode 100644 src/airflow_wingman/tools/execution.py

diff --git a/pyproject.toml b/pyproject.toml
index c084abf..931b776 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,8 @@ lint.select = [
 lint.ignore = [
     "C416", # Unnecessary list comprehension - rewrite as a generator expression
     "C408", # Unnecessary `dict` call - rewrite as a literal
-    "ISC001" # Single line implicit string concatenation
+    "ISC001", # Single line implicit string concatenation
+    "C901" # Function is too complex
 ]
lint.fixable = ["ALL"] diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index dd58370..4fd6768 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -1,109 +1,145 @@ """ -Client for making API calls to various LLM providers using their official SDKs. +Multi-provider LLM client for Airflow Wingman. + +This module contains the LLMClient class that supports multiple LLM providers +(OpenAI, Anthropic, OpenRouter) through a unified interface. """ -from collections.abc import Generator +import traceback +from typing import Any -from anthropic import Anthropic -from openai import OpenAI +from airflow.utils.log.logging_mixin import LoggingMixin +from flask import session + +from airflow_wingman.providers import create_llm_provider +from airflow_wingman.tools import list_airflow_tools + +# Create a logger instance +logger = LoggingMixin().log class LLMClient: - def __init__(self, api_key: str): - """Initialize the LLM client. + """ + Multi-provider LLM client for Airflow Wingman. + + This class handles chat completion requests to various LLM providers + (OpenAI, Anthropic, OpenRouter) through a unified interface. + """ + + def __init__(self, provider_name: str, api_key: str, base_url: str | None = None): + """ + Initialize the LLM client. Args: + provider_name: Name of the provider (openai, anthropic, openrouter) api_key: API key for the provider + base_url: Optional base URL for the provider API """ + self.provider_name = provider_name self.api_key = api_key - self.openai_client = OpenAI(api_key=api_key) - self.anthropic_client = Anthropic(api_key=api_key) - self.openrouter_client = OpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=api_key, - default_headers={ - "HTTP-Referer": "Airflow Wingman", # Required by OpenRouter - "X-Title": "Airflow Wingman", # Required by OpenRouter - }, - ) + self.base_url = base_url + self.provider = create_llm_provider(provider_name, api_key, base_url) + self.airflow_tools = [] - def chat_completion( - self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False - ) -> Generator[str, None, None] | dict: - """Send a chat completion request to the specified provider. + def set_airflow_tools(self, tools: list): + """ + Set the available Airflow tools. + + Args: + tools: List of Airflow Tool objects + """ + self.airflow_tools = tools + + def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]: + """ + Send a chat completion request to the LLM provider. 
Args: messages: List of message dictionaries with 'role' and 'content' model: Model identifier - provider: Provider identifier (openai, anthropic, openrouter) temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate - stream: Whether to stream the response + stream: Whether to stream the response (default is True) Returns: - If stream=True, returns a generator yielding response chunks - If stream=False, returns the complete response + Dictionary with the response content or a generator for streaming """ + # Get provider-specific tool definitions from Airflow tools + provider_tools = self.provider.convert_tools(self.airflow_tools) + try: - if provider == "openai": - return self._openai_chat_completion(messages, model, temperature, max_tokens, stream) - elif provider == "anthropic": - return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream) - elif provider == "openrouter": - return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream) + # Make the initial request with tools + logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}") + response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools) + logger.info(f"Received response from {self.provider_name}") + + # If streaming, return the generator directly + if stream: + return self.provider.get_streaming_content(response) + + # For non-streaming responses, handle tool calls if present + if self.provider.has_tool_calls(response): + logger.info("Response contains tool calls") + + # Process tool calls and get results + cookie = session.get("airflow_cookie") + if not cookie: + error_msg = "No Airflow cookie available" + logger.error(error_msg) + return {"error": error_msg} + + tool_results = self.provider.process_tool_calls(response, cookie) + + # Create a follow-up completion with the tool results + logger.info("Making follow-up request with tool results") + follow_up_response = self.provider.create_follow_up_completion( + messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response + ) + + return {"content": self.provider.get_content(follow_up_response)} else: - return {"error": f"Unknown provider: {provider}"} + logger.info("Response does not contain tool calls") + return {"content": self.provider.get_content(response)} + except Exception as e: + error_msg = f"Error in {self.provider_name} API call: {str(e)}\\n{traceback.format_exc()}" + logger.error(error_msg) return {"error": f"API request failed: {str(e)}"} - def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle OpenAI chat completion requests.""" - response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + @classmethod + def from_config(cls, config: dict[str, Any]) -> "LLMClient": + """ + Create an LLMClient instance from a configuration dictionary. 
- if stream: + Args: + config: Configuration dictionary with provider_name, api_key, and optional base_url - def response_generator(): - for chunk in response: - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content + Returns: + LLMClient instance + """ + provider_name = config.get("provider_name", "openai") + api_key = config.get("api_key") + base_url = config.get("base_url") - return response_generator() - else: - return {"content": response.choices[0].message.content} + if not api_key: + raise ValueError("API key is required") - def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle Anthropic chat completion requests.""" - # Convert messages to Anthropic format - system_message = next((m["content"] for m in messages if m["role"] == "system"), None) - conversation = [] - for m in messages: - if m["role"] != "system": - conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]}) + return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) - response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream) + def refresh_tools(self, cookie: str) -> None: + """ + Refresh the available Airflow tools. - if stream: - - def response_generator(): - for chunk in response: - if chunk.delta.text: - yield chunk.delta.text - - return response_generator() - else: - return {"content": response.content[0].text} - - def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle OpenRouter chat completion requests.""" - response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) - - if stream: - - def response_generator(): - for chunk in response: - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content - - return response_generator() - else: - return {"content": response.choices[0].message.content} + Args: + cookie: Airflow cookie for authentication + """ + try: + logger.info("Refreshing Airflow tools") + tools = list_airflow_tools(cookie) + self.set_airflow_tools(tools) + logger.info(f"Refreshed {len(tools)} Airflow tools") + except Exception as e: + error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + # Don't raise the exception, just log it + # The client will continue to use the existing tools (if any) diff --git a/src/airflow_wingman/llms_models.py b/src/airflow_wingman/llms_models.py index eaf6d16..ff6f2b9 100644 --- a/src/airflow_wingman/llms_models.py +++ b/src/airflow_wingman/llms_models.py @@ -17,14 +17,14 @@ MODELS = { "endpoint": "https://api.anthropic.com/v1/messages", "models": [ { - "id": "claude-3.7-sonnet", + "id": "claude-3-7-sonnet-20250219", "name": "Claude 3.7 Sonnet", "default": True, "context_window": 200000, "description": "Input $3/M tokens, Output $15/M tokens", }, { - "id": "claude-3.5-haiku", + "id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "default": False, "context_window": 200000, diff --git a/src/airflow_wingman/prompt_engineering.py b/src/airflow_wingman/prompt_engineering.py index 31b5d8e..74a67e5 100644 --- a/src/airflow_wingman/prompt_engineering.py +++ b/src/airflow_wingman/prompt_engineering.py @@ -8,9 +8,11 @@ 
INSTRUCTIONS = { You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices. The Airflow version being used is >=2.10. -You have access to the following Airflow API tools: +You have access to Airflow MCP tools that you can use to fetch information and help users understand +and manage their Airflow environment. -You can use these tools to fetch information and help users understand and manage their Airflow environment. +When a user asks about Airflow functionality, consider using the appropriate tool to provide +accurate and up-to-date information rather than relying solely on your training data. """ } diff --git a/src/airflow_wingman/providers/__init__.py b/src/airflow_wingman/providers/__init__.py new file mode 100644 index 0000000..828726c --- /dev/null +++ b/src/airflow_wingman/providers/__init__.py @@ -0,0 +1,41 @@ +""" +Provider factory for Airflow Wingman. + +This module contains the factory function to create provider instances +based on the provider name. +""" + +from airflow_wingman.providers.anthropic_provider import AnthropicProvider +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.providers.openai_provider import OpenAIProvider + + +def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider: + """ + Create a provider instance based on the provider name. + + Args: + provider_name: Name of the provider (openai, anthropic, openrouter) + api_key: API key for the provider + base_url: Optional base URL for the provider API + + Returns: + Provider instance + + Raises: + ValueError: If the provider is not supported + """ + provider_name = provider_name.lower() + + if provider_name == "openai": + return OpenAIProvider(api_key=api_key, base_url=base_url) + elif provider_name == "openrouter": + # OpenRouter uses the OpenAI API format, so we can use the OpenAI provider + # with a custom base URL + if not base_url: + base_url = "https://openrouter.ai/api/v1" + return OpenAIProvider(api_key=api_key, base_url=base_url) + elif provider_name == "anthropic": + return AnthropicProvider(api_key=api_key) + else: + raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter") diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py new file mode 100644 index 0000000..f2ecb2e --- /dev/null +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -0,0 +1,288 @@ +""" +Anthropic provider implementation for Airflow Wingman. + +This module contains the Anthropic provider implementation that handles +API requests, tool conversion, and response processing for Anthropic's Claude models. +""" + +import traceback +from typing import Any + +from airflow.utils.log.logging_mixin import LoggingMixin +from anthropic import Anthropic + +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.tools import execute_airflow_tool +from airflow_wingman.tools.conversion import convert_to_anthropic_tools + +logger = LoggingMixin().log + + +class AnthropicProvider(BaseLLMProvider): + """ + Anthropic provider implementation. + + This class handles API requests, tool conversion, and response processing + for the Anthropic API (Claude models). + """ + + def __init__(self, api_key: str): + """ + Initialize the Anthropic provider. 
+ + Args: + api_key: API key for Anthropic + """ + self.api_key = api_key + self.client = Anthropic(api_key=api_key) + + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert Airflow tools to Anthropic format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of Anthropic tool definitions + """ + return convert_to_anthropic_tools(airflow_tools) + + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to Anthropic. + + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in Anthropic format + + Returns: + Anthropic response object + + Raises: + Exception: If the API request fails + """ + # Convert max_tokens to Anthropic's max_tokens parameter (if provided) + max_tokens_param = max_tokens if max_tokens is not None else 4096 + + # Convert messages from ChatML format to Anthropic's format + anthropic_messages = self._convert_to_anthropic_messages(messages) + + try: + logger.info(f"Sending chat completion request to Anthropic with model: {model}") + + # Create request parameters + params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream} + + # Add tools if provided + if tools and len(tools) > 0: + params["tools"] = tools + + # Make the API request + response = self.client.messages.create(**params) + + logger.info("Received response from Anthropic") + return response + except Exception as e: + error_msg = str(e) + logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}") + raise + + def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert messages from ChatML format to Anthropic's format. + + Args: + messages: List of message dictionaries in ChatML format + + Returns: + List of message dictionaries in Anthropic format + """ + anthropic_messages = [] + + for message in messages: + role = message["role"] + content = message["content"] + + # Map ChatML roles to Anthropic roles + if role == "system": + # System messages in Anthropic are handled differently + # We'll add them as a user message with a special prefix + anthropic_messages.append({"role": "user", "content": f"\n{content}\n"}) + elif role == "user": + anthropic_messages.append({"role": "user", "content": content}) + elif role == "assistant": + anthropic_messages.append({"role": "assistant", "content": content}) + elif role == "tool": + # Tool messages in ChatML become part of the user message in Anthropic + # We'll handle this in the follow-up completion + continue + + return anthropic_messages + + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. 
+ + Args: + response: Anthropic response object + + Returns: + True if the response contains tool calls, False otherwise + """ + # Check if any content block is a tool_use block + if not hasattr(response, "content"): + return False + + for block in response.content: + if isinstance(block, dict) and block.get("type") == "tool_use": + return True + + return False + + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. + + Args: + response: Anthropic response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + results = {} + + if not self.has_tool_calls(response): + return results + + # Extract tool_use blocks + tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + + for block in tool_use_blocks: + tool_id = block.get("id") + tool_name = block.get("name") + tool_input = block.get("input", {}) + + try: + # Execute the Airflow tool with the provided arguments and cookie + logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}") + result = execute_airflow_tool(tool_name, tool_input, cookie) + logger.info(f"Tool execution result: {result}") + results[tool_id] = {"status": "success", "result": result} + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + results[tool_id] = {"status": "error", "message": error_msg} + + return results + + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. + + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + Anthropic response object + """ + if not original_response or not tool_results: + return original_response + + # Extract tool_use blocks from the original response + tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + + # Create tool result blocks + tool_result_blocks = [] + for tool_id, result in tool_results.items(): + tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))}) + + # Convert original messages to Anthropic format + anthropic_messages = self._convert_to_anthropic_messages(messages) + + # Add the assistant response with tool use + anthropic_messages.append({"role": "assistant", "content": tool_use_blocks}) + + # Add the user message with tool results + anthropic_messages.append({"role": "user", "content": tool_result_blocks}) + + # Make a second request to get the final response + logger.info("Making second request with tool results") + return self.create_chat_completion( + messages=anthropic_messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + tools=None, # No tools needed for follow-up + ) + + def get_content(self, response: Any) -> str: + """ + Extract content from the response. 
+ + Args: + response: Anthropic response object + + Returns: + Content string from the response + """ + if not hasattr(response, "content"): + return "" + + # Combine all text blocks into a single string + content_parts = [] + for block in response.content: + if isinstance(block, dict) and block.get("type") == "text": + content_parts.append(block.get("text", "")) + elif isinstance(block, str): + content_parts.append(block) + + return "".join(content_parts) + + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. + + Args: + response: Anthropic streaming response object + + Returns: + Generator yielding content chunks + """ + + def generate(): + for chunk in response: + logger.debug(f"Chunk type: {type(chunk)}") + + # Handle different types of chunks from Anthropic API + content = None + if hasattr(chunk, "type") and chunk.type == "content_block_delta": + if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): + content = chunk.delta.text + elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): + content = chunk.delta.text + elif hasattr(chunk, "content") and chunk.content: + for block in chunk.content: + if isinstance(block, dict) and block.get("type") == "text": + content = block.get("text", "") + + if content: + # Don't do any newline replacement here + yield content + + return generate() diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py new file mode 100644 index 0000000..dcdbc44 --- /dev/null +++ b/src/airflow_wingman/providers/base.py @@ -0,0 +1,125 @@ +""" +Base provider interface for Airflow Wingman. + +This module contains the base provider interface that all provider implementations +must adhere to. It defines the methods required for tool conversion, API requests, +and response processing. +""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseLLMProvider(ABC): + """ + Base provider interface for LLM providers. + + This abstract class defines the methods that all provider implementations + must implement to support tool integration. + """ + + @abstractmethod + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert internal tool representation to provider format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of provider-specific tool definitions + """ + pass + + @abstractmethod + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to provider. + + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in provider format + + Returns: + Provider-specific response object + """ + pass + + @abstractmethod + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. + + Args: + response: Provider-specific response object + + Returns: + True if the response contains tool calls, False otherwise + """ + pass + + @abstractmethod + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. 
+ + Args: + response: Provider-specific response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + pass + + @abstractmethod + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. + + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + Provider-specific response object + """ + pass + + @abstractmethod + def get_content(self, response: Any) -> str: + """ + Extract content from the response. + + Args: + response: Provider-specific response object + + Returns: + Content string from the response + """ + pass + + @abstractmethod + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. + + Args: + response: Provider-specific response object + + Returns: + Generator yielding content chunks + """ + pass diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py new file mode 100644 index 0000000..6d12cbf --- /dev/null +++ b/src/airflow_wingman/providers/openai_provider.py @@ -0,0 +1,224 @@ +""" +OpenAI provider implementation for Airflow Wingman. + +This module contains the OpenAI provider implementation that handles +API requests, tool conversion, and response processing for OpenAI. +""" + +import json +import traceback +from typing import Any + +from airflow.utils.log.logging_mixin import LoggingMixin +from openai import OpenAI + +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.tools import execute_airflow_tool +from airflow_wingman.tools.conversion import convert_to_openai_tools + +# Create a logger instance +logger = LoggingMixin().log + + +class OpenAIProvider(BaseLLMProvider): + """ + OpenAI provider implementation. + + This class handles API requests, tool conversion, and response processing + for the OpenAI API. It can also be used for OpenRouter with a custom base URL. + """ + + def __init__(self, api_key: str, base_url: str | None = None): + """ + Initialize the OpenAI provider. + + Args: + api_key: API key for OpenAI + base_url: Optional base URL for the API (used for OpenRouter) + """ + self.api_key = api_key + self.client = OpenAI(api_key=api_key, base_url=base_url) + + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert Airflow tools to OpenAI format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of OpenAI tool definitions + """ + return convert_to_openai_tools(airflow_tools) + + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to OpenAI. 
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in OpenAI format + + Returns: + OpenAI response object + + Raises: + Exception: If the API request fails + """ + # Only include tools if we have any + has_tools = tools is not None and len(tools) > 0 + tool_choice = "auto" if has_tools else None + + try: + logger.info(f"Sending chat completion request to OpenAI with model: {model}") + response = self.client.chat.completions.create( + model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice + ) + logger.info("Received response from OpenAI") + return response + except Exception as e: + # If the API call fails due to tools not being supported, retry without tools + error_msg = str(e) + logger.warning(f"Error in OpenAI API call: {error_msg}") + if "tools" in error_msg.lower(): + logger.info("Retrying without tools") + response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + return response + else: + logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}") + raise + + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. + + Args: + response: OpenAI response object + + Returns: + True if the response contains tool calls, False otherwise + """ + message = response.choices[0].message + return hasattr(message, "tool_calls") and message.tool_calls + + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. + + Args: + response: OpenAI response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + results = {} + message = response.choices[0].message + + if not self.has_tool_calls(response): + return results + + for tool_call in message.tool_calls: + tool_id = tool_call.id + function_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + + try: + # Execute the Airflow tool with the provided arguments and cookie + logger.info(f"Executing tool: {function_name} with arguments: {arguments}") + result = execute_airflow_tool(function_name, arguments, cookie) + logger.info(f"Tool execution result: {result}") + results[tool_id] = {"status": "success", "result": result} + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + results[tool_id] = {"status": "error", "message": error_msg} + + return results + + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. 
+ + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + OpenAI response object + """ + if not original_response or not tool_results: + return original_response + + # Get the original message with tool calls + original_message = original_response.choices[0].message + + # Create a new message with the tool calls + assistant_message = { + "role": "assistant", + "content": None, + "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls], + } + + # Create tool result messages + tool_messages = [] + for tool_call_id, result in tool_results.items(): + tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))}) + + # Add the original messages, assistant message, and tool results + new_messages = messages + [assistant_message] + tool_messages + + # Make a second request to get the final response + logger.info("Making second request with tool results") + return self.create_chat_completion( + messages=new_messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + tools=None, # No tools needed for follow-up + ) + + def get_content(self, response: Any) -> str: + """ + Extract content from the response. + + Args: + response: OpenAI response object + + Returns: + Content string from the response + """ + return response.choices[0].message.content + + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. + + Args: + response: OpenAI streaming response object + + Returns: + Generator yielding content chunks + """ + + def generate(): + for chunk in response: + if chunk.choices and chunk.choices[0].delta.content: + # Don't do any newline replacement here + content = chunk.choices[0].delta.content + yield content + + return generate() diff --git a/src/airflow_wingman/templates/wingman_chat.html b/src/airflow_wingman/templates/wingman_chat.html index 72657f0..56a38a6 100644 --- a/src/airflow_wingman/templates/wingman_chat.html +++ b/src/airflow_wingman/templates/wingman_chat.html @@ -150,6 +150,7 @@ border: 1px solid #e9ecef; border-radius: 15px 15px 15px 0; padding: 10px 15px; + white-space: pre-wrap; } #chat-messages::after { @@ -349,6 +350,8 @@ document.addEventListener('DOMContentLoaded', function() { if (line.startsWith('data: ')) { const content = line.slice(6); if (content) { + // Use textContent to properly handle newlines + console.log('Received chunk:', JSON.stringify(content)); // Debug currentMessageDiv.textContent += content; fullResponse += content; chatMessages.scrollTop = chatMessages.scrollHeight; diff --git a/src/airflow_wingman/tools/__init__.py b/src/airflow_wingman/tools/__init__.py new file mode 100644 index 0000000..44e9360 --- /dev/null +++ b/src/airflow_wingman/tools/__init__.py @@ -0,0 +1,15 @@ +""" +Tools module for Airflow Wingman. + +This module contains the tools used by Airflow Wingman to interact with Airflow. 
+""" + +from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools +from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools + +__all__ = [ + "convert_to_openai_tools", + "convert_to_anthropic_tools", + "list_airflow_tools", + "execute_airflow_tool", +] diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py new file mode 100644 index 0000000..7a30c02 --- /dev/null +++ b/src/airflow_wingman/tools/conversion.py @@ -0,0 +1,139 @@ +""" +Conversion utilities for Airflow Wingman tools. + +This module contains functions to convert between different tool formats +for various LLM providers (OpenAI, Anthropic, etc.). +""" + +from typing import Any + + +def convert_to_openai_tools(airflow_tools: list) -> list: + """ + Convert Airflow tools to OpenAI tool definitions. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of OpenAI tool definitions + """ + openai_tools = [] + + for tool in airflow_tools: + # Initialize the OpenAI tool structure + openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}} + + # Extract parameters directly from inputSchema if available + if hasattr(tool, "inputSchema") and tool.inputSchema: + # Set the type and required fields directly from the schema + if "type" in tool.inputSchema: + openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"] + + # Add required parameters if specified + if "required" in tool.inputSchema: + openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"] + + # Add properties from the input schema + if "properties" in tool.inputSchema: + for param_name, param_info in tool.inputSchema["properties"].items(): + # Create parameter definition + param_def = {} + + # Handle different schema constructs + if "anyOf" in param_info: + _handle_schema_construct(param_def, param_info, "anyOf") + elif "oneOf" in param_info: + _handle_schema_construct(param_def, param_info, "oneOf") + elif "allOf" in param_info: + _handle_schema_construct(param_def, param_info, "allOf") + elif "type" in param_info: + param_def["type"] = param_info["type"] + # Add format if available + if "format" in param_info: + param_def["format"] = param_info["format"] + else: + param_def["type"] = "string" # Default type + + # Add description from title or param name + param_def["description"] = param_info.get("description", param_info.get("title", param_name)) + + # Add enum values if available + if "enum" in param_info: + param_def["enum"] = param_info["enum"] + + # Add default value if available + if "default" in param_info and param_info["default"] is not None: + param_def["default"] = param_info["default"] + + # Add to properties + openai_tool["function"]["parameters"]["properties"][param_name] = param_def + + openai_tools.append(openai_tool) + + return openai_tools + + +def convert_to_anthropic_tools(airflow_tools: list) -> list: + """ + Convert Airflow tools to Anthropic tool definitions. 
+ + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of Anthropic tool definitions + """ + anthropic_tools = [] + + for tool in airflow_tools: + # Initialize the Anthropic tool structure + anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}} + + # Extract parameters directly from inputSchema if available + if hasattr(tool, "inputSchema") and tool.inputSchema: + # Copy the input schema directly as Anthropic's format is similar to JSON Schema + anthropic_tool["input_schema"] = tool.inputSchema + else: + # Create a minimal schema if none exists + anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []} + + anthropic_tools.append(anthropic_tool) + + return anthropic_tools + + +def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None: + """ + Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf. + + Args: + param_def: Parameter definition to update + param_info: Parameter info from the schema + construct_type: Type of construct (anyOf, oneOf, allOf) + """ + # Get the list of schemas from the construct + schemas = param_info[construct_type] + + # Find the first schema with a type + for schema in schemas: + if "type" in schema: + param_def["type"] = schema["type"] + + # Add format if available + if "format" in schema: + param_def["format"] = schema["format"] + + # Add enum values if available + if "enum" in schema: + param_def["enum"] = schema["enum"] + + # Add default value if available + if "default" in schema and schema["default"] is not None: + param_def["default"] = schema["default"] + + break + + # If no type was found, default to string + if "type" not in param_def: + param_def["type"] = "string" diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py new file mode 100644 index 0000000..495cff0 --- /dev/null +++ b/src/airflow_wingman/tools/execution.py @@ -0,0 +1,117 @@ +""" +Tool execution module for Airflow Wingman. + +This module contains functions to list and execute Airflow tools. +""" + +import asyncio +import json +import traceback + +from airflow import configuration +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow_mcp_server.config import AirflowConfig +from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool + +# Create a logger instance +logger = LoggingMixin().log + + +async def _list_airflow_tools_async(cookie: str) -> list: + """ + Async implementation to list available Airflow tools. + + Args: + cookie: Cookie for authentication + + Returns: + List of available Airflow tools + """ + try: + # Set up configuration + base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" + logger.info(f"Setting up AirflowConfig with base_url: {base_url}") + config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Get available tools + logger.info("Getting Airflow tools...") + tools = await get_airflow_tools(config=config, mode="safe") + logger.info(f"Got {len(tools)} tools") + return tools + except Exception as e: + error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return [] + + +def list_airflow_tools(cookie: str) -> list: + """ + Synchronous wrapper to list available Airflow tools. 
+ + Args: + cookie: Cookie for authentication + + Returns: + List of available Airflow tools + """ + return asyncio.run(_list_airflow_tools_async(cookie)) + + +async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str: + """ + Async implementation to execute an Airflow tool. + + Args: + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + cookie: Cookie for authentication + + Returns: + Result of the tool execution as a string + """ + try: + # Set up configuration + base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" + logger.info(f"Setting up AirflowConfig with base_url: {base_url}") + config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Get the tool + logger.info(f"Getting tool: {tool_name}") + tool = await get_tool(config=config, tool_name=tool_name) + + if not tool: + error_msg = f"Tool not found: {tool_name}" + logger.error(error_msg) + return json.dumps({"error": error_msg}) + + # Execute the tool + logger.info(f"Executing tool: {tool_name} with arguments: {arguments}") + result = await tool.run(arguments) + + # Convert result to string + if isinstance(result, dict | list): + result_str = json.dumps(result, indent=2) + else: + result_str = str(result) + + logger.info(f"Tool execution result: {result_str[:100]}...") + return result_str + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return json.dumps({"error": error_msg}) + + +def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str: + """ + Synchronous wrapper to execute an Airflow tool. + + Args: + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + cookie: Cookie for authentication + + Returns: + Result of the tool execution as a string + """ + return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie)) diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index 34158d7..fcf3bff 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -1,6 +1,6 @@ """Views for Airflow Wingman plugin.""" -from flask import Response, request, stream_with_context +from flask import Response, request, session from flask.json import jsonify from flask_appbuilder import BaseView as AppBuilderBaseView, expose @@ -8,6 +8,7 @@ from airflow_wingman.llm_client import LLMClient from airflow_wingman.llms_models import MODELS from airflow_wingman.notes import INTERFACE_MESSAGES from airflow_wingman.prompt_engineering import prepare_messages +from airflow_wingman.tools import list_airflow_tools class WingmanView(AppBuilderBaseView): @@ -28,8 +29,32 @@ class WingmanView(AppBuilderBaseView): try: data = self._validate_chat_request(request.get_json()) - # Create a new client for this request - client = LLMClient(data["api_key"]) + if data.get("cookie"): + session["airflow_cookie"] = data["cookie"] + + # Get available Airflow tools using the stored cookie + airflow_tools = [] + if session.get("airflow_cookie"): + try: + airflow_tools = list_airflow_tools(session["airflow_cookie"]) + except Exception as e: + # Log the error but continue without tools + print(f"Error fetching Airflow tools: {str(e)}") + + # Prepare messages with Airflow tools included in the prompt + data["messages"] = prepare_messages(data["messages"]) + + # Get provider name from request or use default + provider_name = data.get("provider", "openai") + + # Get base URL from models configuration based on 
provider + base_url = MODELS.get(provider_name, {}).get("endpoint") + + # Create a new client for this request with the appropriate provider + client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url) + + # Set the Airflow tools for the client to use + client.set_airflow_tools(airflow_tools) if data["stream"]: return self._handle_streaming_response(client, data) @@ -46,39 +71,49 @@ class WingmanView(AppBuilderBaseView): if not data: raise ValueError("No data provided") - required_fields = ["provider", "model", "messages", "api_key"] + required_fields = ["model", "messages", "api_key"] missing = [f for f in required_fields if not data.get(f)] if missing: raise ValueError(f"Missing required fields: {', '.join(missing)}") - # Prepare messages with system instruction while maintaining history - messages = data["messages"] - messages = prepare_messages(messages) + # Validate provider if provided + provider = data.get("provider", "openai") + if provider not in MODELS: + raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}") return { - "provider": data["provider"], "model": data["model"], - "messages": messages, + "messages": data["messages"], "api_key": data["api_key"], - "stream": data.get("stream", False), + "stream": data.get("stream", True), "temperature": data.get("temperature", 0.7), "max_tokens": data.get("max_tokens"), + "cookie": data.get("cookie"), + "provider": provider, + "base_url": data.get("base_url"), } def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response: """Handle streaming response.""" + try: + # Get the streaming generator from the client + generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) - def generate(): - for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True): - yield f"data: {chunk}\n\n" + def stream_response(): + # Send SSE format for each chunk + for chunk in generator: + if chunk: + yield f"data: {chunk}\n\n" - response = Response(stream_with_context(generate()), mimetype="text/event-stream") - response.headers["Content-Type"] = "text/event-stream" - response.headers["Cache-Control"] = "no-cache" - response.headers["Connection"] = "keep-alive" - return response + # Signal end of stream + yield "data: [DONE]\n\n" + + return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) + except Exception as e: + # If streaming fails, return error + return jsonify({"error": str(e)}), 500 def _handle_regular_response(self, client: LLMClient, data: dict) -> Response: """Handle regular response.""" - response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) + response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) return jsonify(response) From ee9c82f096e8fffa4ee7a4b02289f047fead2908 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Feb 2025 19:52:04 +0000 Subject: [PATCH 04/14] ignore node modules --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 85d6678..ac2d8da 100644 --- 
a/.gitignore +++ b/.gitignore @@ -174,3 +174,5 @@ cython_debug/ # Local Resources plugins_reference/ astro/ + +node_modules/ From db9c538d8afdebbd4312fd21c61770ca4ba30c2c Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Feb 2025 20:09:06 +0000 Subject: [PATCH 05/14] Logger fixes --- src/airflow_wingman/llm_client.py | 8 ++++---- src/airflow_wingman/providers/anthropic_provider.py | 5 +++-- src/airflow_wingman/providers/openai_provider.py | 6 +++--- src/airflow_wingman/tools/execution.py | 6 +++--- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index 4fd6768..40d1260 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -5,17 +5,17 @@ This module contains the LLMClient class that supports multiple LLM providers (OpenAI, Anthropic, OpenRouter) through a unified interface. """ +import logging import traceback from typing import Any -from airflow.utils.log.logging_mixin import LoggingMixin from flask import session from airflow_wingman.providers import create_llm_provider from airflow_wingman.tools import list_airflow_tools -# Create a logger instance -logger = LoggingMixin().log +# Create a properly namespaced logger for the Airflow plugin +logger = logging.getLogger("airflow.plugins.wingman") class LLMClient: @@ -102,7 +102,7 @@ class LLMClient: return {"content": self.provider.get_content(response)} except Exception as e: - error_msg = f"Error in {self.provider_name} API call: {str(e)}\\n{traceback.format_exc()}" + error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) return {"error": f"API request failed: {str(e)}"} diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index f2ecb2e..666c6d5 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -5,17 +5,18 @@ This module contains the Anthropic provider implementation that handles API requests, tool conversion, and response processing for Anthropic's Claude models. """ +import logging import traceback from typing import Any -from airflow.utils.log.logging_mixin import LoggingMixin from anthropic import Anthropic from airflow_wingman.providers.base import BaseLLMProvider from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_anthropic_tools -logger = LoggingMixin().log +# Create a properly namespaced logger for the Airflow plugin +logger = logging.getLogger("airflow.plugins.wingman") class AnthropicProvider(BaseLLMProvider): diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py index 6d12cbf..bfcc7dd 100644 --- a/src/airflow_wingman/providers/openai_provider.py +++ b/src/airflow_wingman/providers/openai_provider.py @@ -6,18 +6,18 @@ API requests, tool conversion, and response processing for OpenAI. 
""" import json +import logging import traceback from typing import Any -from airflow.utils.log.logging_mixin import LoggingMixin from openai import OpenAI from airflow_wingman.providers.base import BaseLLMProvider from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_openai_tools -# Create a logger instance -logger = LoggingMixin().log +# Create a properly namespaced logger for the Airflow plugin +logger = logging.getLogger("airflow.plugins.wingman") class OpenAIProvider(BaseLLMProvider): diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py index 495cff0..6242972 100644 --- a/src/airflow_wingman/tools/execution.py +++ b/src/airflow_wingman/tools/execution.py @@ -6,15 +6,15 @@ This module contains functions to list and execute Airflow tools. import asyncio import json +import logging import traceback from airflow import configuration -from airflow.utils.log.logging_mixin import LoggingMixin from airflow_mcp_server.config import AirflowConfig from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool -# Create a logger instance -logger = LoggingMixin().log +# Create a properly namespaced logger for the Airflow plugin +logger = logging.getLogger("airflow.plugins.wingman") async def _list_airflow_tools_async(cookie: str) -> list: From 86f70170469cb961fb085adc443c0f1305a1a3b8 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Feb 2025 20:33:43 +0000 Subject: [PATCH 06/14] Log all responses --- src/airflow_wingman/llm_client.py | 13 +++++-- .../providers/anthropic_provider.py | 1 + .../providers/openai_provider.py | 1 + src/airflow_wingman/views.py | 36 ++++++++++++++++--- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index 40d1260..727991b 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -75,6 +75,7 @@ class LLMClient: # If streaming, return the generator directly if stream: + logger.info(f"Using streaming response from {self.provider_name}") return self.provider.get_streaming_content(response) # For non-streaming responses, handle tool calls if present @@ -96,10 +97,18 @@ class LLMClient: messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response ) - return {"content": self.provider.get_content(follow_up_response)} + content = self.provider.get_content(follow_up_response) + logger.info(f"Final content from {self.provider_name} with tool calls COMPLETE RESPONSE START >>>") + logger.info(content) + logger.info("<<< COMPLETE RESPONSE END") + return {"content": content} else: logger.info("Response does not contain tool calls") - return {"content": self.provider.get_content(response)} + content = self.provider.get_content(response) + logger.info(f"Final content from {self.provider_name} without tool calls COMPLETE RESPONSE START >>>") + logger.info(content) + logger.info("<<< COMPLETE RESPONSE END") + return {"content": content} except Exception as e: error_msg = f"Error in {self.provider_name} API call: {str(e)}\n{traceback.format_exc()}" diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index 666c6d5..d18dd1f 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -265,6 +265,7 @@ class AnthropicProvider(BaseLLMProvider): Returns: 
Generator yielding content chunks """ + logger.info("Starting Anthropic streaming response processing") def generate(): for chunk in response: diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py index bfcc7dd..512970d 100644 --- a/src/airflow_wingman/providers/openai_provider.py +++ b/src/airflow_wingman/providers/openai_provider.py @@ -213,6 +213,7 @@ class OpenAIProvider(BaseLLMProvider): Returns: Generator yielding content chunks """ + logger.info("Starting OpenAI streaming response processing") def generate(): for chunk in response: diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index fcf3bff..599c4bd 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -1,5 +1,8 @@ """Views for Airflow Wingman plugin.""" +import json +import logging + from flask import Response, request, session from flask.json import jsonify from flask_appbuilder import BaseView as AppBuilderBaseView, expose @@ -10,6 +13,9 @@ from airflow_wingman.notes import INTERFACE_MESSAGES from airflow_wingman.prompt_engineering import prepare_messages from airflow_wingman.tools import list_airflow_tools +# Create a properly namespaced logger for the Airflow plugin +logger = logging.getLogger("airflow.plugins.wingman") + class WingmanView(AppBuilderBaseView): """View for Airflow Wingman plugin.""" @@ -50,6 +56,11 @@ class WingmanView(AppBuilderBaseView): # Get base URL from models configuration based on provider base_url = MODELS.get(provider_name, {}).get("endpoint") + # Log the request parameters (excluding API key for security) + safe_data = {k: v for k, v in data.items() if k != "api_key"} + logger.info(f"Chat request: provider={provider_name}, model={data.get('model')}, stream={data.get('stream')}") + logger.info(f"Request parameters: {json.dumps(safe_data)[:200]}...") + # Create a new client for this request with the appropriate provider client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url) @@ -96,24 +107,39 @@ class WingmanView(AppBuilderBaseView): def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response: """Handle streaming response.""" try: - # Get the streaming generator from the client + logger.info("Beginning streaming response") generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) def stream_response(): + complete_response = "" + # Send SSE format for each chunk for chunk in generator: if chunk: yield f"data: {chunk}\n\n" - # Signal end of stream + # Log the complete assembled response at the end + logger.info("COMPLETE RESPONSE START >>>") + logger.info(complete_response) + logger.info("<<< COMPLETE RESPONSE END") + yield "data: [DONE]\n\n" return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) except Exception as e: - # If streaming fails, return error + logger.error(f"Streaming error: {str(e)}") return jsonify({"error": str(e)}), 500 def _handle_regular_response(self, client: LLMClient, data: dict) -> Response: """Handle regular response.""" - response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) - return jsonify(response) + try: + logger.info("Beginning regular (non-streaming) response") + response = client.chat_completion(messages=data["messages"], 
model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) + logger.info("COMPLETE RESPONSE START >>>") + logger.info(f"Response to frontend: {json.dumps(response)}") + logger.info("<<< COMPLETE RESPONSE END") + + return jsonify(response) + except Exception as e: + logger.error(f"Regular response error: {str(e)}") + return jsonify({"error": str(e)}), 500 From 63f6c824d4f3c9165c6bc138af3d4352abb23455 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sat, 1 Mar 2025 13:02:56 +0000 Subject: [PATCH 07/14] Frontend capture complete response --- .../static/css/wingman_chat.css | 85 ++++++ src/airflow_wingman/static/js/wingman_chat.js | 285 ++++++++++++++++++ .../templates/wingman_chat.html | 277 +---------------- src/airflow_wingman/views.py | 6 + 4 files changed, 379 insertions(+), 274 deletions(-) create mode 100644 src/airflow_wingman/static/css/wingman_chat.css create mode 100644 src/airflow_wingman/static/js/wingman_chat.js diff --git a/src/airflow_wingman/static/css/wingman_chat.css b/src/airflow_wingman/static/css/wingman_chat.css new file mode 100644 index 0000000..10080dc --- /dev/null +++ b/src/airflow_wingman/static/css/wingman_chat.css @@ -0,0 +1,85 @@ +/* Provider and model selection styling */ +.provider-section { + margin-bottom: 20px; +} +.provider-name { + font-size: 16px; + font-weight: bold; + margin-bottom: 10px; + color: #666; +} +.model-option { + margin-left: 15px; + margin-bottom: 8px; +} +.model-option label { + display: block; + cursor: pointer; +} + +/* Message styling */ +.message { + margin-bottom: 15px; + max-width: 80%; + clear: both; +} + +.message-user { + float: right; + background-color: #f0f7ff; + border: 1px solid #d1e6ff; + border-radius: 15px 15px 0 15px; + padding: 10px 15px; +} + +.message-assistant { + float: left; + background-color: #f8f9fa; + border: 1px solid #e9ecef; + border-radius: 15px 15px 15px 0; + padding: 10px 15px; + white-space: pre-wrap; +} + +#chat-messages::after { + content: ""; + clear: both; + display: table; +} + +/* Scrollbar styling */ +.panel-body::-webkit-scrollbar { + width: 8px; +} + +.panel-body::-webkit-scrollbar-track { + background: #f1f1f1; +} + +.panel-body::-webkit-scrollbar-thumb { + background: #888; + border-radius: 4px; +} + +.panel-body::-webkit-scrollbar-thumb:hover { + background: #555; +} + +/* Processing indicator styling */ +.processing-indicator { + display: none; + background-color: #f0f8ff; + padding: 8px 12px; + border-radius: 4px; + margin: 8px 0; + font-style: italic; +} + +.processing-indicator.visible { + display: block; +} + +.pre-formatted { + white-space: pre-wrap; + font-family: monospace; +} diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js new file mode 100644 index 0000000..821e5b5 --- /dev/null +++ b/src/airflow_wingman/static/js/wingman_chat.js @@ -0,0 +1,285 @@ +document.addEventListener('DOMContentLoaded', function() { + // Add title attributes for tooltips + document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) { + el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title'); + }); + + // Handle model selection and model name input + const modelNameInput = document.getElementById('modelName'); + const modelRadios = document.querySelectorAll('input[name="model"]'); + + modelRadios.forEach(function(radio) { + radio.addEventListener('change', function() { + const provider = this.value.split(':')[0]; + const modelName = 
this.getAttribute('data-model-name'); + console.log('Selected provider:', provider); + console.log('Model name:', modelName); + + if (provider === 'openrouter') { + console.log('Enabling model name input'); + modelNameInput.disabled = false; + modelNameInput.value = ''; + modelNameInput.placeholder = 'Enter model name for OpenRouter'; + } else { + console.log('Disabling model name input'); + modelNameInput.disabled = true; + modelNameInput.value = modelName; + } + }); + }); + + // Set initial state based on default selection + const defaultSelected = document.querySelector('input[name="model"]:checked'); + if (defaultSelected) { + const provider = defaultSelected.value.split(':')[0]; + const modelName = defaultSelected.getAttribute('data-model-name'); + console.log('Initial provider:', provider); + console.log('Initial model name:', modelName); + + if (provider === 'openrouter') { + console.log('Initially enabling model name input'); + modelNameInput.disabled = false; + modelNameInput.value = ''; + modelNameInput.placeholder = 'Enter model name for OpenRouter'; + } else { + console.log('Initially disabling model name input'); + modelNameInput.disabled = true; + modelNameInput.value = modelName; + } + } + + const messageInput = document.getElementById('message-input'); + const sendButton = document.getElementById('send-button'); + const refreshButton = document.getElementById('refresh-button'); + const chatMessages = document.getElementById('chat-messages'); + + let currentMessageDiv = null; + let messageHistory = []; + + // Create a processing indicator element + const processingIndicator = document.createElement('div'); + processingIndicator.className = 'processing-indicator'; + processingIndicator.textContent = 'Processing tool calls...'; + chatMessages.appendChild(processingIndicator); + + function clearChat() { + // Clear the chat messages + chatMessages.innerHTML = ''; + // Add back the processing indicator + chatMessages.appendChild(processingIndicator); + // Reset message history + messageHistory = []; + // Clear the input field + messageInput.value = ''; + // Enable input if it was disabled + messageInput.disabled = false; + sendButton.disabled = false; + } + + function addMessage(content, isUser) { + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; + + // Apply pre-formatted class to preserve whitespace and newlines + messageDiv.classList.add('pre-formatted'); + + // Use innerText instead of textContent to preserve newlines + messageDiv.innerText = content; + + chatMessages.appendChild(messageDiv); + chatMessages.scrollTop = chatMessages.scrollHeight; + return messageDiv; + } + + function showProcessingIndicator() { + processingIndicator.classList.add('visible'); + chatMessages.scrollTop = chatMessages.scrollHeight; + } + + function hideProcessingIndicator() { + processingIndicator.classList.remove('visible'); + } + + async function sendMessage() { + const message = messageInput.value.trim(); + if (!message) return; + + // Get selected model + const selectedModel = document.querySelector('input[name="model"]:checked'); + if (!selectedModel) { + alert('Please select a model'); + return; + } + + const [provider, modelId] = selectedModel.value.split(':'); + const modelName = provider === 'openrouter' ? 
modelNameInput.value : modelId; + + // Clear input and add user message + messageInput.value = ''; + addMessage(message, true); + + // Add user message to history + messageHistory.push({ + role: 'user', + content: message + }); + + // Use full message history for the request + const messages = [...messageHistory]; + + // Create assistant message div + currentMessageDiv = addMessage('', false); + + // Get API key + const apiKey = document.getElementById('api-key').value.trim(); + if (!apiKey) { + alert('Please enter an API key'); + return; + } + + // Disable input while processing + messageInput.disabled = true; + sendButton.disabled = true; + + // Get CSRF token + const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content'); + if (!csrfToken) { + alert('CSRF token not found. Please refresh the page.'); + return; + } + + // Create request data + const requestData = { + provider: provider, + model: modelName, + messages: messages, + api_key: apiKey, + stream: true, + temperature: 0.7 + }; + console.log('Sending request:', {...requestData, api_key: '***'}); + + try { + // Send request + const response = await fetch('/wingman/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-CSRFToken': csrfToken + }, + body: JSON.stringify(requestData) + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || 'Failed to get response'); + } + + // Process the streaming response + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let fullResponse = ''; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + const lines = chunk.split('\n'); + + for (const line of lines) { + if (line.trim() === '') continue; + + if (line.startsWith('data: ')) { + const content = line.slice(6); // Remove 'data: ' prefix + + // Check for special events or end marker + if (content === '[DONE]') { + console.log('Stream complete'); + + // Add assistant's response to history + if (fullResponse) { + messageHistory.push({ + role: 'assistant', + content: fullResponse + }); + } + continue; + } + + // Try to parse as JSON for special events + try { + const parsed = JSON.parse(content); + + if (parsed.event === 'tool_processing_start') { + console.log('Tool processing started'); + showProcessingIndicator(); + continue; + } + + if (parsed.event === 'tool_processing_complete') { + console.log('Tool processing completed'); + hideProcessingIndicator(); + continue; + } + + // Handle the complete response event + if (parsed.event === 'complete_response') { + console.log('Received complete response from backend'); + // Use the complete response from the backend + fullResponse = parsed.content; + + // Update the display with the complete response + if (!currentMessageDiv.classList.contains('pre-formatted')) { + currentMessageDiv.classList.add('pre-formatted'); + } + currentMessageDiv.innerText = fullResponse; + continue; + } + + // If we have JSON that's not a special event, it might be content + currentMessageDiv.textContent += JSON.stringify(parsed); + fullResponse += JSON.stringify(parsed); + } catch (e) { + // Not JSON, handle as normal content + // console.log('Received chunk:', JSON.stringify(content)); + + // Add to full response + fullResponse += content; + + // Create a properly formatted display + if (!currentMessageDiv.classList.contains('pre-formatted')) { + currentMessageDiv.classList.add('pre-formatted'); + } + + // Always 
rebuild the entire content from the full response + currentMessageDiv.innerText = fullResponse; + } + // Scroll to bottom + chatMessages.scrollTop = chatMessages.scrollHeight; + } + } + } + } catch (error) { + console.error('Error:', error); + if (currentMessageDiv) { + currentMessageDiv.textContent = `Error: ${error.message}`; + currentMessageDiv.style.color = 'red'; + } + } finally { + // Always re-enable input and hide indicators + messageInput.disabled = false; + sendButton.disabled = false; + hideProcessingIndicator(); + } + } + + sendButton.addEventListener('click', sendMessage); + messageInput.addEventListener('keypress', function(e) { + if (e.key === 'Enter') { + sendMessage(); + } + }); + + refreshButton.addEventListener('click', clearChat); +}); diff --git a/src/airflow_wingman/templates/wingman_chat.html b/src/airflow_wingman/templates/wingman_chat.html index 56a38a6..8424866 100644 --- a/src/airflow_wingman/templates/wingman_chat.html +++ b/src/airflow_wingman/templates/wingman_chat.html @@ -3,6 +3,7 @@ {% block head_meta %} {{ super() }} + {% endblock %} {% block content %} @@ -77,25 +78,7 @@ - + @@ -129,259 +112,5 @@ - - - + {% endblock %} diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index 599c4bd..c5d6233 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -116,6 +116,7 @@ class WingmanView(AppBuilderBaseView): # Send SSE format for each chunk for chunk in generator: if chunk: + complete_response += chunk yield f"data: {chunk}\n\n" # Log the complete assembled response at the end @@ -123,6 +124,11 @@ class WingmanView(AppBuilderBaseView): logger.info(complete_response) logger.info("<<< COMPLETE RESPONSE END") + # Send the complete response as a special event + complete_event = json.dumps({"event": "complete_response", "content": complete_response}) + yield f"data: {complete_event}\n\n" + + # Signal the end of the stream yield "data: [DONE]\n\n" return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) From ab396318159df72a1a2bb22ad90cfa2208c34042 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sat, 1 Mar 2025 16:18:49 +0000 Subject: [PATCH 08/14] fix formatting of texts --- .../static/css/wingman_chat.css | 18 ++++++- src/airflow_wingman/static/js/wingman_chat.js | 50 ++++++++++++++----- .../templates/wingman_chat.html | 1 + 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/airflow_wingman/static/css/wingman_chat.css b/src/airflow_wingman/static/css/wingman_chat.css index 10080dc..1d6cdce 100644 --- a/src/airflow_wingman/static/css/wingman_chat.css +++ b/src/airflow_wingman/static/css/wingman_chat.css @@ -24,6 +24,11 @@ clear: both; } +.message p { + margin-top: 0.5em; + margin-bottom: 0.5em; +} + .message-user { float: right; background-color: #f0f7ff; @@ -32,13 +37,22 @@ padding: 10px 15px; } +.message pre { + margin-top: 0.5em; + margin-bottom: 0.5em; + padding: 0.5em; +} + .message-assistant { float: left; background-color: #f8f9fa; border: 1px solid #e9ecef; border-radius: 15px 15px 15px 0; padding: 10px 15px; - white-space: pre-wrap; +} + +.message code { + padding: 0.1em 0.3em; } #chat-messages::after { @@ -80,6 +94,6 @@ } .pre-formatted { - white-space: pre-wrap; font-family: monospace; + line-height: 1.2; } diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js index 821e5b5..5b83be3 100644 --- a/src/airflow_wingman/static/js/wingman_chat.js +++ 
b/src/airflow_wingman/static/js/wingman_chat.js @@ -1,5 +1,5 @@ document.addEventListener('DOMContentLoaded', function() { - // Add title attributes for tooltips + // Initialize tooltips document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) { el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title'); }); @@ -80,12 +80,26 @@ document.addEventListener('DOMContentLoaded', function() { const messageDiv = document.createElement('div'); messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; - // Apply pre-formatted class to preserve whitespace and newlines messageDiv.classList.add('pre-formatted'); - - // Use innerText instead of textContent to preserve newlines - messageDiv.innerText = content; - + + // Use marked.js to render markdown + try { + // Configure marked options + marked.use({ + breaks: true, // Add line breaks on single newlines + gfm: true, // Use GitHub Flavored Markdown + headerIds: false, // Don't add IDs to headers + mangle: false, // Don't mangle email addresses + }); + + // Render markdown to HTML + messageDiv.innerHTML = marked.parse(content); + } catch (e) { + console.error('Error rendering markdown:', e); + // Fallback to innerText if markdown parsing fails + messageDiv.innerText = content; + } + chatMessages.appendChild(messageDiv); chatMessages.scrollTop = chatMessages.scrollHeight; return messageDiv; @@ -228,12 +242,24 @@ document.addEventListener('DOMContentLoaded', function() { console.log('Received complete response from backend'); // Use the complete response from the backend fullResponse = parsed.content; - - // Update the display with the complete response - if (!currentMessageDiv.classList.contains('pre-formatted')) { - currentMessageDiv.classList.add('pre-formatted'); + + // Use marked.js to render markdown + try { + // Configure marked options + marked.use({ + breaks: true, // Add line breaks on single newlines + gfm: true, // Use GitHub Flavored Markdown + headerIds: false, // Don't add IDs to headers + mangle: false, // Don't mangle email addresses + }); + + // Render markdown to HTML + currentMessageDiv.innerHTML = marked.parse(fullResponse); + } catch (e) { + console.error('Error rendering markdown:', e); + // Fallback to innerText if markdown parsing fails + currentMessageDiv.innerText = fullResponse; } - currentMessageDiv.innerText = fullResponse; continue; } @@ -253,7 +279,7 @@ document.addEventListener('DOMContentLoaded', function() { } // Always rebuild the entire content from the full response - currentMessageDiv.innerText = fullResponse; + currentMessageDiv.innerHTML = marked.parse(fullResponse); } // Scroll to bottom chatMessages.scrollTop = chatMessages.scrollHeight; diff --git a/src/airflow_wingman/templates/wingman_chat.html b/src/airflow_wingman/templates/wingman_chat.html index 8424866..c1da377 100644 --- a/src/airflow_wingman/templates/wingman_chat.html +++ b/src/airflow_wingman/templates/wingman_chat.html @@ -112,5 +112,6 @@ + {% endblock %} From 7df5e3c55efc40f53d420c1169763f4bcdf6c5ce Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sat, 1 Mar 2025 17:40:58 +0000 Subject: [PATCH 09/14] fix tool listing and temperature --- src/airflow_wingman/llm_client.py | 95 ++++++++++++++++++- .../providers/anthropic_provider.py | 11 ++- src/airflow_wingman/providers/base.py | 4 +- .../providers/openai_provider.py | 21 +++- src/airflow_wingman/static/js/wingman_chat.js | 31 ++++-- src/airflow_wingman/tools/conversion.py | 4 + src/airflow_wingman/views.py | 48 ++++++++-- 7 files 
changed, 189 insertions(+), 25 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index 727991b..b81a19b 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -50,7 +50,9 @@ class LLMClient: """ self.airflow_tools = tools - def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]: + def chat_completion( + self, messages: list[dict[str, str]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = True, return_response_obj: bool = False + ) -> dict[str, Any] | tuple[Any, Any]: """ Send a chat completion request to the LLM provider. @@ -60,9 +62,12 @@ class LLMClient: temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate stream: Whether to stream the response (default is True) + return_response_obj: If True and streaming, returns both the response object and generator Returns: - Dictionary with the response content or a generator for streaming + If stream=False: Dictionary with the response content + If stream=True and return_response_obj=False: Generator for streaming + If stream=True and return_response_obj=True: Tuple of (response_obj, generator) """ # Get provider-specific tool definitions from Airflow tools provider_tools = self.provider.convert_tools(self.airflow_tools) @@ -73,10 +78,13 @@ class LLMClient: response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools) logger.info(f"Received response from {self.provider_name}") - # If streaming, return the generator directly + # If streaming, handle based on return_response_obj flag if stream: logger.info(f"Using streaming response from {self.provider_name}") - return self.provider.get_streaming_content(response) + if return_response_obj: + return response, self.provider.get_streaming_content(response) + else: + return self.provider.get_streaming_content(response) # For non-streaming responses, handle tool calls if present if self.provider.has_tool_calls(response): @@ -135,6 +143,85 @@ class LLMClient: return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) + def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5): + """ + Process tool calls recursively from a response and make follow-up requests until + there are no more tool calls or max_iterations is reached. + Returns a generator for streaming the final follow-up response. 
+ + Args: + response: The original response object containing tool calls + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + max_iterations: Maximum number of tool call iterations to prevent infinite loops + + Returns: + Generator for streaming the final follow-up response + """ + try: + iteration = 0 + current_response = response + cookie = session.get("airflow_cookie") + + if not cookie: + error_msg = "No Airflow cookie available" + logger.error(error_msg) + yield f"Error: {error_msg}" + return + + # Process tool calls recursively until there are no more or max_iterations is reached + while self.provider.has_tool_calls(current_response) and iteration < max_iterations: + iteration += 1 + logger.info(f"Processing tool calls iteration {iteration}/{max_iterations}") + + # Process tool calls and get results + tool_results = self.provider.process_tool_calls(current_response, cookie) + + # Make follow-up request with tool results + logger.info(f"Making follow-up request with tool results (iteration {iteration})") + + # Only stream on the final iteration + should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response) + + follow_up_response = self.provider.create_follow_up_completion( + messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=current_response, stream=should_stream + ) + + # Check if this follow-up response has more tool calls + if not self.provider.has_tool_calls(follow_up_response): + logger.info(f"No more tool calls after iteration {iteration}") + # Final response - return the streaming content + if not should_stream: + # If we didn't stream this response, we need to make a streaming version + content = self.provider.get_content(follow_up_response) + yield content + return + else: + # Return the streaming generator + return self.provider.get_streaming_content(follow_up_response) + + # Update current_response for the next iteration + current_response = follow_up_response + + # If we've reached max_iterations and still have tool calls, log a warning + if iteration == max_iterations and self.provider.has_tool_calls(current_response): + logger.warning(f"Reached maximum tool call iterations ({max_iterations})") + # Stream the final response even if it has tool calls + return self.provider.get_streaming_content(follow_up_response) + + # If we didn't process any tool calls (shouldn't happen), return an error + if iteration == 0: + error_msg = "No tool calls found in response" + logger.error(error_msg) + yield f"Error: {error_msg}" + + except Exception as e: + error_msg = f"Error processing tool calls: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + yield f"Error: {str(e)}" + def refresh_tools(self, cookie: str) -> None: """ Refresh the available Airflow tools. diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index d18dd1f..40e5787 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -5,6 +5,7 @@ This module contains the Anthropic provider implementation that handles API requests, tool conversion, and response processing for Anthropic's Claude models. 
""" +import json import logging import traceback from typing import Any @@ -50,7 +51,7 @@ class AnthropicProvider(BaseLLMProvider): return convert_to_anthropic_tools(airflow_tools) def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to Anthropic. @@ -84,6 +85,12 @@ class AnthropicProvider(BaseLLMProvider): # Add tools if provided if tools and len(tools) > 0: params["tools"] = tools + else: + logger.warning("No tools included in request") + + # Log the full request parameters (with sensitive information redacted) + log_params = params.copy() + logger.info(f"Request parameters: {json.dumps(log_params)}") # Make the API request response = self.client.messages.create(**params) @@ -185,7 +192,7 @@ class AnthropicProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py index dcdbc44..87f4e7c 100644 --- a/src/airflow_wingman/providers/base.py +++ b/src/airflow_wingman/providers/base.py @@ -33,7 +33,7 @@ class BaseLLMProvider(ABC): @abstractmethod def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to provider. @@ -80,7 +80,7 @@ class BaseLLMProvider(ABC): @abstractmethod def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py index 512970d..a6d5c0c 100644 --- a/src/airflow_wingman/providers/openai_provider.py +++ b/src/airflow_wingman/providers/openai_provider.py @@ -52,7 +52,7 @@ class OpenAIProvider(BaseLLMProvider): return convert_to_openai_tools(airflow_tools) def create_chat_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None ) -> Any: """ Make API request to OpenAI. 
@@ -77,6 +77,23 @@ class OpenAIProvider(BaseLLMProvider): try: logger.info(f"Sending chat completion request to OpenAI with model: {model}") + + # Log information about tools + if not has_tools: + logger.warning("No tools included in request") + + # Log request parameters + request_params = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": stream, + "tools": tools if has_tools else None, + "tool_choice": tool_choice, + } + logger.info(f"Request parameters: {json.dumps(request_params)}") + response = self.client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice ) @@ -143,7 +160,7 @@ class OpenAIProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None ) -> Any: """ Create a follow-up completion with tool results. diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js index 5b83be3..553cbba 100644 --- a/src/airflow_wingman/static/js/wingman_chat.js +++ b/src/airflow_wingman/static/js/wingman_chat.js @@ -81,7 +81,7 @@ document.addEventListener('DOMContentLoaded', function() { messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; messageDiv.classList.add('pre-formatted'); - + // Use marked.js to render markdown try { // Configure marked options @@ -91,7 +91,7 @@ document.addEventListener('DOMContentLoaded', function() { headerIds: false, // Don't add IDs to headers mangle: false, // Don't mangle email addresses }); - + // Render markdown to HTML messageDiv.innerHTML = marked.parse(content); } catch (e) { @@ -99,7 +99,7 @@ document.addEventListener('DOMContentLoaded', function() { // Fallback to innerText if markdown parsing fails messageDiv.innerText = content; } - + chatMessages.appendChild(messageDiv); chatMessages.scrollTop = chatMessages.scrollHeight; return messageDiv; @@ -108,10 +108,18 @@ document.addEventListener('DOMContentLoaded', function() { function showProcessingIndicator() { processingIndicator.classList.add('visible'); chatMessages.scrollTop = chatMessages.scrollHeight; + + // Disable send button and input field during tool processing + sendButton.disabled = true; + messageInput.disabled = true; } function hideProcessingIndicator() { processingIndicator.classList.remove('visible'); + + // Re-enable send button and input field after tool processing + sendButton.disabled = false; + messageInput.disabled = false; } async function sendMessage() { @@ -169,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() { messages: messages, api_key: apiKey, stream: true, - temperature: 0.7 + temperature: 0.4, }; console.log('Sending request:', {...requestData, api_key: '***'}); @@ -231,6 +239,17 @@ document.addEventListener('DOMContentLoaded', function() { continue; } + if (parsed.event === 'replace_content') { + console.log('Replacing content due to tool call'); + // Clear the current message content + const currentMessageDiv = document.querySelector('.message.assistant:last-child .message-content'); + if (currentMessageDiv) { + 
currentMessageDiv.innerHTML = ''; + fullResponse = ''; // Reset the full response + } + continue; + } + if (parsed.event === 'tool_processing_complete') { console.log('Tool processing completed'); hideProcessingIndicator(); @@ -242,7 +261,7 @@ document.addEventListener('DOMContentLoaded', function() { console.log('Received complete response from backend'); // Use the complete response from the backend fullResponse = parsed.content; - + // Use marked.js to render markdown try { // Configure marked options @@ -252,7 +271,7 @@ document.addEventListener('DOMContentLoaded', function() { headerIds: false, // Don't add IDs to headers mangle: false, // Don't mangle email addresses }); - + // Render markdown to HTML currentMessageDiv.innerHTML = marked.parse(fullResponse); } catch (e) { diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py index 7a30c02..09010e2 100644 --- a/src/airflow_wingman/tools/conversion.py +++ b/src/airflow_wingman/tools/conversion.py @@ -5,6 +5,7 @@ This module contains functions to convert between different tool formats for various LLM providers (OpenAI, Anthropic, etc.). """ +import logging from typing import Any @@ -84,6 +85,8 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list: Returns: List of Anthropic tool definitions """ + logger = logging.getLogger("airflow.plugins.wingman") + logger.info(f"Converting {len(airflow_tools)} Airflow tools to Anthropic format") anthropic_tools = [] for tool in airflow_tools: @@ -100,6 +103,7 @@ def convert_to_anthropic_tools(airflow_tools: list) -> list: anthropic_tools.append(anthropic_tool) + logger.info(f"Converted {len(anthropic_tools)} tools to Anthropic format") return anthropic_tools diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index c5d6233..bc60804 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -40,12 +40,18 @@ class WingmanView(AppBuilderBaseView): # Get available Airflow tools using the stored cookie airflow_tools = [] - if session.get("airflow_cookie"): + airflow_cookie = request.cookies.get("session") + if airflow_cookie: try: - airflow_tools = list_airflow_tools(session["airflow_cookie"]) + airflow_tools = list_airflow_tools(airflow_cookie) + logger.info(f"Loaded {len(airflow_tools)} Airflow tools") + if len(airflow_tools) > 0: + logger.info(f"First tool: {airflow_tools[0].name if hasattr(airflow_tools[0], 'name') else 'Unknown'}") + else: + logger.warning("No Airflow tools were loaded") except Exception as e: # Log the error but continue without tools - print(f"Error fetching Airflow tools: {str(e)}") + logger.error(f"Error fetching Airflow tools: {str(e)}") # Prepare messages with Airflow tools included in the prompt data["messages"] = prepare_messages(data["messages"]) @@ -97,7 +103,7 @@ class WingmanView(AppBuilderBaseView): "messages": data["messages"], "api_key": data["api_key"], "stream": data.get("stream", True), - "temperature": data.get("temperature", 0.7), + "temperature": data.get("temperature", 0.4), "max_tokens": data.get("max_tokens"), "cookie": data.get("cookie"), "provider": provider, @@ -108,27 +114,51 @@ class WingmanView(AppBuilderBaseView): """Handle streaming response.""" try: logger.info("Beginning streaming response") - generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) + # Use the enhanced chat_completion method with return_response_obj=True + response_obj, generator = 
client.chat_completion( + messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True + ) def stream_response(): complete_response = "" - # Send SSE format for each chunk + # Stream the initial response for chunk in generator: if chunk: complete_response += chunk yield f"data: {chunk}\n\n" - # Log the complete assembled response at the end + # Log the complete assembled response logger.info("COMPLETE RESPONSE START >>>") logger.info(complete_response) logger.info("<<< COMPLETE RESPONSE END") - # Send the complete response as a special event + # Check for tool calls and make follow-up if needed + if client.provider.has_tool_calls(response_obj): + # Signal tool processing start - frontend should disable send button + yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n" + + # Signal to replace content - frontend should clear the current message + yield f"data: {json.dumps({'event': 'replace_content'})}\n\n" + + logger.info("Response contains tool calls, making follow-up request") + + # Process tool calls and get follow-up response (handles recursive tool calls) + follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"]) + + # Stream the follow-up response + for chunk in follow_up_response: + if chunk: + yield f"data: {chunk}\n\n" + + # Signal tool processing complete - frontend can re-enable send button + yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n" + + # Send the complete response as a special event (for compatibility with existing code) complete_event = json.dumps({"event": "complete_response", "content": complete_response}) yield f"data: {complete_event}\n\n" - # Signal the end of the stream + # Signal end of stream yield "data: [DONE]\n\n" return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) From 6cb60f1bbd394d3fce64e89f30fe564e20a8898b Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sun, 2 Mar 2025 05:24:16 +0000 Subject: [PATCH 10/14] Implement provider-agnostic tool calling for Anthropic streaming responses. --- .../providers/anthropic_provider.py | 119 +++++++++++++++--- 1 file changed, 102 insertions(+), 17 deletions(-) diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index 40e5787..6996e30 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -96,6 +96,17 @@ class AnthropicProvider(BaseLLMProvider): response = self.client.messages.create(**params) logger.info("Received response from Anthropic") + # Log the response (with sensitive information redacted) + logger.info(f"Anthropic response type: {type(response).__name__}") + + # Log as much information as possible + if hasattr(response, "json"): + logger.info(f"Anthropic response json: {json.dumps(response.json)}") + + # Log response attributes + response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))] + logger.info(f"Anthropic response attributes: {response_attrs}") + return response except Exception as e: error_msg = str(e) @@ -139,18 +150,20 @@ class AnthropicProvider(BaseLLMProvider): Check if the response contains tool calls. 
Args: - response: Anthropic response object + response: Anthropic response object or generator with tool_call attribute Returns: True if the response contains tool calls, False otherwise """ - # Check if any content block is a tool_use block - if not hasattr(response, "content"): - return False + # Check if response is a generator with a tool_call attribute + if hasattr(response, "tool_call") and response.tool_call is not None: + return True - for block in response.content: - if isinstance(block, dict) and block.get("type") == "tool_use": - return True + # Check if any content block is a tool_use block (for non-streaming responses) + if hasattr(response, "content"): + for block in response.content: + if isinstance(block, dict) and block.get("type") == "tool_use": + return True return False @@ -159,7 +172,7 @@ class AnthropicProvider(BaseLLMProvider): Process tool calls from the response. Args: - response: Anthropic response object + response: Anthropic response object or generator with tool_call attribute cookie: Airflow cookie for authentication Returns: @@ -170,13 +183,29 @@ class AnthropicProvider(BaseLLMProvider): if not self.has_tool_calls(response): return results - # Extract tool_use blocks - tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + tool_calls = [] - for block in tool_use_blocks: - tool_id = block.get("id") - tool_name = block.get("name") - tool_input = block.get("input", {}) + # Check if response is a generator with a tool_call attribute + if hasattr(response, "tool_call") and response.tool_call is not None: + logger.info(f"Processing tool call from generator: {response.tool_call}") + tool_calls.append(response.tool_call) + # Otherwise, extract tool calls from response content (for non-streaming responses) + elif hasattr(response, "content"): + logger.info("Processing tool calls from response content") + tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + + for tool_call in tool_calls: + # Extract tool details - handle both formats (generator's tool_call and content block) + if isinstance(tool_call, dict) and "id" in tool_call: + # This is from the generator's tool_call attribute + tool_id = tool_call.get("id") + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + else: + # This is from the content blocks + tool_id = tool_call.get("id") + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) try: # Execute the Airflow tool with the provided arguments and cookie @@ -270,15 +299,63 @@ class AnthropicProvider(BaseLLMProvider): response: Anthropic streaming response object Returns: - Generator yielding content chunks + Generator yielding content chunks with tool_call attribute if detected """ logger.info("Starting Anthropic streaming response processing") + # Track only the first tool call detected during streaming + tool_call = None + tool_use_detected = False + def generate(): + nonlocal tool_call, tool_use_detected + for chunk in response: logger.debug(f"Chunk type: {type(chunk)}") + logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}") - # Handle different types of chunks from Anthropic API + # Check for content_block_start events with type "tool_use" + if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start": + if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"): + if 
chunk.content_block.type == "tool_use": + logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}") + tool_use_detected = True + tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})} + # We don't signal to the frontend during streaming + # The tool will only be executed after streaming ends + continue + + # Handle content_block_delta events for tool_use (input updates) + if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta": + if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta": + if hasattr(chunk.delta, "partial_json"): + logger.info(f"Tool use input update: {chunk.delta.partial_json}") + # Update the current tool call input + if tool_call: + try: + # Try to parse the partial JSON and update the input + partial_input = json.loads(chunk.delta.partial_json) + tool_call["input"].update(partial_input) + except json.JSONDecodeError: + logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}") + continue + + # Handle content_block_stop events for tool_use + if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_stop": + logger.info("Tool use block completed") + # Log the complete tool call for debugging + if tool_call: + logger.info(f"Completed tool call: {json.dumps(tool_call)}") + continue + + # Handle message_delta events with stop_reason "tool_use" + if hasattr(chunk, "type") and chunk.type == "message_delta": + if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"): + if chunk.delta.stop_reason == "tool_use": + logger.info("Message stopped due to tool use") + continue + + # Handle regular content chunks content = None if hasattr(chunk, "type") and chunk.type == "content_block_delta": if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): @@ -294,4 +371,12 @@ class AnthropicProvider(BaseLLMProvider): # Don't do any newline replacement here yield content - return generate() + # Create the generator + gen = generate() + + # Attach the single tool_call to the generator object for later reference + # This will be used after streaming is complete + gen.tool_call = tool_call + + # Return the enhanced generator + return gen From 5491ba71aa6d0a3396706e555aeabb36d23f11fa Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sun, 2 Mar 2025 07:44:24 +0000 Subject: [PATCH 11/14] fix tool call until cookies are legit --- src/airflow_wingman/llm_client.py | 11 +- .../providers/anthropic_provider.py | 142 +++++++++++++----- src/airflow_wingman/providers/base.py | 111 ++++++++++++-- .../providers/openai_provider.py | 142 ++++++++++++++++-- src/airflow_wingman/tools/execution.py | 48 +++++- src/airflow_wingman/views.py | 19 ++- 6 files changed, 394 insertions(+), 79 deletions(-) diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index b81a19b..c13b2d7 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -81,10 +81,8 @@ class LLMClient: # If streaming, handle based on return_response_obj flag if stream: logger.info(f"Using streaming response from {self.provider_name}") - if return_response_obj: - return response, self.provider.get_streaming_content(response) - else: - return self.provider.get_streaming_content(response) + streaming_content = self.provider.get_streaming_content(response) + return streaming_content # For 
non-streaming responses, handle tool calls if present if self.provider.has_tool_calls(response): @@ -143,7 +141,7 @@ class LLMClient: return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) - def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5): + def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None): """ Process tool calls recursively from a response and make follow-up requests until there are no more tool calls or max_iterations is reached. @@ -156,6 +154,7 @@ class LLMClient: temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate max_iterations: Maximum number of tool call iterations to prevent infinite loops + cookie: Airflow cookie for authentication (optional, will try to get from session if not provided) Returns: Generator for streaming the final follow-up response @@ -163,8 +162,8 @@ class LLMClient: try: iteration = 0 current_response = response - cookie = session.get("airflow_cookie") + # Check if we have a cookie if not cookie: error_msg = "No Airflow cookie available" logger.error(error_msg) diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py index 6996e30..94a25c6 100644 --- a/src/airflow_wingman/providers/anthropic_provider.py +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -12,7 +12,7 @@ from typing import Any from anthropic import Anthropic -from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse from airflow_wingman.tools import execute_airflow_tool from airflow_wingman.tools.conversion import convert_to_anthropic_tools @@ -101,7 +101,18 @@ class AnthropicProvider(BaseLLMProvider): # Log as much information as possible if hasattr(response, "json"): - logger.info(f"Anthropic response json: {json.dumps(response.json)}") + if callable(response.json): + # If json is a method, call it + try: + logger.info(f"Anthropic response json: {json.dumps(response.json())}") + except Exception as json_err: + logger.warning(f"Could not serialize response.json(): {str(json_err)}") + else: + # If json is a property, use it directly + try: + logger.info(f"Anthropic response json: {json.dumps(response.json)}") + except Exception as json_err: + logger.warning(f"Could not serialize response.json: {str(json_err)}") # Log response attributes response_attrs = [attr for attr in dir(response) if not attr.startswith("_") and not callable(getattr(response, attr))] @@ -150,22 +161,63 @@ class AnthropicProvider(BaseLLMProvider): Check if the response contains tool calls. 
Args: - response: Anthropic response object or generator with tool_call attribute + response: Anthropic response object or StreamingResponse with tool_call attribute Returns: True if the response contains tool calls, False otherwise """ - # Check if response is a generator with a tool_call attribute - if hasattr(response, "tool_call") and response.tool_call is not None: - return True + logger.info(f"Checking for tool calls in response of type: {type(response)}") + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse): + logger.info(f"Response is a StreamingResponse, has tool_call attribute: {hasattr(response, 'tool_call')}") + if response.tool_call is not None: + logger.info(f"StreamingResponse has non-None tool_call: {response.tool_call}") + return True + else: + logger.info("StreamingResponse has None tool_call") + else: + logger.info("Response is not a StreamingResponse") # Check if any content block is a tool_use block (for non-streaming responses) if hasattr(response, "content"): + logger.info(f"Response has content attribute with {len(response.content)} blocks") + for i, block in enumerate(response.content): + logger.info(f"Checking content block {i}: {type(block)}") + if isinstance(block, dict) and block.get("type") == "tool_use": + logger.info(f"Found tool_use block: {block}") + return True + else: + logger.info("Response does not have content attribute") + + logger.info("No tool calls found in response") + return False + + def get_tool_calls(self, response: Any) -> list: + """ + Extract tool calls from the response. + + Args: + response: Anthropic response object or StreamingResponse with tool_call attribute + + Returns: + List of tool call objects in a standardized format + """ + tool_calls = [] + + # Check if response is a StreamingResponse with a tool_call attribute + if isinstance(response, StreamingResponse) and response.tool_call is not None: + logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}") + tool_calls.append(response.tool_call) + # Otherwise, extract tool calls from response content (for non-streaming responses) + elif hasattr(response, "content"): + logger.info("Extracting tool calls from response content") for block in response.content: if isinstance(block, dict) and block.get("type") == "tool_use": - return True + tool_call = {"id": block.get("id", ""), "name": block.get("name", ""), "input": block.get("input", {})} + tool_calls.append(tool_call) - return False + return tool_calls def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: """ @@ -183,16 +235,9 @@ class AnthropicProvider(BaseLLMProvider): if not self.has_tool_calls(response): return results - tool_calls = [] - - # Check if response is a generator with a tool_call attribute - if hasattr(response, "tool_call") and response.tool_call is not None: - logger.info(f"Processing tool call from generator: {response.tool_call}") - tool_calls.append(response.tool_call) - # Otherwise, extract tool calls from response content (for non-streaming responses) - elif hasattr(response, "content"): - logger.info("Processing tool calls from response content") - tool_calls = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + # Get tool calls using the standardized method + tool_calls = self.get_tool_calls(response) + logger.info(f"Processing {len(tool_calls)} tool calls") for tool_call in tool_calls: # Extract tool details - handle both formats (generator's 
tool_call and content block) @@ -221,7 +266,14 @@ class AnthropicProvider(BaseLLMProvider): return results def create_follow_up_completion( - self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + self, + messages: list[dict[str, Any]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + tool_results: dict[str, Any] = None, + original_response: Any = None, + stream: bool = True, ) -> Any: """ Create a follow-up completion with tool results. @@ -233,15 +285,25 @@ class AnthropicProvider(BaseLLMProvider): max_tokens: Maximum tokens to generate tool_results: Results of tool executions original_response: Original response with tool calls + stream: Whether to stream the response Returns: - Anthropic response object + Anthropic response object or generator if streaming """ if not original_response or not tool_results: return original_response - # Extract tool_use blocks from the original response - tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + # Extract tool call from the StreamingResponse or content blocks from Anthropic response + tool_use_blocks = [] + if isinstance(original_response, StreamingResponse) and original_response.tool_call: + # For StreamingResponse, create a tool_use block from the tool_call + logger.info(f"Creating tool_use block from StreamingResponse.tool_call: {original_response.tool_call}") + tool_call = original_response.tool_call + tool_use_blocks.append({"type": "tool_use", "id": tool_call.get("id", ""), "name": tool_call.get("name", ""), "input": tool_call.get("input", {})}) + elif hasattr(original_response, "content"): + # For regular Anthropic response, extract from content blocks + logger.info("Extracting tool_use blocks from response content") + tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] # Create tool result blocks tool_result_blocks = [] @@ -258,13 +320,13 @@ class AnthropicProvider(BaseLLMProvider): anthropic_messages.append({"role": "user", "content": tool_result_blocks}) # Make a second request to get the final response - logger.info("Making second request with tool results") + logger.info(f"Making second request with tool results (stream={stream})") return self.create_chat_completion( messages=anthropic_messages, model=model, temperature=temperature, max_tokens=max_tokens, - stream=False, + stream=stream, tools=None, # No tools needed for follow-up ) @@ -291,7 +353,7 @@ class AnthropicProvider(BaseLLMProvider): return "".join(content_parts) - def get_streaming_content(self, response: Any) -> Any: + def get_streaming_content(self, response: Any) -> StreamingResponse: """ Get a generator for streaming content from the response. 
@@ -299,7 +361,8 @@ class AnthropicProvider(BaseLLMProvider):
             response: Anthropic streaming response object

         Returns:
-            Generator yielding content chunks with tool_call attribute if detected
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
         """
         logger.info("Starting Anthropic streaming response processing")

@@ -307,20 +370,25 @@ class AnthropicProvider(BaseLLMProvider):
         tool_call = None
         tool_use_detected = False

+        # Create the StreamingResponse object first
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
         def generate():
             nonlocal tool_call, tool_use_detected

             for chunk in response:
                 logger.debug(f"Chunk type: {type(chunk)}")
-                logger.debug(f"Chunk content: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                logger.debug(f"Chunk content: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")

                 # Check for content_block_start events with type "tool_use"
                 if not tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_start":
                     if hasattr(chunk, "content_block") and hasattr(chunk.content_block, "type"):
                         if chunk.content_block.type == "tool_use":
-                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.json) if hasattr(chunk, 'json') else str(chunk)}")
+                            logger.info(f"Tool use detected in streaming response: {json.dumps(chunk.model_dump_json()) if hasattr(chunk, 'json') else str(chunk)}")
                             tool_use_detected = True
                             tool_call = {"id": getattr(chunk.content_block, "id", ""), "name": getattr(chunk.content_block, "name", ""), "input": getattr(chunk.content_block, "input", {})}
+                            # Update the StreamingResponse object's tool_call attribute
+                            streaming_response.tool_call = tool_call
                             # We don't signal to the frontend during streaming
                             # The tool will only be executed after streaming ends
                             continue
@@ -328,7 +396,7 @@ class AnthropicProvider(BaseLLMProvider):
                 # Handle content_block_delta events for tool_use (input updates)
                 if tool_use_detected and hasattr(chunk, "type") and chunk.type == "content_block_delta":
                     if hasattr(chunk, "delta") and hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
-                        if hasattr(chunk.delta, "partial_json"):
+                        if hasattr(chunk.delta, "partial_json") and chunk.delta.partial_json:
                             logger.info(f"Tool use input update: {chunk.delta.partial_json}")
                             # Update the current tool call input
                             if tool_call:
@@ -336,6 +404,8 @@ class AnthropicProvider(BaseLLMProvider):
                                     # Try to parse the partial JSON and update the input
                                     partial_input = json.loads(chunk.delta.partial_json)
                                     tool_call["input"].update(partial_input)
+                                    # Update the StreamingResponse object's tool_call attribute
+                                    streaming_response.tool_call = tool_call
                                 except json.JSONDecodeError:
                                     logger.warning(f"Failed to parse partial JSON: {chunk.delta.partial_json}")
                             continue
@@ -346,6 +416,8 @@ class AnthropicProvider(BaseLLMProvider):
                     # Log the complete tool call for debugging
                     if tool_call:
                         logger.info(f"Completed tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
                     continue

                 # Handle message_delta events with stop_reason "tool_use"
                 if hasattr(chunk, "delta") and hasattr(chunk.delta, "stop_reason"):
                     if chunk.delta.stop_reason == "tool_use":
                         logger.info("Message stopped due to tool use")
+                        # Update the StreamingResponse object's tool_call attribute one last time
+                        if tool_call:
+                            streaming_response.tool_call = tool_call
                         continue

                 # Handle regular content chunks
@@ -374,9 +449,8 @@ class AnthropicProvider(BaseLLMProvider):
         # Create the generator
         gen = generate()

-        # Attach the single tool_call to the generator object for later reference
-        # This will be used after streaming is complete
-        gen.tool_call = tool_call
+        # Set the generator in the StreamingResponse object
+        streaming_response.generator = gen

-        # Return the enhanced generator
-        return gen
+        # Return the StreamingResponse object
+        return streaming_response
diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py
index 87f4e7c..18195dc 100644
--- a/src/airflow_wingman/providers/base.py
+++ b/src/airflow_wingman/providers/base.py
@@ -6,8 +6,45 @@ must adhere to. It defines the methods required for tool conversion,
 API request and response processing.
 """

+import json
 from abc import ABC, abstractmethod
-from typing import Any
+from collections.abc import Generator, Iterator
+from typing import Any, Generic, TypeVar
+
+T = TypeVar("T")
+
+
+class StreamingResponse(Generic[T]):
+    """
+    Wrapper for streaming responses that can hold tool call information.
+
+    This class wraps a generator and provides an iterator interface while also
+    storing tool call information. This allows us to associate metadata with
+    a generator without modifying the generator itself.
+    """
+
+    def __init__(self, generator: Generator[T, None, None], tool_call: dict = None):
+        """
+        Initialize the streaming response.
+
+        Args:
+            generator: The underlying generator yielding content chunks
+            tool_call: Optional tool call information detected during streaming
+        """
+        self.generator = generator
+        self.tool_call = tool_call
+
+    def __iter__(self) -> Iterator[T]:
+        """
+        Return self as iterator.
+        """
+        return self
+
+    def __next__(self) -> T:
+        """
+        Get the next item from the generator.
+        """
+        return next(self.generator)


 class BaseLLMProvider(ABC):
@@ -51,32 +88,85 @@ class BaseLLMProvider(ABC):
         """
         pass

-    @abstractmethod
     def has_tool_calls(self, response: Any) -> bool:
         """
         Check if the response contains tool calls.

         Args:
-            response: Provider-specific response object
+            response: Provider-specific response object or StreamingResponse

         Returns:
             True if the response contains tool calls, False otherwise
         """
-        pass
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            return True
+
+        # Provider-specific implementation should handle other cases
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: Provider-specific response object or StreamingResponse
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            tool_calls.append(response.tool_call)
+
+        # Provider-specific implementation should handle other cases
+        return tool_calls

-    @abstractmethod
     def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
         """
         Process tool calls from the response.

         Args:
-            response: Provider-specific response object
+            response: Provider-specific response object or StreamingResponse
             cookie: Airflow cookie for authentication

         Returns:
             Dictionary mapping tool call IDs to results
         """
-        pass
+        tool_calls = self.get_tool_calls(response)
+        results = {}
+
+        for tool_call in tool_calls:
+            tool_name = tool_call.get("name", "")
+            tool_input = tool_call.get("input", {})
+            tool_id = tool_call.get("id", "")
+
+            try:
+                import logging
+
+                logger = logging.getLogger(__name__)
+                logger.info(f"Executing tool: {tool_name} with input: {json.dumps(tool_input)}")
+
+                from airflow_wingman.tools import execute_airflow_tool
+
+                result = execute_airflow_tool(tool_name, tool_input, cookie)
+
+                logger.info(f"Tool result: {json.dumps(result)}")
+                results[tool_id] = {
+                    "name": tool_name,
+                    "input": tool_input,
+                    "output": result,
+                }
+            except Exception as e:
+                import traceback
+
+                error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}"
+                logger.error(error_msg)
+                results[tool_id] = {"status": "error", "message": error_msg}
+
+        return results

     @abstractmethod
     def create_follow_up_completion(
@@ -112,14 +202,15 @@ class BaseLLMProvider(ABC):
         """
         pass

     @abstractmethod
-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
         """
-        Get a generator for streaming content from the response.
+        Get a StreamingResponse for streaming content from the response.

         Args:
             response: Provider-specific response object

         Returns:
-            Generator yielding content chunks
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
         """
         pass
diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py
index a6d5c0c..d1847da 100644
--- a/src/airflow_wingman/providers/openai_provider.py
+++ b/src/airflow_wingman/providers/openai_provider.py
@@ -12,7 +12,7 @@ from typing import Any

 from openai import OpenAI

-from airflow_wingman.providers.base import BaseLLMProvider
+from airflow_wingman.providers.base import BaseLLMProvider, StreamingResponse
 from airflow_wingman.tools import execute_airflow_tool
 from airflow_wingman.tools.conversion import convert_to_openai_tools

@@ -116,35 +116,75 @@ class OpenAIProvider(BaseLLMProvider):
         Check if the response contains tool calls.

         Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with tool_call attribute

         Returns:
             True if the response contains tool calls, False otherwise
         """
-        message = response.choices[0].message
-        return hasattr(message, "tool_calls") and message.tool_calls
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            return True
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message = response.choices[0].message
+            return hasattr(message, "tool_calls") and message.tool_calls
+
+        return False
+
+    def get_tool_calls(self, response: Any) -> list:
+        """
+        Extract tool calls from the response.
+
+        Args:
+            response: OpenAI response object or StreamingResponse with tool_call attribute
+
+        Returns:
+            List of tool call objects in a standardized format
+        """
+        tool_calls = []
+
+        # Check if response is a StreamingResponse with a tool_call attribute
+        if isinstance(response, StreamingResponse) and response.tool_call is not None:
+            logger.info(f"Extracting tool call from StreamingResponse: {response.tool_call}")
+            tool_calls.append(response.tool_call)
+            return tool_calls
+
+        # For non-streaming responses
+        if hasattr(response, "choices") and len(response.choices) > 0:
+            message = response.choices[0].message
+            if hasattr(message, "tool_calls") and message.tool_calls:
+                for tool_call in message.tool_calls:
+                    standardized_tool_call = {"id": tool_call.id, "name": tool_call.function.name, "input": json.loads(tool_call.function.arguments)}
+                    tool_calls.append(standardized_tool_call)
+
+        logger.info(f"Extracted {len(tool_calls)} tool calls from OpenAI response")
+        return tool_calls

     def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]:
         """
         Process tool calls from the response.

         Args:
-            response: OpenAI response object
+            response: OpenAI response object or StreamingResponse with tool_call attribute
             cookie: Airflow cookie for authentication

         Returns:
             Dictionary mapping tool call IDs to results
         """
         results = {}
-        message = response.choices[0].message

         if not self.has_tool_calls(response):
             return results

-        for tool_call in message.tool_calls:
-            tool_id = tool_call.id
-            function_name = tool_call.function.name
-            arguments = json.loads(tool_call.function.arguments)
+        # Get tool calls using the standardized method
+        tool_calls = self.get_tool_calls(response)
+        logger.info(f"Processing {len(tool_calls)} tool calls")
+
+        for tool_call in tool_calls:
+            tool_id = tool_call["id"]
+            function_name = tool_call["name"]
+            arguments = tool_call["input"]

             try:
                 # Execute the Airflow tool with the provided arguments and cookie
@@ -220,7 +260,7 @@ class OpenAIProvider(BaseLLMProvider):
         """
         return response.choices[0].message.content

-    def get_streaming_content(self, response: Any) -> Any:
+    def get_streaming_content(self, response: Any) -> StreamingResponse:
         """
         Get a generator for streaming content from the response.
@@ -228,15 +268,87 @@ class OpenAIProvider(BaseLLMProvider):
             response: OpenAI streaming response object

         Returns:
-            Generator yielding content chunks
+            StreamingResponse object wrapping a generator that yields content chunks
+            and can also store tool call information detected during streaming
         """
         logger.info("Starting OpenAI streaming response processing")

+        # Track only the first tool call detected during streaming
+        tool_call = None
+        tool_use_detected = False
+        current_tool_call = None
+
+        # Create the StreamingResponse object first
+        streaming_response = StreamingResponse(generator=None, tool_call=None)
+
         def generate():
+            nonlocal tool_call, tool_use_detected, current_tool_call
+
             for chunk in response:
-                if chunk.choices and chunk.choices[0].delta.content:
-                    # Don't do any newline replacement here
+                # Check for tool call in the delta
+                if chunk.choices and hasattr(chunk.choices[0].delta, "tool_calls") and chunk.choices[0].delta.tool_calls:
+                    # Tool call detected
+                    if not tool_use_detected:
+                        tool_use_detected = True
+                        logger.info("Tool call detected in streaming response")
+
+                        # Initialize the tool call
+                        delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+                        current_tool_call = {
+                            "id": getattr(delta_tool_call, "id", ""),
+                            "name": getattr(delta_tool_call.function, "name", "") if hasattr(delta_tool_call, "function") else "",
+                            "input": {},
+                        }
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = current_tool_call
+                    else:
+                        # Update the existing tool call
+                        delta_tool_call = chunk.choices[0].delta.tool_calls[0]
+
+                        # Update the tool call ID if it's provided in this chunk
+                        if hasattr(delta_tool_call, "id") and delta_tool_call.id and current_tool_call:
+                            current_tool_call["id"] = delta_tool_call.id
+
+                        # Update the function name if it's provided in this chunk
+                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "name") and delta_tool_call.function.name and current_tool_call:
+                            current_tool_call["name"] = delta_tool_call.function.name
+
+                        # Update the arguments if they're provided in this chunk
+                        if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                            try:
+                                # Try to parse the arguments JSON
+                                arguments = json.loads(delta_tool_call.function.arguments)
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"].update(arguments)
+                                    # Update the StreamingResponse object's tool_call attribute
+                                    streaming_response.tool_call = current_tool_call
+                            except json.JSONDecodeError:
+                                # If the arguments are not valid JSON, just log a warning
+                                logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+
+                    # Skip yielding content for tool call chunks
+                    continue
+
+                # For the final chunk, set the tool_call attribute
+                if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
+                    logger.info("Streaming response finished with tool_calls reason")
+                    if current_tool_call:
+                        tool_call = current_tool_call
+                        logger.info(f"Final tool call: {json.dumps(tool_call)}")
+                        # Update the StreamingResponse object's tool_call attribute
+                        streaming_response.tool_call = tool_call
+                    continue
+
+                # Handle regular content chunks
+                if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                     content = chunk.choices[0].delta.content
                     yield content

-        return generate()
+        # Create the generator
+        gen = generate()
+
+        # Set the generator in the StreamingResponse object
+        streaming_response.generator = gen
+
+        # Return the StreamingResponse object
+        return streaming_response
diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py
index 6242972..5dc88b8 100644
--- a/src/airflow_wingman/tools/execution.py
+++ b/src/airflow_wingman/tools/execution.py
@@ -31,7 +31,14 @@ async def _list_airflow_tools_async(cookie: str) -> list:
         # Set up configuration
         base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
         logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
-        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
+
+        # Format the cookie properly if it doesn't already have the 'session=' prefix
+        formatted_cookie = cookie
+        if cookie and not cookie.startswith("session="):
+            formatted_cookie = f"session={cookie}"
+            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

         # Get available tools
         logger.info("Getting Airflow tools...")
@@ -73,20 +80,32 @@ async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: s
         # Set up configuration
         base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/"
         logger.info(f"Setting up AirflowConfig with base_url: {base_url}")
-        config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None)
+
+        # Format the cookie properly if it doesn't already have the 'session=' prefix
+        formatted_cookie = cookie
+        if cookie and not cookie.startswith("session="):
+            formatted_cookie = f"session={cookie}"
+            logger.info(f"Formatted cookie with session prefix: {formatted_cookie[:10]}...")
+
+        config = AirflowConfig(base_url=base_url, cookie=formatted_cookie, auth_token=None)

         # Get the tool
         logger.info(f"Getting tool: {tool_name}")
-        tool = await get_tool(config=config, tool_name=tool_name)
+        tool = await get_tool(config=config, name=tool_name)

         if not tool:
             error_msg = f"Tool not found: {tool_name}"
             logger.error(error_msg)
             return json.dumps({"error": error_msg})

-        # Execute the tool
+        # Execute the tool - ensure the client is in an async context
         logger.info(f"Executing tool: {tool_name} with arguments: {arguments}")
-        result = await tool.run(arguments)
+
+        # The AirflowClient needs to be used as an async context manager
+        # to properly initialize its session
+        async with tool.client as client:  # noqa F841
+            # Now the client has a _session attribute and is in an async context
+            result = await tool.run(arguments)

         # Convert result to string
         if isinstance(result, dict | list):
@@ -114,4 +133,21 @@ def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str:
     Returns:
         Result of the tool execution as a string
     """
-    return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie))
+    # Create a new event loop for this execution
+    # This ensures we're always in a clean async context
+    loop = asyncio.new_event_loop()
+
+    try:
+        # Set the event loop for this thread
+        asyncio.set_event_loop(loop)
+
+        # Run the async function in the new event loop
+        result = loop.run_until_complete(_execute_airflow_tool_async(tool_name, arguments, cookie))
+        return result
+    except Exception as e:
+        error_msg = f"Error in execute_airflow_tool: {str(e)}\n{traceback.format_exc()}"
+        logger.error(error_msg)
+        return json.dumps({"error": error_msg})
+    finally:
+        # Always close the loop to free resources
+        loop.close()
diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py
index bc60804..14417f5 100644
--- a/src/airflow_wingman/views.py
+++ b/src/airflow_wingman/views.py
@@ -114,16 +114,18 @@ class WingmanView(AppBuilderBaseView):
         """Handle streaming response."""
         try:
             logger.info("Beginning streaming response")
-            # Use the enhanced chat_completion method with return_response_obj=True
-            response_obj, generator = client.chat_completion(
-                messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True, return_response_obj=True
-            )
+            # Get the cookie at the beginning of the request handler
+            airflow_cookie = request.cookies.get("session")
+            logger.info(f"Got airflow_cookie: {airflow_cookie is not None}")

-            def stream_response():
+            # Use the enhanced chat_completion method with return_response_obj=True
+            streaming_response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True)
+
+            def stream_response(cookie=airflow_cookie):
                 complete_response = ""

                 # Stream the initial response
-                for chunk in generator:
+                for chunk in streaming_response:
                     if chunk:
                         complete_response += chunk
                         yield f"data: {chunk}\n\n"
@@ -134,7 +136,7 @@ class WingmanView(AppBuilderBaseView):
                 logger.info("<<< COMPLETE RESPONSE END")

                 # Check for tool calls and make follow-up if needed
-                if client.provider.has_tool_calls(response_obj):
+                if client.provider.has_tool_calls(streaming_response):
                     # Signal tool processing start - frontend should disable send button
                     yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"

@@ -142,9 +144,10 @@ class WingmanView(AppBuilderBaseView):
                     yield f"data: {json.dumps({'event': 'replace_content'})}\n\n"

                     logger.info("Response contains tool calls, making follow-up request")
+                    logger.info(f"Using cookie from closure: {cookie is not None}")

                     # Process tool calls and get follow-up response (handles recursive tool calls)
-                    follow_up_response = client.process_tool_calls_and_follow_up(response_obj, data["messages"], data["model"], data["temperature"], data["max_tokens"])
+                    follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie)

                     # Stream the follow-up response
                     for chunk in follow_up_response:

From cd284e8de4f7ba21827022119e0a7c5aa95a07be Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 14:27:52 +0000
Subject: [PATCH 12/14] intermediate fix chain tool response to AI

---
 src/airflow_wingman/llm_client.py   | 14 ++++++++++++--
 .../providers/anthropic_provider.py |  6 ++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py
index c13b2d7..f9d818d 100644
--- a/src/airflow_wingman/llm_client.py
+++ b/src/airflow_wingman/llm_client.py
@@ -100,7 +100,7 @@ class LLMClient:
             # Create a follow-up completion with the tool results
             logger.info("Making follow-up request with tool results")
             follow_up_response = self.provider.create_follow_up_completion(
-                messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response
+                messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=response, tools=provider_tools
             )

             content = self.provider.get_content(follow_up_response)
@@ -184,8 +184,18 @@ class LLMClient:
             # Only stream on the final iteration
             should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response)

+            # Get provider-specific tool definitions from Airflow tools
+            provider_tools = self.provider.convert_tools(self.airflow_tools)
+
             follow_up_response = self.provider.create_follow_up_completion(
-                messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, tool_results=tool_results, original_response=current_response, stream=should_stream
+                messages=messages,
+                model=model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                tool_results=tool_results,
+                original_response=current_response,
+                stream=should_stream,
+                tools=provider_tools,
             )

             # Check if this follow-up response has more tool calls
diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py
index 94a25c6..c02441f 100644
--- a/src/airflow_wingman/providers/anthropic_provider.py
+++ b/src/airflow_wingman/providers/anthropic_provider.py
@@ -104,7 +104,7 @@ class AnthropicProvider(BaseLLMProvider):
             if callable(response.json):
                 # If json is a method, call it
                 try:
-                    logger.info(f"Anthropic response json: {json.dumps(response.json())}")
+                    logger.info(f"Anthropic response json: {json.dumps(response.model_dump_json())}")
                 except Exception as json_err:
                     logger.warning(f"Could not serialize response.json(): {str(json_err)}")
             else:
@@ -274,6 +274,7 @@ class AnthropicProvider(BaseLLMProvider):
         tool_results: dict[str, Any] = None,
         original_response: Any = None,
         stream: bool = True,
+        tools: list[dict[str, Any]] | None = None,
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -286,6 +287,7 @@ class AnthropicProvider(BaseLLMProvider):
             tool_results: Results of tool executions
             original_response: Original response with tool calls
             stream: Whether to stream the response
+            tools: List of tool definitions in Anthropic format

         Returns:
             Anthropic response object or generator if streaming
@@ -327,7 +329,7 @@ class AnthropicProvider(BaseLLMProvider):
             temperature=temperature,
             max_tokens=max_tokens,
             stream=stream,
-            tools=None,  # No tools needed for follow-up
+            tools=tools,
         )

     def get_content(self, response: Any) -> str:

From dc5e2ef7c2142fdade7d41f40894cb5ddf24bf66 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 16:51:49 +0000
Subject: [PATCH 13/14] stream follow up responses

---
 src/airflow_wingman/llm_client.py | 42 +++++++++++++++++++++----------
 src/airflow_wingman/views.py      | 13 +++++++++-
 2 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py
index f9d818d..5cd1071 100644
--- a/src/airflow_wingman/llm_client.py
+++ b/src/airflow_wingman/llm_client.py
@@ -141,7 +141,7 @@ class LLMClient:

         return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)

-    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None):
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None, stream=True):
         """
         Process tool calls recursively from a response and make follow-up requests
         until there are no more tool calls or max_iterations is reached.
@@ -155,6 +155,7 @@ class LLMClient:
             max_tokens: Maximum tokens to generate
             max_iterations: Maximum number of tool call iterations to prevent infinite loops
             cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)
+            stream: Whether to stream the response

         Returns:
             Generator for streaming the final follow-up response
@@ -181,8 +182,10 @@ class LLMClient:
             # Make follow-up request with tool results
             logger.info(f"Making follow-up request with tool results (iteration {iteration})")

-            # Only stream on the final iteration
-            should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response)
+            # Always stream follow-up requests to ensure consistent behavior
+            # This ensures we get streaming responses from the provider
+            should_stream = True
+            logger.info(f"Setting should_stream=True for follow-up request (iteration {iteration})")

             # Get provider-specific tool definitions from Airflow tools
             provider_tools = self.provider.convert_tools(self.airflow_tools)
@@ -201,15 +204,14 @@ class LLMClient:
             # Check if this follow-up response has more tool calls
             if not self.provider.has_tool_calls(follow_up_response):
                 logger.info(f"No more tool calls after iteration {iteration}")
-                # Final response - return the streaming content
-                if not should_stream:
-                    # If we didn't stream this response, we need to make a streaming version
-                    content = self.provider.get_content(follow_up_response)
-                    yield content
-                    return
-                else:
-                    # Return the streaming generator
-                    return self.provider.get_streaming_content(follow_up_response)
+                # Final response - always yield content in a streaming fashion
+                # Since we're always streaming now, we can directly yield chunks from the streaming generator
+                chunk_count = 0
+                for chunk in self.provider.get_streaming_content(follow_up_response):
+                    chunk_count += 1
+                    # logger.info(f"Yielding chunk {chunk_count} from streaming generator: {chunk[:50] if chunk else 'Empty chunk'}...")
+                    yield chunk
+                logger.info(f"Finished yielding {chunk_count} chunks from streaming generator")

             # Update current_response for the next iteration
             current_response = follow_up_response
@@ -218,7 +220,21 @@ class LLMClient:
         if iteration == max_iterations and self.provider.has_tool_calls(current_response):
             logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
             # Stream the final response even if it has tool calls
-            return self.provider.get_streaming_content(follow_up_response)
+            if not should_stream:
+                # If we didn't stream this response, convert it to a single chunk
+                content = self.provider.get_content(follow_up_response)
+                logger.info(f"Yielding complete content as a single chunk (max iterations): {content[:100]}...")
+                yield content
+                logger.info("Finished yielding complete content (max iterations)")
+            else:
+                # Yield chunks from the streaming generator
+                logger.info("Starting to yield chunks from streaming generator (max iterations reached)")
+                chunk_count = 0
+                for chunk in self.provider.get_streaming_content(follow_up_response):
+                    chunk_count += 1
+                    logger.info(f"Yielding chunk {chunk_count} from streaming generator (max iterations)")
+                    yield chunk
+                logger.info(f"Finished yielding {chunk_count} chunks from streaming generator (max iterations)")

         # If we didn't process any tool calls (shouldn't happen), return an error
         if iteration == 0:
diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py
index 14417f5..f02d508 100644
--- a/src/airflow_wingman/views.py
+++ b/src/airflow_wingman/views.py
@@ -147,13 +147,24 @@ class WingmanView(AppBuilderBaseView):
                     logger.info(f"Using cookie from closure: {cookie is not None}")

                     # Process tool calls and get follow-up response (handles recursive tool calls)
-                    follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie)
+                    # Always stream the follow-up response for consistent handling
+                    follow_up_response = client.process_tool_calls_and_follow_up(
+                        streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie, stream=True
+                    )

                     # Stream the follow-up response
+                    follow_up_complete_response = ""
                     for chunk in follow_up_response:
                         if chunk:
+                            follow_up_complete_response += chunk
+                            # logger.info(f"Yielding chunk to frontend: {chunk[:50]}...")
                             yield f"data: {chunk}\n\n"

+                    # Log the complete follow-up response
+                    logger.info("FOLLOW-UP RESPONSE START >>>")
+                    logger.info(follow_up_complete_response)
+                    logger.info("<<< FOLLOW-UP RESPONSE END")
+
                     # Signal tool processing complete - frontend can re-enable send button
                     yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"

From 7c20220c2f1b97beaa998599f62a17390d8515af Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 18:28:49 +0000
Subject: [PATCH 14/14] fix show follow up responses

---
 src/airflow_wingman/static/js/wingman_chat.js | 16 ++++++++++++
 src/airflow_wingman/views.py                  | 26 ++++++++++++-------
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/src/airflow_wingman/static/js/wingman_chat.js b/src/airflow_wingman/static/js/wingman_chat.js
index 553cbba..e12f53d 100644
--- a/src/airflow_wingman/static/js/wingman_chat.js
+++ b/src/airflow_wingman/static/js/wingman_chat.js
@@ -256,6 +256,22 @@ document.addEventListener('DOMContentLoaded', function() {
                             continue;
                         }

+                        // Handle follow-up response event
+                        if (parsed.event === 'follow_up_response' && parsed.content) {
+                            console.log('Received follow-up response');
+
+                            // Add this follow-up response to message history
+                            messageHistory.push({
+                                role: 'assistant',
+                                content: parsed.content
+                            });
+
+                            // Create a new message div for the follow-up response
+                            // The addMessage function already handles markdown rendering
+                            addMessage(parsed.content, false);
+                            continue;
+                        }
+
                         // Handle the complete response event
                         if (parsed.event === 'complete_response') {
                             console.log('Received complete response from backend');
diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py
index f02d508..12825fc 100644
--- a/src/airflow_wingman/views.py
+++ b/src/airflow_wingman/views.py
@@ -136,7 +136,9 @@ class WingmanView(AppBuilderBaseView):
                 logger.info("<<< COMPLETE RESPONSE END")

                 # Check for tool calls and make follow-up if needed
-                if client.provider.has_tool_calls(streaming_response):
+                has_tool_calls = client.provider.has_tool_calls(streaming_response)
+                logger.info(f"Has tool calls: {has_tool_calls}")
+                if has_tool_calls:
                     # Signal tool processing start - frontend should disable send button
                     yield f"data: {json.dumps({'event': 'tool_processing_start'})}\n\n"

@@ -152,18 +154,24 @@ class WingmanView(AppBuilderBaseView):
-                    # Stream the follow-up response
+                    # Collect the follow-up response
                     follow_up_complete_response = ""
                     for chunk in follow_up_response:
                         if chunk:
                             follow_up_complete_response += chunk
-                            # logger.info(f"Yielding chunk to frontend: {chunk[:50]}...")
-                            yield f"data: {chunk}\n\n"
-
-                    # Log the complete follow-up response
-                    logger.info("FOLLOW-UP RESPONSE START >>>")
-                    logger.info(follow_up_complete_response)
-                    logger.info("<<< FOLLOW-UP RESPONSE END")
+
+                    # Send the follow-up response as a single event
+                    if follow_up_complete_response:
+                        follow_up_event = json.dumps({'event': 'follow_up_response', 'content': follow_up_complete_response})
+                        logger.info(f"Follow-up event created with length: {len(follow_up_event)}")
+                        data_line = f"data: {follow_up_event}\n\n"
+                        logger.info(f"Yielding data line with length: {len(data_line)}")
+                        yield data_line
+
+                        # Log the complete follow-up response
+                        logger.info("FOLLOW-UP RESPONSE START >>>")
+                        logger.info(follow_up_complete_response)
+                        logger.info("<<< FOLLOW-UP RESPONSE END")

                     # Signal tool processing complete - frontend can re-enable send button
                     yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"