restructure code for modularity and use ai modules

This commit is contained in:
2025-02-24 13:11:09 +00:00
parent f3cc238130
commit eea23b2097
5 changed files with 304 additions and 47 deletions

View File

@@ -10,7 +10,9 @@ authors = [
] ]
dependencies = [ dependencies = [
"apache-airflow>=2.10.0", "apache-airflow>=2.10.0",
"airflow-mcp-server>=0.2.0" "airflow-mcp-server>=0.2.0",
"openai>=1.64.0",
"anthropic>=0.46.0"
] ]
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",

View File

@@ -0,0 +1,109 @@
"""
Client for making API calls to various LLM providers using their official SDKs.
"""
from collections.abc import AsyncGenerator
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
class LLMClient:
def __init__(self, api_key: str):
"""Initialize the LLM client.
Args:
api_key: API key for the provider
"""
self.api_key = api_key
self.openai_client = AsyncOpenAI(api_key=api_key)
self.anthropic_client = AsyncAnthropic(api_key=api_key)
self.openrouter_client = AsyncOpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=api_key,
default_headers={
"HTTP-Referer": "http://localhost:8080", # Required by OpenRouter
"X-Title": "Airflow Wingman", # Required by OpenRouter
},
)
async def chat_completion(
self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
) -> AsyncGenerator[str, None] | dict:
"""Send a chat completion request to the specified provider.
Args:
messages: List of message dictionaries with 'role' and 'content'
model: Model identifier
provider: Provider identifier (openai, anthropic, openrouter)
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens to generate
stream: Whether to stream the response
Returns:
If stream=True, returns an async generator yielding response chunks
If stream=False, returns the complete response
"""
try:
if provider == "openai":
return await self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "anthropic":
return await self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
elif provider == "openrouter":
return await self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
else:
return {"error": f"Unknown provider: {provider}"}
except Exception as e:
return {"error": f"API request failed: {str(e)}"}
async def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenAI chat completion requests."""
response = await self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
async def response_generator():
async for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
return response_generator()
else:
return {"content": response.choices[0].message.content}
async def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle Anthropic chat completion requests."""
# Convert messages to Anthropic format
system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
conversation = []
for m in messages:
if m["role"] != "system":
conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
response = await self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
async def response_generator():
async for chunk in response:
if chunk.delta.text:
yield chunk.delta.text
return response_generator()
else:
return {"content": response.content[0].text}
async def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
"""Handle OpenRouter chat completion requests."""
response = await self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
if stream:
async def response_generator():
async for chunk in response:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
return response_generator()
else:
return {"content": response.choices[0].message.content}

View File

