# Streamlit RAG Pipeline Initialization Issue

## Problem Statement

When running the Streamlit app, sending "Hi" results in the error: *"AI models are not properly configured. Please check your API keys."*
## Root Cause Analysis

- **Embeddings returning `None`**: The `_initialize_embeddings()` method in `RAGPipeline` returns `None` when initialization fails
- **Silent failures**: Exceptions are caught but only logged as warnings, returning `None` instead of raising errors
- **Streamlit rerun behavior**: Each interaction causes a full script rerun, potentially reinitializing models
- **No persistence**: Models are not stored in `st.session_state`, causing repeated initialization attempts
## Current Behavior

```
# Current problematic flow:
RAGPipeline.__init__()
    → _initialize_embeddings()
    → try/except returns None on failure
    → self.embeddings = None
    → Query fails with generic error
```
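The failure mode can be reproduced without any AI dependencies. This sketch (hypothetical `BrokenPipeline`) shows how a broad `except` plus a `None` model collapses the real cause into the generic user-facing message:

```python
class BrokenPipeline:
    """Minimal reproduction of the silent-failure flow (not the real class)."""

    def __init__(self):
        try:
            # Stand-in for whatever actually fails (missing key, bad import, ...)
            raise ImportError("langchain_google_genai not installed")
        except Exception:
            self.embeddings = None  # silent failure: the real cause is discarded

    def query(self, text):
        if self.embeddings is None:
            return "AI models are not properly configured. Please check your API keys."
        return "answer"

# The user sees only the generic message, with no hint of the ImportError
assert "not properly configured" in BrokenPipeline().query("Hi")
```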
## Proposed Design Changes

### 1. Singleton Pattern with Session State

**Problem:** Multiple `RAGPipeline` instances created on reruns.
**Solution:** Use Streamlit session state as singleton storage.

```python
# In app.py or streamlit_app.py
def get_rag_pipeline():
    """Get or create RAG pipeline with proper session state management"""
    if 'rag_pipeline' not in st.session_state:
        with st.spinner("Initializing AI models..."):
            pipeline = RAGPipeline()
            # Validate initialization
            if pipeline.embeddings is None:
                st.error("Failed to initialize embedding model")
                st.stop()
            if pipeline.llm is None:
                st.error("Failed to initialize language model")
                st.stop()
            st.session_state.rag_pipeline = pipeline
            st.success("AI models initialized successfully")
    return st.session_state.rag_pipeline
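The singleton guarantee can be checked in isolation. This sketch stands a plain dict in for `st.session_state` and a counting stub (hypothetical `FakePipeline`) in for the real pipeline, showing that repeated "reruns" construct the pipeline exactly once:

```python
class FakePipeline:
    """Stand-in for RAGPipeline that counts how often it is constructed."""
    instances = 0

    def __init__(self):
        FakePipeline.instances += 1

session_state = {}  # stand-in for st.session_state

def get_rag_pipeline():
    if 'rag_pipeline' not in session_state:
        session_state['rag_pipeline'] = FakePipeline()
    return session_state['rag_pipeline']

# Simulate three Streamlit reruns: same object, single construction
a, b, c = get_rag_pipeline(), get_rag_pipeline(), get_rag_pipeline()
assert a is b is c
assert FakePipeline.instances == 1
```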
### 2. Lazy Initialization with Caching

**Problem:** Models are initialized in `__init__` even if not needed.
**Solution:** Use lazy properties with `@st.cache_resource`.

```python
class RAGPipeline:
    def __init__(self):
        self._embeddings = None
        self._llm = None
        self.db_path = "data/lancedb"
        self.db = lancedb.connect(self.db_path)

    @property
    def embeddings(self):
        if self._embeddings is None:
            self._embeddings = self._get_or_create_embeddings()
        return self._embeddings

    @property
    def llm(self):
        if self._llm is None:
            self._llm = self._get_or_create_llm()
        return self._llm

    # The leading underscore on `_self` tells st.cache_resource not to hash
    # this argument when computing the cache key.
    @st.cache_resource
    def _get_or_create_embeddings(_self):
        """Cached embedding model creation"""
        return _self._initialize_embeddings()

    @st.cache_resource
    def _get_or_create_llm(_self):
        """Cached LLM creation"""
        return _self._initialize_llm()
```
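The lazy-property half of this pattern is plain Python and can be verified without Streamlit. A minimal sketch (caching decorator omitted, `object()` standing in for the real model) showing that nothing expensive runs in `__init__` and the model is built once on first access:

```python
class LazyPipeline:
    """Minimal sketch of the lazy-property pattern; st.cache_resource omitted."""
    created = 0

    def __init__(self):
        self._embeddings = None  # nothing expensive happens here

    @property
    def embeddings(self):
        if self._embeddings is None:
            LazyPipeline.created += 1
            self._embeddings = object()  # stand-in for the real model
        return self._embeddings

p = LazyPipeline()
assert LazyPipeline.created == 0       # not built during construction
first = p.embeddings                   # built on first access
second = p.embeddings                  # reused afterwards
assert LazyPipeline.created == 1 and first is second
```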
### 3. Explicit Error Handling

**Problem:** Silent failures with `return None`.
**Solution:** Raise exceptions with clear messages.

```python
def _initialize_embeddings(self):
    """Initialize embeddings with explicit error handling"""
    model_type = config.EMBEDDING_MODEL
    try:
        if model_type == "google":
            if not config.GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY is not set in environment variables")

            # Try to import and initialize
            try:
                from langchain_google_genai import GoogleGenerativeAIEmbeddings
            except ImportError as e:
                raise ImportError(f"Failed to import GoogleGenerativeAIEmbeddings: {e}")

            embeddings = GoogleGenerativeAIEmbeddings(
                model=config.GOOGLE_EMBEDDING_MODEL,
                google_api_key=config.GOOGLE_API_KEY
            )

            # Test that the embeddings work
            test_embedding = embeddings.embed_query("test")
            if not test_embedding:
                raise ValueError("Embeddings returned empty result for test query")
            return embeddings

        elif model_type == "openai":
            # Similar explicit handling for OpenAI
            pass
        else:
            raise ValueError(f"Unsupported embedding model: {model_type}")
    except Exception as e:
        logger.error(f"Failed to initialize {model_type} embeddings: {str(e)}")
        # Re-raise with context
        raise RuntimeError(f"Embedding initialization failed: {str(e)}") from e
```
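The `raise ... from e` keeps the original exception attached as `__cause__`, so the full traceback survives the re-raise. A quick check of that chaining behavior in isolation:

```python
def initialize():
    """Mimics the re-raise-with-context pattern from the method above."""
    try:
        raise ValueError("GOOGLE_API_KEY is not set")
    except Exception as e:
        raise RuntimeError(f"Embedding initialization failed: {e}") from e

caught = None
try:
    initialize()
except RuntimeError as err:
    caught = err

assert isinstance(caught.__cause__, ValueError)  # original error preserved
assert "GOOGLE_API_KEY" in str(caught)           # context carried into the message
```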
### 4. Configuration Validation

**Problem:** No upfront validation of configuration.
**Solution:** Add a configuration validator.

```python
def validate_ai_config():
    """Validate AI configuration before initialization"""
    errors = []

    # Check embedding model configuration
    if config.EMBEDDING_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google embedding model")
        if not config.GOOGLE_EMBEDDING_MODEL:
            errors.append("GOOGLE_EMBEDDING_MODEL not configured")

    # Check LLM configuration
    if config.LLM_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google LLM")
        if not config.GOOGLE_MODEL:
            errors.append("GOOGLE_MODEL not configured")

    if errors:
        return False, errors
    return True, []

# Use in Streamlit app
def initialize_app():
    valid, errors = validate_ai_config()
    if not valid:
        st.error("Configuration errors detected:")
        for error in errors:
            st.error(f"• {error}")
        st.stop()
```
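Returning a `(valid, errors)` tuple makes the validator easy to unit-test. A sketch with the config passed in as a parameter and `SimpleNamespace` standing in for the real `config` module (an assumption for testability; the real validator reads the module directly):

```python
from types import SimpleNamespace

def validate_ai_config(config):
    """Same embedding checks as above, parameterized for testing."""
    errors = []
    if config.EMBEDDING_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google embedding model")
        if not config.GOOGLE_EMBEDDING_MODEL:
            errors.append("GOOGLE_EMBEDDING_MODEL not configured")
    return (not errors, errors)

good = SimpleNamespace(EMBEDDING_MODEL="google",
                       GOOGLE_API_KEY="key",
                       GOOGLE_EMBEDDING_MODEL="models/embedding-001")
bad = SimpleNamespace(EMBEDDING_MODEL="google",
                      GOOGLE_API_KEY="",
                      GOOGLE_EMBEDDING_MODEL="")

assert validate_ai_config(good) == (True, [])
valid, errors = validate_ai_config(bad)
assert not valid and len(errors) == 2
```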
### 5. Model Factory Pattern

**Problem:** Model initialization logic mixed with pipeline logic.
**Solution:** Separate model creation into a factory.

```python
class ModelFactory:
    """Factory for creating AI models with proper error handling"""

    @staticmethod
    def create_embeddings(model_type: str):
        """Create embedding model based on type"""
        creators = {
            "google": ModelFactory._create_google_embeddings,
            "openai": ModelFactory._create_openai_embeddings,
            "huggingface": ModelFactory._create_huggingface_embeddings
        }
        creator = creators.get(model_type)
        if not creator:
            raise ValueError(f"Unknown embedding model type: {model_type}")
        return creator()

    @staticmethod
    def _create_google_embeddings():
        """Create Google embeddings with validation"""
        if not config.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY not configured")

        from langchain_google_genai import GoogleGenerativeAIEmbeddings
        embeddings = GoogleGenerativeAIEmbeddings(
            model=config.GOOGLE_EMBEDDING_MODEL,
            google_api_key=config.GOOGLE_API_KEY
        )

        # Validate it works
        try:
            test_result = embeddings.embed_query("test")
            if not test_result or len(test_result) == 0:
                raise ValueError("Embeddings test failed")
        except Exception as e:
            raise ValueError(f"Embeddings validation failed: {e}")

        return embeddings
```
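The dispatch table is what makes unsupported model types fail fast instead of silently falling through. A stripped-down version of the lookup (lambdas standing in for the real creator methods) demonstrates the behavior for known and unknown types:

```python
def make_factory(creators):
    """Stripped-down dispatch: maps model type -> creator callable."""
    def create(model_type):
        creator = creators.get(model_type)
        if not creator:
            raise ValueError(f"Unknown embedding model type: {model_type}")
        return creator()
    return create

create = make_factory({
    "google": lambda: "google-embeddings",
    "openai": lambda: "openai-embeddings",
})

assert create("google") == "google-embeddings"

try:
    create("cohere")
    unknown_rejected = False
except ValueError:
    unknown_rejected = True
assert unknown_rejected
```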
## Implementation Steps

1. **Update `RAGPipeline` class** (`src/clm_system/rag.py`)
   - Remove direct initialization in `__init__`
   - Add lazy property getters
   - Implement explicit error handling
   - Remove all `return None` patterns

2. **Create `ModelFactory`** (`src/clm_system/model_factory.py`)
   - Implement factory methods for each model type
   - Add validation for each model
   - Include test queries to verify models work

3. **Update Streamlit app** (`debug_streamlit.py`, or create a new `app.py`)
   - Use session state for pipeline storage
   - Add configuration validation on startup
   - Show clear error messages to users
   - Add retry mechanism for transient failures

4. **Add configuration validator** (`src/clm_system/validators.py`)
   - Check all required environment variables
   - Validate API key format (if applicable)
   - Test connectivity to services

5. **Update `config.py`**
   - Add helper methods for validation
   - Include default fallbacks where appropriate
   - Provide better error messages for missing values
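Step 3's retry mechanism for transient failures could be sketched as a small wrapper with exponential backoff. The attempt count and delays here are assumptions to be tuned, and the wrapper retries on any exception; a real version would likely retry only known-transient errors:

```python
import time

def with_retries(factory, attempts=3, base_delay=1.0):
    """Call factory(); retry failures with exponential backoff (1s, 2s, ...)."""
    last_error = None
    for attempt in range(attempts):
        try:
            return factory()
        except Exception as e:
            last_error = e
            if attempt < attempts - 1:
                time.sleep(base_delay * 2 ** attempt)
    raise RuntimeError(f"Failed after {attempts} attempts") from last_error

# Demo: a factory that succeeds on the third call
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "pipeline"

assert with_retries(flaky, attempts=3, base_delay=0) == "pipeline"
assert calls["n"] == 3
```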
## Testing Plan

1. **Unit Tests**
   - Test model initialization with valid config
   - Test model initialization with missing config
   - Test error handling and messages

2. **Integration Tests**
   - Test full pipeline initialization
   - Test Streamlit session state persistence
   - Test recovery from failures

3. **Manual Testing**
   - Start app with valid config → should work
   - Start app with missing API key → clear error
   - Send query after initialization → should respond
   - Refresh page → should maintain state
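The missing-config unit test could take the following shape, with a fail-fast stand-in for the real initializer (names here are placeholders, not the actual test suite):

```python
def initialize_embeddings(api_key):
    """Fail-fast stand-in for RAGPipeline._initialize_embeddings."""
    if not api_key:
        raise RuntimeError("Embedding initialization failed: GOOGLE_API_KEY is not set")
    return "embeddings"

def test_missing_api_key_raises():
    try:
        initialize_embeddings(api_key=None)
    except RuntimeError as e:
        assert "GOOGLE_API_KEY" in str(e)  # error message names the missing key
    else:
        raise AssertionError("expected RuntimeError, got no exception")

def test_valid_config_returns_model():
    assert initialize_embeddings(api_key="key") == "embeddings"

test_missing_api_key_raises()
test_valid_config_returns_model()
```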
## Success Criteria

- ✅ Clear error messages when configuration is invalid
- ✅ Models initialize once and persist in session state
- ✅ No silent failures (no `return None`)
- ✅ App handles "Hi" message successfully
- ✅ Page refreshes don't reinitialize models
- ✅ Failed initialization stops the app with a helpful message
## Rollback Plan

If changes cause issues:

1. Revert to original code
2. Add temporary workaround in Streamlit app:

```python
if 'rag_pipeline' not in st.session_state:
    # Force multiple initialization attempts
    for attempt in range(3):
        try:
            pipeline = RAGPipeline()
            if pipeline.embeddings and pipeline.llm:
                st.session_state.rag_pipeline = pipeline
                break
        except Exception as e:
            if attempt == 2:
                st.error(f"Failed after 3 attempts: {e}")
                st.stop()
```
## Notes

- Current code uses Google models (`EMBEDDING_MODEL=google`, `LLM_MODEL=google`)
- Google API key is set in the environment
- Issue only occurs in Streamlit; the CLI works fine
- Root cause: Streamlit's execution model combined with silent failures during initialization