peach.tl.train_archetypal

peach.tl.train_archetypal(adata, n_archetypes=5, n_epochs=50, *, layer=None, pca_key='X_pca', hidden_dims=None, inflation_factor=1.5, model_config=None, optimizer_config=None, device='cpu', save_path=None, archetypal_weight=None, kld_weight=None, reconstruction_weight=0.0, vae_recon_weight=0.0, diversity_weight=0.0, activation_func='relu', track_stability=True, validate_constraints=True, lr_factor=0.1, lr_patience=10, seed=42, constraint_tolerance=0.001, stability_history_size=20, store_coords_key='archetype_coordinates', early_stopping=False, early_stopping_patience=10, early_stopping_metric='archetype_r2', min_improvement=0.0001, validation_check_interval=5, validation_data_loader=None, **kwargs)

Train Deep Archetypal Analysis model to discover cellular archetypes.

This function performs archetypal analysis using a variational autoencoder architecture to identify extreme cellular states (archetypes) that capture the main axes of biological variation. Each cell is represented as a convex combination of the learned archetypes.

Archetypes are initialized with PCHA (principal convex hull analysis) and then pushed outward by an inflation factor (see inflation_factor) to improve their separation. Reported archetype R² on real single-cell datasets is above 0.89.
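The effect of an inflation factor can be pictured with a short NumPy sketch. It assumes the simplest geometric reading, namely that each initial archetype is scaled away from the data centroid by the given factor. The helper inflate_archetypes is hypothetical and not part of peach, and the library's PCHA initialization may differ in detail.

```python
import numpy as np


def inflate_archetypes(archetypes, data, inflation_factor=1.5):
    """Scale archetypes away from the data centroid (hypothetical helper)."""
    centroid = data.mean(axis=0)
    return centroid + inflation_factor * (archetypes - centroid)


rng = np.random.default_rng(42)
data = rng.normal(size=(200, 5))  # stand-in for PCA coordinates
initial = data[:3]                # stand-in for PCHA-initialized archetypes
inflated = inflate_archetypes(initial, data, inflation_factor=1.5)

# Distance of every archetype to the centroid before and after inflation
centroid = data.mean(axis=0)
dist_before = np.linalg.norm(initial - centroid, axis=1)
dist_after = np.linalg.norm(inflated - centroid, axis=1)
```

Under this assumption, a factor of 1.0 leaves the archetypes unchanged and larger values move them proportionally further from the centroid.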

Parameters:
  • adata (AnnData) – Annotated data object containing single-cell expression data. Must have PCA coordinates in adata.obsm[pca_key]. Typically generated using scanpy.pp.pca(adata).

  • n_archetypes (int, default: 5) – Number of archetypal patterns to learn. Should be chosen based on biological knowledge or using hyperparameter optimization. Common values: 3-10 for most datasets.

  • n_epochs (int, default: 50) – Number of training epochs. Larger datasets may require more epochs. For datasets >5K cells, consider 100-200 epochs.

  • layer (str | None, default: None) – AnnData layer to use for training. If None, uses PCA coordinates from adata.obsm[pca_key] (recommended).

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates. The model works best with 5-50 PCA components. Auto-detects: ‘X_pca’, ‘X_PCA’, ‘PCA’.

  • hidden_dims (list[int] | None, default: None) – Encoder/decoder layer dimensions. If None, uses [256, 128, 64]. Smaller architectures like [128, 64] train faster but may underfit. Larger architectures like [512, 256, 128] may overfit on small datasets.

  • inflation_factor (float, default: 1.5) – PCHA inflation factor for archetype initialization. Values > 1.0 push initial archetypes further from the data centroid, improving separation. Recommended range: 1.2-2.0. Higher values for more distinct archetypes.

  • model_config (dict | None, default: None) –

    Additional model configuration parameters (for advanced users):

    • archetypal_weight : float - Archetypal loss weight, default 0.9

    • kld_weight : float - KL divergence weight, default 0.1

    • diversity_weight : float - Archetype diversity weight, default 0.05

    • use_barycentric : bool - Use softmax constraints, default True

  • optimizer_config (dict | None, default: None) –

    Optimizer configuration parameters:

    • lr : float - Learning rate, default 1e-3

    • weight_decay : float - L2 regularization, default 0.0

    • betas : tuple[float, float] - Adam momentum parameters

  • device (str, default: "cpu") – Compute device for training. One of “cpu”, “cuda”, “mps”.

  • save_path (str | None, default: None) – Path to save model checkpoints during training.

  • archetypal_weight (float | None, default: None) – Weight for archetypal loss component. Uses model’s configured value if None.

  • kld_weight (float | None, default: None) – Weight for KL divergence loss. Uses model’s configured value if None.

  • reconstruction_weight (float, default: 0.0) – Legacy reconstruction weight parameter.

  • vae_recon_weight (float, default: 0.0) – VAE reconstruction weight.

  • diversity_weight (float, default: 0.0) – Weight for archetype diversity loss.

  • activation_func (str, default: "relu") – Activation function for the model.

  • track_stability (bool, default: True) – Whether to monitor archetype drift during training. Adds stability metrics to history: archetype_drift_mean, archetype_variance_mean, etc.

  • validate_constraints (bool, default: True) – Whether to validate archetypal constraints during training. Adds constraint metrics to history: constraints_satisfied, A_sum_error, etc.

  • lr_factor (float, default: 0.1) – Factor for learning rate reduction on plateau.

  • lr_patience (int, default: 10) – Number of epochs with no improvement before reducing learning rate.

  • seed (int, default: 42) – Random seed for reproducibility.

  • constraint_tolerance (float, default: 1e-3) – Tolerance for constraint validation.

  • stability_history_size (int, default: 20) – Number of epochs to track for stability analysis.

  • store_coords_key (str, default: "archetype_coordinates") – Key to store learned archetype positions in adata.uns.

  • early_stopping (bool, default: False) – Whether to use early stopping based on validation metrics.

  • early_stopping_patience (int, default: 10) – Patience for early stopping (number of checks without improvement).

  • early_stopping_metric (str, default: "archetype_r2") – Metric to monitor for early stopping. One of: ‘archetype_r2’, ‘loss’, ‘rmse’.

  • min_improvement (float, default: 1e-4) – Minimum improvement required to reset patience counter.

  • validation_check_interval (int, default: 5) – How often to check validation metrics (in epochs).

  • validation_data_loader (DataLoader | None, default: None) – Validation data loader for early stopping. Uses training data if None.

  • **kwargs – Additional arguments passed to the core training function.
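The interaction between early_stopping_patience, min_improvement and validation_check_interval can be sketched in plain Python. This is one plausible reading of the parameter descriptions above, not peach's actual implementation: the function early_stop_check is hypothetical, and the mapping from check index to epoch assumes checks happen every validation_check_interval epochs starting at that interval.

```python
def early_stop_check(metric_values, patience=10, min_improvement=1e-4,
                     higher_is_better=True):
    """Return the index of the check that triggers a stop, or None.

    Hypothetical helper: a check counts as an improvement only if it beats
    the best value so far by more than `min_improvement`.
    """
    best = None
    checks_without_improvement = 0
    for i, value in enumerate(metric_values):
        # Flip the sign for metrics such as loss, where lower is better
        score = value if higher_is_better else -value
        if best is None or score > best + min_improvement:
            best = score
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return i
    return None


# R² values observed at successive validation checks (made-up numbers)
r2_checks = [0.50, 0.70, 0.80, 0.80005, 0.80007, 0.80008]
stop_at = early_stop_check(r2_checks, patience=3, min_improvement=1e-4)

# Assumed mapping from check index to epoch number
validation_check_interval = 5
stop_epoch = None if stop_at is None else (stop_at + 1) * validation_check_interval
```

In this example the last three checks improve on 0.80 by less than min_improvement, so the patience counter reaches 3 at the sixth check.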

Returns:

Training results dictionary with the following structure:

Guaranteed keys (always present):

  • history (dict)

    Training metrics per epoch. Keys depend on tracking options:

    • Core metrics (always): loss, archetypal_loss, archetype_r2, rmse

    • KLD metrics: kld_loss, KLD

    • Stability metrics (if track_stability=True): archetype_drift_mean, archetype_drift_max, archetype_variance_mean, etc.

    • Constraint metrics (if validate_constraints=True): constraints_satisfied, A_sum_error, B_sum_error, constraint_violation_rate

    • Validation metrics (if early_stopping=True): val_loss, val_archetype_r2

  • final_model (torch.nn.Module)

    Trained Deep_AA model instance.

  • model (torch.nn.Module)

    Alias for final_model (same object, for compatibility).

  • final_optimizer (torch.optim.Optimizer)

    Final optimizer state.

  • final_analysis (dict)

    Final training analysis containing:

    • final_constraint_validation : dict with constraint metrics

    • archetypal_weights : dict with A_matrix and B_matrix analysis

    • final_coordinates : dict with ‘A’, ‘B’, ‘Y’ tensors

    • error : str (only if analysis failed)

  • epoch_archetype_positions (list[torch.Tensor])

    Archetype positions at each epoch, shape (n_archetypes, input_dim).

  • training_config (dict)

    Training configuration with keys: n_epochs, actual_epochs, early_stop_triggered, archetypal_weight, kld_weight, reconstruction_weight, activation_func, seed, constraint_tolerance, stability_history_size, early_stopping, early_stopping_patience, early_stopping_metric.

Convenience keys (conditional - use .get() to access safely):

  • final_archetype_r2 (float | None)

    Last value of history['archetype_r2'] if tracked.

  • final_rmse (float | None)

    Last value of history['rmse'] if tracked.

  • final_mae (float | None)

    Last value of history['mae'] if tracked.

  • final_loss (float | None)

    Last value of history['loss'] if tracked.

  • convergence_epoch (int | None)

    Equals training_config['actual_epochs'].

Return type:

dict

Raises:
  • ValueError – If adata.obsm[pca_key] is not found (run scanpy.pp.pca() first), or if n_archetypes exceeds the number of PCA dimensions.

  • RuntimeError – If CUDA device is requested but not available.

Stores

The function stores the following in AnnData:

  • adata.uns[store_coords_key] (np.ndarray) – Archetype positions in PCA space, shape (n_archetypes, n_pcs). Default key: ‘archetype_coordinates’.

Notes

Archetypal Analysis Theory: Archetypal analysis represents each data point as a convex combination of extreme points (archetypes). Unlike clustering, which partitions data, archetypal analysis allows cells to have partial membership in multiple archetypes, better reflecting biological continuity.
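A convex combination requires non-negative weights that sum to one for every cell. The sketch below shows how such a check could be written against a tolerance like constraint_tolerance. The function check_simplex_rows and its output keys are hypothetical; the exact definitions behind peach's constraints_satisfied and A_sum_error metrics may differ.

```python
import numpy as np


def check_simplex_rows(weights, tolerance=1e-3):
    """Report how far the rows of `weights` are from the probability simplex."""
    weights = np.asarray(weights, dtype=float)
    # Largest deviation of any row sum from 1
    sum_error = float(np.abs(weights.sum(axis=1) - 1.0).max())
    # Most negative entry (0 or above means all weights are non-negative)
    min_value = float(weights.min())
    return {
        "sum_error": sum_error,
        "min_value": min_value,
        "satisfied": sum_error <= tolerance and min_value >= -tolerance,
    }


A_ok = np.array([[0.2, 0.3, 0.5], [1.0, 0.0, 0.0]])   # valid weights
A_bad = np.array([[0.2, 0.3, 0.6], [0.5, 0.5, 0.0]])  # first row sums to 1.1

report_ok = check_simplex_rows(A_ok)
report_bad = check_simplex_rows(A_bad)
```

A cell with weights (1.0, 0.0, 0.0) sits exactly on one archetype, while (0.2, 0.3, 0.5) is a mixture of all three.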

Model Architecture: Uses a variational autoencoder where the latent space directly represents archetypal coordinates (A matrix). Archetypes are learned as model parameters (Y matrix) rather than constructed from data points.
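The relationship between the A and Y matrices can be illustrated with NumPy. This is a minimal sketch, assuming the softmax constraint named by use_barycentric is applied row-wise to the encoder output; the random arrays are stand-ins for the encoder output and the learned archetypes, and none of this is peach's actual model code.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
n_cells, n_archetypes, n_pcs = 6, 3, 4

latent_logits = rng.normal(size=(n_cells, n_archetypes))  # stand-in for encoder output
Y = rng.normal(size=(n_archetypes, n_pcs))                # stand-in for learned archetypes

A = softmax(latent_logits)  # barycentric weights, one row per cell
X_hat = A @ Y               # each cell reconstructed as a convex combination of archetypes
```

Because every row of A lies on the simplex, each reconstructed cell falls inside the convex hull spanned by the rows of Y.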

Accessing Results Safely:

  • For guaranteed keys, direct access works: results['history']

  • For convenience keys, use .get(): results.get('final_archetype_r2')

  • Or access from history: results['history']['archetype_r2'][-1]
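The two access patterns above can be combined in a small helper that tries the convenience key first and falls back to the history. The function final_metric is a hypothetical convenience, not part of peach, and the dictionary below is a mock standing in for real training results.

```python
def final_metric(results, name):
    """Return the final value of a metric, or None if it was not tracked."""
    # Convenience keys follow the pattern final_<metric>, e.g. final_rmse
    value = results.get(f"final_{name}")
    if value is not None:
        return value
    # Fall back to the last entry of the per-epoch history, if present
    series = results.get("history", {}).get(name)
    return series[-1] if series else None


# Mock results dictionary (made-up values) with one convenience key missing
mock_results = {
    "history": {"loss": [1.2, 0.8, 0.5], "archetype_r2": [0.4, 0.7, 0.85]},
    "final_loss": 0.5,
}

final_loss = final_metric(mock_results, "loss")        # from the convenience key
final_r2 = final_metric(mock_results, "archetype_r2")  # from the history fallback
final_mae = final_metric(mock_results, "mae")          # not tracked
```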

Type Validation: For IDE autocomplete and runtime validation:

from peach._core.types import TrainingResults, validate_training_results

validated = validate_training_results(results)

Examples

Basic usage with default parameters:

>>> import scanpy as sc
>>> import peach as pc
>>> # Prepare data with PCA
>>> sc.pp.pca(adata, n_comps=30)
>>> # Train archetypal model
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5, n_epochs=100)
>>> # Access final R² safely (may be None if not tracked)
>>> r2 = results.get("final_archetype_r2")
>>> if r2 is not None:
...     print(f"Final R²: {r2:.3f}")
>>> # Or access from history (guaranteed if metric was tracked)
>>> if results["history"].get("archetype_r2"):
...     r2 = results["history"]["archetype_r2"][-1]
...     print(f"Final R²: {r2:.3f}")

Advanced usage with custom configuration:

>>> model_config = {"hidden_dims": [512, 256, 128], "inflation_factor": 2.0}
>>> results = pc.tl.train_archetypal(
...     adata,
...     n_archetypes=4,
...     n_epochs=150,
...     model_config=model_config,
...     early_stopping=True,
...     early_stopping_patience=15,
...     device="cuda",
... )
>>> # Check if early stopping triggered
>>> config = results["training_config"]
>>> if config["early_stop_triggered"]:
...     print(f"Converged at epoch {config['actual_epochs']}")

Accessing archetype coordinates stored in AnnData:

>>> # After training, coordinates are in adata.uns
>>> archetype_coords = adata.uns["archetype_coordinates"]
>>> print(f"Learned {archetype_coords.shape[0]} archetypes")
>>> print(f"Each archetype has {archetype_coords.shape[1]} PCA dimensions")
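Because the stored coordinates live in the same PCA space as the cells, they can be compared directly with the PCA matrix. The sketch below computes Euclidean cell-to-archetype distances with NumPy, using random arrays as stand-ins for adata.obsm["X_pca"] and adata.uns["archetype_coordinates"]. It is an illustration only and not necessarily how peach.tl.archetypal_coordinates computes distances.

```python
import numpy as np

rng = np.random.default_rng(1)
X_pca = rng.normal(size=(100, 30))           # stand-in for adata.obsm["X_pca"]
archetype_coords = rng.normal(size=(5, 30))  # stand-in for adata.uns["archetype_coordinates"]

# Pairwise Euclidean distances via broadcasting, shape (n_cells, n_archetypes)
diffs = X_pca[:, None, :] - archetype_coords[None, :, :]
distances = np.linalg.norm(diffs, axis=-1)

# Index of the closest archetype for each cell
nearest = distances.argmin(axis=1)
```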

See also

peach.tl.archetypal_coordinates – Extract cell-archetype distances

peach.tl.assign_archetypes – Assign cells to discovered archetypes

peach.tl.extract_archetype_weights – Get barycentric coordinates for cells

peach.pl.training_metrics – Visualize training curves

peach._core.types.TrainingResults – Type definition for return structure