From dc5e2ef7c2142fdade7d41f40894cb5ddf24bf66 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Sun, 2 Mar 2025 16:51:49 +0000
Subject: [PATCH] Stream follow-up responses

---
 src/airflow_wingman/llm_client.py | 31 ++++++++++++++++++-------------
 src/airflow_wingman/views.py      | 12 +++++++++++-
 2 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/src/airflow_wingman/llm_client.py b/src/airflow_wingman/llm_client.py
index f9d818d..5cd1071 100644
--- a/src/airflow_wingman/llm_client.py
+++ b/src/airflow_wingman/llm_client.py
@@ -141,7 +141,7 @@ class LLMClient:
 
         return cls(provider_name=provider_name, api_key=api_key, base_url=base_url)
 
-    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None):
+    def process_tool_calls_and_follow_up(self, response, messages, model, temperature, max_tokens, max_iterations=5, cookie=None, stream=True):
         """
         Process tool calls recursively from a response and make follow-up requests
         until there are no more tool calls or max_iterations is reached.
@@ -155,6 +155,7 @@
             max_tokens: Maximum tokens to generate
             max_iterations: Maximum number of tool call iterations to prevent infinite loops
             cookie: Airflow cookie for authentication (optional, will try to get from session if not provided)
+            stream: Whether to stream the response (follow-up requests currently always stream)
 
         Returns:
             Generator for streaming the final follow-up response
@@ -181,8 +182,9 @@
             # Make follow-up request with tool results
             logger.info(f"Making follow-up request with tool results (iteration {iteration})")
 
-            # Only stream on the final iteration
-            should_stream = (iteration == max_iterations) or not self.provider.has_tool_calls(current_response)
+            # Always stream follow-up requests so callers get consistent streaming behavior
+            should_stream = True
+            logger.info(f"Setting should_stream=True for follow-up request (iteration {iteration})")
 
             # Get provider-specific tool definitions from Airflow tools
             provider_tools = self.provider.convert_tools(self.airflow_tools)
@@ -201,15 +203,13 @@
             # Check if this follow-up response has more tool calls
             if not self.provider.has_tool_calls(follow_up_response):
                 logger.info(f"No more tool calls after iteration {iteration}")
-                # Final response - return the streaming content
-                if not should_stream:
-                    # If we didn't stream this response, we need to make a streaming version
-                    content = self.provider.get_content(follow_up_response)
-                    yield content
-                    return
-                else:
-                    # Return the streaming generator
-                    return self.provider.get_streaming_content(follow_up_response)
+                # Final response: yield each chunk from the streaming generator
+                chunk_count = 0
+                for chunk in self.provider.get_streaming_content(follow_up_response):
+                    chunk_count += 1
+                    yield chunk
+                logger.info(f"Finished yielding {chunk_count} chunks from streaming generator")
+                return
 
             # Update current_response for the next iteration
             current_response = follow_up_response
@@ -218,7 +218,12 @@
         if iteration == max_iterations and self.provider.has_tool_calls(current_response):
             logger.warning(f"Reached maximum tool call iterations ({max_iterations})")
             # Stream the final response even if it has tool calls
-            return self.provider.get_streaming_content(follow_up_response)
+            # (should_stream is always True above, so yield directly from the streaming generator)
+            chunk_count = 0
+            for chunk in self.provider.get_streaming_content(follow_up_response):
+                chunk_count += 1
+                yield chunk
+            logger.info(f"Finished yielding {chunk_count} chunks from streaming generator (max iterations)")
 
         # If we didn't process any tool calls (shouldn't happen), return an error
         if iteration == 0:
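
Note on the llm_client.py change: the pre-patch code executed `return self.provider.get_streaming_content(follow_up_response)` inside a function that also contains `yield` statements, i.e. inside a generator function. In a generator, `return value` only stores the value on the `StopIteration` it raises; a caller iterating the generator never receives it, so the follow-up content was silently dropped. Yielding each chunk, as the patched code now does, is the correct way to delegate a stream. A minimal standalone sketch (illustrative only, not part of the patch):

    def old_style():
        """Mimics the pre-patch pattern: a generator that tries to `return` an iterator."""
        if False:
            yield  # the mere presence of `yield` makes this a generator function
        return iter(["chunk-1", "chunk-2"])  # stored on StopIteration, never delivered

    def new_style():
        """Mimics the patched pattern: yield each chunk (or use `yield from`)."""
        yield from ["chunk-1", "chunk-2"]

    print(list(old_style()))  # [] -- the chunks are lost
    print(list(new_style()))  # ['chunk-1', 'chunk-2']
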
diff --git a/src/airflow_wingman/views.py b/src/airflow_wingman/views.py
index 14417f5..f02d508 100644
--- a/src/airflow_wingman/views.py
+++ b/src/airflow_wingman/views.py
@@ -147,13 +147,23 @@
                     logger.info(f"Using cookie from closure: {cookie is not None}")
 
                     # Process tool calls and get follow-up response (handles recursive tool calls)
-                    follow_up_response = client.process_tool_calls_and_follow_up(streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie)
+                    # Always stream the follow-up response for consistent handling
+                    follow_up_response = client.process_tool_calls_and_follow_up(
+                        streaming_response, data["messages"], data["model"], data["temperature"], data["max_tokens"], cookie=cookie, stream=True
+                    )
 
                     # Stream the follow-up response
+                    follow_up_complete_response = ""
                     for chunk in follow_up_response:
                         if chunk:
+                            follow_up_complete_response += chunk
                             yield f"data: {chunk}\n\n"
 
+                    # Log the complete follow-up response
+                    logger.info("FOLLOW-UP RESPONSE START >>>")
+                    logger.info(follow_up_complete_response)
+                    logger.info("<<< FOLLOW-UP RESPONSE END")
+
                     # Signal tool processing complete - frontend can re-enable send button
                     yield f"data: {json.dumps({'event': 'tool_processing_complete'})}\n\n"
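
For context, a hypothetical client-side sketch of consuming the SSE stream the view now emits. The endpoint URL, model name, and payload fields below are assumptions for illustration; only the `data: ...` framing and the `tool_processing_complete` event come from the code above:

    import json

    import requests  # assumed HTTP client; not used by the plugin itself

    payload = {
        "messages": [{"role": "user", "content": "List my DAGs"}],
        "model": "gpt-4o",  # assumed model name
        "temperature": 0.4,
        "max_tokens": 1024,
    }

    # Placeholder URL; the real route depends on how the plugin registers its view.
    with requests.post("http://localhost:8080/wingman/chat", json=payload, stream=True) as resp:
        for raw_line in resp.iter_lines(decode_unicode=True):
            if not raw_line or not raw_line.startswith("data: "):
                continue  # skip blank SSE separators and non-data lines
            data = raw_line[len("data: "):]
            try:
                event = json.loads(data)
                if isinstance(event, dict) and event.get("event") == "tool_processing_complete":
                    break  # server signals that tool handling is finished
            except ValueError:
                pass  # plain content chunk, not a JSON control event
            print(data, end="", flush=True)
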