@@ -1,10 +1,11 @@
"""Plugin definition for Airflow Wingman."""
from airflow.plugins_manager import AirflowPlugin from airflow.plugins_manager import AirflowPlugin
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
from flask import Blueprint from flask import Blueprint
from airflow_wingman.llms_models import MODELS from airflow_wingman.views import WingmanView
# Create Blueprint
bp = Blueprint( bp = Blueprint(
"wingman", "wingman",
__name__, __name__,
@@ -13,21 +14,6 @@ bp = Blueprint(
static_url_path="/static/wingman", static_url_path="/static/wingman",
) )
class WingmanView(AppBuilderBaseView):
route_base = "/wingman"
default_view = "chat"
@expose("/")
def chat(self):
"""
Chat interface for Airflow Wingman.
"""
return self.render_template(
"wingman_chat.html", title="Airflow Wingman", models=MODELS
)
# Create AppBuilder View # Create AppBuilder View
v_appbuilder_view = WingmanView() v_appbuilder_view = WingmanView()
v_appbuilder_package = { v_appbuilder_package = {
@@ -39,6 +25,8 @@ v_appbuilder_package = {
# Create Plugin # Create Plugin
class WingmanPlugin(AirflowPlugin): class WingmanPlugin(AirflowPlugin):
"""Airflow plugin for Wingman chat interface."""
name = "wingman" name = "wingman"
flask_blueprints = [bp] flask_blueprints = [bp]
appbuilder_views = [v_appbuilder_package] appbuilder_views = [v_appbuilder_package]

View File

@@ -11,6 +11,8 @@
</div> </div>
<div class="alert alert-info" style="margin: 15px;"> <div class="alert alert-info" style="margin: 15px;">
<p><strong>Note:</strong> For best results with function/tool calling capabilities, we recommend using models like Claude-3.5 Sonnet or GPT-4o. These models excel at understanding and using complex tools effectively.</p> <p><strong>Note:</strong> For best results with function/tool calling capabilities, we recommend using models like Claude-3.5 Sonnet or GPT-4o. These models excel at understanding and using complex tools effectively.</p>
<hr style="margin: 10px 0;">
<p><strong>Security:</strong> For your security, API keys are required for each session and are never stored. If you refresh the page or close the browser, you'll need to enter your API key again. This ensures your API keys remain secure in shared environments.</p>
</div> </div>
</div> </div>
</div> </div>
@@ -34,7 +36,6 @@
name="model" name="model"
value="{{ provider_id }}:{{ model.id }}" value="{{ provider_id }}:{{ model.id }}"
{% if model.default %}checked{% endif %} {% if model.default %}checked{% endif %}
data-endpoint="{{ provider.endpoint }}"
data-context-window="{{ model.context_window }}" data-context-window="{{ model.context_window }}"
data-provider="{{ provider_id }}" data-provider="{{ provider_id }}"
data-model-name="{{ model.name }}"> data-model-name="{{ model.name }}">
@@ -51,6 +52,21 @@
<div class="form-group"> <div class="form-group">
<label for="modelName">Model Name</label> <label for="modelName">Model Name</label>
<input type="text" class="form-control" id="modelName" placeholder="Enter model name for OpenRouter" disabled> <input type="text" class="form-control" id="modelName" placeholder="Enter model name for OpenRouter" disabled>
<small class="form-text text-muted">Only required for OpenRouter provider</small>
</div>
</div>
<!-- API Key Input -->
<div class="panel-body" style="border-top: 1px solid #ddd; padding-top: 15px;">
<div class="form-group">
<label for="api-key">API Key</label>
<input type="password"
class="form-control"
id="api-key"
placeholder="Enter API key for selected provider"
required
autocomplete="off">
<small class="text-muted">Your API key will be used for the selected provider</small>
</div> </div>
</div> </div>
@@ -74,30 +90,11 @@
} }
</style> </style>
</div> </div>
<!-- API Key Input -->
<div class="panel panel-default mt-3">
<div class="panel-heading">
<h3 class="panel-title">API Key</h3>
</div>
<div class="panel-body">
<div class="form-group">
<input type="password"
class="form-control"
id="api-key"
placeholder="Enter your API key"
autocomplete="off">
<small class="text-muted">
Your API key will be used for the selected provider
</small>
</div>
</div>
</div>
</div> </div>
<!-- Main Chat Window --> <!-- Main Chat Window -->
<div class="col-md-9"> <div class="col-md-9">
<div class="panel panel-default" style="height: calc(100vh - 250px); display: flex; flex-direction: column;"> <div class="panel panel-default" style="height: calc(80vh - 250px); display: flex; flex-direction: column;">
<div class="panel-body" style="flex-grow: 1; overflow-y: auto; padding: 15px;" id="chat-messages"> <div class="panel-body" style="flex-grow: 1; overflow-y: auto; padding: 15px;" id="chat-messages">
<!-- Messages will be dynamically added here --> <!-- Messages will be dynamically added here -->
</div> </div>
@@ -218,20 +215,105 @@ document.addEventListener('DOMContentLoaded', function() {
const sendButton = document.getElementById('send-button'); const sendButton = document.getElementById('send-button');
const chatMessages = document.getElementById('chat-messages'); const chatMessages = document.getElementById('chat-messages');
let currentMessageDiv = null;
function addMessage(content, isUser) { function addMessage(content, isUser) {
const messageDiv = document.createElement('div'); const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`; messageDiv.className = `message ${isUser ? 'message-user' : 'message-assistant'}`;
messageDiv.textContent = content; messageDiv.textContent = content;
chatMessages.appendChild(messageDiv); chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight; chatMessages.scrollTop = chatMessages.scrollHeight;
return messageDiv;
} }
function sendMessage() { async function sendMessage() {
const message = messageInput.value.trim(); const message = messageInput.value.trim();
if (message) { if (!message) return;
addMessage(message, true);
messageInput.value = ''; // Get selected model
// TODO: Add API call to send message and get response const selectedModel = document.querySelector('input[name="model"]:checked');
if (!selectedModel) {
alert('Please select a model');
return;
}
const provider = selectedModel.getAttribute('data-provider');
const modelId = selectedModel.value.split(':')[1];
const modelName = provider === 'openrouter' ? modelNameInput.value : modelId;
// Clear input and add user message
messageInput.value = '';
addMessage(message, true);
try {
// Create messages array with system message
const messages = [
{
role: 'system',
content: 'You are a helpful AI assistant integrated into Apache Airflow.'
},
{
role: 'user',
content: message
}
];
// Create assistant message div
currentMessageDiv = addMessage('', false);
// Get API key
const apiKey = document.getElementById('api-key').value.trim();
if (!apiKey) {
alert('Please enter an API key');
return;
}
// Send request
const response = await fetch('/wingman/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
provider: provider,
model: modelName,
messages: messages,
api_key: apiKey,
stream: true,
temperature: 0.7
})
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || 'Failed to get response');
}
// Handle streaming response
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { value, done } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const content = line.slice(6);
if (content) {
currentMessageDiv.textContent += content;
chatMessages.scrollTop = chatMessages.scrollHeight;
}
}
}
}
} catch (error) {
console.error('Error:', error);
currentMessageDiv.textContent = `Error: ${error.message}`;
currentMessageDiv.style.color = 'red';
} }
} }

View File

@@ -0,0 +1,76 @@
"""Views for Airflow Wingman plugin."""
from flask import Response, request, stream_with_context
from flask.json import jsonify
from flask_appbuilder import BaseView as AppBuilderBaseView, expose
from airflow_wingman.llm_client import LLMClient
from airflow_wingman.llms_models import MODELS
class WingmanView(AppBuilderBaseView):
"""View for Airflow Wingman plugin."""
route_base = "/wingman"
default_view = "chat"
@expose("/")
def chat(self):
"""Render chat interface."""
providers = {provider: info["name"] for provider, info in MODELS.items()}
return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers)
@expose("/chat", methods=["POST"])
async def chat_completion(self):
"""Handle chat completion requests."""
try:
data = self._validate_chat_request(request.get_json())
# Create a new client for this request
client = LLMClient(data["api_key"])
if data["stream"]:
return self._handle_streaming_response(client, data)
else:
return await self._handle_regular_response(client, data)
except ValueError as e:
return jsonify({"error": str(e)}), 400
except Exception as e:
return jsonify({"error": str(e)}), 500
def _validate_chat_request(self, data: dict) -> dict:
"""Validate chat request data."""
if not data:
raise ValueError("No data provided")
required_fields = ["provider", "model", "messages", "api_key"]
missing = [f for f in required_fields if not data.get(f)]
if missing:
raise ValueError(f"Missing required fields: {', '.join(missing)}")
return {
"provider": data["provider"],
"model": data["model"],
"messages": data["messages"],
"api_key": data["api_key"],
"stream": data.get("stream", False),
"temperature": data.get("temperature", 0.7),
"max_tokens": data.get("max_tokens"),
}
def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
"""Handle streaming response."""
async def generate():
async for chunk in await client.chat_completion(
messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True
):
yield f"data: {chunk}\n\n"
return Response(stream_with_context(generate()), mimetype="text/event-stream")
async def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
"""Handle regular response."""
response = await client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
return jsonify(response)