Use synchronous calls

2025-02-24 13:35:36 +00:00
parent 093577dd96
commit f89f6e7d5e
2 changed files with 35 additions and 33 deletions

View File

@@ -2,10 +2,10 @@
 Client for making API calls to various LLM providers using their official SDKs.
 """
-from collections.abc import AsyncGenerator
-from anthropic import AsyncAnthropic
-from openai import AsyncOpenAI
+from collections.abc import Generator
+from anthropic import Anthropic
+from openai import OpenAI
 class LLMClient:
@@ -16,9 +16,9 @@ class LLMClient:
             api_key: API key for the provider
         """
         self.api_key = api_key
-        self.openai_client = AsyncOpenAI(api_key=api_key)
-        self.anthropic_client = AsyncAnthropic(api_key=api_key)
-        self.openrouter_client = AsyncOpenAI(
+        self.openai_client = OpenAI(api_key=api_key)
+        self.anthropic_client = Anthropic(api_key=api_key)
+        self.openrouter_client = OpenAI(
             base_url="https://openrouter.ai/api/v1",
             api_key=api_key,
             default_headers={
@@ -27,9 +27,9 @@ class LLMClient:
             },
         )
-    async def chat_completion(
+    def chat_completion(
         self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
-    ) -> AsyncGenerator[str, None] | dict:
+    ) -> Generator[str, None, None] | dict:
         """Send a chat completion request to the specified provider.
         Args:
@@ -41,29 +41,29 @@ class LLMClient:
             stream: Whether to stream the response
         Returns:
-            If stream=True, returns an async generator yielding response chunks
+            If stream=True, returns a generator yielding response chunks
             If stream=False, returns the complete response
         """
         try:
             if provider == "openai":
-                return await self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
             elif provider == "anthropic":
-                return await self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
             elif provider == "openrouter":
-                return await self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
             else:
                 return {"error": f"Unknown provider: {provider}"}
         except Exception as e:
             return {"error": f"API request failed: {str(e)}"}
-    async def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle OpenAI chat completion requests."""
-        response = await self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
         if stream:
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.choices[0].delta.content:
                         yield chunk.choices[0].delta.content
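Note: with the synchronous SDK, stream=True makes chat.completions.create return a plain iterable of chunks, which is why an ordinary generator is enough above. A minimal standalone sketch of that pattern; the API key and model name are placeholders, not values from this repository:

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    # each chunk carries an incremental delta; skip empty ones
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)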
@@ -71,7 +71,7 @@ class LLMClient:
         else:
             return {"content": response.choices[0].message.content}
-    async def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle Anthropic chat completion requests."""
         # Convert messages to Anthropic format
         system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
@@ -80,12 +80,12 @@ class LLMClient:
             if m["role"] != "system":
                 conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
-        response = await self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
         if stream:
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.delta.text:
                         yield chunk.delta.text
@@ -93,14 +93,14 @@ class LLMClient:
         else:
             return {"content": response.content[0].text}
-    async def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle OpenRouter chat completion requests."""
-        response = await self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
         if stream:
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.choices[0].delta.content:
                         yield chunk.choices[0].delta.content
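With every provider helper synchronous, chat_completion now returns either a regular generator of text chunks (stream=True) or a plain dict (stream=False, or any error). A rough usage sketch; the import path and model name are assumptions for illustration only:

from airflow_wingman.llm_client import LLMClient  # assumed module path

client = LLMClient(api_key="...")
messages = [{"role": "user", "content": "Which DAG runs failed today?"}]

# Non-streaming: a dict containing either "content" or "error"
result = client.chat_completion(messages, model="gpt-4o-mini", provider="openai", stream=False)
print(result.get("content") or result.get("error"))

# Streaming: a dict can still come back on failure, so check before iterating
response = client.chat_completion(messages, model="gpt-4o-mini", provider="openai", stream=True)
if isinstance(response, dict):
    print(response["error"])
else:
    for chunk in response:
        print(chunk, end="", flush=True)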

View File

@@ -21,7 +21,7 @@ class WingmanView(AppBuilderBaseView):
         return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers)
     @expose("/chat", methods=["POST"])
-    async def chat_completion(self):
+    def chat_completion(self):
         """Handle chat completion requests."""
         try:
             data = self._validate_chat_request(request.get_json())
@@ -32,7 +32,7 @@ class WingmanView(AppBuilderBaseView):
             if data["stream"]:
                 return self._handle_streaming_response(client, data)
             else:
-                return await self._handle_regular_response(client, data)
+                return self._handle_regular_response(client, data)
         except ValueError as e:
             return jsonify({"error": str(e)}), 400
@@ -62,15 +62,17 @@ class WingmanView(AppBuilderBaseView):
     def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
         """Handle streaming response."""
-        async def generate():
-            async for chunk in await client.chat_completion(
-                messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True
-            ):
+        def generate():
+            for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
                 yield f"data: {chunk}\n\n"
-        return Response(stream_with_context(generate()), mimetype="text/event-stream")
+        response = Response(stream_with_context(generate()), mimetype="text/event-stream")
+        response.headers["Content-Type"] = "text/event-stream"
+        response.headers["Cache-Control"] = "no-cache"
+        response.headers["Connection"] = "keep-alive"
+        return response
-    async def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
+    def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
         """Handle regular response."""
-        response = await client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
+        response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
         return jsonify(response)
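Since generate() is now a plain generator, stream_with_context can iterate it directly and each data: line goes out as a server-sent event. A hedged sketch of consuming the endpoint; the URL prefix depends on how the AppBuilder view is registered, the payload keys are inferred from the handler above, and authentication is omitted for brevity:

import requests

payload = {
    "provider": "openai",
    "model": "gpt-4o-mini",  # placeholder model
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
    "max_tokens": None,
    "stream": True,
}

# Illustrative URL; the real route prefix comes from the AppBuilder registration
with requests.post("http://localhost:8080/wingman/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)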