peach.tl.train_archetypal
- peach.tl.train_archetypal(adata, n_archetypes=5, n_epochs=50, *, layer=None, pca_key='X_pca', hidden_dims=None, inflation_factor=1.5, model_config=None, optimizer_config=None, device='cpu', save_path=None, archetypal_weight=None, kld_weight=None, reconstruction_weight=0.0, vae_recon_weight=0.0, diversity_weight=0.0, activation_func='relu', track_stability=True, validate_constraints=True, lr_factor=0.1, lr_patience=10, seed=42, constraint_tolerance=0.001, stability_history_size=20, store_coords_key='archetype_coordinates', early_stopping=False, early_stopping_patience=10, early_stopping_metric='archetype_r2', min_improvement=0.0001, validation_check_interval=5, validation_data_loader=None, **kwargs)
Train Deep Archetypal Analysis model to discover cellular archetypes.
This function performs archetypal analysis using a variational autoencoder architecture to identify extreme cellular states (archetypes) that capture the main axes of biological variation. Each cell is represented as a convex combination of the learned archetypes.
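To make "convex combination" concrete, the sketch below rebuilds one cell from a set of archetypes and rejects weight vectors that are not convex (a negative entry, or a sum that is off from 1 by more than a tolerance). It is an illustration under stated assumptions, not peach code: `convex_combination` is a hypothetical helper, the archetype values are toy numbers, and the `tol` argument merely mirrors the role that `constraint_tolerance` plays in the parameter list.

```python
def convex_combination(weights, archetypes, tol=1e-3):
    """Reconstruct a cell as sum_k weights[k] * archetypes[k].

    Hypothetical helper. Raises ValueError if the weights are not a valid
    convex combination: any entry below -tol, or a sum differing from 1 by
    more than tol.
    """
    if any(w < -tol for w in weights):
        raise ValueError("weights must be non-negative")
    if abs(sum(weights) - 1.0) > tol:
        raise ValueError("weights must sum to 1")
    dim = len(archetypes[0])
    # Weighted sum of archetype coordinates, one output dimension at a time.
    return [sum(w * arch[j] for w, arch in zip(weights, archetypes)) for j in range(dim)]


# Three toy archetypes in a 2-D PCA space (the corners of a triangle).
archetypes = [[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]]
cell = convex_combination([0.5, 0.25, 0.25], archetypes)
print(cell)  # [1.0, 1.0]
```

Because the weights are non-negative and sum to 1, every reconstructed cell lies inside the polytope spanned by the archetypes. This is why the archetypes behave as "extreme" states.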
The model initializes archetypes with PCHA and applies an inflation factor that pushes them away from the data centroid for better separation. It reaches R² > 0.89 on real single-cell datasets.
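The exact inflation rule is not spelled out here. The sketch below assumes the common form in which each initial archetype's offset from the data centroid is scaled by the factor, y_new = c + inflation_factor · (y − c). `inflate_archetypes` is a hypothetical helper, and the "initial" archetypes are typed in by hand rather than computed by PCHA.

```python
def inflate_archetypes(archetypes, data, inflation_factor=1.5):
    """Push each archetype away from the data centroid.

    Assumed rule (illustrative only): y_new = c + inflation_factor * (y - c),
    where c is the mean of the data points.
    """
    n = len(data)
    dim = len(data[0])
    centroid = [sum(row[j] for row in data) / n for j in range(dim)]
    return [
        [c + inflation_factor * (y - c) for y, c in zip(arch, centroid)]
        for arch in archetypes
    ]


data = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]  # centroid is (1, 1)
initial = [[2.0, 2.0], [0.0, 0.0]]                        # stand-in for PCHA output
print(inflate_archetypes(initial, data, inflation_factor=1.5))
# [[2.5, 2.5], [-0.5, -0.5]]
```

A factor of 1.0 leaves the archetypes unchanged. Values above 1.0 move them outward, consistent with the description of `inflation_factor` in the parameter list.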
- Parameters:
adata (AnnData) – Annotated data object containing single-cell expression data. Must have PCA coordinates in adata.obsm[pca_key], typically generated using scanpy.pp.pca(adata).
n_archetypes (int, default: 5) – Number of archetypal patterns to learn. Should be chosen based on biological knowledge or using hyperparameter optimization. Common values: 3-10 for most datasets.
n_epochs (int, default: 50) – Number of training epochs. Larger datasets may require more epochs. For datasets >5K cells, consider 100-200 epochs.
layer (str | None, default: None) – AnnData layer to use for training. If None, uses PCA coordinates from adata.obsm[pca_key] (recommended).
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates. The model works best with 5-50 PCA components. Auto-detects: ‘X_pca’, ‘X_PCA’, ‘PCA’.
hidden_dims (list[int] | None, default: None) – Encoder/decoder layer dimensions. If None, uses [256, 128, 64]. Smaller architectures like [128, 64] train faster but may underfit. Larger architectures like [512, 256, 128] may overfit on small datasets.
inflation_factor (float, default: 1.5) – PCHA inflation factor for archetype initialization. Values > 1.0 push initial archetypes further from the data centroid, improving separation. Recommended range: 1.2-2.0; use higher values for more distinct archetypes.
model_config (dict | None, default: None) – Additional model configuration parameters (for advanced users):
archetypal_weight: float - Archetypal loss weight, default 0.9
kld_weight: float - KL divergence weight, default 0.1
diversity_weight: float - Archetype diversity weight, default 0.05
use_barycentric: bool - Use softmax constraints, default True
optimizer_config (dict | None, default: None) – Optimizer configuration parameters:
lr: float - Learning rate, default 1e-3
weight_decay: float - L2 regularization, default 0.0
betas: tuple[float, float] - Adam momentum parameters
device (str, default: "cpu") – Compute device for training. One of “cpu”, “cuda”, “mps”.
save_path (str | None, default: None) – Path to save model checkpoints during training.
archetypal_weight (float | None, default: None) – Weight for archetypal loss component. Uses model’s configured value if None.
kld_weight (float | None, default: None) – Weight for KL divergence loss. Uses model’s configured value if None.
reconstruction_weight (float, default: 0.0) – Legacy reconstruction weight parameter.
vae_recon_weight (float, default: 0.0) – VAE reconstruction weight.
diversity_weight (float, default: 0.0) – Weight for archetype diversity loss.
activation_func (str, default: "relu") – Activation function for the model.
track_stability (bool, default: True) – Whether to monitor archetype drift during training. Adds stability metrics to history: archetype_drift_mean, archetype_variance_mean, etc.
validate_constraints (bool, default: True) – Whether to validate archetypal constraints during training. Adds constraint metrics to history: constraints_satisfied, A_sum_error, etc.
lr_factor (float, default: 0.1) – Factor for learning rate reduction on plateau.
lr_patience (int, default: 10) – Number of epochs with no improvement before reducing learning rate.
seed (int, default: 42) – Random seed for reproducibility.
constraint_tolerance (float, default: 1e-3) – Tolerance for constraint validation.
stability_history_size (int, default: 20) – Number of epochs to track for stability analysis.
store_coords_key (str, default: "archetype_coordinates") – Key to store learned archetype positions in adata.uns.
early_stopping (bool, default: False) – Whether to use early stopping based on validation metrics.
early_stopping_patience (int, default: 10) – Patience for early stopping (number of checks without improvement).
early_stopping_metric (str, default: "archetype_r2") – Metric to monitor for early stopping. One of: ‘archetype_r2’, ‘loss’, ‘rmse’.
min_improvement (float, default: 1e-4) – Minimum improvement required to reset patience counter.
validation_check_interval (int, default: 5) – How often to check validation metrics (in epochs).
validation_data_loader (DataLoader | None, default: None) – Validation data loader for early stopping. Uses training data if None.
**kwargs – Additional arguments passed to the core training function.
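The early-stopping parameters (early_stopping_patience, min_improvement, validation_check_interval, early_stopping_metric) can be read as a patience counter. The minimal sketch below shows that reading. It assumes that one "check" happens every validation_check_interval epochs, that stopping occurs once patience consecutive checks fail to improve by more than min_improvement, and that archetype_r2 is maximized while loss and rmse are minimized. `early_stop_epoch` is a hypothetical helper, not part of peach.

```python
def early_stop_epoch(metric_values, patience=10, min_improvement=1e-4,
                     check_interval=5, higher_is_better=True):
    """Return the 1-based epoch at which training would stop, or None.

    metric_values[i] is the validation metric after epoch i + 1. The metric is
    only inspected every `check_interval` epochs, and training stops once
    `patience` consecutive checks fail to improve by more than `min_improvement`.
    """
    best = None
    stale_checks = 0
    for epoch in range(check_interval, len(metric_values) + 1, check_interval):
        value = metric_values[epoch - 1]
        if best is None:
            improved = True
        elif higher_is_better:   # e.g. archetype_r2
            improved = value - best > min_improvement
        else:                    # e.g. loss or rmse
            improved = best - value > min_improvement
        if improved:
            best = value
            stale_checks = 0
        else:
            stale_checks += 1
            if stale_checks >= patience:
                return epoch
    return None


r2_curve = [0.5] * 5 + [0.7] * 25  # improves once, then plateaus
print(early_stop_epoch(r2_curve, patience=2, check_interval=5))  # 20
```

With a check every 5 epochs and a patience of 2, the plateau that starts after epoch 10 is detected at epochs 15 and 20, so training would stop at epoch 20.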
- Returns:
Training results dictionary with the following structure:
Guaranteed keys (always present):
history (dict) – Training metrics per epoch. Keys depend on tracking options:
Core metrics (always): loss, archetypal_loss, archetype_r2, rmse
KLD metrics: kld_loss, KLD
Stability metrics (if track_stability=True): archetype_drift_mean, archetype_drift_max, archetype_variance_mean, etc.
Constraint metrics (if validate_constraints=True): constraints_satisfied, A_sum_error, B_sum_error, constraint_violation_rate
Validation metrics (if early_stopping=True): val_loss, val_archetype_r2
final_model (torch.nn.Module) – Trained Deep_AA model instance.
model (torch.nn.Module) – Alias for final_model (same object, for compatibility).
final_optimizer (torch.optim.Optimizer) – Final optimizer state.
final_analysis (dict) – Final training analysis containing:
final_constraint_validation: dict with constraint metrics
archetypal_weights: dict with A_matrix and B_matrix analysis
final_coordinates: dict with ‘A’, ‘B’, ‘Y’ tensors
error: str (only if analysis failed)
epoch_archetype_positions (list[torch.Tensor]) – Archetype positions at each epoch, shape (n_archetypes, input_dim).
training_config (dict) – Training configuration with keys: n_epochs, actual_epochs, early_stop_triggered, archetypal_weight, kld_weight, reconstruction_weight, activation_func, seed, constraint_tolerance, stability_history_size, early_stopping, early_stopping_patience, early_stopping_metric.
Convenience keys (conditional - use .get() to access safely):
final_archetype_r2 (float | None) – Last value of history['archetype_r2'], if tracked.
final_rmse (float | None) – Last value of history['rmse'], if tracked.
final_mae (float | None) – Last value of history['mae'], if tracked.
final_loss (float | None) – Last value of history['loss'], if tracked.
convergence_epoch (int | None) – Equals training_config['actual_epochs'].
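The stability metrics are derived from how far archetypes move between epochs. The sketch below computes a mean and a maximum per-archetype displacement over a trailing window of epoch_archetype_positions. It assumes that archetype_drift_mean and archetype_drift_max are defined this way, which is a guess made for illustration. `archetype_drift` is a hypothetical helper that takes nested lists; a list of torch tensors could be converted first with `[t.tolist() for t in results["epoch_archetype_positions"]]`.

```python
import math


def archetype_drift(epoch_positions, window=20):
    """Mean and max per-archetype displacement between consecutive epochs.

    epoch_positions: list over epochs of (n_archetypes x dim) nested lists.
    Only the last `window` epochs are used (cf. stability_history_size).
    """
    recent = epoch_positions[-window:]
    drifts = []
    for prev, curr in zip(recent, recent[1:]):
        for a_prev, a_curr in zip(prev, curr):
            drifts.append(math.dist(a_prev, a_curr))  # Euclidean step length
    if not drifts:
        return 0.0, 0.0
    return sum(drifts) / len(drifts), max(drifts)


positions = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[3.0, 4.0], [1.0, 1.0]],  # archetype 0 moves by 5
    [[3.0, 4.0], [1.0, 1.0]],  # nothing moves
]
print(archetype_drift(positions))  # (1.25, 5.0)
```

A drift that shrinks toward zero over the window suggests that the archetype positions have settled.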
- Return type:
dict
- Raises:
ValueError – If adata.obsm[pca_key] is not found (run scanpy.pp.pca() first), or if n_archetypes exceeds the number of PCA dimensions.
RuntimeError – If CUDA device is requested but not available.
- Stores:
The function stores the following in AnnData:
adata.uns[store_coords_key] (np.ndarray) – Archetype positions in PCA space, shape (n_archetypes, n_pcs). Default key: ‘archetype_coordinates’.
Notes
Archetypal Analysis Theory: Archetypal analysis represents each data point as a convex combination of extreme points (archetypes). Unlike clustering, which partitions data, archetypal analysis allows cells to have partial membership in multiple archetypes, better reflecting biological continuity.
Model Architecture: Uses a variational autoencoder where the latent space directly represents archetypal coordinates (A matrix). Archetypes are learned as model parameters (Y matrix) rather than constructed from data points.
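One way a latent code can "directly represent" the A matrix is to pass unconstrained encoder scores through a softmax, which guarantees non-negative weights that sum to 1. The listed model_config option use_barycentric ("Use softmax constraints") suggests this approach. The sketch below assumes the encoder emits one score per archetype and that decoding is the weighted sum of the archetype matrix Y. `softmax` and `decode` are illustrative helpers, not peach internals.

```python
import math


def softmax(logits):
    """Map unconstrained scores to non-negative weights that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits, archetypes):
    """Barycentric decoding: x_hat = softmax(logits) @ Y."""
    a = softmax(logits)
    dim = len(archetypes[0])
    return [sum(w * y[j] for w, y in zip(a, archetypes)) for j in range(dim)]


Y = [[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]]  # toy archetype matrix (3 x 2)
x_hat = decode([0.0, 0.0, 0.0], Y)        # equal scores -> centroid of the archetypes
print(x_hat)                               # approximately [1.333, 1.333]
```

Because the constraint is built into the parameterization, gradient descent on the scores can never produce an invalid row of A.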
Accessing Results Safely:
For guaranteed keys, direct access works: results['history']
For convenience keys, use .get(): results.get('final_archetype_r2')
Or access from history: results['history']['archetype_r2'][-1]
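The two access patterns above can be folded into one fallback. In the sketch below, `final_metric` is a hypothetical helper, and the `results` dict is a hand-built stand-in with made-up values shaped like the documented return structure. The helper relies on the documented naming, where `final_<metric>` mirrors `history[<metric>]`.

```python
def final_metric(results, name):
    """Return the last recorded value of a metric, or None if it was not tracked.

    Prefers the convenience key (e.g. 'final_archetype_r2') and falls back to
    the per-epoch list in results['history'].
    """
    value = results.get(f"final_{name}")
    if value is not None:
        return value
    series = results.get("history", {}).get(name)
    return series[-1] if series else None


# Hand-built stand-in for the dict returned by train_archetypal (toy values).
results = {
    "history": {"archetype_r2": [0.71, 0.85, 0.9], "loss": [3.2, 1.4, 1.1]},
    "final_loss": 1.1,
}
print(final_metric(results, "archetype_r2"))  # 0.9 (from history)
print(final_metric(results, "mae"))           # None (never tracked)
```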
Type Validation: For IDE autocomplete and runtime validation:
from peach._core.types import TrainingResults, validate_training_results
validated = validate_training_results(results)
Examples
Basic usage with default parameters:
>>> import scanpy as sc
>>> import peach as pc
>>> # Prepare data with PCA
>>> sc.pp.pca(adata, n_comps=30)
>>> # Train archetypal model
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5, n_epochs=100)
>>> # Access final R² safely (may be None if not tracked)
>>> r2 = results.get("final_archetype_r2")
>>> if r2 is not None:
...     print(f"Final R²: {r2:.3f}")
>>> # Or access from history (guaranteed if metric was tracked)
>>> if results["history"].get("archetype_r2"):
...     r2 = results["history"]["archetype_r2"][-1]
...     print(f"Final R²: {r2:.3f}")
Advanced usage with custom configuration:
>>> model_config = {"hidden_dims": [512, 256, 128], "inflation_factor": 2.0}
>>> results = pc.tl.train_archetypal(
...     adata,
...     n_archetypes=4,
...     n_epochs=150,
...     model_config=model_config,
...     early_stopping=True,
...     early_stopping_patience=15,
...     device="cuda",
... )
>>> # Check if early stopping triggered
>>> config = results["training_config"]
>>> if config["early_stop_triggered"]:
...     print(f"Converged at epoch {config['actual_epochs']}")
Accessing archetype coordinates stored in AnnData:
>>> # After training, coordinates are in adata.uns
>>> archetype_coords = adata.uns["archetype_coordinates"]
>>> print(f"Learned {archetype_coords.shape[0]} archetypes")
>>> print(f"Each archetype has {archetype_coords.shape[1]} PCA dimensions")
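With the archetype positions in hand, a simple follow-up is to measure each cell's Euclidean distance to every archetype and pick the closest one. The sketch below is an illustration of that idea only. It is not how peach.tl.archetypal_coordinates or peach.tl.assign_archetypes are implemented. `nearest_archetype` is a hypothetical helper that takes plain lists, for example `adata.uns["archetype_coordinates"].tolist()`. The cell coordinates must be in the same PCA space and dimensionality as the archetypes.

```python
import math


def nearest_archetype(cells, archetype_coords):
    """For each cell, return (index of closest archetype, Euclidean distance)."""
    out = []
    for cell in cells:
        dists = [math.dist(cell, arch) for arch in archetype_coords]
        k = min(range(len(dists)), key=dists.__getitem__)  # argmin
        out.append((k, dists[k]))
    return out


archetype_coords = [[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]]  # toy stand-in
cells = [[0.5, 0.0], [3.0, 0.0], [0.0, 3.0]]
print(nearest_archetype(cells, archetype_coords))
# [(0, 0.5), (1, 1.0), (2, 1.0)]
```

Hard nearest-archetype labels discard the partial membership that archetypal analysis provides, so they are best treated as a summary alongside the barycentric weights.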
See also
peach.tl.archetypal_coordinates – Extract cell-archetype distances
peach.tl.assign_archetypes – Assign cells to discovered archetypes
peach.tl.extract_archetype_weights – Get barycentric coordinates for cells
peach.pl.training_metrics – Visualize training curves
peach._core.types.TrainingResults – Type definition for return structure