Restructure the plugin for modularity and add dedicated AI/LLM provider modules
This commit is contained in:
@@ -10,7 +10,9 @@ authors = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"apache-airflow>=2.10.0",
|
"apache-airflow>=2.10.0",
|
||||||
"airflow-mcp-server>=0.2.0"
|
"airflow-mcp-server>=0.2.0",
|
||||||
|
"openai>=1.64.0",
|
||||||
|
"anthropic>=0.46.0"
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
|||||||
109
airflow-wingman/src/airflow_wingman/llm_client.py
Normal file
109
airflow-wingman/src/airflow_wingman/llm_client.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
"""
|
||||||
|
Client for making API calls to various LLM providers using their official SDKs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
def __init__(self, api_key: str):
|
||||||
|
"""Initialize the LLM client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for the provider
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.openai_client = AsyncOpenAI(api_key=api_key)
|
||||||
|
self.anthropic_client = AsyncAnthropic(api_key=api_key)
|
||||||
|
self.openrouter_client = AsyncOpenAI(
|
||||||
|
base_url="https://openrouter.ai/api/v1",
|
||||||
|
api_key=api_key,
|
||||||
|
default_headers={
|
||||||
|
"HTTP-Referer": "http://localhost:8080", # Required by OpenRouter
|
||||||
|
"X-Title": "Airflow Wingman", # Required by OpenRouter
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
|
||||||
|
) -> AsyncGenerator[str, None] | dict:
|
||||||
|
"""Send a chat completion request to the specified provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries with 'role' and 'content'
|
||||||
|
model: Model identifier
|
||||||
|
provider: Provider identifier (openai, anthropic, openrouter)
|
||||||
|
temperature: Sampling temperature (0-1)
|
||||||
|
max_tokens: Maximum tokens to generate
|
||||||
|
stream: Whether to stream the response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If stream=True, returns an async generator yielding response chunks
|
||||||
|
If stream=False, returns the complete response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if provider == "openai":
|
||||||
|
return await self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
|
||||||
|
elif provider == "anthropic":
|
||||||
|
return await self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
|
||||||
|
elif provider == "openrouter":
|
||||||
|
return await self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
|
||||||
|
else:
|
||||||
|
return {"error": f"Unknown provider: {provider}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"API request failed: {str(e)}"}
|
||||||
|
|
||||||
|
async def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
|
||||||
|
"""Handle OpenAI chat completion requests."""
|
||||||
|
response = await self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
|
||||||
|
async def response_generator():
|
||||||
|
async for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
|
return response_generator()
|
||||||
|
else:
|
||||||
|
return {"content": response.choices[0].message.content}
|
||||||
|
|
||||||
|
async def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
|
||||||
|
"""Handle Anthropic chat completion requests."""
|
||||||
|
# Convert messages to Anthropic format
|
||||||
|
system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
|
||||||
|
conversation = []
|
||||||
|
for m in messages:
|
||||||
|
if m["role"] != "system":
|
||||||
|
conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
|
||||||
|
|
||||||
|
response = await self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
|
||||||
|
async def response_generator():
|
||||||
|
async for chunk in response:
|
||||||
|
if chunk.delta.text:
|
||||||
|
yield chunk.delta.text
|
||||||
|
|
||||||
|
return response_generator()
|
||||||
|
else:
|
||||||
|
return {"content": response.content[0].text}
|
||||||
|
|
||||||
|
async def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
|
||||||
|
"""Handle OpenRouter chat completion requests."""
|
||||||
|
response = await self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
|
||||||
|
async def response_generator():
|
||||||
|
async for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
|
return response_generator()
|
||||||
|
else:
|
||||||
|
return {"content": response.choices[0].message.content}
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
|
"""Plugin definition for Airflow Wingman."""
|
||||||
|
|
||||||
from airflow.plugins_manager import AirflowPlugin
|
from airflow.plugins_manager import AirflowPlugin
|
||||||
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
from airflow_wingman.llms_models import MODELS
|
from airflow_wingman.views import WingmanView
|
||||||
|
|
||||||
|
|
||||||
|
# Create Blueprint
|
||||||
bp = Blueprint(
|
bp = Blueprint(
|
||||||
"wingman",
|
"wingman",
|
||||||
__name__,
|
__name__,
|
||||||
@@ -13,21 +14,6 @@ bp = Blueprint(
|
|||||||
static_url_path="/static/wingman",
|
static_url_path="/static/wingman",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WingmanView(AppBuilderBaseView):
|
|
||||||
route_base = "/wingman"
|
|
||||||
default_view = "chat"
|
|
||||||
|
|
||||||
@expose("/")
|
|
||||||
def chat(self):
|
|
||||||
"""
|
|
||||||
Chat interface for Airflow Wingman.
|
|
||||||
"""
|
|
||||||
return self.render_template(
|
|
||||||
"wingman_chat.html", title="Airflow Wingman", models=MODELS
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Create AppBuilder View
|
# Create AppBuilder View
|
||||||
v_appbuilder_view = WingmanView()
|
v_appbuilder_view = WingmanView()
|
||||||
v_appbuilder_package = {
|
v_appbuilder_package = {
|
||||||
@@ -39,6 +25,8 @@ v_appbuilder_package = {
|
|||||||
|
|
||||||
# Create Plugin
|
# Create Plugin
|
||||||
class WingmanPlugin(AirflowPlugin):
    """Airflow plugin for Wingman chat interface."""

    # Plugin identifier registered with Airflow's plugin manager.
    name = "wingman"
    # Blueprint serving the plugin's templates and static assets.
    flask_blueprints = [bp]
    # AppBuilder view package that surfaces the chat UI in the webserver menu.
    appbuilder_views = [v_appbuilder_package]
|
||||||
|
|||||||
@@ -11,6 +11,8 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="alert alert-info" style="margin: 15px;">
|
<div class="alert alert-info" style="margin: 15px;">
|
||||||
<p><strong>Note:</strong> For best results with function/tool calling capabilities, we recommend using models like Claude-3.5 Sonnet or GPT-4o. These models excel at understanding and using complex tools effectively.</p>
|
<p><strong>Note:</strong> For best results with function/tool calling capabilities, we recommend using models like Claude-3.5 Sonnet or GPT-4o. These models excel at understanding and using complex tools effectively.</p>
|
||||||
|
<hr style="margin: 10px 0;">
|
||||||
|
<p><strong>Security:</strong> For your security, API keys are required for each session and are never stored. If you refresh the page or close the browser, you'll need to enter your API key again. This ensures your API keys remain secure in shared environments.</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -34,7 +36,6 @@
|
|||||||
name="model"
|
name="model"
|
||||||
value="{{ provider_id }}:{{ model.id }}"
|
value="{{ provider_id }}:{{ model.id }}"
|
||||||
{% if model.default %}checked{% endif %}
|
{% if model.default %}checked{% endif %}
|
||||||
data-endpoint="{{ provider.endpoint }}"
|
|
||||||
data-context-window="{{ model.context_window }}"
|
data-context-window="{{ model.context_window }}"
|
||||||
data-provider="{{ provider_id }}"
|
data-provider="{{ provider_id }}"
|
||||||
data-model-name="{{ model.name }}">
|
data-model-name="{{ model.name }}">
|
||||||
@@ -51,6 +52,21 @@
|
|||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="modelName">Model Name</label>
|
<label for="modelName">Model Name</label>
|
||||||
<input type="text" class="form-control" id="modelName" placeholder="Enter model name for OpenRouter" disabled>
|
<input type="text" class="form-control" id="modelName" placeholder="Enter model name for OpenRouter" disabled>
|
||||||
|
<small class="form-text text-muted">Only required for OpenRouter provider</small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- API Key Input -->
|
||||||
|
<div class="panel-body" style="border-top: 1px solid #ddd; padding-top: 15px;">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="api-key">API Key</label>
|
||||||
|
<input type="password"
|
||||||
|
class="form-control"
|
||||||
|
id="api-key"
|
||||||
|
placeholder="Enter API key for selected provider"
|
||||||
|
required
|
||||||
|
autocomplete="off">
|
||||||
|
<small class="text-muted">Your API key will be used for the selected provider</small>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -74,30 +90,11 @@
|
|||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- API Key Input -->
|
|
||||||
<div class="panel panel-default mt-3">
|
|
||||||
<div class="panel-heading">
|
|
||||||
<h3 class="panel-title">API Key</h3>
|
|
||||||
</div>
|
|
||||||
<div class="panel-body">
|
|
||||||
<div class="form-group">
|
|
||||||
<input type="password"
|
|
||||||
class="form-control"
|
|
||||||
id="api-key"
|
|
||||||
placeholder="Enter your API key"
|
|
||||||
autocomplete="off">
|
|
||||||
<small class="text-muted">
|
|
||||||
Your API key will be used for the selected provider
|
|
||||||
</small>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Main Chat Window -->
|
<!-- Main Chat Window -->
|
||||||
<div class="col-md-9">
|
<div class="col-md-9">
|
||||||
<div class="panel panel-default" style="height: calc(100vh - 250px); display: flex; flex-direction: column;">
|
<div class="panel panel-default" style="height: calc(80vh - 250px); display: flex; flex-direction: column;">
|
||||||
<div class="panel-body" style="flex-grow: 1; overflow-y: auto; padding: 15px;" id="chat-messages">
|
<div class="panel-body" style="flex-grow: 1; overflow-y: auto; padding: 15px;" id="chat-messages">
|
||||||
<!-- Messages will be dynamically added here -->
|
<!-- Messages will be dynamically added here -->
|
||||||
</div>
|
</div>
|
||||||
@@ -180,7 +177,7 @@ document.addEventListener('DOMContentLoaded', function() {
|
|||||||
const modelName = this.getAttribute('data-model-name');
|
const modelName = this.getAttribute('data-model-name');
|
||||||
console.log('Selected provider:', provider);
|
console.log('Selected provider:', provider);
|
||||||
console.log('Model name:', modelName);
|
console.log('Model name:', modelName);
|
||||||
|
|
||||||
if (provider === 'openrouter') {
|
if (provider === 'openrouter') {
|
||||||
console.log('Enabling model name input');
|
console.log('Enabling model name input');
|
||||||
modelNameInput.disabled = false;
|
modelNameInput.disabled = false;
|
||||||
@@ -201,7 +198,7 @@ document.addEventListener('DOMContentLoaded', function() {
|
|||||||
const modelName = defaultSelected.getAttribute('data-model-name');
|
const modelName = defaultSelected.getAttribute('data-model-name');
|
||||||
console.log('Initial provider:', provider);
|
console.log('Initial provider:', provider);
|
||||||
console.log('Initial model name:', modelName);
|
console.log('Initial model name:', modelName);
|
||||||
|
|
||||||
if (provider === 'openrouter') {
|
if (provider === 'openrouter') {
|
||||||
console.log('Initially enabling model name input');
|
console.log('Initially enabling model name input');
|
||||||
modelNameInput.disabled = false;
|
modelNameInput.disabled = false;
|
||||||
@@ -218,20 +215,105 @@ document.addEventListener('DOMContentLoaded', function() {
|
|||||||
const sendButton = document.getElementById('send-button');
|
const sendButton = document.getElementById('send-button');
|
||||||
const chatMessages = document.getElementById('chat-messages');
|
const chatMessages = document.getElementById('chat-messages');
|
||||||
|
|
||||||
|
let currentMessageDiv = null;
|
||||||
|
|
||||||
/**
 * Append a chat bubble to the message list and scroll it into view.
 *
 * @param {string} content - Initial text of the message; may be empty when
 *     the caller wants a placeholder to stream assistant text into.
 * @param {boolean} isUser - True for a user message, false for the assistant.
 * @returns {HTMLDivElement} The created element, so callers can keep
 *     appending streamed text to it.
 */
function addMessage(content, isUser) {
    const messageDiv = document.createElement('div');
    messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
    // textContent (not innerHTML) so model/user output is never parsed as HTML.
    messageDiv.textContent = content;
    chatMessages.appendChild(messageDiv);
    // Keep the newest message visible.
    chatMessages.scrollTop = chatMessages.scrollHeight;
    return messageDiv;
}
|
||||||
|
|
||||||
/**
 * Send the current input to the /wingman/chat endpoint and stream the
 * assistant's reply into a placeholder bubble.
 *
 * Reads the selected model radio (data-provider / value "provider:modelId"),
 * the OpenRouter model-name input when applicable, and the per-session API
 * key field. Errors are rendered in red inside the assistant bubble.
 */
async function sendMessage() {
    const message = messageInput.value.trim();
    if (!message) return;

    // Get selected model
    const selectedModel = document.querySelector('input[name="model"]:checked');
    if (!selectedModel) {
        alert('Please select a model');
        return;
    }

    const provider = selectedModel.getAttribute('data-provider');
    // Radio value is "provider:modelId"; keep only the model id part.
    const modelId = selectedModel.value.split(':')[1];
    // OpenRouter takes a free-form model name; other providers use the id.
    const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;

    // Clear input and add user message
    messageInput.value = '';
    addMessage(message, true);

    try {
        // Create messages array with system message
        const messages = [
            {
                role: 'system',
                content: 'You are a helpful AI assistant integrated into Apache Airflow.'
            },
            {
                role: 'user',
                content: message
            }
        ];

        // Create assistant message div (empty placeholder; streamed chunks
        // are appended to it below).
        currentMessageDiv = addMessage('', false);

        // Get API key
        // NOTE(review): this check runs after the input was cleared and the
        // placeholder bubble was created, so a missing key leaves an empty
        // assistant bubble behind — consider validating before addMessage.
        const apiKey = document.getElementById('api-key').value.trim();
        if (!apiKey) {
            alert('Please enter an API key');
            return;
        }

        // Send request
        const response = await fetch('/wingman/chat', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                provider: provider,
                model: modelName,
                messages: messages,
                api_key: apiKey,
                stream: true,
                temperature: 0.7
            })
        });

        if (!response.ok) {
            const error = await response.json();
            throw new Error(error.error || 'Failed to get response');
        }

        // Handle streaming response: server sends SSE-style "data: ..." lines.
        const reader = response.body.getReader();
        const decoder = new TextDecoder();

        while (true) {
            const { value, done } = await reader.read();
            if (done) break;

            const chunk = decoder.decode(value);
            // NOTE(review): a "data: " frame split across two network chunks
            // would be dropped by this line-based parse — confirm chunk sizes
            // make this unlikely, or buffer partial lines.
            const lines = chunk.split('\n');

            for (const line of lines) {
                if (line.startsWith('data: ')) {
                    const content = line.slice(6);
                    if (content) {
                        currentMessageDiv.textContent += content;
                        chatMessages.scrollTop = chatMessages.scrollHeight;
                    }
                }
            }
        }
    } catch (error) {
        console.error('Error:', error);
        currentMessageDiv.textContent = `Error: ${error.message}`;
        currentMessageDiv.style.color = 'red';
    }
}
|
||||||
|
|
||||||
|
|||||||
76
airflow-wingman/src/airflow_wingman/views.py
Normal file
76
airflow-wingman/src/airflow_wingman/views.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Views for Airflow Wingman plugin."""
|
||||||
|
|
||||||
|
from flask import Response, request, stream_with_context
|
||||||
|
from flask.json import jsonify
|
||||||
|
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
|
||||||
|
|
||||||
|
from airflow_wingman.llm_client import LLMClient
|
||||||
|
from airflow_wingman.llms_models import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
class WingmanView(AppBuilderBaseView):
    """View for Airflow Wingman plugin.

    Exposes the chat UI (GET /wingman/) and a JSON chat-completion endpoint
    (POST /wingman/chat) that proxies requests to the selected LLM provider.
    """

    route_base = "/wingman"
    default_view = "chat"

    @expose("/")
    def chat(self):
        """Render chat interface."""
        # Provider id -> display name, derived from the shared MODELS registry.
        providers = {provider: info["name"] for provider, info in MODELS.items()}
        return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers)

    @expose("/chat", methods=["POST"])
    async def chat_completion(self):
        """Handle chat completion requests.

        Validation errors map to HTTP 400; any other failure to HTTP 500.

        NOTE(review): async Flask views need the flask[async] extra installed —
        confirm the deployment environment provides it.
        """
        try:
            data = self._validate_chat_request(request.get_json())

            # Create a new client for this request; the API key arrives
            # per-request and is never stored server-side.
            client = LLMClient(data["api_key"])

            if data["stream"]:
                return self._handle_streaming_response(client, data)
            else:
                return await self._handle_regular_response(client, data)

        except ValueError as e:
            # Client-side problem (empty body / missing fields).
            return jsonify({"error": str(e)}), 400
        except Exception as e:
            # Anything else (SDK or network failure) surfaces as a 500.
            return jsonify({"error": str(e)}), 500

    def _validate_chat_request(self, data: dict) -> dict:
        """Validate chat request data.

        Args:
            data: Parsed JSON request body (may be None/empty).

        Returns:
            Normalized dict with defaults applied for optional fields.

        Raises:
            ValueError: If the body is empty or a required field is missing.
        """
        if not data:
            raise ValueError("No data provided")

        required_fields = ["provider", "model", "messages", "api_key"]
        # Falsy values (e.g. an empty messages list) count as missing here.
        missing = [f for f in required_fields if not data.get(f)]
        if missing:
            raise ValueError(f"Missing required fields: {', '.join(missing)}")

        return {
            "provider": data["provider"],
            "model": data["model"],
            "messages": data["messages"],
            "api_key": data["api_key"],
            "stream": data.get("stream", False),
            "temperature": data.get("temperature", 0.7),
            "max_tokens": data.get("max_tokens"),
        }

    def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
        """Handle streaming response.

        Wraps the provider's chunk stream in server-sent-event framing
        ("data: <chunk>" followed by a blank line), which the template's
        fetch-based reader parses line by line.

        NOTE(review): generate() is an async generator handed to the sync
        stream_with_context/Response pair — verify the installed Flask
        version actually supports streaming async generators this way.
        """

        async def generate():
            async for chunk in await client.chat_completion(
                messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True
            ):
                yield f"data: {chunk}\n\n"

        return Response(stream_with_context(generate()), mimetype="text/event-stream")

    async def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
        """Handle regular response.

        Returns the provider's full completion dict — either {"content": ...}
        or {"error": ...} — serialized as JSON.
        """
        response = await client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
        return jsonify(response)
|
||||||
Reference in New Issue
Block a user