peach.tl.archetypal
Core archetypal analysis functions.
This module provides the primary interface for training Deep Archetypal Analysis models and extracting archetypal coordinates. All functions work directly with AnnData objects and follow scVerse conventions.
Main Functions:
- train_archetypal(): Train a Deep AA model to discover cellular archetypes.
- archetypal_coordinates(): Extract archetypal coordinates for all cells.
- assign_archetypes(): Assign cells to discovered archetypes based on distances.
The module integrates PCHA initialization, inflation factors, and comprehensive training diagnostics for production-ready archetypal analysis workflows.
Functions
- archetypal_coordinates(): Extract archetypal coordinates for all cells.
- assign_archetypes(): Assign cells to archetypes based on distances.
- assign_to_centroids(): Assign cells to the nearest centroid based on distance (the closest bin_prop fraction of cells).
- compute_conditional_centroids(): Compute centroid positions in PCA space for each level of a categorical condition.
- extract_archetype_weights(): Extract cell archetype weights from a trained Deep_AA model.
- train_archetypal(): Train a Deep Archetypal Analysis model to discover cellular archetypes.
- peach.tl.archetypal.train_archetypal(adata, n_archetypes=5, n_epochs=50, *, layer=None, pca_key='X_pca', hidden_dims=None, inflation_factor=1.5, model_config=None, optimizer_config=None, device='cpu', save_path=None, archetypal_weight=None, kld_weight=None, reconstruction_weight=0.0, vae_recon_weight=0.0, diversity_weight=0.0, activation_func='relu', track_stability=True, validate_constraints=True, lr_factor=0.1, lr_patience=10, seed=42, constraint_tolerance=0.001, stability_history_size=20, store_coords_key='archetype_coordinates', early_stopping=False, early_stopping_patience=10, early_stopping_metric='archetype_r2', min_improvement=0.0001, validation_check_interval=5, validation_data_loader=None, **kwargs)
Train Deep Archetypal Analysis model to discover cellular archetypes.
This function performs archetypal analysis using a variational autoencoder architecture to identify extreme cellular states (archetypes) that capture the main axes of biological variation. Each cell is represented as a convex combination of the learned archetypes.
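To make the convex-combination idea concrete, here is a minimal NumPy sketch. It is not PEACH code: the archetype positions Y and the per-cell scores are invented toy values, and producing the weights with a row-wise softmax is an assumption modeled on the softmax constraint mentioned for use_barycentric under model_config.

```python
import numpy as np


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax: maps arbitrary scores onto the probability simplex."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Toy archetypes Y in a 2-D "PCA" space, shape (n_archetypes, n_pcs)
Y = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])

# Toy unconstrained scores for 6 cells (stand-in for encoder outputs), shape (n_cells, n_archetypes)
rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 3))

# A: barycentric weights, non-negative with each row summing to 1
A = softmax_rows(logits)

# Each reconstructed cell is a convex combination of the archetypes: X_hat = A @ Y
X_hat = A @ Y
print(A.round(3))
print(X_hat.round(3))
```

Because every row of A lies on the simplex, every reconstructed cell falls inside the triangle spanned by the three archetypes: a cell near a vertex is dominated by one archetype, while a cell in the interior is a mixture.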
The model initializes archetypes with PCHA and applies an inflation factor that pushes them outward from the data centroid before training; the authors report R² > 0.89 on real single-cell datasets.
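The inflation step is described only qualitatively on this page (values of inflation_factor above 1.0 push the initial archetypes further from the data centroid). A minimal sketch, assuming inflation is a linear rescaling about the data centroid applied to archetypes that a PCHA run has already produced; the helper `inflate_archetypes` is hypothetical and PCHA itself is not shown:

```python
import numpy as np


def inflate_archetypes(archetypes: np.ndarray, data: np.ndarray, inflation_factor: float = 1.5) -> np.ndarray:
    """Push archetypes away from the data centroid (hypothetical helper).

    Assumes inflation is a linear rescaling about the centroid:
        inflated = centroid + inflation_factor * (archetype - centroid)
    """
    centroid = data.mean(axis=0)
    return centroid + inflation_factor * (archetypes - centroid)


# Toy data: four cells at the corners of a square, centroid (1, 1)
data = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]])
# Toy initial archetypes (e.g. from PCHA) sitting on two of the corners
initial = np.array([[0.0, 0.0], [2.0, 2.0]])

inflated = inflate_archetypes(initial, data, inflation_factor=1.5)
print(inflated)  # each archetype is now 1.5x farther from the centroid
```

Under this assumption, a factor of 1.0 leaves the PCHA solution unchanged, and larger factors start the archetypes outside the data cloud so that training pulls them inward.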
- Parameters:
adata (AnnData) – Annotated data object containing single-cell expression data. Must have PCA coordinates in adata.obsm[pca_key], typically generated using scanpy.pp.pca(adata).
n_archetypes (int, default: 5) – Number of archetypal patterns to learn. Should be chosen based on biological knowledge or using hyperparameter optimization. Common values: 3-10 for most datasets.
n_epochs (int, default: 50) – Number of training epochs. Larger datasets may require more epochs. For datasets >5K cells, consider 100-200 epochs.
layer (str | None, default: None) – AnnData layer to use for training. If None, uses PCA coordinates from adata.obsm[pca_key] (recommended).
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates. The model works best with 5-50 PCA components. Auto-detects: ‘X_pca’, ‘X_PCA’, ‘PCA’.
hidden_dims (list[int] | None, default: None) – Encoder/decoder layer dimensions. If None, uses [256, 128, 64]. Smaller architectures like [128, 64] train faster but may underfit. Larger architectures like [512, 256, 128] may overfit on small datasets.
inflation_factor (float, default: 1.5) – PCHA inflation factor for archetype initialization. Values > 1.0 push initial archetypes further from the data centroid, improving separation. Recommended range: 1.2-2.0. Higher values for more distinct archetypes.
model_config (dict | None, default: None) –
Additional model configuration parameters (for advanced users):
- archetypal_weight: float - Archetypal loss weight, default 0.9
- kld_weight: float - KL divergence weight, default 0.1
- diversity_weight: float - Archetype diversity weight, default 0.05
- use_barycentric: bool - Use softmax constraints, default True
optimizer_config (dict | None, default: None) –
Optimizer configuration parameters:
- lr: float - Learning rate, default 1e-3
- weight_decay: float - L2 regularization, default 0.0
- betas: tuple[float, float] - Adam momentum parameters
device (str, default: "cpu") – Compute device for training. One of “cpu”, “cuda”, “mps”.
save_path (str | None, default: None) – Path to save model checkpoints during training.
archetypal_weight (float | None, default: None) – Weight for archetypal loss component. Uses model’s configured value if None.
kld_weight (float | None, default: None) – Weight for KL divergence loss. Uses model’s configured value if None.
reconstruction_weight (float, default: 0.0) – Legacy reconstruction weight parameter.
vae_recon_weight (float, default: 0.0) – VAE reconstruction weight.
diversity_weight (float, default: 0.0) – Weight for archetype diversity loss.
activation_func (str, default: "relu") – Activation function for the model.
track_stability (bool, default: True) – Whether to monitor archetype drift during training. Adds stability metrics to history: archetype_drift_mean, archetype_variance_mean, etc.
validate_constraints (bool, default: True) – Whether to validate archetypal constraints during training. Adds constraint metrics to history: constraints_satisfied, A_sum_error, etc.
lr_factor (float, default: 0.1) – Factor for learning rate reduction on plateau.
lr_patience (int, default: 10) – Number of epochs with no improvement before reducing learning rate.
seed (int, default: 42) – Random seed for reproducibility.
constraint_tolerance (float, default: 1e-3) – Tolerance for constraint validation.
stability_history_size (int, default: 20) – Number of epochs to track for stability analysis.
store_coords_key (str, default: "archetype_coordinates") – Key to store learned archetype positions in adata.uns.
early_stopping (bool, default: False) – Whether to use early stopping based on validation metrics.
early_stopping_patience (int, default: 10) – Patience for early stopping (number of checks without improvement).
early_stopping_metric (str, default: "archetype_r2") – Metric to monitor for early stopping. One of: ‘archetype_r2’, ‘loss’, ‘rmse’.
min_improvement (float, default: 1e-4) – Minimum improvement required to reset patience counter.
validation_check_interval (int, default: 5) – How often to check validation metrics (in epochs).
validation_data_loader (DataLoader | None, default: None) – Validation data loader for early stopping. Uses training data if None.
**kwargs – Additional arguments passed to the core training function.
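Taken together, the early-stopping parameters describe this loop: validation metrics are checked every validation_check_interval epochs, a check counts as an improvement only if it beats the best value by more than min_improvement, and training stops after early_stopping_patience checks without one. A sketch of that logic, in which `EarlyStopper` and `stop_epoch` are hypothetical names and the metric directions (higher archetype_r2 is better, lower loss/rmse is better) are assumptions rather than the peach implementation:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EarlyStopper:
    """Hypothetical tracker mirroring early_stopping_patience / early_stopping_metric / min_improvement."""

    metric: str = "archetype_r2"
    patience: int = 10
    min_improvement: float = 1e-4
    best: Optional[float] = None
    checks_without_improvement: int = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.metric == "archetype_r2":  # assumed: higher R² is better
            return value > self.best + self.min_improvement
        return value < self.best - self.min_improvement  # assumed: lower 'loss'/'rmse' is better

    def update(self, value: float) -> bool:
        """Record one validation check and return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.checks_without_improvement = 0  # improvement resets the patience counter
        else:
            self.checks_without_improvement += 1
        return self.checks_without_improvement >= self.patience


def stop_epoch(metric_per_epoch, check_interval=5, patience=2):
    """Return the epoch at which training would stop (or the last epoch)."""
    stopper = EarlyStopper(patience=patience)
    for epoch, value in enumerate(metric_per_epoch, start=1):
        # Validation metrics are only checked every `check_interval` epochs
        if epoch % check_interval == 0 and stopper.update(value):
            return epoch
    return len(metric_per_epoch)


# Toy R² curve that rises and then plateaus at 0.8
r2_curve = [min(0.05 * epoch, 0.8) for epoch in range(1, 41)]
print(stop_epoch(r2_curve))  # 30
```

Because patience counts checks rather than epochs, the effective patience in epochs is roughly early_stopping_patience × validation_check_interval.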
- Returns:
Training results dictionary with the following structure:
Guaranteed keys (always present):
- history (dict) – Training metrics per epoch. Keys depend on tracking options:
  - Core metrics (always): loss, archetypal_loss, archetype_r2, rmse
  - KLD metrics: kld_loss, KLD
  - Stability metrics (if track_stability=True): archetype_drift_mean, archetype_drift_max, archetype_variance_mean, etc.
  - Constraint metrics (if validate_constraints=True): constraints_satisfied, A_sum_error, B_sum_error, constraint_violation_rate
  - Validation metrics (if early_stopping=True): val_loss, val_archetype_r2
- final_model (torch.nn.Module) – Trained Deep_AA model instance.
- model (torch.nn.Module) – Alias for final_model (same object, for compatibility).
- final_optimizer (torch.optim.Optimizer) – Final optimizer state.
- final_analysis (dict) – Final training analysis containing:
  - final_constraint_validation: dict with constraint metrics
  - archetypal_weights: dict with A_matrix and B_matrix analysis
  - final_coordinates: dict with ‘A’, ‘B’, ‘Y’ tensors
  - error: str (only if analysis failed)
- epoch_archetype_positions (list[torch.Tensor]) – Archetype positions at each epoch, each of shape (n_archetypes, input_dim).
- training_config (dict) – Training configuration with keys: n_epochs, actual_epochs, early_stop_triggered, archetypal_weight, kld_weight, reconstruction_weight, activation_func, seed, constraint_tolerance, stability_history_size, early_stopping, early_stopping_patience, early_stopping_metric.
Convenience keys (conditional; use .get() to access safely):
- final_archetype_r2 (float | None) – Last value of history['archetype_r2'], if tracked.
- final_rmse (float | None) – Last value of history['rmse'], if tracked.
- final_mae (float | None) – Last value of history['mae'], if tracked.
- final_loss (float | None) – Last value of history['loss'], if tracked.
- convergence_epoch (int | None) – Equals training_config['actual_epochs'].
- Return type:
- Raises:
ValueError – If adata.obsm[pca_key] is not found (run scanpy.pp.pca() first), or if n_archetypes exceeds the number of PCA dimensions.
RuntimeError – If a CUDA device is requested but not available.
Stores
The function stores the following in AnnData:
adata.uns[store_coords_key] (np.ndarray) – Archetype positions in PCA space, shape (n_archetypes, n_pcs). Default key: ‘archetype_coordinates’.
Notes
Archetypal Analysis Theory: Archetypal analysis represents each data point as a convex combination of extreme points (archetypes). Unlike clustering, which partitions data, archetypal analysis allows cells to have partial membership in multiple archetypes, better reflecting biological continuity.
Model Architecture: Uses a variational autoencoder where the latent space directly represents archetypal coordinates (A matrix). Archetypes are learned as model parameters (Y matrix) rather than constructed from data points.
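Since A holds per-cell weights, the constraint metrics reported in history (constraints_satisfied, A_sum_error, B_sum_error, constraint_violation_rate) are presumably checks on A and B, but this page does not define them. One plausible reading, sketched below, assumes the classic archetypal-analysis convention that each row of A (cells × archetypes) and each row of B (archetypes × cells) is a non-negative weight vector summing to 1, checked against constraint_tolerance; `validate_constraints` is a hypothetical helper, not the peach function.

```python
import numpy as np


def validate_constraints(A: np.ndarray, B: np.ndarray, tolerance: float = 1e-3) -> dict:
    """Hypothetical constraint check for A (n_cells, k) and B (k, n_cells).

    Assumes every row of A and every row of B should be non-negative and sum to 1.
    """
    a_err = np.abs(A.sum(axis=1) - 1.0)  # per-cell deviation from sum-to-one
    b_err = np.abs(B.sum(axis=1) - 1.0)  # per-archetype deviation from sum-to-one
    violations = np.concatenate([a_err > tolerance, b_err > tolerance])
    non_negative = bool((A >= -tolerance).all() and (B >= -tolerance).all())
    return {
        "A_sum_error": float(a_err.mean()),
        "B_sum_error": float(b_err.mean()),
        "constraint_violation_rate": float(violations.mean()),
        "constraints_satisfied": non_negative and not violations.any(),
    }


A = np.full((4, 2), 0.5)   # 4 cells, 2 archetypes, each row sums to 1
B = np.full((2, 4), 0.25)  # 2 archetypes, each a uniform mix of the 4 cells
print(validate_constraints(A, B))

A_bad = A.copy()
A_bad[0] = [0.7, 0.5]      # first cell's weights now sum to 1.2
print(validate_constraints(A_bad, B))
```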
Accessing Results Safely:
For guaranteed keys, direct access works: results['history']
For convenience keys, use .get(): results.get('final_archetype_r2')
Or access them from history: results['history']['archetype_r2'][-1]
Type Validation: For IDE autocomplete and runtime validation:
from peach._core.types import TrainingResults, validate_training_results
validated = validate_training_results(results)
Examples
Basic usage with default parameters:
>>> import scanpy as sc
>>> import peach as pc
>>> # Prepare data with PCA
>>> sc.pp.pca(adata, n_comps=30)
>>> # Train archetypal model
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5, n_epochs=100)
>>> # Access final R² safely (may be None if not tracked)
>>> r2 = results.get("final_archetype_r2")
>>> if r2 is not None:
...     print(f"Final R²: {r2:.3f}")
>>> # Or access from history (guaranteed if metric was tracked)
>>> if results["history"].get("archetype_r2"):
...     r2 = results["history"]["archetype_r2"][-1]
...     print(f"Final R²: {r2:.3f}")
Advanced usage with custom configuration:
>>> results = pc.tl.train_archetypal(
...     adata,
...     n_archetypes=4,
...     n_epochs=150,
...     hidden_dims=[512, 256, 128],
...     inflation_factor=2.0,
...     early_stopping=True,
...     early_stopping_patience=15,
...     device="cuda",
... )
>>> # Check if early stopping triggered
>>> config = results["training_config"]
>>> if config["early_stop_triggered"]:
...     print(f"Converged at epoch {config['actual_epochs']}")
Accessing archetype coordinates stored in AnnData:
>>> # After training, coordinates are in adata.uns
>>> archetype_coords = adata.uns["archetype_coordinates"]
>>> print(f"Learned {archetype_coords.shape[0]} archetypes")
>>> print(f"Each archetype has {archetype_coords.shape[1]} PCA dimensions")
See also
peach.tl.archetypal_coordinates – Extract cell-archetype distances
peach.tl.assign_archetypes – Assign cells to discovered archetypes
peach.tl.extract_archetype_weights – Get barycentric coordinates for cells
peach.pl.training_metrics – Visualize training curves
peach._core.types.TrainingResults – Type definition for return structure
- peach.tl.archetypal.archetypal_coordinates(adata, *, pca_key='X_pca', archetype_coords_key='archetype_coordinates', obsm_key='archetype_distances', uns_prefix='archetype', verbose=True, **kwargs)
Extract archetypal coordinates for all cells.
- Parameters:
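adata (AnnData) – Annotated data object with PCA coordinates in adata.obsm[pca_key] and archetype coordinates in adata.uns[archetype_coords_key] (as stored by train_archetypal)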
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates
archetype_coords_key (str, default: "archetype_coordinates") – Key in adata.uns containing archetype coordinates
obsm_key (str, default: "archetype_distances") – Key to store distance matrix in adata.obsm
uns_prefix (str, default: "archetype") – Prefix for keys stored in adata.uns
verbose (bool, default: True) – Whether to print progress messages
**kwargs – Additional arguments passed to compute_archetype_distances
- Returns:
Dictionary with archetypal coordinates and distances
- Return type:
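The matrix written to adata.obsm[obsm_key] has one row per cell and one column per archetype. A minimal sketch of such a matrix, assuming plain Euclidean distance in PCA space (the metric actually used by compute_archetype_distances is not stated here) and that cells and archetypes share the same PCs; `archetype_distances` is a hypothetical name:

```python
import numpy as np


def archetype_distances(X_pca: np.ndarray, archetype_coords: np.ndarray) -> np.ndarray:
    """Euclidean distance from every cell to every archetype (hypothetical helper).

    X_pca: (n_cells, n_pcs); archetype_coords: (n_archetypes, n_pcs), same PCs.
    Returns an array of shape (n_cells, n_archetypes).
    """
    diff = X_pca[:, None, :] - archetype_coords[None, :, :]  # broadcast to (n_cells, n_archetypes, n_pcs)
    return np.linalg.norm(diff, axis=-1)


X_pca = np.array([[0.0, 0.0], [3.0, 4.0]])
archetypes = np.array([[0.0, 0.0], [3.0, 0.0]])
print(archetype_distances(X_pca, archetypes))
```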
- peach.tl.archetypal.assign_archetypes(adata, *, percentage_per_archetype=0.1, obsm_key='archetype_distances', obs_key='archetypes', include_central_archetype=True, verbose=True, **kwargs)
Assign cells to archetypes based on distances.
- Parameters:
adata (AnnData) – Annotated data object with archetype distances
percentage_per_archetype (float, default: 0.1) – Fraction of cells to assign to each archetype, given as a proportion (0.1 = 10%)
obsm_key (str, default: "archetype_distances") – Key in adata.obsm containing distance matrix
obs_key (str, default: "archetypes") – Key to store assignments in adata.obs
include_central_archetype (bool, default: True) – Whether to include a central archetype (cells far from all extreme archetypes)
verbose (bool, default: True) – Whether to print progress messages
**kwargs – Additional arguments passed to bin_cells_by_archetype
- Return type:
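The binning rule behind assign_archetypes is not spelled out beyond percentage_per_archetype. The sketch below shows one reading under stated assumptions: each archetype claims its closest percentage_per_archetype fraction of cells, a cell claimed by several archetypes keeps the nearest one, unclaimed cells are labelled 'no_archetype', and labels are numbered from archetype_1. The central archetype is not modelled, and `bin_by_distance` is a hypothetical helper rather than bin_cells_by_archetype.

```python
import numpy as np


def bin_by_distance(distances: np.ndarray, fraction: float = 0.1, unassigned: str = "no_archetype") -> np.ndarray:
    """Label the `fraction` of cells closest to each archetype (hypothetical helper).

    distances: (n_cells, n_archetypes). A cell selected by several archetypes keeps
    the nearest one; cells selected by none are labelled `unassigned`.
    """
    n_cells, n_archetypes = distances.shape
    n_per_archetype = max(1, int(round(fraction * n_cells)))
    labels = np.full(n_cells, unassigned, dtype=object)
    best = np.full(n_cells, np.inf)  # distance to the archetype a cell is currently assigned to
    for j in range(n_archetypes):
        closest = np.argsort(distances[:, j])[:n_per_archetype]
        for i in closest:
            if distances[i, j] < best[i]:
                best[i] = distances[i, j]
                labels[i] = f"archetype_{j + 1}"
    return labels


# Toy distances for 10 cells: archetype 1 is near cell 0, archetype 2 near cell 9
d = np.column_stack([np.arange(10.0), np.arange(10.0)[::-1]])
print(bin_by_distance(d, fraction=0.2))
```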
- peach.tl.archetypal.extract_archetype_weights(adata, model=None, *, pca_key='X_pca', weights_key='cell_archetype_weights', batch_size=256, device='cpu', verbose=True)
Extract cell archetype weights from trained Deep_AA model.
This function computes the barycentric coordinates (weights) of each cell, which describe how that cell is composed of the learned archetypes.
- Parameters:
adata (AnnData) – Annotated data object with PCA coordinates
model (Deep_AA model, optional) – Trained model. If None, the model is looked up in adata.uns[‘trained_model’]
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates
weights_key (str, default: "cell_archetype_weights") – Key to store weights in adata.obsm
batch_size (int, default: 256) – Batch size for processing
device (str, default: "cpu") – Device for computation (‘cpu’, ‘cuda’, or ‘mps’)
verbose (bool, default: True) – Whether to print progress
- Returns:
Cell archetype weights of shape (n_cells, n_archetypes). Also stores the weights in adata.obsm[weights_key].
- Return type:
np.ndarray
Examples
>>> # After training
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5)
>>>
>>> # Extract weights
>>> weights = pc.tl.extract_archetype_weights(adata, results["model"])
>>>
>>> # Weights are now in adata.obsm['cell_archetype_weights']
>>> print(adata.obsm["cell_archetype_weights"].shape)
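The batch_size parameter suggests the model is applied to the cells chunk by chunk, which bounds memory use, with the per-batch weights stacked into one (n_cells, n_archetypes) array. A sketch of that pattern with a stand-in encoder (a softmax over negative squared distances to fixed toy archetypes); `extract_weights_batched`, `toy_encode` and the encoder itself are hypothetical, not the Deep_AA model:

```python
import numpy as np

# Toy archetype positions for the stand-in encoder, shape (n_archetypes, n_pcs)
TOY_ARCHETYPES = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])


def toy_encode(batch: np.ndarray) -> np.ndarray:
    """Stand-in encoder: softmax over negative squared distances to the toy archetypes."""
    sq_dist = ((batch[:, None, :] - TOY_ARCHETYPES[None, :, :]) ** 2).sum(axis=-1)
    logits = -sq_dist
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)


def extract_weights_batched(X_pca: np.ndarray, encode, batch_size: int = 256) -> np.ndarray:
    """Apply `encode` to consecutive chunks of cells and stack the per-batch weights."""
    chunks = [encode(X_pca[start:start + batch_size]) for start in range(0, X_pca.shape[0], batch_size)]
    return np.vstack(chunks)


X_pca = np.random.default_rng(1).normal(size=(1000, 2))
weights = extract_weights_batched(X_pca, toy_encode, batch_size=256)
print(weights.shape)  # (1000, 3)
```

Since each row is computed independently, the batched result matches a single full-size pass; batch_size only trades memory for the number of forward passes.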
- peach.tl.archetypal.compute_conditional_centroids(adata, condition_column, *, pca_key='X_pca', store_key='conditional_centroids', exclude_archetypes=None, groupby=None, verbose=True)
Compute centroid positions in PCA space for each level of a categorical condition.
This function calculates the mean position (centroid) in PCA space for cells belonging to each level of a categorical variable. Useful for visualizing how different conditions (e.g., treatment phases, timepoints) relate to the archetypal structure.
Following R template patterns:
- Uses ALL PCs for centroid calculation (equivalent to R’s colMeans)
- Stores the full-PC centroid but extracts the first 3 PCs for 3D visualization
- Excludes ‘no_archetype’ and ‘archetype_0’ cells by default
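These rules amount to a per-level column mean over all PCs after dropping cells whose archetype label is excluded. A minimal NumPy sketch, assuming archetype labels are available as one string per cell; `conditional_centroids` is a hypothetical helper and its output keeps only a subset of the documented result keys:

```python
import numpy as np


def conditional_centroids(X_pca, conditions, archetype_labels, exclude=("no_archetype", "archetype_0")):
    """Per-level centroids over all PCs plus their first 3 PCs (hypothetical helper)."""
    conditions = np.asarray(conditions)
    excluded = set(exclude)
    keep = np.array([label not in excluded for label in archetype_labels])  # drop excluded archetype labels
    centroids, centroids_3d, cell_counts = {}, {}, {}
    for level in np.unique(conditions[keep]):
        mask = keep & (conditions == level)
        mean = X_pca[mask].mean(axis=0)  # analogous to R's colMeans over all PCs
        centroids[str(level)] = mean.tolist()
        centroids_3d[str(level)] = mean[:3].tolist()  # first 3 PCs for 3D plotting
        cell_counts[str(level)] = int(mask.sum())
    return {"centroids": centroids, "centroids_3d": centroids_3d, "cell_counts": cell_counts}


X_pca = np.array([[1.0] * 4, [3.0] * 4, [10.0] * 4, [5.0] * 4])  # 4 cells, 4 PCs
conditions = ["pre", "pre", "post", "post"]
archetype_labels = ["archetype_1", "archetype_2", "no_archetype", "archetype_1"]
result = conditional_centroids(X_pca, conditions, archetype_labels)
print(result["centroids_3d"])
```

Passing exclude=() keeps every cell, mirroring the documented option of setting exclude_archetypes to an empty list.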
- Parameters:
adata (AnnData) – Annotated data object with PCA coordinates in adata.obsm[pca_key].
condition_column (str) – Name of categorical column in adata.obs to group by. Examples: ‘treatment_phase’, ‘timepoint’, ‘batch’.
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates.
store_key (str, default: "conditional_centroids") – Key in adata.uns to store results.
exclude_archetypes (list, optional) – Archetype labels to exclude from centroid calculation. Default: [‘no_archetype’, ‘archetype_0’] (following R template). Set to empty list [] to include all cells.
groupby (str, optional) – Second categorical column for multi-group trajectories. If provided, centroids are computed for each (group, level) combination. Example: groupby=’response_group’ to get separate trajectories per response.
verbose (bool, default: True) – Whether to print progress messages.
- Returns:
Dictionary with keys:
- condition_column: str - name of the condition column
- n_levels: int - number of unique levels
- levels: List[str] - list of level names
- centroids: Dict[str, List[float]] - level → full PCA coordinates
- centroids_3d: Dict[str, List[float]] - level → [x, y, z] first 3 PCs
- cell_counts: Dict[str, int] - level → cell count
- pca_key: str - PCA key used
- exclude_archetypes: List[str] - archetypes excluded
- groupby: Optional[str] - groupby column if used
- group_centroids: Optional[Dict] - if groupby: {group: {level: coords}}
- group_centroids_3d: Optional[Dict] - if groupby: {group: {level: [x,y,z]}}
- group_cell_counts: Optional[Dict] - if groupby: {group: {level: count}}
- Return type:
- Raises:
ValueError – If condition_column is not in adata.obs or PCA coordinates are not found.
Stores
The function stores results in AnnData:
adata.uns[store_key][condition_column] (dict) – Full results dictionary as returned.
Examples
>>> # Compute centroids for treatment phase
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase")
>>> print(result["centroids_3d"])
{'chemo-naive': [1.2, 0.5, -0.3], 'IDS': [0.8, 1.1, 0.2]}
>>> # Then visualize with trajectory
>>> fig = pc.pl.archetypal_space(
...     adata, show_centroids=True, centroid_condition="treatment_phase", centroid_order=["chemo-naive", "IDS"]
... )
>>> # Multi-group centroids for trajectory comparison
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase", groupby="response_group")
>>> fig = pc.pl.archetypal_space(
...     adata,
...     show_centroids=True,
...     centroid_condition="treatment_phase",
...     centroid_groupby="response_group",
...     centroid_order=["chemo-naive", "IDS"],
...     centroid_colors={"long": "magenta", "short": "cyan"},
... )
See also
peach.pl.archetypal_space – Visualize with centroid trajectory overlay
- peach.tl.archetypal.assign_to_centroids(adata, condition_column, *, pca_key='X_pca', centroid_key='conditional_centroids', bin_prop=0.15, obs_key='centroid_assignments', exclude_archetypes=None, verbose=True)
Assign cells to the nearest centroid based on distance (the closest bin_prop fraction of cells per centroid).
This function mirrors assign_archetypes but for condition-based centroids. It enables using treatment phase centroids as trajectory endpoints in single_trajectory_analysis by creating categorical assignments that CellRank can use as terminal states.
- Parameters:
adata (AnnData) – Annotated data object. Must have:
- PCA coordinates in adata.obsm[pca_key]
- Centroids computed via compute_conditional_centroids in adata.uns[centroid_key]
condition_column (str) – Name of the condition column used in compute_conditional_centroids. This identifies which centroid set to use.
pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates.
centroid_key (str, default: "conditional_centroids") – Key in adata.uns containing centroid results from compute_conditional_centroids.
bin_prop (float, default: 0.15) – Proportion of cells to assign to each centroid (top 15% closest). Similar to percentage_per_archetype in assign_archetypes.
obs_key (str, default: "centroid_assignments") – Key in adata.obs to store assignments.
exclude_archetypes (list, optional) – Archetype labels to exclude from assignment. Default: [‘no_archetype’] - these cells get ‘unassigned’.
verbose (bool, default: True) – Whether to print progress messages.
- Returns:
Modifies adata.obs[obs_key] with Categorical assignments. Values are condition levels (e.g., ‘chemo_naive’, ‘IDS’) or ‘unassigned’.
- Return type:
None
Examples
>>> # First compute centroids for treatment phases
>>> pc.tl.compute_conditional_centroids(adata, "treatment_stage")
>>>
>>> # Then assign cells to nearest centroid (top 15% closest)
>>> pc.tl.assign_to_centroids(adata, "treatment_stage", bin_prop=0.15)
>>>
>>> # Check assignments
>>> print(adata.obs["centroid_assignments"].value_counts())
>>>
>>> # Now can use with CellRank for trajectory analysis
>>> # (setup_cellrank can use centroid_assignments as terminal states)
See also
compute_conditional_centroids – Compute centroids for condition levels
assign_archetypes – Similar function for archetype assignments
single_trajectory_analysis – Uses centroid assignments for trajectory analysis