WORKFLOW 03: Archetypal Model Training
This workflow demonstrates how to train a final archetypal analysis model:
1. Load data with PCA preprocessing
2. Configure training parameters based on the hyperparameter search (WORKFLOW_02)
3. Train the archetypal model
4. Access and understand the training results
The training function returns a TrainingResults dict with the following structure:
GUARANTEED KEYS (always present):
- history: Training metrics per epoch (dict)
- final_model: Trained model object
- model: Same object as final_model
- final_optimizer: Optimizer state
- final_analysis: Analysis results dict
- epoch_archetype_positions: Archetype positions over training
- training_config: Configuration used for training
OPTIONAL KEYS (use .get() to access):
- final_archetype_r2: Final R² score (if computed)
- final_rmse, final_mae, final_loss: Other metrics
- convergence_epoch: Epoch where convergence occurred
Additionally, the function modifies adata.uns['archetype_coordinates'] in-place.
Example usage: python WORKFLOW_03.py
Requirements:
- peach
- scanpy
- Data with PCA (from WORKFLOW_01)
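The split between guaranteed and optional keys described above suggests a simple access pattern: index guaranteed keys directly and read optional ones with .get(). The following is a minimal sketch of that pattern, run on a hand-built dictionary that stands in for the real return value. The helper `split_results` and the placeholder values are hypothetical and not part of peach; only the key names come from the list above.

```python
# Key names are taken from the key list in this workflow;
# everything else here is illustrative.
GUARANTEED_KEYS = (
    "history", "final_model", "model", "final_optimizer",
    "final_analysis", "epoch_archetype_positions", "training_config",
)
OPTIONAL_KEYS = (
    "final_archetype_r2", "final_rmse", "final_mae", "final_loss",
    "convergence_epoch",
)


def split_results(results):
    """Return (missing guaranteed keys, optional values that are present)."""
    missing = [key for key in GUARANTEED_KEYS if key not in results]
    optional = {
        key: results[key]
        for key in OPTIONAL_KEYS
        if results.get(key) is not None
    }
    return missing, optional


# Stand-in for the dict returned by training (placeholder values only).
fake_results = {key: None for key in GUARANTEED_KEYS}
fake_results["history"] = {"loss": [0.9, 0.5, 0.4]}
fake_results["final_loss"] = 0.4

missing, optional = split_results(fake_results)
print("Missing guaranteed keys:", missing)
print("Optional values present:", optional)
```

A non-empty `missing` list would indicate that the returned dict does not match the documented structure.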
import scanpy as sc
import peach as pc
from pathlib import Path
Configuration
# Data path
data_path = Path("~/data/HSC.h5ad").expanduser()  # Path does not expand "~" on its own
# Training parameters (from hyperparameter search)
n_archetypes = 5 # Number of archetypes
hidden_dims = [256, 128, 64] # Encoder architecture
inflation_factor = 1.5 # PCA inflation factor
n_epochs = 100 # Training epochs
early_stopping_patience = 10 # Stop if no improvement for N epochs
seed = 42 # For reproducibility
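The comment on early_stopping_patience describes patience-based early stopping. The sketch below illustrates the generic technique on a plain list of losses. It is an assumption about the general idea, not a description of peach's actual stopping rule; `stopping_epoch` and `min_delta` are hypothetical names introduced for illustration.

```python
def stopping_epoch(losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which training would stop.

    Training stops once `patience` consecutive epochs fail to improve on the
    best loss seen so far by more than `min_delta`. If that never happens,
    the index of the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            # New best loss: reset the patience counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(losses) - 1


# Made-up loss curve that plateaus after the third epoch.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]
print(stopping_epoch(losses, patience=3))
print(stopping_epoch(losses, patience=10))
```

With a small patience the run ends during the plateau; with a patience longer than the plateau it continues to the final epoch, which is the trade-off this parameter controls.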
Step 1: Load Data with PCA
print("Loading data...")
adata = sc.read_h5ad(data_path)
print(f" Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")
# Ensure PCA exists
if 'X_pca' not in adata.obsm:
    print(" Running PCA...")
    sc.tl.pca(adata, n_comps=13)
    print(f" PCA computed: {adata.obsm['X_pca'].shape}")
else:
    print(f" PCA found: {adata.obsm['X_pca'].shape}")
Step 2: Train Archetypal Model
print(f"\nTraining archetypal model...")
print(f" n_archetypes: {n_archetypes}")
print(f" hidden_dims: {hidden_dims}")
print(f" inflation_factor: {inflation_factor}")
print(f" n_epochs: {n_epochs}")
print(f" This may take several minutes...")
results = pc.tl.train_archetypal(
    adata,
    n_archetypes=n_archetypes,
    n_epochs=n_epochs,
    hidden_dims=hidden_dims,
    inflation_factor=inflation_factor,
    early_stopping_patience=early_stopping_patience,
    seed=seed,
    device='cpu',  # Use 'cuda' if GPU available
)
print(" Training complete!")
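The comment on the device argument suggests switching to 'cuda' when a GPU is available. A minimal sketch of choosing the device automatically follows. It assumes that peach runs on PyTorch, which the 'cuda' device string suggests but this workflow does not state; `pick_device` is a hypothetical helper, not part of peach.

```python
import importlib.util


def pick_device():
    """Return 'cuda' if PyTorch is installed and sees a GPU, else 'cpu'."""
    # Assumption: the training backend is PyTorch. If torch is not
    # installed, fall back to 'cpu' without raising.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

The resulting string could then be passed as device=device in place of the hard-coded 'cpu'.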
Step 3: Examine Training Results
print("\nExamining training results...")
# Access guaranteed keys
history = results['history']
final_model = results['final_model']
training_config = results['training_config']
print(f"\nTraining history:")
print(f" Epochs completed: {len(history.get('loss', []))}")
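# .get() returns None for a missing key; an empty list is also skipped,
# which avoids an IndexError on [-1]
if history.get('loss'):
    print(f" Final loss: {history['loss'][-1]:.6f}")
if history.get('archetype_r2'):
    print(f" Final archetype R²: {history['archetype_r2'][-1]:.4f}")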
# Access optional keys safely with .get()
final_r2 = results.get('final_archetype_r2')
convergence_epoch = results.get('convergence_epoch')
if final_r2 is not None:
    print(f"\nModel performance:")
    print(f" Final archetype R²: {final_r2:.4f}")
if convergence_epoch is not None:
    print(f" Converged at epoch: {convergence_epoch}")
# Check that adata was modified
print(f"\nModifications to adata:")
if 'archetype_coordinates' in adata.uns:
    print(f" adata.uns['archetype_coordinates'] created")
    coords = adata.uns['archetype_coordinates']
    print(f" Shape: {coords.shape if hasattr(coords, 'shape') else 'N/A'}")
else:
    print(f" Warning: adata.uns['archetype_coordinates'] not found")
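Because history holds one value per epoch, a few summary figures can be derived from it beyond the final loss. The sketch below assumes history['loss'] is a list of floats with one entry per epoch, as the code above implies; `summarize_loss` and the example numbers are hypothetical and only illustrate the idea.

```python
def summarize_loss(loss_history):
    """Summarize a per-epoch loss curve (lower is better)."""
    if not loss_history:
        return None
    # Index of the epoch with the smallest loss (first one on ties).
    best_epoch = min(range(len(loss_history)), key=loss_history.__getitem__)
    return {
        "epochs": len(loss_history),
        "first": loss_history[0],
        "final": loss_history[-1],
        "best": loss_history[best_epoch],
        "best_epoch": best_epoch,  # 0-based
    }


# Made-up history dict standing in for results['history'].
example_history = {"loss": [0.90, 0.55, 0.42, 0.45]}
summary = summarize_loss(example_history.get("loss", []))
print(summary)
```

A best epoch that lies well before the final epoch can hint that the loss had already plateaued or started to rise when training ended.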
Step 4: Accessing the Trained Model
print("\nTrained model access:")
print(f" Model type: {type(final_model).__name__}")
print(f" Model ready for inference and analysis")
# The model can be used for:
# - Computing archetype coordinates (WORKFLOW_04)
# - Computing archetype weights
# - Gene/pathway associations (WORKFLOW_05)
Summary
print("\n" + "="*70)
print("WORKFLOW 03 COMPLETE")
print("="*70)
print(f"Trained model:")
print(f" • n_archetypes: {n_archetypes}")
print(f" • Epochs run: {len(history.get('loss', []))}")
if final_r2 is not None:  # a plain truthiness test would skip a valid R² of 0.0
    print(f" • Final R²: {final_r2:.4f}")
print(f"\nKey outputs:")
print(f" • results['final_model'] - Trained model for inference")
print(f" • results['history'] - Training metrics per epoch")
print(f" • adata.uns['archetype_coordinates'] - Cell coordinates")
print("\nNext workflow: WORKFLOW_04 (Archetype Coordinates & Assignment)")
print("="*70)