diff --git a/airflow-wingman/src/airflow_wingman/llm_client.py b/airflow-wingman/src/airflow_wingman/llm_client.py
index dc2543c..0c51316 100644
--- a/airflow-wingman/src/airflow_wingman/llm_client.py
+++ b/airflow-wingman/src/airflow_wingman/llm_client.py
@@ -2,10 +2,10 @@
 Client for making API calls to various LLM providers using their official SDKs.
 """
 
-from collections.abc import AsyncGenerator
+from collections.abc import Generator
 
-from anthropic import AsyncAnthropic
-from openai import AsyncOpenAI
+from anthropic import Anthropic
+from openai import OpenAI
 
 
 class LLMClient:
@@ -16,9 +16,9 @@ class LLMClient:
             api_key: API key for the provider
         """
         self.api_key = api_key
-        self.openai_client = AsyncOpenAI(api_key=api_key)
-        self.anthropic_client = AsyncAnthropic(api_key=api_key)
-        self.openrouter_client = AsyncOpenAI(
+        self.openai_client = OpenAI(api_key=api_key)
+        self.anthropic_client = Anthropic(api_key=api_key)
+        self.openrouter_client = OpenAI(
             base_url="https://openrouter.ai/api/v1",
             api_key=api_key,
             default_headers={
@@ -27,9 +27,9 @@ class LLMClient:
             },
         )
 
-    async def chat_completion(
+    def chat_completion(
         self, messages: list[dict[str, str]], model: str, provider: str, temperature: float = 0.7, max_tokens: int | None = None, stream: bool = False
-    ) -> AsyncGenerator[str, None] | dict:
+    ) -> Generator[str, None, None] | dict:
         """Send a chat completion request to the specified provider.
 
         Args:
@@ -41,29 +41,29 @@ class LLMClient:
             stream: Whether to stream the response
 
         Returns:
-            If stream=True, returns an async generator yielding response chunks
+            If stream=True, returns a generator yielding response chunks
             If stream=False, returns the complete response
         """
         try:
             if provider == "openai":
-                return await self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._openai_chat_completion(messages, model, temperature, max_tokens, stream)
             elif provider == "anthropic":
-                return await self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._anthropic_chat_completion(messages, model, temperature, max_tokens, stream)
             elif provider == "openrouter":
-                return await self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
+                return self._openrouter_chat_completion(messages, model, temperature, max_tokens, stream)
             else:
                 return {"error": f"Unknown provider: {provider}"}
         except Exception as e:
             return {"error": f"API request failed: {str(e)}"}
 
-    async def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _openai_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle OpenAI chat completion requests."""
-        response = await self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.openai_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
 
         if stream:
 
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.choices[0].delta.content:
                         yield chunk.choices[0].delta.content
 
@@ -71,7 +71,7 @@ class LLMClient:
         else:
             return {"content": response.choices[0].message.content}
 
-    async def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _anthropic_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle Anthropic chat completion requests."""
         # Convert messages to Anthropic format
         system_message = next((m["content"] for m in messages if m["role"] == "system"), None)
@@ -80,12 +80,12 @@ class LLMClient:
             if m["role"] != "system":
                 conversation.append({"role": "assistant" if m["role"] == "assistant" else "user", "content": m["content"]})
 
-        response = await self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.anthropic_client.messages.create(model=model, messages=conversation, system=system_message, temperature=temperature, max_tokens=max_tokens, stream=stream)
 
         if stream:
 
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.delta.text:
                         yield chunk.delta.text
 
@@ -93,14 +93,14 @@ class LLMClient:
         else:
            return {"content": response.content[0].text}
 
-    async def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
+    def _openrouter_chat_completion(self, messages: list[dict[str, str]], model: str, temperature: float, max_tokens: int | None, stream: bool):
         """Handle OpenRouter chat completion requests."""
-        response = await self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
+        response = self.openrouter_client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream)
 
         if stream:
 
-            async def response_generator():
-                async for chunk in response:
+            def response_generator():
+                for chunk in response:
                     if chunk.choices[0].delta.content:
                         yield chunk.choices[0].delta.content
diff --git a/airflow-wingman/src/airflow_wingman/views.py b/airflow-wingman/src/airflow_wingman/views.py
index 0dc6934..a56b7dc 100644
--- a/airflow-wingman/src/airflow_wingman/views.py
+++ b/airflow-wingman/src/airflow_wingman/views.py
@@ -21,7 +21,7 @@ class WingmanView(AppBuilderBaseView):
         return self.render_template("wingman_chat.html", title="Airflow Wingman", models=MODELS, providers=providers)
 
     @expose("/chat", methods=["POST"])
-    async def chat_completion(self):
+    def chat_completion(self):
         """Handle chat completion requests."""
         try:
             data = self._validate_chat_request(request.get_json())
@@ -32,7 +32,7 @@ class WingmanView(AppBuilderBaseView):
             if data["stream"]:
                 return self._handle_streaming_response(client, data)
             else:
-                return await self._handle_regular_response(client, data)
+                return self._handle_regular_response(client, data)
 
         except ValueError as e:
             return jsonify({"error": str(e)}), 400
@@ -62,15 +62,17 @@ class WingmanView(AppBuilderBaseView):
     def _handle_streaming_response(self, client: LLMClient, data: dict) -> Response:
         """Handle streaming response."""
 
-        async def generate():
-            async for chunk in await client.chat_completion(
-                messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True
-            ):
+        def generate():
+            for chunk in client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=True):
                 yield f"data: {chunk}\n\n"
 
-        return Response(stream_with_context(generate()), mimetype="text/event-stream")
+        response = Response(stream_with_context(generate()), mimetype="text/event-stream")
+        response.headers["Content-Type"] = "text/event-stream"
+        response.headers["Cache-Control"] = "no-cache"
+        response.headers["Connection"] = "keep-alive"
+        return response
 
-    async def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
+    def _handle_regular_response(self, client: LLMClient, data: dict) -> Response:
         """Handle regular response."""
-        response = await client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
+        response = client.chat_completion(messages=data["messages"], model=data["model"], provider=data["provider"], temperature=data["temperature"], max_tokens=data["max_tokens"], stream=False)
         return jsonify(response)
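
A minimal usage sketch of the now-synchronous client outside Flask (not part of the patch; the API key and model name below are placeholders, not values taken from this change):

    from airflow_wingman.llm_client import LLMClient

    client = LLMClient(api_key="sk-...")  # placeholder key
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ]

    # stream=True now returns a plain generator, consumed with a regular for loop
    for chunk in client.chat_completion(messages=messages, model="gpt-4o-mini", provider="openai", temperature=0.7, max_tokens=64, stream=True):
        print(chunk, end="", flush=True)

    # stream=False returns a dict, e.g. {"content": "..."} or {"error": "..."}
    result = client.chat_completion(messages=messages, model="gpt-4o-mini", provider="openai", temperature=0.7, max_tokens=64, stream=False)
    print(result.get("content") or result.get("error"))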