Initial plugin

2025-02-24 16:44:55 +00:00
parent 0b2c16076e
commit 4e3355aa53
12 changed files with 806 additions and 2 deletions

.pre-commit-config.yaml Normal file

@@ -0,0 +1,15 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.11
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
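With this file at the repository root, running `pre-commit install` once wires the hooks into git, and `pre-commit run --all-files` applies Ruff's lint/format pass plus the whitespace and YAML checks to the whole tree.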

LICENSE

@@ -1,6 +1,6 @@
 MIT License
-Copyright (c) 2025 Abhishek
+Copyright (c) 2025 Abhishek Bhakat
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

README.md

@@ -1 +1,2 @@
-# airflow-wingman
+# Airflow Wingman
+Airflow plugin to enable LLM chat in the Airflow Webserver.

pyproject.toml Normal file

@@ -0,0 +1,78 @@
[project]
name = "airflow-wingman"
version = "0.2.0"
description = "Airflow plugin to enable LLM chat"
readme = "README.md"
requires-python = ">=3.11"
authors = [
    {name = "Abhishek Bhakat", email = "abhishek.bhakat@hotmail.com"}
]
dependencies = [
    "apache-airflow>=2.10.0",
    "openai>=1.64.0",
    "anthropic>=0.46.0"
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Environment :: Plugins",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]
license = "MIT"
license-files = ["LICEN[CS]E*"]

[project.urls]
GitHub = "https://github.com/abhishekbhakat/airflow-wingman"
Issues = "https://github.com/abhishekbhakat/airflow-wingman/issues"

[project.entry-points."airflow.plugins"]
wingman = "airflow_wingman:WingmanPlugin"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/airflow_wingman"]

[tool.ruff]
line-length = 200
indent-width = 4
fix = true
preview = true
lint.select = [
    "E",    # pycodestyle errors
    "F",    # pyflakes
    "I",    # isort
    "W",    # pycodestyle warnings
    "C90",  # complexity (mccabe)
    "C",    # flake8-comprehensions
    "ISC",  # flake8-implicit-str-concat
    "T10",  # flake8-debugger
    "A",    # flake8-builtins
    "UP",   # pyupgrade
]
lint.ignore = [
    "C416",   # unnecessary list comprehension - rewrite as a generator expression
    "C408",   # unnecessary `dict` call - rewrite as a literal
    "ISC001", # single-line implicit string concatenation (conflicts with the formatter)
]
lint.fixable = ["ALL"]
lint.unfixable = []

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false

[tool.ruff.lint.isort]
combine-as-imports = true

[tool.ruff.lint.mccabe]
max-complexity = 12
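A note on the `[project.entry-points."airflow.plugins"]` table: this is what lets Airflow discover the plugin with zero configuration, since Airflow's plugins manager scans that entry-point group at startup and loads every target it finds. A minimal sketch of the same lookup, using only the standard library:

```python
from importlib.metadata import entry_points

# After `pip install airflow-wingman`, the plugin is discoverable in the
# "airflow.plugins" entry-point group declared in pyproject.toml.
for ep in entry_points(group="airflow.plugins"):
    plugin_cls = ep.load()  # e.g. airflow_wingman.plugin.WingmanPlugin
    print(ep.name, "->", plugin_cls.__name__)
```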

src/airflow_wingman/__init__.py Normal file

@@ -0,0 +1,6 @@
from importlib.metadata import version

from airflow_wingman.plugin import WingmanPlugin

__version__ = version("airflow-wingman")
__all__ = ["WingmanPlugin"]

src/airflow_wingman/llm_client.py Normal file

@@ -0,0 +1,109 @@
"""
Client for making API calls to various LLM providers using their official SDKs.
"""
from collections.abc import Generator
from anthropic import Anthropic
from openai import OpenAI
class LLMClient:
def __init__(self, api_key: str):
"""Initialize the LLM client.
Args:
api_key: API key for the provider
"""
self.api_key = api_key
self.openai_client = OpenAI(api_key=api_key)
self.anthropic_client = Anthropic(api_key=api_key)
self.openrouter_client = OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=api_key,
default_headers={
"HTTP-Referer": "Airflow Wingman", # Required by OpenRouter
"X-Title": "Airflow Wingman", # Required by OpenRouter
},
)
def chat_completion(
self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
) -> Generator[str, None, None] | dict:
"""Send a chat completion request to the specified provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
provider: Provider identifier (openai, anthropic, openrouter)
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
Returns:
If stream=True, returns a generator yielding response chunks
If stream=False, returns the complete response
"""
try:
if provider == "openai":
return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "anthropic":
return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "openrouter":
return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
else:
return {"error": f"Unknown provider: {provider}"}
except Exception as e:
return {"error": f"API request failed: {str(e)}"}
def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenAI chat completion requests."""
response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
def response_generator():
for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
return response_generator()
else:
return {"content": response.choices[0].message.content}
def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle Anthropic chat completion requests."""
# Convert messages to Anthropic format
system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
conversation = []
for m in messages:
if m["role"] != "system":
conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
def response_generator():
for chunk in response:
if chunk.delta.text:
yield chunk.delta.text
return response_generator()
else:
return {"content": response.content[0].text}
def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenRouter chat completion requests."""
response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
def response_generator():
for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
return response_generator()
else:
return {"content": response.choices[0].message.content}

src/airflow_wingman/llms_models.py Normal file

@@ -0,0 +1,48 @@
MODELS = {
    "openai": {
        "name": "OpenAI",
        "endpoint": "https://api.openai.com/v1/chat/completions",
        "models": [
            {
                "id": "gpt-4o",
                "name": "GPT-4o",
                "default": True,
                "context_window": 128000,
                "description": "Input $5/M tokens, Output $15/M tokens",
            }
        ],
    },
    "anthropic": {
        "name": "Anthropic",
        "endpoint": "https://api.anthropic.com/v1/messages",
        "models": [
            {
                # Anthropic's API expects the dashed alias, not "claude-3.5-sonnet"
                "id": "claude-3-5-sonnet-latest",
                "name": "Claude 3.5 Sonnet",
                "default": True,
                "context_window": 200000,
                "description": "Input $3/M tokens, Output $15/M tokens",
            },
            {
                "id": "claude-3-5-haiku-latest",
                "name": "Claude 3.5 Haiku",
                "default": False,
                "context_window": 200000,
                "description": "Input $0.80/M tokens, Output $4/M tokens",
            },
        ],
    },
    "openrouter": {
        "name": "OpenRouter",
        "endpoint": "https://openrouter.ai/api/v1/chat/completions",
        "models": [
            {
                "id": "custom",
                "name": "Custom Model",
                "default": False,
                "context_window": 128000,  # Default; the real window depends on the chosen model
                "description": "Enter any model name supported by OpenRouter (e.g., 'anthropic/claude-3-opus', 'meta-llama/llama-2-70b')",
            },
        ],
    },
}
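One thing this registry makes easy is resolving a provider's default model. A hypothetical helper (not part of the commit) to illustrate the shape of the data:

```python
def default_model_id(provider: str) -> str | None:
    """Return the ID of the provider's default model, if any (hypothetical helper)."""
    provider_info = MODELS.get(provider)
    if not provider_info:
        return None
    return next((m["id"] for m in provider_info["models"] if m.get("default")), None)


assert default_model_id("openai") == "gpt-4o"
assert default_model_id("openrouter") is None  # no default; the user types a model name
```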

src/airflow_wingman/notes.py Normal file

@@ -0,0 +1,13 @@
INTERFACE_MESSAGES = {
    "model_recommendation": {
        "title": "Note",
        "content": "For best results with function/tool calling capabilities, we recommend using models like Claude 3.5 Sonnet or GPT-4o.",
    },
    "security_note": {
        "title": "Security",
        "content": "For your security, API keys are required for each session and are never stored. If you refresh the page or close the browser, you'll need to enter your API key again.",
    },
    "context_window": {
        "title": "Context Window",
        "content": "Each model has a maximum context window size that determines how much text it can process. "
        "For long conversations or large code snippets, consider using models with larger context windows, such as Claude 3 Opus (200K tokens) or GPT-4 Turbo (128K tokens). "
        "For best results, keep the context as small as possible: start a new chat rather than reusing a long one.",
    },
}

src/airflow_wingman/plugin.py Normal file

@@ -0,0 +1,32 @@
"""Plugin definition for Airflow Wingman."""
from airflow.plugins_manager import AirflowPlugin
from flask import Blueprint
from airflow_wingman.views import WingmanView
# Create Blueprint
bp = Blueprint(
"wingman",
__name__,
template_folder="templates",
static_folder="static",
static_url_path="/static/wingman",
)
# Create AppBuilder View
v_appbuilder_view = WingmanView()
v_appbuilder_package = {
"name": "Wingman",
"category": "AI",
"view": v_appbuilder_view,
}
# Create Plugin
class WingmanPlugin(AirflowPlugin):
"""Airflow plugin for Wingman chat interface."""
name = "wingman"
flask_blueprints = [bp]
appbuilder_views = [v_appbuilder_package]
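Once the distribution is installed, `airflow plugins` should list `wingman`, and the chat view appears in the webserver menu under the AI category declared above.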

src/airflow_wingman/prompt_engineering.py Normal file

@@ -0,0 +1,34 @@
"""
Prompt engineering for the Airflow Wingman plugin.
Contains prompts and instructions for the AI assistant.
"""
INSTRUCTIONS = {
"default": """You are Airflow Wingman, a helpful AI assistant integrated into Apache Airflow.
You have deep knowledge of Apache Airflow's architecture, DAGs, operators, and best practices.
The Airflow version being used is >=2.10.
You have access to the following Airflow API tools:
You can use these tools to fetch information and help users understand and manage their Airflow environment.
"""
}
def prepare_messages(messages: list[dict[str, str]], instruction_key: str = "default") -> list[dict[str, str]]:
"""Prepare messages for the chat completion request.
Args:
messages: List of messages in the conversation
instruction_key: Key for the instruction template to use
Returns:
List of message dictionaries ready for the chat completion API
"""
instruction = INSTRUCTIONS.get(instruction_key, INSTRUCTIONS["default"])
# Add instruction as first system message if not present
if not messages or messages[0].get("role") != "system":
messages.insert(0, {"role": "system", "content": instruction})
return messages
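A quick illustration of `prepare_messages` (hypothetical calls): the system instruction is prepended only when the history does not already start with one.

```python
from airflow_wingman.prompt_engineering import prepare_messages

history = [{"role": "user", "content": "Why is my DAG not scheduling?"}]
prepared = prepare_messages(history)

assert prepared[0]["role"] == "system"  # instruction inserted at the front
assert prepared[1]["content"] == "Why is my DAG not scheduling?"  # history preserved

# A second call is a no-op for the system slot, since one is already present
assert len(prepare_messages(prepared)) == 2
```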

src/airflow_wingman/templates/wingman_chat.html Normal file

@@ -0,0 +1,384 @@
{% extends "appbuilder/base.html" %}
{% block head_meta %}
{{ super() }}
<meta name="csrf-token" content="{{ csrf_token() }}">
{% endblock %}
{% block content %}
<div class="container-fluid">
<!-- Banner -->
<div class="row">
<div class="col-md-12">
<div class="panel panel-primary">
<div class="panel-heading">
<h3 class="panel-title">Airflow Wingman</h3>
</div>
<div class="alert alert-info" style="margin: 15px;">
<p><strong>{{ interface_messages.model_recommendation.title }}:</strong> {{ interface_messages.model_recommendation.content }}</p>
<hr style="margin: 10px 0;">
<p><strong>{{ interface_messages.security_note.title }}:</strong> {{ interface_messages.security_note.content }}</p>
<hr style="margin: 10px 0;">
<p><strong>{{ interface_messages.context_window.title }}:</strong> {{ interface_messages.context_window.content }}</p>
</div>
</div>
</div>
</div>
<div class="row">
<!-- Sidebar -->
<div class="col-md-3">
<div class="panel panel-default">
<div class="panel-heading">
<h3 class="panel-title">Provider Selection</h3>
</div>
<div class="panel-body">
{% for provider_id, provider in models.items() %}
<div class="provider-section mb-3">
<h4 class="provider-name">{{ provider.name }}</h4>
{% for model in provider.models %}
<div class="radio model-option">
<label class="model-label" title="{{ model.description }}">
<input type="radio"
name="model"
value="{{ provider_id }}:{{ model.id }}"
{% if model.default %}checked{% endif %}
data-context-window="{{ model.context_window }}"
data-provider="{{ provider_id }}"
data-model-name="{{ model.name }}">
{{ model.name }}
</label>
</div>
{% endfor %}
</div>
{% endfor %}
</div>
<!-- Model Name Input -->
<div class="panel-body" style="border-top: 1px solid #ddd; padding-top: 15px;">
<div class="form-group">
<label for="modelName">Model Name</label>
<input type="text" class="form-control" id="modelName" placeholder="Enter model name for OpenRouter" disabled>
<small class="form-text text-muted">Only required for OpenRouter provider</small>
</div>
</div>
<!-- API Key Input -->
<div class="panel-body" style="border-top: 1px solid #ddd; padding-top: 15px;">
<div class="form-group">
<label for="api-key">API Key</label>
<input type="password"
class="form-control"
id="api-key"
placeholder="Enter API key for selected provider"
required
autocomplete="off">
<small class="text-muted">Your API key will be used for the selected provider</small>
</div>
</div>
<style>
.provider-section {
margin-bottom: 20px;
}
.provider-name {
font-size: 16px;
font-weight: bold;
margin-bottom: 10px;
color: #666;
}
.model-option {
margin-left: 15px;
margin-bottom: 8px;
}
.model-option label {
display: block;
cursor: pointer;
}
</style>
</div>
</div>
<!-- Main Chat Window -->
<div class="col-md-9">
<div class="panel panel-default" style="height: calc(80vh - 250px); display: flex; flex-direction: column;">
<div class="panel-body" style="flex-grow: 1; overflow-y: auto; padding: 15px;" id="chat-messages">
<!-- Messages will be dynamically added here -->
</div>
<div class="panel-footer" style="padding: 15px; background-color: white;">
<div class="row">
<div class="col-md-2">
<button class="btn btn-default btn-block" type="button" id="refresh-button" title="Start a new chat">
<i class="fa fa-refresh"></i> New Chat
</button>
</div>
<div class="col-md-10">
<div class="input-group">
<input type="text" class="form-control" id="message-input" placeholder="Type your message...">
<span class="input-group-btn">
<button class="btn btn-primary" type="button" id="send-button">
<i class="fa fa-paper-plane"></i> Send
</button>
</span>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<style>
.message {
margin-bottom: 15px;
max-width: 80%;
clear: both;
}
.message-user {
float: right;
background-color: #f0f7ff;
border: 1px solid #d1e6ff;
border-radius: 15px 15px 0 15px;
padding: 10px 15px;
}
.message-assistant {
float: left;
background-color: #f8f9fa;
border: 1px solid #e9ecef;
border-radius: 15px 15px 15px 0;
padding: 10px 15px;
}
#chat-messages::after {
content: "";
clear: both;
display: table;
}
.panel-body::-webkit-scrollbar {
width: 8px;
}
.panel-body::-webkit-scrollbar-track {
background: #f1f1f1;
}
.panel-body::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.panel-body::-webkit-scrollbar-thumb:hover {
background: #555;
}
</style>
<script>
document.addEventListener('DOMContentLoaded', function() {
// Add title attributes for tooltips
document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach(function(el) {
el.title = el.getAttribute('title') || el.getAttribute('data-bs-original-title');
});
// Handle model selection and model name input
const modelNameInput = document.getElementById('modelName');
const modelRadios = document.querySelectorAll('input[name="model"]');
modelRadios.forEach(function(radio) {
radio.addEventListener('change', function() {
const provider = this.value.split(':')[0]; // Get provider from value instead of data attribute
const modelName = this.getAttribute('data-model-name');
console.log('Selected provider:', provider);
console.log('Model name:', modelName);
if (provider === 'openrouter') {
console.log('Enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
});
});
// Set initial state based on default selection
const defaultSelected = document.querySelector('input[name="model"]:checked');
if (defaultSelected) {
const provider = defaultSelected.value.split(':')[0]; // Get provider from value instead of data attribute
const modelName = defaultSelected.getAttribute('data-model-name');
console.log('Initial provider:', provider);
console.log('Initial model name:', modelName);
if (provider === 'openrouter') {
console.log('Initially enabling model name input');
modelNameInput.disabled = false;
modelNameInput.value = '';
modelNameInput.placeholder = 'Enter model name for OpenRouter';
} else {
console.log('Initially disabling model name input');
modelNameInput.disabled = true;
modelNameInput.value = modelName;
}
}
const messageInput = document.getElementById('message-input');
const sendButton = document.getElementById('send-button');
const refreshButton = document.getElementById('refresh-button');
const chatMessages = document.getElementById('chat-messages');
let currentMessageDiv = null;
let messageHistory = [];
function clearChat() {
// Clear the chat messages
chatMessages.innerHTML = '';
// Reset message history
messageHistory = [];
// Clear the input field
messageInput.value = '';
// Enable input if it was disabled
messageInput.disabled = false;
sendButton.disabled = false;
}
function addMessage(content, isUser) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
messageDiv.textContent = content;
chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
return messageDiv;
}
async function sendMessage() {
const message = messageInput.value.trim();
if (!message) return;
// Get selected model
const selectedModel = document.querySelector('input[name="model"]:checked');
if (!selectedModel) {
alert('Please select a model');
return;
}
const [provider, modelId] = selectedModel.value.split(':');
const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
// Clear input and add user message
messageInput.value = '';
addMessage(message, true);
try {
// Add user message to history
messageHistory.push({
role: 'user',
content: message
});
// Use full message history for the request
const messages = [...messageHistory];
// Get the API key before creating the assistant bubble, so an early
// return does not leave an empty message in the chat window
const apiKey = document.getElementById('api-key').value.trim();
if (!apiKey) {
alert('Please enter an API key');
return;
}
// Create assistant message div
currentMessageDiv = addMessage('', false);
// Build the request payload once and reuse it for logging and the fetch body
const requestData = {
provider: provider,
model: modelName,
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.7
};
console.log('Sending request:', {...requestData, api_key: '***'});
// Get CSRF token
const csrfToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content');
if (!csrfToken) {
throw new Error('CSRF token not found. Please refresh the page.');
}
// Send request
const response = await fetch('/wingman/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
body: JSON.stringify(requestData)
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || 'Failed to get response');
}
// Handle streaming response
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = '';
let buffer = '';
while (true) {
const { value, done } = await reader.read();
if (done) break;
// Buffer partial lines: a network read can end mid-event
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\n');
buffer = lines.pop();
for (const line of lines) {
if (line.startsWith('data: ')) {
const payload = line.slice(6);
if (payload) {
// The server JSON-encodes each chunk so newlines in model
// output survive SSE framing; decode before appending
const content = JSON.parse(payload);
currentMessageDiv.textContent += content;
fullResponse += content;
chatMessages.scrollTop = chatMessages.scrollHeight;
}
}
}
}
// Add assistant's response to history
if (fullResponse) {
messageHistory.push({
role: 'assistant',
content: fullResponse
});
}
} catch (error) {
console.error('Error:', error);
currentMessageDiv.textContent = `Error: ${error.message}`;
currentMessageDiv.style.color = 'red';
}
}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
sendMessage();
}
});
refreshButton.addEventListener('click', clearChat);
});
</script>
{% endblock %}

src/airflow_wingman/views.py Normal file

@@ -0,0 +1,84 @@
"""Views for Airflow Wingman plugin."""
from flask import Response, request, stream_with_context
from flask.json import jsonify
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
from airflow_wingman.llm_client import LLMClient
from airflow_wingman.llms_models import MODELS
from airflow_wingman.notes import INTERFACE_MESSAGES
from airflow_wingman.prompt_engineering import prepare_messages
class WingmanView(AppBuilderBaseView):
"""View for Airflow Wingman plugin."""
route_base = "/wingman"
default_view = "chat"
@expose("/")
def chat(self):
"""Render chat interface."""
providers = {provider: info["name"] for provider, info in MODELS.items()}
return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers, interface_messages=INTERFACE_MESSAGES)
@expose("/chat", methods=["POST"])
def chat_completion(self):
"""Handle chat completion requests."""
try:
data = self._validate_chat_request(request.get_json())
# Create a new client for this request
client = LLMClient(data["api_key"])
if data["stream"]:
return self._handle_streaming_response(client, data)
else:
return self._handle_regular_response(client, data)
except ValueError as e:
return jsonify({"error": str(e)}), 400
except Exception as e:
return jsonify({"error": str(e)}), 500
def _validate_chat_request(self, data: dict) -> dict:
"""Validate chat request data."""
if not data:
raise ValueError("No data provided")
required_fields = ["provider", "model", "messages", "api_key"]
missing = [f for f in required_fields if not data.get(f)]
if missing:
raise ValueError(f"Missing required fields: {', '.join(missing)}")
# Prepare messages with system instruction while maintaining history
messages = data["messages"]
messages = prepare_messages(messages)
return {
"provider": data["provider"],
"model": data["model"],
"messages": messages,
"api_key": data["api_key"],
"stream": data.get("stream", False),
"temperature": data.get("temperature", 0.7),
"max_tokens": data.get("max_tokens"),
}
def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
"""Handle streaming response."""
def generate():
for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
yield f"data: {chunk}\n\n"
response = Response(stream_with_context(generate()), mimetype="text/event-stream")
response.headers["Content-Type"] = "text/event-stream"
response.headers["Cache-Control"] = "no-cache"
response.headers["Connection"] = "keep-alive"
return response
def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
"""Handle regular response."""
response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
return jsonify(response)
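For reference, a sketch of exercising the `/wingman/chat` endpoint outside the browser. Everything here is an assumption: the local URL, the session cookie, and the CSRF token (which the chat page embeds in a `<meta>` tag) all depend on your Airflow auth setup.

```python
import requests

# Hypothetical local setup; in practice you need an authenticated session
# cookie and the CSRF token from the rendered chat page.
resp = requests.post(
    "http://localhost:8080/wingman/chat",
    json={
        "provider": "openai",
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "List my paused DAGs"}],
        "api_key": "sk-...",  # placeholder
        "stream": False,
    },
    headers={"X-CSRFToken": "<token from the page>"},
    cookies={"session": "<your session cookie>"},
    timeout=60,
)
print(resp.json())  # {"content": ...} on success, {"error": ...} otherwise
```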