peach.tl.archetypal

Core archetypal analysis functions.

This module provides the primary interface for training Deep Archetypal Analysis models and extracting archetypal coordinates. All functions work directly with AnnData objects and follow scVerse conventions.

Main Functions:

  • train_archetypal(): Train a Deep AA model to discover cellular archetypes

  • archetypal_coordinates(): Extract archetypal coordinates for all cells

  • assign_archetypes(): Assign cells to discovered archetypes based on distances

The module integrates PCHA initialization, inflation factors, and training diagnostics (stability tracking and constraint validation) for end-to-end archetypal analysis workflows.

Functions

archetypal_coordinates(adata, *[, pca_key, ...])

Extract archetypal coordinates for all cells.

assign_archetypes(adata, *[, ...])

Assign cells to archetypes based on distances.

assign_to_centroids(adata, condition_column, *)

Assign cells to the nearest centroid by distance (the closest bin_prop fraction of cells per centroid).

compute_conditional_centroids(adata, ...[, ...])

Compute centroid positions in PCA space for each level of a categorical condition.

extract_archetype_weights(adata[, model, ...])

Extract cell archetype weights from trained Deep_AA model.

train_archetypal(adata[, n_archetypes, ...])

Train Deep Archetypal Analysis model to discover cellular archetypes.

peach.tl.archetypal.train_archetypal(adata, n_archetypes=5, n_epochs=50, *, layer=None, pca_key='X_pca', hidden_dims=None, inflation_factor=1.5, model_config=None, optimizer_config=None, device='cpu', save_path=None, archetypal_weight=None, kld_weight=None, reconstruction_weight=0.0, vae_recon_weight=0.0, diversity_weight=0.0, activation_func='relu', track_stability=True, validate_constraints=True, lr_factor=0.1, lr_patience=10, seed=42, constraint_tolerance=0.001, stability_history_size=20, store_coords_key='archetype_coordinates', early_stopping=False, early_stopping_patience=10, early_stopping_metric='archetype_r2', min_improvement=0.0001, validation_check_interval=5, validation_data_loader=None, **kwargs)

Train Deep Archetypal Analysis model to discover cellular archetypes.

This function performs archetypal analysis using a variational autoencoder architecture to identify extreme cellular states (archetypes) that capture the main axes of biological variation. Each cell is represented as a convex combination of the learned archetypes.

The model initializes archetypes with PCHA, pushed away from the data centroid by an inflation factor, and has reached archetype R² > 0.89 on real single-cell datasets.
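A minimal sketch of the inflation step, assuming it is a uniform scaling of each initial archetype about the mean of the PCA coordinates (consistent with the inflation_factor parameter description, but not confirmed as peach's exact formula). The helper name inflate_archetypes is hypothetical, and the PCHA fit that would produce the initial archetypes is replaced by a stand-in.

```python
import numpy as np


def inflate_archetypes(archetypes: np.ndarray, X: np.ndarray, inflation_factor: float = 1.5) -> np.ndarray:
    """Push archetypes away from the data centroid (hypothetical helper, not peach API).

    archetypes: (n_archetypes, n_pcs) initial positions, e.g. from PCHA.
    X: (n_cells, n_pcs) PCA coordinates.
    Assumption: inflation is a uniform scaling about the data mean.
    """
    centroid = X.mean(axis=0)
    # A factor > 1 moves each archetype further out along its centroid-to-archetype direction
    return centroid + inflation_factor * (archetypes - centroid)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
init = X[:3]  # stand-in for PCHA output
inflated = inflate_archetypes(init, X, inflation_factor=1.5)
```

With inflation_factor=1.0 the positions are unchanged, so the factor acts purely as a radial stretch around the data mean.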

Parameters:
  • adata (AnnData) – Annotated data object containing single-cell expression data. Must have PCA coordinates in adata.obsm[pca_key]. Typically generated using scanpy.pp.pca(adata).

  • n_archetypes (int, default: 5) – Number of archetypal patterns to learn. Should be chosen based on biological knowledge or using hyperparameter optimization. Common values: 3-10 for most datasets.

  • n_epochs (int, default: 50) – Number of training epochs. Larger datasets may require more epochs. For datasets >5K cells, consider 100-200 epochs.

  • layer (str | None, default: None) – AnnData layer to use for training. If None, uses PCA coordinates from adata.obsm[pca_key] (recommended).

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates. The model works best with 5-50 PCA components. The keys ‘X_pca’, ‘X_PCA’, and ‘PCA’ are auto-detected.

  • hidden_dims (list[int] | None, default: None) – Encoder/decoder layer dimensions. If None, uses [256, 128, 64]. Smaller architectures like [128, 64] train faster but may underfit. Larger architectures like [512, 256, 128] may overfit on small datasets.

  • inflation_factor (float, default: 1.5) – PCHA inflation factor for archetype initialization. Values > 1.0 push initial archetypes further from the data centroid, improving separation. Recommended range: 1.2-2.0; use higher values for more distinct archetypes.

  • model_config (dict | None, default: None) –

    Additional model configuration parameters (for advanced users):

    • archetypal_weight : float - Archetypal loss weight, default 0.9

    • kld_weight : float - KL divergence weight, default 0.1

    • diversity_weight : float - Archetype diversity weight, default 0.05

    • use_barycentric : bool - Use softmax constraints, default True

  • optimizer_config (dict | None, default: None) –

    Optimizer configuration parameters:

    • lr : float - Learning rate, default 1e-3

    • weight_decay : float - L2 regularization, default 0.0

    • betas : tuple[float, float] - Adam momentum parameters

  • device (str, default: "cpu") – Compute device for training. One of “cpu”, “cuda”, “mps”.

  • save_path (str | None, default: None) – Path to save model checkpoints during training.

  • archetypal_weight (float | None, default: None) – Weight for archetypal loss component. Uses model’s configured value if None.

  • kld_weight (float | None, default: None) – Weight for KL divergence loss. Uses model’s configured value if None.

  • reconstruction_weight (float, default: 0.0) – Legacy reconstruction weight parameter.

  • vae_recon_weight (float, default: 0.0) – VAE reconstruction weight.

  • diversity_weight (float, default: 0.0) – Weight for archetype diversity loss.

  • activation_func (str, default: "relu") – Activation function for the model.

  • track_stability (bool, default: True) – Whether to monitor archetype drift during training. Adds stability metrics to history: archetype_drift_mean, archetype_variance_mean, etc.

  • validate_constraints (bool, default: True) – Whether to validate archetypal constraints during training. Adds constraint metrics to history: constraints_satisfied, A_sum_error, etc.

  • lr_factor (float, default: 0.1) – Factor for learning rate reduction on plateau.

  • lr_patience (int, default: 10) – Number of epochs with no improvement before reducing learning rate.

  • seed (int, default: 42) – Random seed for reproducibility.

  • constraint_tolerance (float, default: 1e-3) – Tolerance for constraint validation.

  • stability_history_size (int, default: 20) – Number of epochs to track for stability analysis.

  • store_coords_key (str, default: "archetype_coordinates") – Key to store learned archetype positions in adata.uns.

  • early_stopping (bool, default: False) – Whether to use early stopping based on validation metrics.

  • early_stopping_patience (int, default: 10) – Patience for early stopping (number of checks without improvement).

  • early_stopping_metric (str, default: "archetype_r2") – Metric to monitor for early stopping. One of: ‘archetype_r2’, ‘loss’, ‘rmse’.

  • min_improvement (float, default: 1e-4) – Minimum improvement required to reset patience counter.

  • validation_check_interval (int, default: 5) – How often to check validation metrics (in epochs).

  • validation_data_loader (DataLoader | None, default: None) – Validation data loader for early stopping. Uses training data if None.

  • **kwargs – Additional arguments passed to the core training function.
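The early-stopping parameters can be read together as: check the monitored metric every validation_check_interval epochs, and stop once early_stopping_patience consecutive checks fail to improve it by at least min_improvement. The plain-Python sketch below encodes that reading; early_stop_epoch is a hypothetical helper, and treating min_improvement as an absolute (not relative) threshold is an assumption.

```python
def early_stop_epoch(metric_per_epoch, *, higher_is_better=True, patience=10,
                     min_improvement=1e-4, check_interval=5):
    """Return the 1-based epoch at which training would stop, or None.

    Hypothetical helper, not peach API. Assumptions: the metric is only
    inspected every `check_interval` epochs, and an improvement must exceed
    `min_improvement` in absolute terms.
    """
    best = None
    checks_without_improvement = 0
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if epoch % check_interval != 0:
            continue  # not a validation epoch
        # Flip the sign for metrics like loss/rmse so "larger is better" throughout
        score = value if higher_is_better else -value
        if best is None or score > best + min_improvement:
            best = score
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return epoch
    return None


# An archetype_r2 curve that rises and then plateaus at 0.9
r2_curve = [min(0.9, 0.01 * e) for e in range(1, 201)]
stop = early_stop_epoch(r2_curve, patience=3)
```

Because patience counts checks rather than epochs, the effective patience in epochs is roughly early_stopping_patience × validation_check_interval under this reading.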

Returns:

Training results dictionary with the following structure:

Guaranteed keys (always present):

  • history : dict

    Training metrics per epoch. Keys depend on tracking options:

    • Core metrics (always): loss, archetypal_loss, archetype_r2, rmse

    • KLD metrics: kld_loss, KLD

    • Stability metrics (if track_stability=True): archetype_drift_mean, archetype_drift_max, archetype_variance_mean, etc.

    • Constraint metrics (if validate_constraints=True): constraints_satisfied, A_sum_error, B_sum_error, constraint_violation_rate

    • Validation metrics (if early_stopping=True): val_loss, val_archetype_r2

  • final_model : torch.nn.Module

    Trained Deep_AA model instance.

  • model : torch.nn.Module

    Alias for final_model (same object, for compatibility).

  • final_optimizer : torch.optim.Optimizer

    Final optimizer state.

  • final_analysis : dict

    Final training analysis containing:

    • final_constraint_validation : dict with constraint metrics

    • archetypal_weights : dict with A_matrix and B_matrix analysis

    • final_coordinates : dict with ‘A’, ‘B’, ‘Y’ tensors

    • error : str (only if analysis failed)

  • epoch_archetype_positions : list[torch.Tensor]

    Archetype positions at each epoch, shape (n_archetypes, input_dim).

  • training_config : dict

    Training configuration with keys: n_epochs, actual_epochs, early_stop_triggered, archetypal_weight, kld_weight, reconstruction_weight, activation_func, seed, constraint_tolerance, stability_history_size, early_stopping, early_stopping_patience, early_stopping_metric.

Convenience keys (conditional - use .get() to access safely):

  • final_archetype_r2 : float | None

    Last value of history['archetype_r2'] if tracked.

  • final_rmse : float | None

    Last value of history['rmse'] if tracked.

  • final_mae : float | None

    Last value of history['mae'] if tracked.

  • final_loss : float | None

    Last value of history['loss'] if tracked.

  • convergence_epoch : int | None

    Equals training_config['actual_epochs'].

Return type:

dict
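The per-epoch positions in epoch_archetype_positions can be used to check convergence independently of the history metrics. This sketch, an approximation in the spirit of archetype_drift_mean and archetype_drift_max rather than peach's own computation, measures how far each archetype moves between consecutive epochs. It assumes the tensors have been converted to NumPy first (for torch tensors, e.g. `[p.detach().cpu().numpy() for p in results["epoch_archetype_positions"]]`); archetype_drift is a hypothetical helper and synthetic positions stand in for real results.

```python
import numpy as np


def archetype_drift(positions):
    """Mean and max archetype displacement between consecutive epochs.

    Hypothetical helper, not peach API.
    positions: sequence of (n_archetypes, n_dims) arrays, one per epoch.
    Returns two arrays of length len(positions) - 1.
    """
    stacked = np.stack([np.asarray(p, dtype=float) for p in positions])  # (n_epochs, k, d)
    # Euclidean distance each archetype travelled from one epoch to the next
    step = np.linalg.norm(np.diff(stacked, axis=0), axis=2)  # (n_epochs - 1, k)
    return step.mean(axis=1), step.max(axis=1)


# Synthetic example: 3 archetypes in 2D that settle towards fixed positions
target = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
positions = [target + 0.5 ** epoch for epoch in range(6)]
drift_mean, drift_max = archetype_drift(positions)
```

A drift curve that keeps shrinking towards zero suggests the archetypes have stabilized, which is the behaviour the stability metrics are meant to surface.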

Raises:
  • ValueError – If adata.obsm[pca_key] is not found (run scanpy.pp.pca() first), or if n_archetypes exceeds the number of PCA dimensions.

  • RuntimeError – If a CUDA device is requested but not available.

Stores:

  • adata.uns[store_coords_key] : np.ndarray – Archetype positions in PCA space, shape (n_archetypes, n_pcs). Default key: ‘archetype_coordinates’.

Notes

Archetypal Analysis Theory: Archetypal analysis represents each data point as a convex combination of extreme points (archetypes). Unlike clustering, which partitions data, archetypal analysis allows cells to have partial membership in multiple archetypes, better reflecting biological continuity.

Model Architecture: Uses a variational autoencoder where the latent space directly represents archetypal coordinates (A matrix). Archetypes are learned as model parameters (Y matrix) rather than constructed from data points.
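The A/Y relationship can be illustrated with NumPy: each row of A is a set of non-negative weights summing to 1, and a cell is reconstructed as A @ Y. The sketch assumes a row-wise softmax produces A (in line with the use_barycentric option) and scores the reconstruction with the usual R² = 1 - SSE/SST; whether archetype_r2 is computed exactly this way is an assumption. softmax_rows and r2_score are hypothetical helpers, and all data are synthetic.

```python
import numpy as np


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax: each row becomes non-negative and sums to 1."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def r2_score(X: np.ndarray, X_hat: np.ndarray) -> float:
    """Coefficient of determination over all entries (assumed definition)."""
    sse = np.sum((X - X_hat) ** 2)
    sst = np.sum((X - X.mean(axis=0)) ** 2)
    return 1.0 - sse / sst


rng = np.random.default_rng(0)
n_cells, n_archetypes, n_pcs = 500, 4, 10
Y = rng.normal(scale=3.0, size=(n_archetypes, n_pcs))        # archetypes (model parameters)
A = softmax_rows(rng.normal(size=(n_cells, n_archetypes)))   # barycentric weights per cell
X = A @ Y + rng.normal(scale=0.1, size=(n_cells, n_pcs))     # cells = convex combinations + noise

X_hat = A @ Y  # reconstruction from archetypes
print(f"R² = {r2_score(X, X_hat):.3f}")
```

Because every row of A sums to 1 with non-negative entries, each reconstructed cell lies inside the convex hull of the rows of Y, which is what distinguishes this from an unconstrained linear factorization.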

Accessing Results Safely:

  • For guaranteed keys, direct access works: results['history']

  • For convenience keys, use .get(): results.get('final_archetype_r2')

  • Or access from history: results['history']['archetype_r2'][-1]

Type Validation: For IDE autocomplete and runtime validation:

>>> from peach._core.types import TrainingResults, validate_training_results
>>> validated = validate_training_results(results)

Examples

Basic usage with default parameters:

>>> import scanpy as sc
>>> import peach as pc
>>> # Prepare data with PCA
>>> sc.pp.pca(adata, n_comps=30)
>>> # Train archetypal model
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5, n_epochs=100)
>>> # Access final R² safely (may be None if not tracked)
>>> r2 = results.get("final_archetype_r2")
>>> if r2 is not None:
...     print(f"Final R²: {r2:.3f}")
>>> # Or access from history (guaranteed if metric was tracked)
>>> if results["history"].get("archetype_r2"):
...     r2 = results["history"]["archetype_r2"][-1]
...     print(f"Final R²: {r2:.3f}")

Advanced usage with custom configuration:

>>> results = pc.tl.train_archetypal(
...     adata,
...     n_archetypes=4,
...     n_epochs=150,
...     hidden_dims=[512, 256, 128],
...     inflation_factor=2.0,
...     early_stopping=True,
...     early_stopping_patience=15,
...     device="cuda",
... )
>>> # Check if early stopping triggered
>>> config = results["training_config"]
>>> if config["early_stop_triggered"]:
...     print(f"Converged at epoch {config['actual_epochs']}")

Accessing archetype coordinates stored in AnnData:

>>> # After training, coordinates are in adata.uns
>>> archetype_coords = adata.uns["archetype_coordinates"]
>>> print(f"Learned {archetype_coords.shape[0]} archetypes")
>>> print(f"Each archetype has {archetype_coords.shape[1]} PCA dimensions")

See also

peach.tl.archetypal_coordinates

Extract cell-archetype distances

peach.tl.assign_archetypes

Assign cells to discovered archetypes

peach.tl.extract_archetype_weights

Get barycentric coordinates for cells

peach.pl.training_metrics

Visualize training curves

peach._core.types.TrainingResults

Type definition for return structure

peach.tl.archetypal.archetypal_coordinates(adata, *, pca_key='X_pca', archetype_coords_key='archetype_coordinates', obsm_key='archetype_distances', uns_prefix='archetype', verbose=True, **kwargs)

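Extract archetypal coordinates for all cells.

Computes the distance from each cell in adata.obsm[pca_key] to each archetype in adata.uns[archetype_coords_key] and stores the resulting (n_cells, n_archetypes) distance matrix in adata.obsm[obsm_key].

Parameters:
  • adata (AnnData) – Annotated data object with archetype coordinates in adata.uns[archetype_coords_key] (for example, written by train_archetypal)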

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates

  • archetype_coords_key (str, default: "archetype_coordinates") – Key in adata.uns containing archetype coordinates

  • obsm_key (str, default: "archetype_distances") – Key to store distance matrix in adata.obsm

  • uns_prefix (str, default: "archetype") – Prefix for keys stored in adata.uns

  • verbose (bool, default: True) – Whether to print progress messages

  • **kwargs – Additional arguments passed to compute_archetype_distances

Returns:

Dictionary with archetypal coordinates and distances

Return type:

dict
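A minimal sketch of the cell-by-archetype distance matrix, assuming plain Euclidean distance over all PCs (the metric used by compute_archetype_distances is not stated here). With real data the inputs would be adata.obsm["X_pca"] and adata.uns["archetype_coordinates"]; the helper name cell_archetype_distances is hypothetical.

```python
import numpy as np


def cell_archetype_distances(X_pca: np.ndarray, archetypes: np.ndarray) -> np.ndarray:
    """Euclidean distance from every cell to every archetype (hypothetical helper).

    X_pca: (n_cells, n_pcs); archetypes: (n_archetypes, n_pcs).
    Returns an (n_cells, n_archetypes) matrix.
    """
    if X_pca.shape[1] != archetypes.shape[1]:
        raise ValueError("cells and archetypes must live in the same PCA space")
    # (n_cells, 1, n_pcs) - (1, n_archetypes, n_pcs) -> (n_cells, n_archetypes, n_pcs)
    diff = X_pca[:, None, :] - archetypes[None, :, :]
    return np.linalg.norm(diff, axis=2)


X_pca = np.array([[0.0, 0.0], [3.0, 4.0]])
archetypes = np.array([[0.0, 0.0], [3.0, 0.0]])
D = cell_archetype_distances(X_pca, archetypes)
```

Broadcasting materializes an (n_cells, n_archetypes, n_pcs) intermediate; with a handful of archetypes this stays small even for large datasets.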

peach.tl.archetypal.assign_archetypes(adata, *, percentage_per_archetype=0.1, obsm_key='archetype_distances', obs_key='archetypes', include_central_archetype=True, verbose=True, **kwargs)

Assign cells to archetypes based on distances.

Parameters:
  • adata (AnnData) – Annotated data object with archetype distances

  • percentage_per_archetype (float, default: 0.1) – Fraction of cells (0.1 = 10%) assigned to each archetype, taken as the cells closest to it

  • obsm_key (str, default: "archetype_distances") – Key in adata.obsm containing distance matrix

  • obs_key (str, default: "archetypes") – Key to store assignments in adata.obs

  • include_central_archetype (bool, default: True) – Whether to include a central archetype (cells far from all extreme archetypes)

  • verbose (bool, default: True) – Whether to print progress messages

  • **kwargs – Additional arguments passed to bin_cells_by_archetype

Return type:

None
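To make percentage_per_archetype concrete, the sketch below labels, for each archetype, that fraction of cells with the smallest distance to it and leaves all other cells as ‘no_archetype’. It is an illustration, not bin_cells_by_archetype itself: resolving overlaps in favour of the nearest archetype, numbering archetypes from 1, and omitting the central archetype are all assumptions, and assign_top_fraction is a hypothetical helper.

```python
import numpy as np


def assign_top_fraction(distances: np.ndarray, fraction: float = 0.1) -> np.ndarray:
    """Label the closest `fraction` of cells for each archetype (hypothetical helper).

    distances: (n_cells, n_archetypes) cell-to-archetype distances.
    Returns an object array of string labels, one per cell.
    """
    n_cells, n_archetypes = distances.shape
    n_per_archetype = max(1, int(round(fraction * n_cells)))
    labels = np.full(n_cells, "no_archetype", dtype=object)
    best_dist = np.full(n_cells, np.inf)
    for k in range(n_archetypes):
        closest = np.argsort(distances[:, k])[:n_per_archetype]
        # Assumption: a cell selected by several archetypes keeps the nearest one
        wins = distances[closest, k] < best_dist[closest]
        labels[closest[wins]] = f"archetype_{k + 1}"
        best_dist[closest[wins]] = distances[closest[wins], k]
    return labels


# 10 cells on a line between two archetypes: distance i to the first, 9 - i to the second
distances = np.stack([np.arange(10), 9 - np.arange(10)], axis=1).astype(float)
labels = assign_top_fraction(distances, fraction=0.2)
```

With fraction=0.2 the two cells nearest each end are labelled and the middle six stay ‘no_archetype’; larger fractions make the per-archetype bins overlap, which is why some tie-breaking rule is needed.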

peach.tl.archetypal.extract_archetype_weights(adata, model=None, *, pca_key='X_pca', weights_key='cell_archetype_weights', batch_size=256, device='cpu', verbose=True)

Extract cell archetype weights from trained Deep_AA model.

This function computes the barycentric coordinates (weights) of each cell, which describe how the cell is composed of the learned archetypes.

Parameters:
  • adata (AnnData) – Annotated data object with PCA coordinates

  • model (Deep_AA model, optional) – Trained model. If None, the model is looked up in adata.uns[‘trained_model’]

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates

  • weights_key (str, default: "cell_archetype_weights") – Key to store weights in adata.obsm

  • batch_size (int, default: 256) – Batch size for processing

  • device (str, default: "cpu") – Device for computation (‘cpu’, ‘cuda’, or ‘mps’)

  • verbose (bool, default: True) – Whether to print progress

Returns:

Cell archetype weights of shape (n_cells, n_archetypes). The weights are also stored in adata.obsm[weights_key].

Return type:

np.ndarray

Examples

>>> # After training
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5)
>>>
>>> # Extract weights
>>> weights = pc.tl.extract_archetype_weights(adata, results["model"])
>>>
>>> # Weights are now in adata.obsm['cell_archetype_weights']
>>> print(adata.obsm["cell_archetype_weights"].shape)
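A quick sanity check on the returned weights is that every row lies on the probability simplex (non-negative entries summing to 1). The sketch below reports the largest row-sum deviation, similar in spirit to the A_sum_error training diagnostic, against a tolerance like constraint_tolerance, and picks each cell's dominant archetype with argmax. check_simplex is a hypothetical helper, and Dirichlet samples stand in for adata.obsm["cell_archetype_weights"].

```python
import numpy as np


def check_simplex(weights: np.ndarray, tol: float = 1e-3) -> dict:
    """Summarize how well each row of `weights` satisfies the convexity constraints.

    Hypothetical helper, not peach API.
    """
    row_sum_error = np.abs(weights.sum(axis=1) - 1.0)
    nonnegative = bool((weights >= -tol).all())
    return {
        "max_sum_error": float(row_sum_error.max()),
        "all_nonnegative": nonnegative,
        "satisfied": bool(row_sum_error.max() <= tol and nonnegative),
    }


rng = np.random.default_rng(42)
weights = rng.dirichlet(alpha=np.ones(5), size=1000)  # stand-in, shape (n_cells, n_archetypes)
report = check_simplex(weights, tol=1e-3)
dominant = weights.argmax(axis=1)  # index of each cell's largest archetype weight
```

The argmax gives a hard label per cell, but the full weight vector is usually more informative, since cells near the middle of the simplex have no single dominant archetype.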

peach.tl.archetypal.compute_conditional_centroids(adata, condition_column, *, pca_key='X_pca', store_key='conditional_centroids', exclude_archetypes=None, groupby=None, verbose=True)

Compute centroid positions in PCA space for each level of a categorical condition.

This function calculates the mean position (centroid) in PCA space for cells belonging to each level of a categorical variable. Useful for visualizing how different conditions (e.g., treatment phases, timepoints) relate to the archetypal structure.

Following R template patterns:

  • Uses ALL PCs for centroid calculation (equivalent to R’s colMeans)

  • Stores the full-PC centroid but extracts the first 3 PCs for 3D visualization

  • Excludes ‘no_archetype’ and ‘archetype_0’ cells by default
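The centroid rule described here (column means over all PCs per condition level after dropping excluded archetype labels, with the first three PCs kept for plotting) can be sketched with NumPy. The function conditional_centroids is hypothetical; in practice its inputs would be adata.obsm[pca_key], adata.obs[condition_column], and an archetype label column such as adata.obs["archetypes"] (assumed here to be the one written by assign_archetypes).

```python
import numpy as np


def conditional_centroids(X_pca, conditions, archetype_labels,
                          exclude=("no_archetype", "archetype_0")):
    """Mean PCA position per condition level, skipping excluded archetype labels.

    Hypothetical helper, not peach API.
    """
    X_pca = np.asarray(X_pca, dtype=float)
    conditions = np.asarray(conditions)
    excluded = set(exclude)
    keep = np.array([label not in excluded for label in archetype_labels], dtype=bool)
    centroids, centroids_3d, counts = {}, {}, {}
    for level in np.unique(conditions[keep]):
        mask = keep & (conditions == level)
        centre = X_pca[mask].mean(axis=0)  # equivalent of R's colMeans over all PCs
        centroids[str(level)] = centre.tolist()
        centroids_3d[str(level)] = centre[:3].tolist()  # first 3 PCs for 3D plots
        counts[str(level)] = int(mask.sum())
    return {"centroids": centroids, "centroids_3d": centroids_3d, "cell_counts": counts}


X = np.array([[0, 0, 0, 0], [2, 2, 2, 2], [10, 10, 10, 10], [4, 4, 4, 4]], dtype=float)
cond = ["a", "a", "a", "b"]
arch = ["archetype_1", "archetype_2", "no_archetype", "archetype_1"]
res = conditional_centroids(X, cond, arch)
```

Here the ‘no_archetype’ cell is dropped, so level "a" averages only its first two cells; passing exclude=() includes every cell, mirroring exclude_archetypes=[].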

Parameters:
  • adata (AnnData) – Annotated data object with PCA coordinates in adata.obsm[pca_key].

  • condition_column (str) – Name of categorical column in adata.obs to group by. Examples: ‘treatment_phase’, ‘timepoint’, ‘batch’.

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates.

  • store_key (str, default: "conditional_centroids") – Key in adata.uns to store results.

  • exclude_archetypes (list, optional) – Archetype labels to exclude from centroid calculation. Default: [‘no_archetype’, ‘archetype_0’] (following R template). Set to empty list [] to include all cells.

  • groupby (str, optional) – Second categorical column for multi-group trajectories. If provided, centroids are computed for each (group, level) combination. Example: groupby=‘response_group’ to get separate trajectories per response.

  • verbose (bool, default: True) – Whether to print progress messages.

Returns:

Dictionary with keys:

  • condition_column : str - name of the condition column

  • n_levels : int - number of unique levels

  • levels : List[str] - list of level names

  • centroids : Dict[str, List[float]] - level → full PCA coordinates

  • centroids_3d : Dict[str, List[float]] - level → [x, y, z] first 3 PCs

  • cell_counts : Dict[str, int] - level → cell count

  • pca_key : str - PCA key used

  • exclude_archetypes : List[str] - archetypes excluded

  • groupby : Optional[str] - groupby column if used

  • group_centroids : Optional[Dict] - if groupby: {group: {level: coords}}

  • group_centroids_3d : Optional[Dict] - if groupby: {group: {level: [x,y,z]}}

  • group_cell_counts : Optional[Dict] - if groupby: {group: {level: count}}

Return type:

dict

Raises:
  • ValueError – If condition_column is not in adata.obs or PCA coordinates are not found.

Stores:

  • adata.uns[store_key][condition_column] : dict – Full results dictionary as returned.

Examples

>>> # Compute centroids for treatment phase
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase")
>>> print(result["centroids_3d"])
{'chemo-naive': [1.2, 0.5, -0.3], 'IDS': [0.8, 1.1, 0.2]}
>>> # Then visualize with trajectory
>>> fig = pc.pl.archetypal_space(
...     adata, show_centroids=True, centroid_condition="treatment_phase", centroid_order=["chemo-naive", "IDS"]
... )
>>> # Multi-group centroids for trajectory comparison
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase", groupby="response_group")
>>> fig = pc.pl.archetypal_space(
...     adata,
...     show_centroids=True,
...     centroid_condition="treatment_phase",
...     centroid_groupby="response_group",
...     centroid_order=["chemo-naive", "IDS"],
...     centroid_colors={"long": "magenta", "short": "cyan"},
... )

See also

peach.pl.archetypal_space

Visualize with centroid trajectory overlay

peach.tl.archetypal.assign_to_centroids(adata, condition_column, *, pca_key='X_pca', centroid_key='conditional_centroids', bin_prop=0.15, obs_key='centroid_assignments', exclude_archetypes=None, verbose=True)

Assign cells to the nearest centroid by distance (the closest bin_prop fraction of cells per centroid).

This function mirrors assign_archetypes but for condition-based centroids. It enables using treatment phase centroids as trajectory endpoints in single_trajectory_analysis by creating categorical assignments that CellRank can use as terminal states.

Parameters:
  • adata (AnnData) – Annotated data object. Must have:

    • PCA coordinates in adata.obsm[pca_key]

    • Centroids computed via compute_conditional_centroids in adata.uns[centroid_key]

  • condition_column (str) – Name of the condition column used in compute_conditional_centroids. This identifies which centroid set to use.

  • pca_key (str, default: "X_pca") – Key in adata.obsm containing PCA coordinates.

  • centroid_key (str, default: "conditional_centroids") – Key in adata.uns containing centroid results from compute_conditional_centroids.

  • bin_prop (float, default: 0.15) – Proportion of cells to assign to each centroid (top 15% closest). Similar to percentage_per_archetype in assign_archetypes.

  • obs_key (str, default: "centroid_assignments") – Key in adata.obs to store assignments.

  • exclude_archetypes (list, optional) – Archetype labels to exclude from assignment. Default: [‘no_archetype’] - these cells get ‘unassigned’.

  • verbose (bool, default: True) – Whether to print progress messages.

Returns:

None. Categorical assignments are stored in adata.obs[obs_key]; values are condition levels (e.g., ‘chemo_naive’, ‘IDS’) or ‘unassigned’.

Return type:

None

Examples

>>> # First compute centroids for treatment phases
>>> pc.tl.compute_conditional_centroids(adata, "treatment_stage")
>>>
>>> # Then assign cells to nearest centroid (top 15% closest)
>>> pc.tl.assign_to_centroids(adata, "treatment_stage", bin_prop=0.15)
>>>
>>> # Check assignments
>>> print(adata.obs["centroid_assignments"].value_counts())
>>>
>>> # Now can use with CellRank for trajectory analysis
>>> # (setup_cellrank can use centroid_assignments as terminal states)
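A rough sketch of the assignment step, assuming Euclidean distance to the full-PC centroids stored under the "centroids" key of the results dictionary. Instead of sorting, it uses a per-centroid distance quantile as the cut-off, which approximates "the closest bin_prop fraction"; giving overlapping cells to the nearest centroid is an assumption, and assign_cells_to_centroids is a hypothetical helper. Cells flagged as excluded stay ‘unassigned’.

```python
import numpy as np


def assign_cells_to_centroids(X_pca, centroids, bin_prop=0.15, excluded_mask=None):
    """Label roughly the closest `bin_prop` fraction of cells for each centroid.

    Hypothetical helper, not peach API.
    X_pca: (n_cells, n_pcs); centroids: {level: full-PC coordinate list}.
    excluded_mask: optional boolean array; True cells stay 'unassigned'.
    """
    X_pca = np.asarray(X_pca, dtype=float)
    levels = list(centroids)
    C = np.array([centroids[lvl] for lvl in levels], dtype=float)   # (n_levels, n_pcs)
    D = np.linalg.norm(X_pca[:, None, :] - C[None, :, :], axis=2)   # (n_cells, n_levels)
    if excluded_mask is None:
        excluded_mask = np.zeros(len(X_pca), dtype=bool)
    D[excluded_mask] = np.inf  # excluded cells can never be selected
    # Per-centroid distance cut-off: the bin_prop quantile among eligible cells
    cutoffs = np.quantile(D[~excluded_mask], bin_prop, axis=0)
    within = D <= cutoffs  # candidate memberships, (n_cells, n_levels)
    # Assumption: a cell within several cut-offs goes to its nearest centroid
    nearest = np.where(within, D, np.inf).argmin(axis=1)
    labels = np.array(levels, dtype=object)[nearest]
    labels[~within.any(axis=1)] = "unassigned"
    return labels


X = np.column_stack([np.arange(10.0), np.zeros(10)])  # cells on a line
cents = {"early": [0.0, 0.0], "late": [9.0, 0.0]}
labels = assign_cells_to_centroids(X, cents, bin_prop=0.2)
```

The quantile cut-off avoids a full sort per centroid, at the cost of selecting slightly more or fewer cells than bin_prop × n_cells when distances are tied or the quantile falls between cells.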

See also

compute_conditional_centroids

Compute centroids for condition levels

assign_archetypes

Similar function for archetype assignments

single_trajectory_analysis

Uses centroid assignments for trajectory analysis