clm-system/PLANNING/streamlit_init_issue.md

# Streamlit RAG Pipeline Initialization Issue
## Problem Statement
When running the Streamlit app, sending "Hi" results in the error: `AI models are not properly configured. Please check your API keys.`
### Root Cause Analysis
1. **Embeddings returning None**: The `_initialize_embeddings()` method in `RAGPipeline` returns `None` when initialization fails
2. **Silent failures**: Exceptions are caught but only logged as warnings, returning `None` instead of raising errors
3. **Streamlit rerun behavior**: Each interaction causes a full script rerun, potentially reinitializing models
4. **No persistence**: Models are not stored in `st.session_state`, causing repeated initialization attempts
### Current Behavior
```python
# Current problematic flow:
RAGPipeline.__init__()
  -> _initialize_embeddings()
  -> try/except returns None on failure
  -> self.embeddings = None
  -> query fails with generic error
```
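The silent-failure flow above can be reduced to a toy reproduction. All names here are hypothetical stand-ins, not the actual `RAGPipeline` code; the point is the contrast between swallowing the exception and failing fast:

```python
# Hypothetical minimal reproduction of the silent-failure pattern.

def init_model_silent(api_key):
    """Current style: swallow the error and return None."""
    try:
        if not api_key:
            raise ValueError("API key missing")
        return object()  # stand-in for a real model client
    except Exception:
        return None  # caller only sees a vague failure much later

def init_model_fail_fast(api_key):
    """Proposed style: raise immediately with a clear message."""
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY is not set; cannot build embeddings")
    return object()

model = init_model_silent(None)
print(model)  # None -> downstream query code fails with a generic error
```

With the fail-fast variant, the missing key surfaces at initialization time instead of at query time, which is exactly the behavior change proposed below.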
## Proposed Design Changes
### 1. Singleton Pattern with Session State
**Problem**: Multiple RAGPipeline instances created on reruns
**Solution**: Use Streamlit session state as singleton storage
```python
# In app.py or streamlit_app.py
def get_rag_pipeline():
    """Get or create the RAG pipeline with proper session state management"""
    if 'rag_pipeline' not in st.session_state:
        with st.spinner("Initializing AI models..."):
            pipeline = RAGPipeline()
            # Validate initialization before caching the instance
            if pipeline.embeddings is None:
                st.error("Failed to initialize embedding model")
                st.stop()
            if pipeline.llm is None:
                st.error("Failed to initialize language model")
                st.stop()
            st.session_state.rag_pipeline = pipeline
            st.success("AI models initialized successfully")
    return st.session_state.rag_pipeline
```
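The singleton behavior can be exercised without Streamlit by modeling `st.session_state` as a plain dict. This is an illustrative sketch with hypothetical names (`FakePipeline`, `get_pipeline`), not part of the codebase, but it shows why the pattern survives reruns: the second lookup reuses the cached instance instead of constructing a new pipeline.

```python
# Hypothetical stand-in for st.session_state, which persists across reruns.
session_state = {}

class FakePipeline:
    instances = 0
    def __init__(self):
        FakePipeline.instances += 1

def get_pipeline(state):
    # Same pattern as get_rag_pipeline: create once, then reuse.
    if "rag_pipeline" not in state:
        state["rag_pipeline"] = FakePipeline()
    return state["rag_pipeline"]

a = get_pipeline(session_state)
b = get_pipeline(session_state)  # simulated second "rerun"
assert a is b                    # same object both times
assert FakePipeline.instances == 1
```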
### 2. Lazy Initialization with Caching
**Problem**: Models initialized in `__init__` even if not needed
**Solution**: Use lazy properties with `@st.cache_resource`
```python
class RAGPipeline:
    def __init__(self):
        # Defer model creation: nothing heavy happens at construction time
        self._embeddings = None
        self._llm = None
        self.db_path = "data/lancedb"
        self.db = lancedb.connect(self.db_path)

    @property
    def embeddings(self):
        if self._embeddings is None:
            self._embeddings = self._get_or_create_embeddings()
        return self._embeddings

    @property
    def llm(self):
        if self._llm is None:
            self._llm = self._get_or_create_llm()
        return self._llm

    @st.cache_resource
    def _get_or_create_embeddings(_self):
        """Cached embedding model creation. The leading underscore on
        `_self` tells st.cache_resource not to hash the instance."""
        return _self._initialize_embeddings()

    @st.cache_resource
    def _get_or_create_llm(_self):
        """Cached LLM creation (same `_self` convention as above)."""
        return _self._initialize_llm()
```
### 3. Explicit Error Handling
**Problem**: Silent failures with `return None`
**Solution**: Raise exceptions with clear messages
```python
def _initialize_embeddings(self):
    """Initialize embeddings with explicit error handling"""
    model_type = config.EMBEDDING_MODEL
    try:
        if model_type == "google":
            if not config.GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY is not set in environment variables")
            # Try to import and initialize
            try:
                from langchain_google_genai import GoogleGenerativeAIEmbeddings
            except ImportError as e:
                raise ImportError(f"Failed to import GoogleGenerativeAIEmbeddings: {e}")
            embeddings = GoogleGenerativeAIEmbeddings(
                model=config.GOOGLE_EMBEDDING_MODEL,
                google_api_key=config.GOOGLE_API_KEY
            )
            # Smoke-test the embeddings before returning them
            test_embedding = embeddings.embed_query("test")
            if not test_embedding:
                raise ValueError("Embeddings returned empty result for test query")
            return embeddings
        elif model_type == "openai":
            # Similar explicit handling for OpenAI; raising here avoids
            # reintroducing an implicit `return None` path
            raise NotImplementedError("OpenAI embeddings not yet implemented")
        else:
            raise ValueError(f"Unsupported embedding model: {model_type}")
    except Exception as e:
        logger.error(f"Failed to initialize {model_type} embeddings: {str(e)}")
        # Re-raise with context
        raise RuntimeError(f"Embedding initialization failed: {str(e)}") from e
```
### 4. Configuration Validation
**Problem**: No upfront validation of configuration
**Solution**: Add configuration validator
```python
def validate_ai_config():
    """Validate AI configuration before initialization"""
    errors = []
    # Check embedding model configuration
    if config.EMBEDDING_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google embedding model")
        if not config.GOOGLE_EMBEDDING_MODEL:
            errors.append("GOOGLE_EMBEDDING_MODEL not configured")
    # Check LLM configuration
    if config.LLM_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google LLM")
        if not config.GOOGLE_MODEL:
            errors.append("GOOGLE_MODEL not configured")
    if errors:
        return False, errors
    return True, []

# Use in Streamlit app
def initialize_app():
    valid, errors = validate_ai_config()
    if not valid:
        st.error("Configuration errors detected:")
        for error in errors:
            st.error(f"• {error}")
        st.stop()
### 5. Model Factory Pattern
**Problem**: Model initialization logic mixed with pipeline logic
**Solution**: Separate model creation into factory
```python
class ModelFactory:
    """Factory for creating AI models with proper error handling"""

    @staticmethod
    def create_embeddings(model_type: str):
        """Create embedding model based on type"""
        creators = {
            "google": ModelFactory._create_google_embeddings,
            "openai": ModelFactory._create_openai_embeddings,
            "huggingface": ModelFactory._create_huggingface_embeddings
        }
        creator = creators.get(model_type)
        if not creator:
            raise ValueError(f"Unknown embedding model type: {model_type}")
        return creator()

    @staticmethod
    def _create_google_embeddings():
        """Create Google embeddings with validation"""
        if not config.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY not configured")
        from langchain_google_genai import GoogleGenerativeAIEmbeddings
        embeddings = GoogleGenerativeAIEmbeddings(
            model=config.GOOGLE_EMBEDDING_MODEL,
            google_api_key=config.GOOGLE_API_KEY
        )
        # Validate it works before handing it back
        try:
            test_result = embeddings.embed_query("test")
            if not test_result:
                raise ValueError("Embeddings test failed")
        except Exception as e:
            raise ValueError(f"Embeddings validation failed: {e}") from e
        return embeddings
```
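The value of the dispatch-table approach is that unknown model types fail loudly at the factory boundary. A trimmed-down, self-contained illustration (the `TinyFactory` name and `"fake"` creator are hypothetical, purely for demonstration):

```python
# Minimal sketch of the factory dispatch pattern used by ModelFactory.
class TinyFactory:
    @staticmethod
    def create(model_type: str):
        creators = {"fake": lambda: "fake-model"}
        creator = creators.get(model_type)
        if not creator:
            # Unknown types fail here, not deep inside the pipeline
            raise ValueError(f"Unknown embedding model type: {model_type}")
        return creator()

print(TinyFactory.create("fake"))  # fake-model
try:
    TinyFactory.create("missing")
except ValueError as e:
    print(e)  # Unknown embedding model type: missing
```

At the call site, `RAGPipeline` would then delegate with a single line such as `ModelFactory.create_embeddings(config.EMBEDDING_MODEL)`, keeping model construction out of the pipeline class entirely.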
## Implementation Steps
1. **Update RAGPipeline class** (`src/clm_system/rag.py`)
   - Remove direct initialization in `__init__`
   - Add lazy property getters
   - Implement explicit error handling
   - Remove all `return None` patterns
2. **Create ModelFactory** (`src/clm_system/model_factory.py`)
   - Implement factory methods for each model type
   - Add validation for each model
   - Include test queries to verify models work
3. **Update Streamlit app** (`debug_streamlit.py` or create new `app.py`)
   - Use session state for pipeline storage
   - Add configuration validation on startup
   - Show clear error messages to users
   - Add retry mechanism for transient failures
4. **Add configuration validator** (`src/clm_system/validators.py`)
   - Check all required environment variables
   - Validate API keys format (if applicable)
   - Test connectivity to services
5. **Update config.py**
   - Add helper methods for validation
   - Include default fallbacks where appropriate
   - Better error messages for missing values
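The retry mechanism mentioned in step 3 could be sketched as a small helper with exponential backoff. The helper name `with_retries` and the `flaky` demo function are hypothetical; `base_delay` is set to 0 here only so the example runs instantly:

```python
import time

def with_retries(fn, attempts=3, base_delay=0.0):
    """Retry fn on exceptions, with exponential backoff between attempts.
    Re-raises as RuntimeError if every attempt fails (no silent None)."""
    last_exc = None
    for attempt in range(attempts):
        try:
            return fn()
        except Exception as exc:
            last_exc = exc
            time.sleep(base_delay * (2 ** attempt))  # backoff between tries
    raise RuntimeError(f"Failed after {attempts} attempts") from last_exc

# Demo: a function that fails twice, then succeeds on the third call.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"

print(with_retries(flaky))  # ok
```

Note this helper only makes sense for genuinely transient failures (e.g. network blips); a missing API key should still fail immediately via the explicit checks above.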
## Testing Plan
1. **Unit Tests**
   - Test model initialization with valid config
   - Test model initialization with missing config
   - Test error handling and messages
2. **Integration Tests**
   - Test full pipeline initialization
   - Test Streamlit session state persistence
   - Test recovery from failures
3. **Manual Testing**
   - Start app with valid config → should work
   - Start app with missing API key → clear error
   - Send query after initialization → should respond
   - Refresh page → should maintain state
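A unit test for the missing-config case (item 1 above) might look like the following. This sketch parameterizes the validator on a config object so it can run without real environment variables; the real validator reads the module-level `config` instead:

```python
from types import SimpleNamespace

def validate_ai_config(config):
    """Same shape as the validator proposed above, but parameterised on
    config so it is testable in isolation (hypothetical refactor)."""
    errors = []
    if config.EMBEDDING_MODEL == "google" and not config.GOOGLE_API_KEY:
        errors.append("GOOGLE_API_KEY not set for Google embedding model")
    return (not errors), errors

# Stand-in config objects for the two test cases
good = SimpleNamespace(EMBEDDING_MODEL="google", GOOGLE_API_KEY="key-123")
bad = SimpleNamespace(EMBEDDING_MODEL="google", GOOGLE_API_KEY="")

assert validate_ai_config(good) == (True, [])
ok, errors = validate_ai_config(bad)
assert not ok and "GOOGLE_API_KEY" in errors[0]
```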
## Success Criteria
1. ✅ Clear error messages when configuration is invalid
2. ✅ Models initialize once and persist in session state
3. ✅ No silent failures (no `return None`)
4. ✅ App handles "Hi" message successfully
5. ✅ Page refreshes don't reinitialize models
6. ✅ Failed initialization stops app with helpful message
## Rollback Plan
If changes cause issues:
1. Revert to original code
2. Add temporary workaround in Streamlit app:
```python
if 'rag_pipeline' not in st.session_state:
    # Force multiple initialization attempts
    last_error = None
    for attempt in range(3):
        try:
            pipeline = RAGPipeline()
            if pipeline.embeddings and pipeline.llm:
                st.session_state.rag_pipeline = pipeline
                break
        except Exception as e:
            last_error = e
    else:
        # No attempt produced a working pipeline (raised or returned None models)
        st.error(f"Failed after 3 attempts: {last_error}")
        st.stop()
```
## Notes
- Current code uses Google models (`EMBEDDING_MODEL=google`, `LLM_MODEL=google`)
- Google API key is set in environment
- Issue only occurs in Streamlit, CLI works fine
- Root cause: Streamlit's execution model + silent failures in initialization