From eea23b2097be39af360a54403d190184c48f2c97 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Mon, 24 Feb 2025 13:11:09 +0000 Subject: [PATCH] restructure code for modularity and use ai modules --- airflow-wingman/pyproject.toml | 4 +- .../src/airflow_wingman/llm_client.py | 109 ++++++++++++++ airflow-wingman/src/airflow_wingman/plugin.py | 24 +-- .../templates/wingman_chat.html | 138 ++++++++++++++---- airflow-wingman/src/airflow_wingman/views.py | 76 ++++++++++ 5 files changed, 304 insertions(+), 47 deletions(-) create mode 100644 airflow-wingman/src/airflow_wingman/llm_client.py create mode 100644 airflow-wingman/src/airflow_wingman/views.py diff --git a/airflow-wingman/pyproject.toml b/airflow-wingman/pyproject.toml index 37d22a1..5dfb861 100644 --- a/airflow-wingman/pyproject.toml +++ b/airflow-wingman/pyproject.toml @@ -10,7 +10,9 @@ authors = [ ] dependencies = [ "apache-airflow>=2.10.0", - "airflow-mcp-server>=0.2.0" + "airflow-mcp-server>=0.2.0", + "openai>=1.64.0", + "anthropic>=0.46.0" ] classifiers = [ "Development Status :: 3 - Alpha", diff --git a/airflow-wingman/src/airflow_wingman/llm_client.py b/airflow-wingman/src/airflow_wingman/llm_client.py new file mode 100644 index 0000000..dc2543c --- /dev/null +++ b/airflow-wingman/src/airflow_wingman/llm_client.py @@ -0,0 +1,109 @@ +""" +Client for making API calls to various LLM providers using their official SDKs. +""" + +from collections.abc import AsyncGenerator + +from anthropic import AsyncAnthropic +from openai import AsyncOpenAI + + +class LLMClient: + def __init__(self, api_key: str): + """Initialize the LLM client. + + Args: + api_key: API key for the provider + """ + self.api_key = api_key + self.openai_client = AsyncOpenAI(api_key=api_key) + self.anthropic_client = AsyncAnthropic(api_key=api_key) + self.openrouter_client = AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=api_key, + default_headers={ + "HTTP-Referer": "http://localhost:8080", # Required by OpenRouter + "X-Title": "Airflow Wingman", # Required by OpenRouter + }, + ) + + async def chat_completion( + self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False + ) -> AsyncGenerator[str, None] | dict: + """Send a chat completion request to the specified provider. 
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + model: Model identifier + provider: Provider identifier (openai, anthropic, openrouter) + temperature: Sampling temperature (0-1) + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + + Returns: + If stream=True, returns an async generator yielding response chunks + If stream=False, returns the complete response + """ + try: + if provider == "openai": + return await self._openai_chat_completion(messages, model, temperature, max_tokens, stream) + elif provider == "anthropic": + return await self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream) + elif provider == "openrouter": + return await self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream) + else: + return {"error": f"Unknown provider: {provider}"} + except Exception as e: + return {"error": f"API request failed: {str(e)}"} + + async def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): + """Handle OpenAI chat completion requests.""" + response = await self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + + if stream: + + async def response_generator(): + async for chunk in response: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + return response_generator() + else: + return {"content": response.choices[0].message.content} + + async def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): + """Handle Anthropic chat completion requests.""" + # Convert messages to Anthropic format + system_message = next((m["content"] for m in messages if m["role"] == "system"), None) + conversation = [] + for m in messages: + if m["role"] != "system": + conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]}) + + response = await self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream) + + if stream: + + async def response_generator(): + async for chunk in response: + if chunk.delta.text: + yield chunk.delta.text + + return response_generator() + else: + return {"content": response.content[0].text} + + async def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool): + """Handle OpenRouter chat completion requests.""" + response = await self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream) + + if stream: + + async def response_generator(): + async for chunk in response: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + return response_generator() + else: + return {"content": response.choices[0].message.content} diff --git a/airflow-wingman/src/airflow_wingman/plugin.py b/airflow-wingman/src/airflow_wingman/plugin.py index 78b01fd..b1e37d6 100644 --- a/airflow-wingman/src/airflow_wingman/plugin.py +++ b/airflow-wingman/src/airflow_wingman/plugin.py @@ -1,10 +1,11 @@ +"""Plugin definition for Airflow Wingman.""" + from airflow.plugins_manager import AirflowPlugin -from flask_appbuilder import BaseView as AppBuilderBaseView, expose from flask import Blueprint -from 
airflow_wingman.llms_models import MODELS - +from airflow_wingman.views import WingmanView +# Create Blueprint bp = Blueprint( "wingman", __name__, @@ -13,21 +14,6 @@ bp = Blueprint( static_url_path="/static/wingman", ) - -class WingmanView(AppBuilderBaseView): - route_base = "/wingman" - default_view = "chat" - - @expose("/") - def chat(self): - """ - Chat interface for Airflow Wingman. - """ - return self.render_template( - "wingman_chat.html", title="Airflow Wingman", models=MODELS - ) - - # Create AppBuilder View v_appbuilder_view = WingmanView() v_appbuilder_package = { @@ -39,6 +25,8 @@ v_appbuilder_package = { # Create Plugin class WingmanPlugin(AirflowPlugin): + """Airflow plugin for Wingman chat interface.""" + name = "wingman" flask_blueprints = [bp] appbuilder_views = [v_appbuilder_package] diff --git a/airflow-wingman/src/airflow_wingman/templates/wingman_chat.html b/airflow-wingman/src/airflow_wingman/templates/wingman_chat.html index 6c6360b..3187fde 100644 --- a/airflow-wingman/src/airflow_wingman/templates/wingman_chat.html +++ b/airflow-wingman/src/airflow_wingman/templates/wingman_chat.html @@ -11,6 +11,8 @@

 [HTML markup lost in extraction; the surviving text and hunk headers of the template diff are retained below]
     Note: For best results with function/tool calling capabilities, we recommend using models like Claude-3.5 Sonnet or GPT-4o. These models excel at understanding and using complex tools effectively.
+
+
     Security: For your security, API keys are required for each session and are never stored. If you refresh the page or close the browser, you'll need to enter your API key again. This ensures your API keys remain secure in shared environments.
@@ -34,7 +36,6 @@
                     name="model"
                     value="{{ provider_id }}:{{ model.id }}"
                     {% if model.default %}checked{% endif %}
-                    data-endpoint="{{ provider.endpoint }}"
                     data-context-window="{{ model.context_window }}"
                     data-provider="{{ provider_id }}"
                     data-model-name="{{ model.name }}">
@@ -51,6 +52,21 @@ [adds the API-key field beside the model-name input]
+    Only required for OpenRouter provider
+    Your API key will be used for the selected provider
@@ -74,30 +90,11 @@ [removes the old standalone "API Key" form section]
-    API Key
-    Your API key will be used for the selected provider
+
@@ -180,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() { const modelName = this.getAttribute('data-model-name'); console.log('Selected provider:', provider); console.log('Model name:', modelName); - + if (provider === 'openrouter') { console.log('Enabling model name input'); modelNameInput.disabled = false; @@ -201,7 +198,7 @@ document.addEventListener('DOMContentLoaded', function() { const modelName = defaultSelected.getAttribute('data-model-name'); console.log('Initial provider:', provider); console.log('Initial model name:', modelName); - + if (provider === 'openrouter') { console.log('Initially enabling model name input'); modelNameInput.disabled = false; @@ -218,20 +215,105 @@ document.addEventListener('DOMContentLoaded', function() { const sendButton = document.getElementById('send-button'); const chatMessages = document.getElementById('chat-messages'); + let currentMessageDiv = null; + function addMessage(content, isUser) { const messageDiv = document.createElement('div'); messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; messageDiv.textContent = content; chatMessages.appendChild(messageDiv); chatMessages.scrollTop = chatMessages.scrollHeight; + return messageDiv; } - function sendMessage() { + async function sendMessage() { const message = messageInput.value.trim(); - if (message) { - addMessage(message, true); - messageInput.value = ''; - // TODO: Add API call to send message and get response + if (!message) return; + + // Get selected model + const selectedModel = document.querySelector('input[name="model"]:checked'); + if (!selectedModel) { + alert('Please select a model'); + return; + } + + const provider = selectedModel.getAttribute('data-provider'); + const modelId = selectedModel.value.split(':')[1]; + const modelName = provider === 'openrouter' ? modelNameInput.value : modelId; + + // Clear input and add user message + messageInput.value = ''; + addMessage(message, true); + + try { + // Create messages array with system message + const messages = [ + { + role: 'system', + content: 'You are a helpful AI assistant integrated into Apache Airflow.' 
+ }, + { + role: 'user', + content: message + } + ]; + + // Create assistant message div + currentMessageDiv = addMessage('', false); + + // Get API key + const apiKey = document.getElementById('api-key').value.trim(); + if (!apiKey) { + alert('Please enter an API key'); + return; + } + + // Send request + const response = await fetch('/wingman/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + provider: provider, + model: modelName, + messages: messages, + api_key: apiKey, + stream: true, + temperature: 0.7 + }) + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || 'Failed to get response'); + } + + // Handle streaming response + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + const lines = chunk.split('\n'); + + for (const line of lines) { + if (line.startsWith('data: ')) { + const content = line.slice(6); + if (content) { + currentMessageDiv.textContent += content; + chatMessages.scrollTop = chatMessages.scrollHeight; + } + } + } + } + } catch (error) { + console.error('Error:', error); + currentMessageDiv.textContent = `Error: ${error.message}`; + currentMessageDiv.style.color = 'red'; } } diff --git a/airflow-wingman/src/airflow_wingman/views.py b/airflow-wingman/src/airflow_wingman/views.py new file mode 100644 index 0000000..0dc6934 --- /dev/null +++ b/airflow-wingman/src/airflow_wingman/views.py @@ -0,0 +1,76 @@ +"""Views for Airflow Wingman plugin.""" + +from flask import Response, request, stream_with_context +from flask.json import jsonify +from flask_appbuilder import BaseView as AppBuilderBaseView, expose + +from airflow_wingman.llm_client import LLMClient +from airflow_wingman.llms_models import MODELS + + +class WingmanView(AppBuilderBaseView): + """View for Airflow Wingman plugin.""" + + route_base = "/wingman" + default_view = "chat" + + @expose("/") + def chat(self): + """Render chat interface.""" + providers = {provider: info["name"] for provider, info in MODELS.items()} + return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers) + + @expose("/chat", methods=["POST"]) + async def chat_completion(self): + """Handle chat completion requests.""" + try: + data = self._validate_chat_request(request.get_json()) + + # Create a new client for this request + client = LLMClient(data["api_key"]) + + if data["stream"]: + return self._handle_streaming_response(client, data) + else: + return await self._handle_regular_response(client, data) + + except ValueError as e: + return jsonify({"error": str(e)}), 400 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + def _validate_chat_request(self, data: dict) -> dict: + """Validate chat request data.""" + if not data: + raise ValueError("No data provided") + + required_fields = ["provider", "model", "messages", "api_key"] + missing = [f for f in required_fields if not data.get(f)] + if missing: + raise ValueError(f"Missing required fields: {', '.join(missing)}") + + return { + "provider": data["provider"], + "model": data["model"], + "messages": data["messages"], + "api_key": data["api_key"], + "stream": data.get("stream", False), + "temperature": data.get("temperature", 0.7), + "max_tokens": data.get("max_tokens"), + } + + def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response: + 
"""Handle streaming response.""" + + async def generate(): + async for chunk in await client.chat_completion( + messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True + ): + yield f"data: {chunk}\n\n" + + return Response(stream_with_context(generate()), mimetype="text/event-stream") + + async def _handle_regular_response(self, client: LLMClient, data: dict) -> Response: + """Handle regular response.""" + response = await client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False) + return jsonify(response)