From 3904bfc6443843b39cfb96ce3af78b66e39e3956 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Feb 2025 19:51:49 +0000 Subject: [PATCH] Restructure foe 2 main providers and tools conversion --- pyproject.toml | 3 +- src/airflow_wingman/llm_client.py | 188 +++++++----- src/airflow_wingman/llms_models.py | 4 +- src/airflow_wingman/prompt_engineering.py | 6 +- src/airflow_wingman/providers/__init__.py | 41 +++ .../providers/anthropic_provider.py | 288 ++++++++++++++++++ src/airflow_wingman/providers/base.py | 125 ++++++++ .../providers/openai_provider.py | 224 ++++++++++++++ .../templates/wingman_chat.html | 3 + src/airflow_wingman/tools/__init__.py | 15 + src/airflow_wingman/tools/conversion.py | 139 +++++++++ src/airflow_wingman/tools/execution.py | 117 +++++++ src/airflow_wingman/views.py | 73 +++-- 13 files changed, 1126 insertions(+), 100 deletions(-) create mode 100644 src/airflow_wingman/providers/__init__.py create mode 100644 src/airflow_wingman/providers/anthropic_provider.py create mode 100644 src/airflow_wingman/providers/base.py create mode 100644 src/airflow_wingman/providers/openai_provider.py create mode 100644 src/airflow_wingman/tools/__init__.py create mode 100644 src/airflow_wingman/tools/conversion.py create mode 100644 src/airflow_wingman/tools/execution.py diff --git a/pyproject.toml b/pyproject.toml index c084abf..931b776 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,8 @@ lint.select = [ lint.ignore = [ "C416", # Unnecessary list comprehension - rewrite as a generator expression "C408", # Unnecessary `dict` call - rewrite as a literal - "ISC001" # Single line implicit string concatenation + "ISC001", # Single line implicit string concatenation + "C901" ] lint.fixable = ["ALL"] diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py index dd58370..4fd6768 100644 --- a/src/airflow_wingman/llm_client.py +++ b/src/airflow_wingman/llm_client.py @@ -1,109 +1,145 @@ """ -Client for making API calls to various LLM providers using their official SDKs. +Multi-provider LLM client for Airflow Wingman. + +This module contains the LLMClient class that supports multiple LLM providers +(OpenAI, Anthropic, OpenRouter) through a unified interface. """ -from collections.abc import Generator +import traceback +from typing import Any -from anthropic import Anthropic -from openai import OpenAI +from airflow.utils.log.logging_mixin import LoggingMixin +from flask import session + +from airflow_wingman.providers import create_llm_provider +from airflow_wingman.tools import list_airflow_tools + +# Create a logger instance +logger = LoggingMixin().log class LLMClient: - def __init__(self, api_key: str): - """Initialize the LLM client. + """ + Multi-provider LLM client for Airflow Wingman. + + This class handles chat completion requests to various LLM providers + (OpenAI, Anthropic, OpenRouter) through a unified interface. + """ + + def __init__(self, provider_name: str, api_key: str, base_url: str | None = None): + """ + Initialize the LLM client. 
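
# A minimal usage sketch (not part of this patch) for the constructor
# documented below; the API keys are placeholders and the OpenRouter base URL
# matches the default applied by the provider factory.
from airflow_wingman.llm_client import LLMClient

openai_client = LLMClient(provider_name="openai", api_key="sk-...")
anthropic_client = LLMClient(provider_name="anthropic", api_key="sk-ant-...")
openrouter_client = LLMClient(
    provider_name="openrouter",
    api_key="sk-or-...",
    base_url="https://openrouter.ai/api/v1",
)
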
Args: + provider_name: Name of the provider (openai, anthropic, openrouter) api_key: API key for the provider + base_url: Optional base URL for the provider API """ + self.provider_name = provider_name self.api_key = api_key - self.openai_client = OpenAI(api_key=api_key) - self.anthropic_client = Anthropic(api_key=api_key) - self.openrouter_client = OpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=api_key, - default_headers={ - "HTTP-Referer": "Airflow Wingman", # Required by OpenRouter - "X-Title": "Airflow Wingman", # Required by OpenRouter - }, - ) + self.base_url = base_url + self.provider = create_llm_provider(provider_name, api_key, base_url) + self.airflow_tools = [] - def chat_completion( - self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False - ) -> Generator[str, None, None] | dict: - """Send a chat completion request to the specified provider. + def set_airflow_tools(self, tools: list): + """ + Set the available Airflow tools. + + Args: + tools: List of Airflow Tool objects + """ + self.airflow_tools = tools + + def chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False) -> dict[str, Any]: + """ + Send a chat completion request to the LLM provider. Args: messages: List of message dictionaries with 'role' and 'content' model: Model identifier - provider: Provider identifier (openai, anthropic, openrouter) temperature: Sampling temperature (0-1) max_tokens: Maximum tokens to generate - stream: Whether to stream the response + stream: Whether to stream the response (default is True) Returns: - If stream=True, returns a generator yielding response chunks - If stream=False, returns the complete response + Dictionary with the response content or a generator for streaming """ + # Get provider-specific tool definitions from Airflow tools + provider_tools = self.provider.convert_tools(self.airflow_tools) + try: - if provider == "openai": - return self._openai_chat_completion(messages, model, temperature, max_tokens, stream) - elif provider == "anthropic": - return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream) - elif provider == "openrouter": - return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream) + # Make the initial request with tools + logger.info(f"Sending chat completion request to {self.provider_name} with model: {model}") + response = self.provider.create_chat_completion(messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=provider_tools) + logger.info(f"Received response from {self.provider_name}") + + # If streaming, return the generator directly + if stream: + return self.provider.get_streaming_content(response) + + # For non-streaming responses, handle tool calls if present + if self.provider.has_tool_calls(response): + logger.info("Response contains tool calls") + + # Process tool calls and get results + cookie = session.get("airflow_cookie") + if not cookie: + error_msg = "No Airflow cookie available" + logger.error(error_msg) + return {"error": error_msg} + + tool_results = self.provider.process_tool_calls(response, cookie) + + # Create a follow-up completion with the tool results + logger.info("Making follow-up request with tool results") + follow_up_response = self.provider.create_follow_up_completion( + messages=messages, model=model, temperature=temperature, 
max_tokens=max_tokens, tool_results=tool_results, original_response=response + ) + + return {"content": self.provider.get_content(follow_up_response)} else: - return {"error": f"Unknown provider: {provider}"} + logger.info("Response does not contain tool calls") + return {"content": self.provider.get_content(response)} + except Exception as e: + error_msg = f"Error in {self.provider_name} API call: {str(e)}\\n{traceback.format_exc()}" + logger.error(error_msg) return {"error": f"API request failed: {str(e)}"} - def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle OpenAI chat completion requests.""" - response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + @classmethod + def from_config(cls, config: dict[str, Any]) -> "LLMClient": + """ + Create an LLMClient instance from a configuration dictionary. - if stream: + Args: + config: Configuration dictionary with provider_name, api_key, and optional base_url - def response_generator(): - for chunk in response: - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content + Returns: + LLMClient instance + """ + provider_name = config.get("provider_name", "openai") + api_key = config.get("api_key") + base_url = config.get("base_url") - return response_generator() - else: - return {"content": response.choices[0].message.content} + if not api_key: + raise ValueError("API key is required") - def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle Anthropic chat completion requests.""" - # Convert messages to Anthropic format - system_message = next((m["content"] for m in messages if m["role"] == "system"), None) - conversation = [] - for m in messages: - if m["role"] != "system": - conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]}) + return cls(provider_name=provider_name, api_key=api_key, base_url=base_url) - response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream) + def refresh_tools(self, cookie: str) -> None: + """ + Refresh the available Airflow tools. 
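
# Sketch (not part of this patch): building the client from a configuration
# dictionary via the from_config classmethod above and refreshing the tool
# list; the key and cookie values are placeholders.
client = LLMClient.from_config({"provider_name": "anthropic", "api_key": "sk-ant-..."})
client.refresh_tools(cookie="session=...")
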
- if stream: - - def response_generator(): - for chunk in response: - if chunk.delta.text: - yield chunk.delta.text - - return response_generator() - else: - return {"content": response.content[0].text} - - def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): - """Handle OpenRouter chat completion requests.""" - response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) - - if stream: - - def response_generator(): - for chunk in response: - if chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content - - return response_generator() - else: - return {"content": response.choices[0].message.content} + Args: + cookie: Airflow cookie for authentication + """ + try: + logger.info("Refreshing Airflow tools") + tools = list_airflow_tools(cookie) + self.set_airflow_tools(tools) + logger.info(f"Refreshed {len(tools)} Airflow tools") + except Exception as e: + error_msg = f"Error refreshing Airflow tools: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + # Don't raise the exception, just log it + # The client will continue to use the existing tools (if any) diff --git a/src/airflow_wingman/llms_models.py b/src/airflow_wingman/llms_models.py index eaf6d16..ff6f2b9 100644 --- a/src/airflow_wingman/llms_models.py +++ b/src/airflow_wingman/llms_models.py @@ -17,14 +17,14 @@ MODELS = { "endpoint": "https://api.anthropic.com/v1/messages", "models": [ { - "id": "claude-3.7-sonnet", + "id": "claude-3-7-sonnet-20250219", "name": "Claude 3.7 Sonnet", "default": True, "context_window": 200000, "description": "Input $3/M tokens, Output $15/M tokens", }, { - "id": "claude-3.5-haiku", + "id": "claude-3-5-haiku-20241022", "name": "Claude 3.5 Haiku", "default": False, "context_window": 200000, diff --git a/src/airflow_wingman/prompt_engineering.py b/src/airflow_wingman/prompt_engineering.py index 31b5d8e..74a67e5 100644 --- a/src/airflow_wingman/prompt_engineering.py +++ b/src/airflow_wingman/prompt_engineering.py @@ -8,9 +8,11 @@ INSTRUCTIONS = { You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices. The Airflow version being used is >=2.10. -You have access to the following Airflow API tools: +You have access to Airflow MCP tools that you can use to fetch information and help users understand +and manage their Airflow environment. -You can use these tools to fetch information and help users understand and manage their Airflow environment. +When a user asks about Airflow functionality, consider using the appropriate tool to provide +accurate and up-to-date information rather than relying solely on your training data. """ } diff --git a/src/airflow_wingman/providers/__init__.py b/src/airflow_wingman/providers/__init__.py new file mode 100644 index 0000000..828726c --- /dev/null +++ b/src/airflow_wingman/providers/__init__.py @@ -0,0 +1,41 @@ +""" +Provider factory for Airflow Wingman. + +This module contains the factory function to create provider instances +based on the provider name. +""" + +from airflow_wingman.providers.anthropic_provider import AnthropicProvider +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.providers.openai_provider import OpenAIProvider + + +def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseLLMProvider: + """ + Create a provider instance based on the provider name. 
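
# Sketch (not part of this patch): typical use of the factory documented
# below; "openrouter" reuses the OpenAI-compatible provider with the
# OpenRouter base URL, and the key is a placeholder.
provider = create_llm_provider("anthropic", api_key="sk-ant-...")
provider_tools = provider.convert_tools(airflow_tools)  # airflow_tools come from the MCP server
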
+ + Args: + provider_name: Name of the provider (openai, anthropic, openrouter) + api_key: API key for the provider + base_url: Optional base URL for the provider API + + Returns: + Provider instance + + Raises: + ValueError: If the provider is not supported + """ + provider_name = provider_name.lower() + + if provider_name == "openai": + return OpenAIProvider(api_key=api_key, base_url=base_url) + elif provider_name == "openrouter": + # OpenRouter uses the OpenAI API format, so we can use the OpenAI provider + # with a custom base URL + if not base_url: + base_url = "https://openrouter.ai/api/v1" + return OpenAIProvider(api_key=api_key, base_url=base_url) + elif provider_name == "anthropic": + return AnthropicProvider(api_key=api_key) + else: + raise ValueError(f"Unsupported provider: {provider_name}. Supported providers: openai, anthropic, openrouter") diff --git a/src/airflow_wingman/providers/anthropic_provider.py b/src/airflow_wingman/providers/anthropic_provider.py new file mode 100644 index 0000000..f2ecb2e --- /dev/null +++ b/src/airflow_wingman/providers/anthropic_provider.py @@ -0,0 +1,288 @@ +""" +Anthropic provider implementation for Airflow Wingman. + +This module contains the Anthropic provider implementation that handles +API requests, tool conversion, and response processing for Anthropic's Claude models. +""" + +import traceback +from typing import Any + +from airflow.utils.log.logging_mixin import LoggingMixin +from anthropic import Anthropic + +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.tools import execute_airflow_tool +from airflow_wingman.tools.conversion import convert_to_anthropic_tools + +logger = LoggingMixin().log + + +class AnthropicProvider(BaseLLMProvider): + """ + Anthropic provider implementation. + + This class handles API requests, tool conversion, and response processing + for the Anthropic API (Claude models). + """ + + def __init__(self, api_key: str): + """ + Initialize the Anthropic provider. + + Args: + api_key: API key for Anthropic + """ + self.api_key = api_key + self.client = Anthropic(api_key=api_key) + + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert Airflow tools to Anthropic format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of Anthropic tool definitions + """ + return convert_to_anthropic_tools(airflow_tools) + + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to Anthropic. 
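
# Sketch (not part of this patch): a non-streaming request through the method
# documented below; the model id matches llms_models.py and max_tokens falls
# back to 4096 when omitted.
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "List my DAGs"}],
    model="claude-3-7-sonnet-20250219",
    temperature=0.7,
    stream=False,
    tools=provider.convert_tools(airflow_tools),
)
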
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in Anthropic format + + Returns: + Anthropic response object + + Raises: + Exception: If the API request fails + """ + # Convert max_tokens to Anthropic's max_tokens parameter (if provided) + max_tokens_param = max_tokens if max_tokens is not None else 4096 + + # Convert messages from ChatML format to Anthropic's format + anthropic_messages = self._convert_to_anthropic_messages(messages) + + try: + logger.info(f"Sending chat completion request to Anthropic with model: {model}") + + # Create request parameters + params = {"model": model, "messages": anthropic_messages, "temperature": temperature, "max_tokens": max_tokens_param, "stream": stream} + + # Add tools if provided + if tools and len(tools) > 0: + params["tools"] = tools + + # Make the API request + response = self.client.messages.create(**params) + + logger.info("Received response from Anthropic") + return response + except Exception as e: + error_msg = str(e) + logger.error(f"Failed to get response from Anthropic: {error_msg}\n{traceback.format_exc()}") + raise + + def _convert_to_anthropic_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert messages from ChatML format to Anthropic's format. + + Args: + messages: List of message dictionaries in ChatML format + + Returns: + List of message dictionaries in Anthropic format + """ + anthropic_messages = [] + + for message in messages: + role = message["role"] + content = message["content"] + + # Map ChatML roles to Anthropic roles + if role == "system": + # System messages in Anthropic are handled differently + # We'll add them as a user message with a special prefix + anthropic_messages.append({"role": "user", "content": f"\n{content}\n"}) + elif role == "user": + anthropic_messages.append({"role": "user", "content": content}) + elif role == "assistant": + anthropic_messages.append({"role": "assistant", "content": content}) + elif role == "tool": + # Tool messages in ChatML become part of the user message in Anthropic + # We'll handle this in the follow-up completion + continue + + return anthropic_messages + + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. + + Args: + response: Anthropic response object + + Returns: + True if the response contains tool calls, False otherwise + """ + # Check if any content block is a tool_use block + if not hasattr(response, "content"): + return False + + for block in response.content: + if isinstance(block, dict) and block.get("type") == "tool_use": + return True + + return False + + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. 
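
# Sketch (not part of this patch): the content-block shape these helpers look
# for; the tool name and input are illustrative. Note that the code above
# checks isinstance(block, dict), while the anthropic SDK may return typed
# content-block objects, so treat the dict form as an assumption.
example_tool_use_block = {
    "type": "tool_use",
    "id": "toolu_01...",
    "name": "get_dags",        # illustrative tool name
    "input": {"limit": 10},    # illustrative arguments
}
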
+ + Args: + response: Anthropic response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + results = {} + + if not self.has_tool_calls(response): + return results + + # Extract tool_use blocks + tool_use_blocks = [block for block in response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + + for block in tool_use_blocks: + tool_id = block.get("id") + tool_name = block.get("name") + tool_input = block.get("input", {}) + + try: + # Execute the Airflow tool with the provided arguments and cookie + logger.info(f"Executing tool: {tool_name} with arguments: {tool_input}") + result = execute_airflow_tool(tool_name, tool_input, cookie) + logger.info(f"Tool execution result: {result}") + results[tool_id] = {"status": "success", "result": result} + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + results[tool_id] = {"status": "error", "message": error_msg} + + return results + + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. + + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + Anthropic response object + """ + if not original_response or not tool_results: + return original_response + + # Extract tool_use blocks from the original response + tool_use_blocks = [block for block in original_response.content if isinstance(block, dict) and block.get("type") == "tool_use"] + + # Create tool result blocks + tool_result_blocks = [] + for tool_id, result in tool_results.items(): + tool_result_blocks.append({"type": "tool_result", "tool_use_id": tool_id, "content": result.get("result", str(result))}) + + # Convert original messages to Anthropic format + anthropic_messages = self._convert_to_anthropic_messages(messages) + + # Add the assistant response with tool use + anthropic_messages.append({"role": "assistant", "content": tool_use_blocks}) + + # Add the user message with tool results + anthropic_messages.append({"role": "user", "content": tool_result_blocks}) + + # Make a second request to get the final response + logger.info("Making second request with tool results") + return self.create_chat_completion( + messages=anthropic_messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + tools=None, # No tools needed for follow-up + ) + + def get_content(self, response: Any) -> str: + """ + Extract content from the response. + + Args: + response: Anthropic response object + + Returns: + Content string from the response + """ + if not hasattr(response, "content"): + return "" + + # Combine all text blocks into a single string + content_parts = [] + for block in response.content: + if isinstance(block, dict) and block.get("type") == "text": + content_parts.append(block.get("text", "")) + elif isinstance(block, str): + content_parts.append(block) + + return "".join(content_parts) + + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. 
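
# Sketch (not part of this patch): the message pair appended by
# create_follow_up_completion above -- an assistant turn replaying the
# tool_use blocks and a user turn carrying the matching tool_result blocks.
# IDs and contents are placeholders.
follow_up_turns = [
    {"role": "assistant", "content": [
        {"type": "tool_use", "id": "toolu_01...", "name": "get_dags", "input": {}},
    ]},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01...", "content": "{...}"},
    ]},
]
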
+ + Args: + response: Anthropic streaming response object + + Returns: + Generator yielding content chunks + """ + + def generate(): + for chunk in response: + logger.debug(f"Chunk type: {type(chunk)}") + + # Handle different types of chunks from Anthropic API + content = None + if hasattr(chunk, "type") and chunk.type == "content_block_delta": + if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): + content = chunk.delta.text + elif hasattr(chunk, "delta") and hasattr(chunk.delta, "text"): + content = chunk.delta.text + elif hasattr(chunk, "content") and chunk.content: + for block in chunk.content: + if isinstance(block, dict) and block.get("type") == "text": + content = block.get("text", "") + + if content: + # Don't do any newline replacement here + yield content + + return generate() diff --git a/src/airflow_wingman/providers/base.py b/src/airflow_wingman/providers/base.py new file mode 100644 index 0000000..dcdbc44 --- /dev/null +++ b/src/airflow_wingman/providers/base.py @@ -0,0 +1,125 @@ +""" +Base provider interface for Airflow Wingman. + +This module contains the base provider interface that all provider implementations +must adhere to. It defines the methods required for tool conversion, API requests, +and response processing. +""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseLLMProvider(ABC): + """ + Base provider interface for LLM providers. + + This abstract class defines the methods that all provider implementations + must implement to support tool integration. + """ + + @abstractmethod + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert internal tool representation to provider format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of provider-specific tool definitions + """ + pass + + @abstractmethod + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to provider. + + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in provider format + + Returns: + Provider-specific response object + """ + pass + + @abstractmethod + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. + + Args: + response: Provider-specific response object + + Returns: + True if the response contains tool calls, False otherwise + """ + pass + + @abstractmethod + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. + + Args: + response: Provider-specific response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + pass + + @abstractmethod + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. 
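
# Sketch (not part of this patch): the skeleton a hypothetical new provider
# would follow to satisfy this interface; method bodies are elided.
class MyProvider(BaseLLMProvider):
    def convert_tools(self, airflow_tools: list) -> list: ...
    def create_chat_completion(self, messages, model, temperature=0.7, max_tokens=None, stream=False, tools=None): ...
    def has_tool_calls(self, response) -> bool: ...
    def process_tool_calls(self, response, cookie) -> dict: ...
    def create_follow_up_completion(self, messages, model, temperature=0.7, max_tokens=None, tool_results=None, original_response=None): ...
    def get_content(self, response) -> str: ...
    def get_streaming_content(self, response): ...
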
+ + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + Provider-specific response object + """ + pass + + @abstractmethod + def get_content(self, response: Any) -> str: + """ + Extract content from the response. + + Args: + response: Provider-specific response object + + Returns: + Content string from the response + """ + pass + + @abstractmethod + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. + + Args: + response: Provider-specific response object + + Returns: + Generator yielding content chunks + """ + pass diff --git a/src/airflow_wingman/providers/openai_provider.py b/src/airflow_wingman/providers/openai_provider.py new file mode 100644 index 0000000..6d12cbf --- /dev/null +++ b/src/airflow_wingman/providers/openai_provider.py @@ -0,0 +1,224 @@ +""" +OpenAI provider implementation for Airflow Wingman. + +This module contains the OpenAI provider implementation that handles +API requests, tool conversion, and response processing for OpenAI. +""" + +import json +import traceback +from typing import Any + +from airflow.utils.log.logging_mixin import LoggingMixin +from openai import OpenAI + +from airflow_wingman.providers.base import BaseLLMProvider +from airflow_wingman.tools import execute_airflow_tool +from airflow_wingman.tools.conversion import convert_to_openai_tools + +# Create a logger instance +logger = LoggingMixin().log + + +class OpenAIProvider(BaseLLMProvider): + """ + OpenAI provider implementation. + + This class handles API requests, tool conversion, and response processing + for the OpenAI API. It can also be used for OpenRouter with a custom base URL. + """ + + def __init__(self, api_key: str, base_url: str | None = None): + """ + Initialize the OpenAI provider. + + Args: + api_key: API key for OpenAI + base_url: Optional base URL for the API (used for OpenRouter) + """ + self.api_key = api_key + self.client = OpenAI(api_key=api_key, base_url=base_url) + + def convert_tools(self, airflow_tools: list) -> list: + """ + Convert Airflow tools to OpenAI format. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of OpenAI tool definitions + """ + return convert_to_openai_tools(airflow_tools) + + def create_chat_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False, tools: list[dict[str, Any]] | None = None + ) -> Any: + """ + Make API request to OpenAI. 
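
# Sketch (not part of this patch): the tool shape produced by
# convert_to_openai_tools and passed via the tools parameter documented
# below; the tool name and parameter are illustrative.
example_openai_tool = {
    "type": "function",
    "function": {
        "name": "get_dags",
        "description": "List DAGs",
        "parameters": {
            "type": "object",
            "properties": {"limit": {"type": "integer", "description": "limit"}},
            "required": [],
        },
    },
}
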
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + tools: List of tool definitions in OpenAI format + + Returns: + OpenAI response object + + Raises: + Exception: If the API request fails + """ + # Only include tools if we have any + has_tools = tools is not None and len(tools) > 0 + tool_choice = "auto" if has_tools else None + + try: + logger.info(f"Sending chat completion request to OpenAI with model: {model}") + response = self.client.chat.completions.create( + model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, tools=tools if has_tools else None, tool_choice=tool_choice + ) + logger.info("Received response from OpenAI") + return response + except Exception as e: + # If the API call fails due to tools not being supported, retry without tools + error_msg = str(e) + logger.warning(f"Error in OpenAI API call: {error_msg}") + if "tools" in error_msg.lower(): + logger.info("Retrying without tools") + response = self.client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + return response + else: + logger.error(f"Failed to get response from OpenAI: {error_msg}\n{traceback.format_exc()}") + raise + + def has_tool_calls(self, response: Any) -> bool: + """ + Check if the response contains tool calls. + + Args: + response: OpenAI response object + + Returns: + True if the response contains tool calls, False otherwise + """ + message = response.choices[0].message + return hasattr(message, "tool_calls") and message.tool_calls + + def process_tool_calls(self, response: Any, cookie: str) -> dict[str, Any]: + """ + Process tool calls from the response. + + Args: + response: OpenAI response object + cookie: Airflow cookie for authentication + + Returns: + Dictionary mapping tool call IDs to results + """ + results = {} + message = response.choices[0].message + + if not self.has_tool_calls(response): + return results + + for tool_call in message.tool_calls: + tool_id = tool_call.id + function_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + + try: + # Execute the Airflow tool with the provided arguments and cookie + logger.info(f"Executing tool: {function_name} with arguments: {arguments}") + result = execute_airflow_tool(function_name, arguments, cookie) + logger.info(f"Tool execution result: {result}") + results[tool_id] = {"status": "success", "result": result} + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + results[tool_id] = {"status": "error", "message": error_msg} + + return results + + def create_follow_up_completion( + self, messages: list[dict[str, Any]], model: str, temperature: float = 0.7, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None + ) -> Any: + """ + Create a follow-up completion with tool results. 
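
# Sketch (not part of this patch): the messages this method appends for the
# follow-up request -- the assistant turn echoing the tool calls, then one
# role="tool" message per result keyed by tool_call_id; ids are placeholders.
assistant_turn = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{"id": "call_abc", "type": "function",
                    "function": {"name": "get_dags", "arguments": "{}"}}],
}
tool_turn = {"role": "tool", "tool_call_id": "call_abc", "content": "{...}"}
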
+ + Args: + messages: Original messages + model: Model identifier + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + tool_results: Results of tool executions + original_response: Original response with tool calls + + Returns: + OpenAI response object + """ + if not original_response or not tool_results: + return original_response + + # Get the original message with tool calls + original_message = original_response.choices[0].message + + # Create a new message with the tool calls + assistant_message = { + "role": "assistant", + "content": None, + "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls], + } + + # Create tool result messages + tool_messages = [] + for tool_call_id, result in tool_results.items(): + tool_messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": result.get("result", str(result))}) + + # Add the original messages, assistant message, and tool results + new_messages = messages + [assistant_message] + tool_messages + + # Make a second request to get the final response + logger.info("Making second request with tool results") + return self.create_chat_completion( + messages=new_messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + tools=None, # No tools needed for follow-up + ) + + def get_content(self, response: Any) -> str: + """ + Extract content from the response. + + Args: + response: OpenAI response object + + Returns: + Content string from the response + """ + return response.choices[0].message.content + + def get_streaming_content(self, response: Any) -> Any: + """ + Get a generator for streaming content from the response. + + Args: + response: OpenAI streaming response object + + Returns: + Generator yielding content chunks + """ + + def generate(): + for chunk in response: + if chunk.choices and chunk.choices[0].delta.content: + # Don't do any newline replacement here + content = chunk.choices[0].delta.content + yield content + + return generate() diff --git a/src/airflow_wingman/templates/wingman_chat.html b/src/airflow_wingman/templates/wingman_chat.html index 72657f0..56a38a6 100644 --- a/src/airflow_wingman/templates/wingman_chat.html +++ b/src/airflow_wingman/templates/wingman_chat.html @@ -150,6 +150,7 @@ border: 1px solid #e9ecef; border-radius: 15px 15px 15px 0; padding: 10px 15px; + white-space: pre-wrap; } #chat-messages::after { @@ -349,6 +350,8 @@ document.addEventListener('DOMContentLoaded', function() { if (line.startsWith('data: ')) { const content = line.slice(6); if (content) { + // Use textContent to properly handle newlines + console.log('Received chunk:', JSON.stringify(content)); // Debug currentMessageDiv.textContent += content; fullResponse += content; chatMessages.scrollTop = chatMessages.scrollHeight; diff --git a/src/airflow_wingman/tools/__init__.py b/src/airflow_wingman/tools/__init__.py new file mode 100644 index 0000000..44e9360 --- /dev/null +++ b/src/airflow_wingman/tools/__init__.py @@ -0,0 +1,15 @@ +""" +Tools module for Airflow Wingman. + +This module contains the tools used by Airflow Wingman to interact with Airflow. 
+""" + +from airflow_wingman.tools.conversion import convert_to_anthropic_tools, convert_to_openai_tools +from airflow_wingman.tools.execution import execute_airflow_tool, list_airflow_tools + +__all__ = [ + "convert_to_openai_tools", + "convert_to_anthropic_tools", + "list_airflow_tools", + "execute_airflow_tool", +] diff --git a/src/airflow_wingman/tools/conversion.py b/src/airflow_wingman/tools/conversion.py new file mode 100644 index 0000000..7a30c02 --- /dev/null +++ b/src/airflow_wingman/tools/conversion.py @@ -0,0 +1,139 @@ +""" +Conversion utilities for Airflow Wingman tools. + +This module contains functions to convert between different tool formats +for various LLM providers (OpenAI, Anthropic, etc.). +""" + +from typing import Any + + +def convert_to_openai_tools(airflow_tools: list) -> list: + """ + Convert Airflow tools to OpenAI tool definitions. + + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of OpenAI tool definitions + """ + openai_tools = [] + + for tool in airflow_tools: + # Initialize the OpenAI tool structure + openai_tool = {"type": "function", "function": {"name": tool.name, "description": tool.description or tool.name, "parameters": {"type": "object", "properties": {}, "required": []}}} + + # Extract parameters directly from inputSchema if available + if hasattr(tool, "inputSchema") and tool.inputSchema: + # Set the type and required fields directly from the schema + if "type" in tool.inputSchema: + openai_tool["function"]["parameters"]["type"] = tool.inputSchema["type"] + + # Add required parameters if specified + if "required" in tool.inputSchema: + openai_tool["function"]["parameters"]["required"] = tool.inputSchema["required"] + + # Add properties from the input schema + if "properties" in tool.inputSchema: + for param_name, param_info in tool.inputSchema["properties"].items(): + # Create parameter definition + param_def = {} + + # Handle different schema constructs + if "anyOf" in param_info: + _handle_schema_construct(param_def, param_info, "anyOf") + elif "oneOf" in param_info: + _handle_schema_construct(param_def, param_info, "oneOf") + elif "allOf" in param_info: + _handle_schema_construct(param_def, param_info, "allOf") + elif "type" in param_info: + param_def["type"] = param_info["type"] + # Add format if available + if "format" in param_info: + param_def["format"] = param_info["format"] + else: + param_def["type"] = "string" # Default type + + # Add description from title or param name + param_def["description"] = param_info.get("description", param_info.get("title", param_name)) + + # Add enum values if available + if "enum" in param_info: + param_def["enum"] = param_info["enum"] + + # Add default value if available + if "default" in param_info and param_info["default"] is not None: + param_def["default"] = param_info["default"] + + # Add to properties + openai_tool["function"]["parameters"]["properties"][param_name] = param_def + + openai_tools.append(openai_tool) + + return openai_tools + + +def convert_to_anthropic_tools(airflow_tools: list) -> list: + """ + Convert Airflow tools to Anthropic tool definitions. 
+ + Args: + airflow_tools: List of Airflow tools from MCP server + + Returns: + List of Anthropic tool definitions + """ + anthropic_tools = [] + + for tool in airflow_tools: + # Initialize the Anthropic tool structure + anthropic_tool = {"name": tool.name, "description": tool.description or tool.name, "input_schema": {}} + + # Extract parameters directly from inputSchema if available + if hasattr(tool, "inputSchema") and tool.inputSchema: + # Copy the input schema directly as Anthropic's format is similar to JSON Schema + anthropic_tool["input_schema"] = tool.inputSchema + else: + # Create a minimal schema if none exists + anthropic_tool["input_schema"] = {"type": "object", "properties": {}, "required": []} + + anthropic_tools.append(anthropic_tool) + + return anthropic_tools + + +def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any], construct_type: str) -> None: + """ + Helper function to handle JSON Schema constructs like anyOf, oneOf, allOf. + + Args: + param_def: Parameter definition to update + param_info: Parameter info from the schema + construct_type: Type of construct (anyOf, oneOf, allOf) + """ + # Get the list of schemas from the construct + schemas = param_info[construct_type] + + # Find the first schema with a type + for schema in schemas: + if "type" in schema: + param_def["type"] = schema["type"] + + # Add format if available + if "format" in schema: + param_def["format"] = schema["format"] + + # Add enum values if available + if "enum" in schema: + param_def["enum"] = schema["enum"] + + # Add default value if available + if "default" in schema and schema["default"] is not None: + param_def["default"] = schema["default"] + + break + + # If no type was found, default to string + if "type" not in param_def: + param_def["type"] = "string" diff --git a/src/airflow_wingman/tools/execution.py b/src/airflow_wingman/tools/execution.py new file mode 100644 index 0000000..495cff0 --- /dev/null +++ b/src/airflow_wingman/tools/execution.py @@ -0,0 +1,117 @@ +""" +Tool execution module for Airflow Wingman. + +This module contains functions to list and execute Airflow tools. +""" + +import asyncio +import json +import traceback + +from airflow import configuration +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow_mcp_server.config import AirflowConfig +from airflow_mcp_server.tools.tool_manager import get_airflow_tools, get_tool + +# Create a logger instance +logger = LoggingMixin().log + + +async def _list_airflow_tools_async(cookie: str) -> list: + """ + Async implementation to list available Airflow tools. + + Args: + cookie: Cookie for authentication + + Returns: + List of available Airflow tools + """ + try: + # Set up configuration + base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" + logger.info(f"Setting up AirflowConfig with base_url: {base_url}") + config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Get available tools + logger.info("Getting Airflow tools...") + tools = await get_airflow_tools(config=config, mode="safe") + logger.info(f"Got {len(tools)} tools") + return tools + except Exception as e: + error_msg = f"Error listing Airflow tools: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return [] + + +def list_airflow_tools(cookie: str) -> list: + """ + Synchronous wrapper to list available Airflow tools. 
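
# Sketch (not part of this patch): the synchronous wrappers below drive the
# async helpers with asyncio.run, so they are meant for regular synchronous
# call sites such as Flask views; the tool name and cookie are placeholders.
tools = list_airflow_tools(cookie="session=...")
result = execute_airflow_tool("get_dags", {"limit": 5}, cookie="session=...")
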
+ + Args: + cookie: Cookie for authentication + + Returns: + List of available Airflow tools + """ + return asyncio.run(_list_airflow_tools_async(cookie)) + + +async def _execute_airflow_tool_async(tool_name: str, arguments: dict, cookie: str) -> str: + """ + Async implementation to execute an Airflow tool. + + Args: + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + cookie: Cookie for authentication + + Returns: + Result of the tool execution as a string + """ + try: + # Set up configuration + base_url = f"{configuration.conf.get('webserver', 'base_url')}/api/v1/" + logger.info(f"Setting up AirflowConfig with base_url: {base_url}") + config = AirflowConfig(base_url=base_url, cookie=cookie, auth_token=None) + + # Get the tool + logger.info(f"Getting tool: {tool_name}") + tool = await get_tool(config=config, tool_name=tool_name) + + if not tool: + error_msg = f"Tool not found: {tool_name}" + logger.error(error_msg) + return json.dumps({"error": error_msg}) + + # Execute the tool + logger.info(f"Executing tool: {tool_name} with arguments: {arguments}") + result = await tool.run(arguments) + + # Convert result to string + if isinstance(result, dict | list): + result_str = json.dumps(result, indent=2) + else: + result_str = str(result) + + logger.info(f"Tool execution result: {result_str[:100]}...") + return result_str + except Exception as e: + error_msg = f"Error executing tool: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return json.dumps({"error": error_msg}) + + +def execute_airflow_tool(tool_name: str, arguments: dict, cookie: str) -> str: + """ + Synchronous wrapper to execute an Airflow tool. + + Args: + tool_name: Name of the tool to execute + arguments: Arguments to pass to the tool + cookie: Cookie for authentication + + Returns: + Result of the tool execution as a string + """ + return asyncio.run(_execute_airflow_tool_async(tool_name, arguments, cookie)) diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py index 34158d7..fcf3bff 100644 --- a/src/airflow_wingman/views.py +++ b/src/airflow_wingman/views.py @@ -1,6 +1,6 @@ """Views for Airflow Wingman plugin.""" -from flask import Response, request, stream_with_context +from flask import Response, request, session from flask.json import jsonify from flask_appbuilder import BaseView as AppBuilderBaseView, expose @@ -8,6 +8,7 @@ from airflow_wingman.llm_client import LLMClient from airflow_wingman.llms_models import MODELS from airflow_wingman.notes import INTERFACE_MESSAGES from airflow_wingman.prompt_engineering import prepare_messages +from airflow_wingman.tools import list_airflow_tools class WingmanView(AppBuilderBaseView): @@ -28,8 +29,32 @@ class WingmanView(AppBuilderBaseView): try: data = self._validate_chat_request(request.get_json()) - # Create a new client for this request - client = LLMClient(data["api_key"]) + if data.get("cookie"): + session["airflow_cookie"] = data["cookie"] + + # Get available Airflow tools using the stored cookie + airflow_tools = [] + if session.get("airflow_cookie"): + try: + airflow_tools = list_airflow_tools(session["airflow_cookie"]) + except Exception as e: + # Log the error but continue without tools + print(f"Error fetching Airflow tools: {str(e)}") + + # Prepare messages with Airflow tools included in the prompt + data["messages"] = prepare_messages(data["messages"]) + + # Get provider name from request or use default + provider_name = data.get("provider", "openai") + + # Get base URL from models configuration based on 
provider + base_url = MODELS.get(provider_name, {}).get("endpoint") + + # Create a new client for this request with the appropriate provider + client = LLMClient(provider_name=provider_name, api_key=data["api_key"], base_url=base_url) + + # Set the Airflow tools for the client to use + client.set_airflow_tools(airflow_tools) if data["stream"]: return self._handle_streaming_response(client, data) @@ -46,39 +71,49 @@ class WingmanView(AppBuilderBaseView): if not data: raise ValueError("No data provided") - required_fields = ["provider", "model", "messages", "api_key"] + required_fields = ["model", "messages", "api_key"] missing = [f for f in required_fields if not data.get(f)] if missing: raise ValueError(f"Missing required fields: {', '.join(missing)}") - # Prepare messages with system instruction while maintaining history - messages = data["messages"] - messages = prepare_messages(messages) + # Validate provider if provided + provider = data.get("provider", "openai") + if provider not in MODELS: + raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(MODELS.keys())}") return { - "provider": data["provider"], "model": data["model"], - "messages": messages, + "messages": data["messages"], "api_key": data["api_key"], - "stream": data.get("stream", False), + "stream": data.get("stream", True), "temperature": data.get("temperature", 0.7), "max_tokens": data.get("max_tokens"), + "cookie": data.get("cookie"), + "provider": provider, + "base_url": data.get("base_url"), } def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response: """Handle streaming response.""" + try: + # Get the streaming generator from the client + generator = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True) - def generate(): - for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True): - yield f"data: {chunk}\n\n" + def stream_response(): + # Send SSE format for each chunk + for chunk in generator: + if chunk: + yield f"data: {chunk}\n\n" - response = Response(stream_with_context(generate()), mimetype="text/event-stream") - response.headers["Content-Type"] = "text/event-stream" - response.headers["Cache-Control"] = "no-cache" - response.headers["Connection"] = "keep-alive" - return response + # Signal end of stream + yield "data: [DONE]\n\n" + + return Response(stream_response(), mimetype="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) + except Exception as e: + # If streaming fails, return error + return jsonify({"error": str(e)}), 500 def _handle_regular_response(self, client: LLMClient, data: dict) -> Response: """Handle regular response.""" - response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) + response = client.chat_completion(messages=data["messages"], model=data["model"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) return jsonify(response)
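
# Sketch (not part of this patch): the request body the chat endpoint accepts
# after this change ("provider" defaults to "openai", "stream" defaults to
# True) and the SSE framing emitted by _handle_streaming_response: each chunk
# is sent as "data: <text>\n\n" and the stream ends with "data: [DONE]\n\n".
# All values below are placeholders.
example_request = {
    "model": "claude-3-7-sonnet-20250219",
    "messages": [{"role": "user", "content": "How many DAGs failed today?"}],
    "api_key": "sk-ant-...",
    "provider": "anthropic",
    "stream": True,
    "cookie": "session=...",
}
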