# Streamlit RAG Pipeline Initialization Issue
## Problem Statement

When running the Streamlit app, sending "Hi" results in the error: `AI models are not properly configured. Please check your API keys.`

### Root Cause Analysis

1. **Embeddings returning None**: The `_initialize_embeddings()` method in `RAGPipeline` returns `None` when initialization fails
2. **Silent failures**: Exceptions are caught but only logged as warnings, returning `None` instead of raising errors
3. **Streamlit rerun behavior**: Each interaction causes a full script rerun, potentially reinitializing models
4. **No persistence**: Models are not stored in `st.session_state`, causing repeated initialization attempts

### Current Behavior

```python
# Current problematic flow:
RAGPipeline.__init__()
    → _initialize_embeddings()
        → try/except returns None on failure
        → self.embeddings = None
        → Query fails with generic error
```
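This failure mode can be reproduced without any AI dependencies. The sketch below uses hypothetical names and a simulated import failure to show how `return None` in the `except` block surfaces later as the generic error message:

```python
class BrokenPipeline:
    """Minimal repro of the silent-failure pattern (hypothetical names)."""

    def __init__(self):
        self.embeddings = self._initialize_embeddings()

    def _initialize_embeddings(self):
        try:
            # Simulated failure (e.g. missing package or API key)
            raise ImportError("langchain_google_genai not installed")
        except Exception:
            # The real code logs a warning here; the caller only ever sees None
            return None

    def query(self, text):
        if self.embeddings is None:
            return "AI models are not properly configured. Please check your API keys."
        return "ok"

print(BrokenPipeline().query("Hi"))
# → AI models are not properly configured. Please check your API keys.
```

The root cause (the `ImportError`) never reaches the user; only the generic message does.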
## Proposed Design Changes

### 1. Singleton Pattern with Session State

**Problem**: Multiple `RAGPipeline` instances are created on reruns

**Solution**: Use Streamlit session state as singleton storage

```python
# In app.py or streamlit_app.py
def get_rag_pipeline():
    """Get or create RAG pipeline with proper session state management"""
    if 'rag_pipeline' not in st.session_state:
        with st.spinner("Initializing AI models..."):
            pipeline = RAGPipeline()

            # Validate initialization
            if pipeline.embeddings is None:
                st.error("Failed to initialize embedding model")
                st.stop()
            if pipeline.llm is None:
                st.error("Failed to initialize language model")
                st.stop()

            st.session_state.rag_pipeline = pipeline
            st.success("AI models initialized successfully")

    return st.session_state.rag_pipeline
```
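The rerun-safe behavior can be verified without Streamlit by substituting a plain dict for `st.session_state`. The harness and the `factory` parameter below are hypothetical test scaffolding, not part of the proposed API:

```python
session_state = {}  # stand-in for st.session_state

def get_pipeline(factory):
    """Create the pipeline once; every later call returns the stored instance."""
    if 'rag_pipeline' not in session_state:
        session_state['rag_pipeline'] = factory()
    return session_state['rag_pipeline']

calls = []

def fake_factory():
    calls.append(1)  # count how many times initialization actually runs
    return object()

p1 = get_pipeline(fake_factory)  # first call: factory runs
p2 = get_pipeline(fake_factory)  # simulated rerun: cached instance returned
print(p1 is p2, len(calls))  # → True 1
```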
### 2. Lazy Initialization with Caching

**Problem**: Models initialized in `__init__` even if not needed

**Solution**: Use lazy properties with `@st.cache_resource`

```python
class RAGPipeline:
    def __init__(self):
        self._embeddings = None
        self._llm = None
        self.db_path = "data/lancedb"
        self.db = lancedb.connect(self.db_path)

    @property
    def embeddings(self):
        if self._embeddings is None:
            self._embeddings = self._get_or_create_embeddings()
        return self._embeddings

    @property
    def llm(self):
        if self._llm is None:
            self._llm = self._get_or_create_llm()
        return self._llm

    @st.cache_resource
    def _get_or_create_embeddings(_self):
        """Cached embedding model creation.

        The leading underscore on `_self` tells Streamlit not to hash
        the instance when computing the cache key.
        """
        return _self._initialize_embeddings()

    @st.cache_resource
    def _get_or_create_llm(_self):
        """Cached LLM creation"""
        return _self._initialize_llm()
```
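Independent of the Streamlit cache, the lazy-property half of this design can be sketched and checked in plain Python (names are hypothetical; `object()` stands in for the real model):

```python
class LazyPipeline:
    """Lazy-property sketch: the model is built on first access, then reused."""

    def __init__(self):
        self._embeddings = None
        self.init_calls = 0  # instrumentation for this sketch only

    @property
    def embeddings(self):
        if self._embeddings is None:
            self.init_calls += 1
            self._embeddings = object()  # stand-in for the real embedding model
        return self._embeddings

p = LazyPipeline()
print(p.init_calls)  # → 0 (nothing built yet)
_ = p.embeddings
_ = p.embeddings
print(p.init_calls)  # → 1 (built exactly once)
```

Constructing the pipeline is now cheap; any failure happens at first use, where it can be caught and reported.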
### 3. Explicit Error Handling

**Problem**: Silent failures with `return None`

**Solution**: Raise exceptions with clear messages

```python
def _initialize_embeddings(self):
    """Initialize embeddings with explicit error handling"""
    model_type = config.EMBEDDING_MODEL

    try:
        if model_type == "google":
            if not config.GOOGLE_API_KEY:
                raise ValueError("GOOGLE_API_KEY is not set in environment variables")

            # Try to import and initialize
            try:
                from langchain_google_genai import GoogleGenerativeAIEmbeddings
            except ImportError as e:
                raise ImportError(f"Failed to import GoogleGenerativeAIEmbeddings: {e}")

            embeddings = GoogleGenerativeAIEmbeddings(
                model=config.GOOGLE_EMBEDDING_MODEL,
                google_api_key=config.GOOGLE_API_KEY
            )

            # Test that the embeddings work
            test_embedding = embeddings.embed_query("test")
            if not test_embedding:
                raise ValueError("Embeddings returned empty result for test query")

            return embeddings

        elif model_type == "openai":
            # Similar explicit handling for OpenAI; raise rather than
            # fall through and implicitly return None
            raise NotImplementedError("OpenAI embeddings not implemented yet")

        else:
            raise ValueError(f"Unsupported embedding model: {model_type}")

    except Exception as e:
        logger.error(f"Failed to initialize {model_type} embeddings: {str(e)}")
        # Re-raise with context
        raise RuntimeError(f"Embedding initialization failed: {str(e)}") from e
```
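The `raise ... from e` pattern keeps the original cause attached to the new exception, so the UI can show a high-level message while logs retain the root cause. A self-contained check with a simulated failure (hypothetical function name):

```python
def initialize_embeddings_or_raise():
    """Simulated initializer that fails and re-raises with context."""
    try:
        raise ValueError("GOOGLE_API_KEY is not set in environment variables")
    except Exception as e:
        raise RuntimeError(f"Embedding initialization failed: {e}") from e

try:
    initialize_embeddings_or_raise()
except RuntimeError as err:
    print(err)                  # high-level message for the UI
    print(repr(err.__cause__))  # the original ValueError, preserved for logs
```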
### 4. Configuration Validation

**Problem**: No upfront validation of configuration

**Solution**: Add a configuration validator

```python
def validate_ai_config():
    """Validate AI configuration before initialization"""
    errors = []

    # Check embedding model configuration
    if config.EMBEDDING_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google embedding model")
        if not config.GOOGLE_EMBEDDING_MODEL:
            errors.append("GOOGLE_EMBEDDING_MODEL not configured")

    # Check LLM configuration
    if config.LLM_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google LLM")
        if not config.GOOGLE_MODEL:
            errors.append("GOOGLE_MODEL not configured")

    if errors:
        return False, errors
    return True, []

# Use in Streamlit app
def initialize_app():
    valid, errors = validate_ai_config()
    if not valid:
        st.error("Configuration errors detected:")
        for error in errors:
            st.error(f"• {error}")
        st.stop()
```
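Passing the config in as a parameter (a small deviation from the module-global `config` used above) makes the validator unit-testable with a stub:

```python
from types import SimpleNamespace

def validate_ai_config(config):
    """Parameterized variant of the validator, for testing with a stub config."""
    errors = []
    if config.EMBEDDING_MODEL == "google":
        if not config.GOOGLE_API_KEY:
            errors.append("GOOGLE_API_KEY not set for Google embedding model")
        if not config.GOOGLE_EMBEDDING_MODEL:
            errors.append("GOOGLE_EMBEDDING_MODEL not configured")
    return (len(errors) == 0, errors)

# Stub config with a missing API key (field values are placeholders)
cfg = SimpleNamespace(EMBEDDING_MODEL="google", GOOGLE_API_KEY="",
                      GOOGLE_EMBEDDING_MODEL="some-embedding-model")
valid, errs = validate_ai_config(cfg)
print(valid, errs)  # → False ['GOOGLE_API_KEY not set for Google embedding model']
```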
### 5. Model Factory Pattern

**Problem**: Model initialization logic mixed with pipeline logic

**Solution**: Separate model creation into a factory

```python
class ModelFactory:
    """Factory for creating AI models with proper error handling"""

    @staticmethod
    def create_embeddings(model_type: str):
        """Create embedding model based on type"""
        creators = {
            "google": ModelFactory._create_google_embeddings,
            "openai": ModelFactory._create_openai_embeddings,
            "huggingface": ModelFactory._create_huggingface_embeddings
        }

        creator = creators.get(model_type)
        if not creator:
            raise ValueError(f"Unknown embedding model type: {model_type}")

        return creator()

    @staticmethod
    def _create_google_embeddings():
        """Create Google embeddings with validation"""
        if not config.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY not configured")

        from langchain_google_genai import GoogleGenerativeAIEmbeddings

        embeddings = GoogleGenerativeAIEmbeddings(
            model=config.GOOGLE_EMBEDDING_MODEL,
            google_api_key=config.GOOGLE_API_KEY
        )

        # Validate it works
        try:
            test_result = embeddings.embed_query("test")
            if not test_result or len(test_result) == 0:
                raise ValueError("Embeddings test failed")
        except Exception as e:
            raise ValueError(f"Embeddings validation failed: {e}")

        return embeddings
```
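The dispatch-table behavior (known type → creator runs, unknown type → `ValueError`) can be exercised with stub creators standing in for the real constructors:

```python
# Stub creators standing in for the real model constructors
creators = {
    "google": lambda: "google-embeddings",
    "openai": lambda: "openai-embeddings",
}

def create_embeddings(model_type):
    """Look up and invoke the creator for the requested model type."""
    creator = creators.get(model_type)
    if creator is None:
        raise ValueError(f"Unknown embedding model type: {model_type}")
    return creator()

print(create_embeddings("google"))  # → google-embeddings
try:
    create_embeddings("cohere")
except ValueError as e:
    print(e)  # → Unknown embedding model type: cohere
```

Adding a new provider then means registering one entry in the table, with no changes to pipeline code.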
## Implementation Steps

1. **Update RAGPipeline class** (`src/clm_system/rag.py`)
   - Remove direct initialization in `__init__`
   - Add lazy property getters
   - Implement explicit error handling
   - Remove all `return None` patterns

2. **Create ModelFactory** (`src/clm_system/model_factory.py`)
   - Implement factory methods for each model type
   - Add validation for each model
   - Include test queries to verify models work

3. **Update Streamlit app** (`debug_streamlit.py` or create new `app.py`)
   - Use session state for pipeline storage
   - Add configuration validation on startup
   - Show clear error messages to users
   - Add a retry mechanism for transient failures

4. **Add configuration validator** (`src/clm_system/validators.py`)
   - Check all required environment variables
   - Validate API key format (if applicable)
   - Test connectivity to services

5. **Update config.py**
   - Add helper methods for validation
   - Include default fallbacks where appropriate
   - Provide better error messages for missing values
## Testing Plan

1. **Unit Tests**
   - Test model initialization with valid config
   - Test model initialization with missing config
   - Test error handling and messages

2. **Integration Tests**
   - Test full pipeline initialization
   - Test Streamlit session state persistence
   - Test recovery from failures

3. **Manual Testing**
   - Start app with valid config → should work
   - Start app with missing API key → clear error
   - Send query after initialization → should respond
   - Refresh page → should maintain state
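A minimal shape for the missing-config unit test, using plain asserts and a hypothetical creator that mirrors the factory's validation step (no test framework assumed):

```python
def create_google_embeddings(api_key):
    """Hypothetical creator mirroring the factory's validation step."""
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not configured")
    return "embeddings-model"

# Missing key → a clear, specific error (not None)
try:
    create_google_embeddings(None)
    raised = False
except ValueError as e:
    raised = "GOOGLE_API_KEY" in str(e)
assert raised

# Valid key → a model object comes back
assert create_google_embeddings("fake-key") == "embeddings-model"
```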
## Success Criteria

1. ✅ Clear error messages when configuration is invalid
2. ✅ Models initialize once and persist in session state
3. ✅ No silent failures (no `return None`)
4. ✅ App handles "Hi" message successfully
5. ✅ Page refreshes don't reinitialize models
6. ✅ Failed initialization stops the app with a helpful message
## Rollback Plan

If the changes cause issues:

1. Revert to the original code
2. Add a temporary workaround in the Streamlit app:

```python
if 'rag_pipeline' not in st.session_state:
    # Force multiple initialization attempts
    for attempt in range(3):
        try:
            pipeline = RAGPipeline()
            if pipeline.embeddings and pipeline.llm:
                st.session_state.rag_pipeline = pipeline
                break
        except Exception as e:
            if attempt == 2:
                st.error(f"Failed after 3 attempts: {e}")
                st.stop()
    else:
        # Models came back None on every attempt without raising
        st.error("Models failed to initialize after 3 attempts")
        st.stop()
```
## Notes

- Current code uses Google models (`EMBEDDING_MODEL=google`, `LLM_MODEL=google`)
- The Google API key is set in the environment
- The issue only occurs in Streamlit; the CLI works fine
- Root cause: Streamlit's execution model combined with silent failures during initialization