"""
Core archetypal analysis functions.
This module provides the primary interface for training Deep Archetypal Analysis
models and extracting archetypal coordinates. All functions work directly with
AnnData objects and follow scVerse conventions.
Main Functions:
- train_archetypal(): Train Deep AA model to discover cellular archetypes
- archetypal_coordinates(): Extract archetypal coordinates for all cells
- assign_archetypes(): Assign cells to discovered archetypes based on distances
The module integrates PCHA initialization, inflation factors, and comprehensive
training diagnostics for production-ready archetypal analysis workflows.
"""
from typing import Any
import numpy as np
import torch
from anndata import AnnData
from .._core.models.Deep_AA import Deep_AA
from .._core.utils.analysis import bin_cells_by_archetype as _bin_cells_by_archetype
from .._core.utils.analysis import compute_archetype_distances as _compute_archetype_distances
# Import existing battle-tested functions
from .._core.utils.training import train_vae as _train_vae
[docs]
def train_archetypal(
adata: AnnData,
n_archetypes: int = 5,
n_epochs: int = 50,
*, # Keyword-only arguments (scVerse convention)
layer: str | None = None,
pca_key: str = "X_pca",
# Model architecture parameters
hidden_dims: list[int] | None = None,
inflation_factor: float = 1.5,
# Config dicts (for advanced users)
model_config: dict[str, Any] | None = None,
optimizer_config: dict[str, Any] | None = None,
# Parameters from _core.train_vae
device: str = "cpu",
save_path: str = None,
archetypal_weight: float = None,
kld_weight: float = None,
reconstruction_weight: float = 0.0,
vae_recon_weight: float = 0.0,
diversity_weight: float = 0.0,
activation_func: str = "relu",
track_stability: bool = True,
validate_constraints: bool = True,
lr_factor: float = 0.1,
lr_patience: int = 10,
seed: int = 42,
constraint_tolerance: float = 1e-3,
stability_history_size: int = 20,
store_coords_key: str = "archetype_coordinates",
early_stopping: bool = False,
early_stopping_patience: int = 10,
early_stopping_metric: str = "archetype_r2",
min_improvement: float = 1e-4,
validation_check_interval: int = 5,
validation_data_loader=None,
**kwargs,
) -> dict[str, Any]:
"""Train Deep Archetypal Analysis model to discover cellular archetypes.
This function performs archetypal analysis using a variational autoencoder
architecture to identify extreme cellular states (archetypes) that capture
the main axes of biological variation. Each cell is represented as a convex
combination of the learned archetypes.
The model uses PCHA initialization with inflation factors for optimal
archetype positioning and achieves state-of-the-art performance (R² > 0.89)
on real single-cell datasets.
Parameters
----------
adata : AnnData
Annotated data object containing single-cell expression data.
Must have PCA coordinates in ``adata.obsm[pca_key]``.
Typically generated using ``scanpy.pp.pca(adata)``.
n_archetypes : int, default: 5
Number of archetypal patterns to learn. Should be chosen based on
biological knowledge or using hyperparameter optimization.
Common values: 3-10 for most datasets.
n_epochs : int, default: 50
Number of training epochs. Larger datasets may require more epochs.
For datasets >5K cells, consider 100-200 epochs.
layer : str | None, default: None
AnnData layer to use for training. If None, uses PCA coordinates
from ``adata.obsm[pca_key]`` (recommended).
pca_key : str, default: "X_pca"
Key in ``adata.obsm`` containing PCA coordinates. The model works
best with 5-50 PCA components. Auto-detects: 'X_pca', 'X_PCA', 'PCA'.
hidden_dims : list[int] | None, default: None
Encoder/decoder layer dimensions. If None, uses [256, 128, 64].
Smaller architectures like [128, 64] train faster but may underfit.
Larger architectures like [512, 256, 128] may overfit on small datasets.
inflation_factor : float, default: 1.5
PCHA inflation factor for archetype initialization. Values > 1.0 push
initial archetypes further from the data centroid, improving separation.
Recommended range: 1.2-2.0. Higher values for more distinct archetypes.
model_config : dict | None, default: None
Additional model configuration parameters (for advanced users):
- ``archetypal_weight`` : float - Archetypal loss weight, default 0.9
- ``kld_weight`` : float - KL divergence weight, default 0.1
- ``diversity_weight`` : float - Archetype diversity weight, default 0.05
- ``use_barycentric`` : bool - Use softmax constraints, default True
optimizer_config : dict | None, default: None
Optimizer configuration parameters:
- ``lr`` : float - Learning rate, default 1e-3
- ``weight_decay`` : float - L2 regularization, default 0.0
- ``betas`` : tuple[float, float] - Adam momentum parameters
device : str, default: "cpu"
Compute device for training. One of "cpu", "cuda", "mps".
save_path : str | None, default: None
Path to save model checkpoints during training.
archetypal_weight : float | None, default: None
Weight for archetypal loss component. Uses model's configured value if None.
kld_weight : float | None, default: None
Weight for KL divergence loss. Uses model's configured value if None.
reconstruction_weight : float, default: 0.0
Legacy reconstruction weight parameter.
vae_recon_weight : float, default: 0.0
VAE reconstruction weight.
diversity_weight : float, default: 0.0
Weight for archetype diversity loss.
activation_func : str, default: "relu"
Activation function for the model.
track_stability : bool, default: True
Whether to monitor archetype drift during training. Adds stability
metrics to history: archetype_drift_mean, archetype_variance_mean, etc.
validate_constraints : bool, default: True
Whether to validate archetypal constraints during training. Adds
constraint metrics to history: constraints_satisfied, A_sum_error, etc.
lr_factor : float, default: 0.1
Factor for learning rate reduction on plateau.
lr_patience : int, default: 10
Number of epochs with no improvement before reducing learning rate.
seed : int, default: 42
Random seed for reproducibility.
constraint_tolerance : float, default: 1e-3
Tolerance for constraint validation.
stability_history_size : int, default: 20
Number of epochs to track for stability analysis.
store_coords_key : str, default: "archetype_coordinates"
Key to store learned archetype positions in ``adata.uns``.
early_stopping : bool, default: False
Whether to use early stopping based on validation metrics.
early_stopping_patience : int, default: 10
Patience for early stopping (number of checks without improvement).
early_stopping_metric : str, default: "archetype_r2"
Metric to monitor for early stopping. One of: 'archetype_r2', 'loss', 'rmse'.
min_improvement : float, default: 1e-4
Minimum improvement required to reset patience counter.
validation_check_interval : int, default: 5
How often to check validation metrics (in epochs).
validation_data_loader : DataLoader | None, default: None
Validation data loader for early stopping. Uses training data if None.
**kwargs
Additional arguments passed to the core training function.
Returns
-------
dict
Training results dictionary with the following structure:
**Guaranteed keys (always present):**
- ``history`` : dict
Training metrics per epoch. Keys depend on tracking options:
- Core metrics (always): ``loss``, ``archetypal_loss``, ``archetype_r2``, ``rmse``
- KLD metrics: ``kld_loss``, ``KLD``
- Stability metrics (if track_stability=True): ``archetype_drift_mean``,
``archetype_drift_max``, ``archetype_variance_mean``, etc.
- Constraint metrics (if validate_constraints=True): ``constraints_satisfied``,
``A_sum_error``, ``B_sum_error``, ``constraint_violation_rate``
- Validation metrics (if early_stopping=True): ``val_loss``, ``val_archetype_r2``
- ``final_model`` : torch.nn.Module
Trained Deep_AA model instance.
- ``model`` : torch.nn.Module
Alias for ``final_model`` (same object, for compatibility).
- ``final_optimizer`` : torch.optim.Optimizer
Final optimizer state.
- ``final_analysis`` : dict
Final training analysis containing:
- ``final_constraint_validation`` : dict with constraint metrics
- ``archetypal_weights`` : dict with A_matrix and B_matrix analysis
- ``final_coordinates`` : dict with 'A', 'B', 'Y' tensors
- ``error`` : str (only if analysis failed)
- ``epoch_archetype_positions`` : list[torch.Tensor]
Archetype positions at each epoch, shape (n_archetypes, input_dim).
- ``training_config`` : dict
Training configuration with keys: ``n_epochs``, ``actual_epochs``,
``early_stop_triggered``, ``archetypal_weight``, ``kld_weight``,
``reconstruction_weight``, ``activation_func``, ``seed``,
``constraint_tolerance``, ``stability_history_size``, ``early_stopping``,
``early_stopping_patience``, ``early_stopping_metric``.
**Convenience keys (conditional - use .get() to access safely):**
- ``final_archetype_r2`` : float | None
Last value of ``history['archetype_r2']`` if tracked.
- ``final_rmse`` : float | None
Last value of ``history['rmse']`` if tracked.
- ``final_mae`` : float | None
Last value of ``history['mae']`` if tracked.
- ``final_loss`` : float | None
Last value of ``history['loss']`` if tracked.
- ``convergence_epoch`` : int | None
Equals ``training_config['actual_epochs']``.
Raises
------
ValueError
If ``adata.obsm[pca_key]`` is not found. Run ``scanpy.pp.pca()`` first.
If ``n_archetypes`` exceeds PCA dimensions.
RuntimeError
If CUDA device is requested but not available.
Stores
------
The function stores the following in AnnData:
- ``adata.uns[store_coords_key]`` : np.ndarray
Archetype positions in PCA space, shape (n_archetypes, n_pcs).
Default key: 'archetype_coordinates'.
Notes
-----
**Archetypal Analysis Theory**: Archetypal analysis represents each data point
as a convex combination of extreme points (archetypes). Unlike clustering,
which partitions data, archetypal analysis allows cells to have partial
membership in multiple archetypes, better reflecting biological continuity.
**Model Architecture**: Uses a variational autoencoder where the latent space
directly represents archetypal coordinates (A matrix). Archetypes are learned
as model parameters (Y matrix) rather than constructed from data points.
**Accessing Results Safely**:
- For guaranteed keys, direct access works: ``results['history']``
- For convenience keys, use ``.get()``: ``results.get('final_archetype_r2')``
- Or access from history: ``results['history']['archetype_r2'][-1]``
**Type Validation**: For IDE autocomplete and runtime validation::
from peach._core.types import TrainingResults, validate_training_results
validated = validate_training_results(results)
Examples
--------
Basic usage with default parameters:
>>> import scanpy as sc
>>> import peach as pc
>>> # Prepare data with PCA
>>> sc.pp.pca(adata, n_comps=30)
>>> # Train archetypal model
>>> results = pc.tl.train_archetypal(adata, n_archetypes=5, n_epochs=100)
>>> # Access final R² safely (may be None if not tracked)
>>> r2 = results.get("final_archetype_r2")
>>> if r2 is not None:
... print(f"Final R²: {r2:.3f}")
>>> # Or access from history (guaranteed if metric was tracked)
>>> if results["history"].get("archetype_r2"):
... r2 = results["history"]["archetype_r2"][-1]
... print(f"Final R²: {r2:.3f}")
Advanced usage with custom configuration:
>>> model_config = {"hidden_dims": [512, 256, 128], "inflation_factor": 2.0}
>>> results = pc.tl.train_archetypal(
... adata,
... n_archetypes=4,
... n_epochs=150,
... model_config=model_config,
... early_stopping=True,
... early_stopping_patience=15,
... device="cuda",
... )
>>> # Check if early stopping triggered
>>> config = results["training_config"]
>>> if config["early_stop_triggered"]:
... print(f"Converged at epoch {config['actual_epochs']}")
Accessing archetype coordinates stored in AnnData:
>>> # After training, coordinates are in adata.uns
>>> archetype_coords = adata.uns["archetype_coordinates"]
>>> print(f"Learned {archetype_coords.shape[0]} archetypes")
>>> print(f"Each archetype has {archetype_coords.shape[1]} PCA dimensions")
See Also
--------
peach.tl.archetypal_coordinates : Extract cell-archetype distances
peach.tl.assign_archetypes : Assign cells to discovered archetypes
peach.tl.extract_archetype_weights : Get barycentric coordinates for cells
peach.pl.training_metrics : Visualize training curves
peach._core.types.TrainingResults : Type definition for return structure
"""
# Comprehensive input validation (scVerse convention)
if pca_key not in adata.obsm:
# Try to auto-detect PCA coordinates
pca_candidates = ["X_pca", "X_PCA", "PCA", "x_pca"]
found_pca = None
for candidate in pca_candidates:
if candidate in adata.obsm:
found_pca = candidate
break
if found_pca:
print(f"[WARNING] Using {found_pca} instead of {pca_key}")
pca_key = found_pca
else:
raise ValueError(f"PCA coordinates not found. Tried: {pca_candidates}. Run scanpy.pp.pca(adata) first.")
# Validate PCA dimensions
pca_shape = adata.obsm[pca_key].shape
if pca_shape[1] < 3:
raise ValueError(f"PCA has only {pca_shape[1]} components. Need at least 3 for archetypal analysis.")
# Validate archetype count vs PCA dimensions
if n_archetypes > pca_shape[1]:
raise ValueError(
f"n_archetypes ({n_archetypes}) cannot exceed PCA dimensions ({pca_shape[1]}). "
f"Reduce n_archetypes or increase PCA components."
)
# Default configurations - use proven working setup
# Use all available PCA components (or user-specified input_dim)
input_dim = adata.obsm[pca_key].shape[1] # Use all available PCA components
default_model_config = {
"input_dim": input_dim,
"n_archetypes": n_archetypes,
"latent_dim": n_archetypes,
"hidden_dims": hidden_dims if hidden_dims is not None else [256, 128, 64],
"archetypal_weight": 0.9,
"kld_weight": 0.1,
"diversity_weight": 0.05,
"manifold_weight": 0.0,
"regularity_weight": 0.0,
"sparsity_weight": 0.0,
"inflation_factor": inflation_factor,
"use_barycentric": True,
"use_hidden_transform": True,
"device": device,
}
if model_config:
default_model_config.update(model_config)
default_optimizer_config = {"lr": 1e-3} # User's successful learning rate
if optimizer_config:
default_optimizer_config.update(optimizer_config)
# Create DataLoader with correct dimensionality
from ..pp.basic import prepare_training
# Slice PCA data to match input_dim for optimal performance
if input_dim < adata.obsm[pca_key].shape[1]:
# Temporarily modify adata to use only first input_dim components
original_pca = adata.obsm[pca_key].copy()
adata.obsm[pca_key] = adata.obsm[pca_key][:, :input_dim]
print(f" Using first {input_dim} PCA components for training (shape: {adata.obsm[pca_key].shape})")
dataloader = prepare_training(adata, pca_key=pca_key)
# Restore original PCA data
adata.obsm[pca_key] = original_pca
else:
dataloader = prepare_training(adata, pca_key=pca_key)
# Initialize model (existing code)
model = Deep_AA(**default_model_config)
optimizer = torch.optim.Adam(model.parameters(), **default_optimizer_config)
# CRITICAL: Initialize with PCHA + inflation (user's successful setup!)
print(f" Initializing with PCHA + inflation_factor={default_model_config['inflation_factor']}...")
# Get sample data for PCHA initialization
# Ensure contiguous array (fixes negative stride issue)
pca_sample = np.ascontiguousarray(adata.obsm[pca_key][:, :input_dim])
sample_data = torch.FloatTensor(pca_sample)
try:
if hasattr(model, "initialize_archetypes"):
success = model.initialize_archetypes(
sample_data,
use_pcha=True,
use_inflation=True,
inflation_factor=default_model_config["inflation_factor"],
)
else:
# Fallback to deprecated method
success = model.initialize_with_pcha_and_inflation(
sample_data, inflation_factor=default_model_config["inflation_factor"], n_subsample=1000
)
if success:
print("[OK] PCHA + inflation initialization successful!")
else:
print("[WARNING] PCHA initialization failed, using default initialization")
except Exception as e:
print(f"[WARNING] PCHA initialization error: {e}, using default initialization")
# Train model (delegate to existing battle-tested function)
results = _train_vae(
model=model,
data_loader=dataloader,
optimizer=optimizer,
n_epochs=n_epochs,
device=device,
save_path=save_path,
archetypal_weight=archetypal_weight,
kld_weight=kld_weight,
reconstruction_weight=reconstruction_weight,
vae_recon_weight=vae_recon_weight,
diversity_weight=diversity_weight,
activation_func=activation_func,
track_stability=track_stability,
validate_constraints=validate_constraints,
lr_factor=lr_factor,
lr_patience=lr_patience,
seed=seed,
constraint_tolerance=constraint_tolerance,
stability_history_size=stability_history_size,
adata=adata, # Auto-stores coordinates
store_coords_key=store_coords_key,
early_stopping=early_stopping,
early_stopping_patience=early_stopping_patience,
early_stopping_metric=early_stopping_metric,
min_improvement=min_improvement,
validation_check_interval=validation_check_interval,
validation_data_loader=validation_data_loader,
**kwargs,
)
# train_vae ALWAYS returns (results_dict, trained_model) tuple
# Extract just the results dictionary for user
results_dict, trained_model = results
# Store trained model in adata for downstream functions (extract_archetype_weights, etc.)
adata.uns["trained_model"] = trained_model
# Also add model to results dict for direct access
results_dict["final_model"] = trained_model
results_dict["model"] = trained_model # Alias for compatibility
# Add convenience keys for common access patterns (for user-friendly API)
# These provide direct access without navigating nested dicts
history = results_dict.get("history", {})
if "archetype_r2" in history and len(history["archetype_r2"]) > 0:
results_dict["final_archetype_r2"] = history["archetype_r2"][-1]
if "rmse" in history and len(history["rmse"]) > 0:
results_dict["final_rmse"] = history["rmse"][-1]
if "mae" in history and len(history["mae"]) > 0:
results_dict["final_mae"] = history["mae"][-1]
if "loss" in history and len(history["loss"]) > 0:
results_dict["final_loss"] = history["loss"][-1]
training_config = results_dict.get("training_config", {})
if "actual_epochs" in training_config:
results_dict["convergence_epoch"] = training_config["actual_epochs"]
# Optional: validate internally (catches bugs early)
# validate_training_results(results_dict)
return results_dict
[docs]
def archetypal_coordinates(
adata: AnnData,
*,
pca_key: str = "X_pca",
archetype_coords_key: str = "archetype_coordinates",
obsm_key: str = "archetype_distances",
uns_prefix: str = "archetype",
verbose: bool = True,
**kwargs,
) -> dict[str, Any]:
"""Extract archetypal coordinates for all cells.
Parameters
----------
adata : AnnData
Annotated data object with trained model coordinates
pca_key : str, default: "X_pca"
Key in adata.obsm containing PCA coordinates
archetype_coords_key : str, default: "archetype_coordinates"
Key in adata.uns containing archetype coordinates
obsm_key : str, default: "archetype_distances"
Key to store distance matrix in adata.obsm
uns_prefix : str, default: "archetype"
Prefix for keys stored in adata.uns
verbose : bool, default: True
Whether to print progress messages
**kwargs
Additional arguments passed to compute_archetype_distances
Returns
-------
dict
Dictionary with archetypal coordinates and distances
"""
# Input validation
if archetype_coords_key not in adata.uns:
raise ValueError(f"adata.uns['{archetype_coords_key}'] not found. Run pc.tl.train_archetypal() first.")
if pca_key not in adata.obsm:
# Try to auto-detect PCA coordinates
pca_candidates = ["X_pca", "X_PCA", "PCA", "x_pca"]
found_pca = None
for candidate in pca_candidates:
if candidate in adata.obsm:
found_pca = candidate
break
if found_pca:
pca_key = found_pca
else:
raise ValueError(f"PCA coordinates not found. Tried: {pca_candidates}. Run scanpy.pp.pca(adata) first.")
return _compute_archetype_distances(
adata=adata,
pca_key=pca_key,
archetype_coords_key=archetype_coords_key,
obsm_key=obsm_key,
uns_prefix=uns_prefix,
verbose=verbose,
**kwargs,
)
[docs]
def assign_archetypes(
adata: AnnData,
*,
percentage_per_archetype: float = 0.1,
obsm_key: str = "archetype_distances",
obs_key: str = "archetypes",
include_central_archetype: bool = True,
verbose: bool = True,
**kwargs,
) -> None:
"""Assign cells to archetypes based on distances.
Parameters
----------
adata : AnnData
Annotated data object with archetype distances
percentage_per_archetype : float, default: 0.1
Percentage of cells to assign to each archetype
obsm_key : str, default: "archetype_distances"
Key in adata.obsm containing distance matrix
obs_key : str, default: "archetypes"
Key to store assignments in adata.obs
include_central_archetype : bool, default: True
Whether to include a central archetype (cells far from all extreme archetypes)
verbose : bool, default: True
Whether to print progress messages
**kwargs
Additional arguments passed to bin_cells_by_archetype
"""
# Input validation
if obsm_key not in adata.obsm:
raise ValueError(f"adata.obsm['{obsm_key}'] not found. Run pc.tl.archetypal_coordinates() first.")
# Validate percentage
if not 0.0 < percentage_per_archetype <= 1.0:
raise ValueError(f"percentage_per_archetype must be between 0 and 1, got {percentage_per_archetype}")
_bin_cells_by_archetype(
adata=adata,
percentage_per_archetype=percentage_per_archetype,
obsm_key=obsm_key,
obs_key=obs_key,
include_central_archetype=include_central_archetype,
verbose=verbose,
**kwargs,
)
[docs]
def compute_conditional_centroids(
adata,
condition_column: str,
*,
pca_key: str = "X_pca",
store_key: str = "conditional_centroids",
exclude_archetypes: list = None,
groupby: str = None,
verbose: bool = True,
):
"""Compute centroid positions in PCA space for each level of a categorical condition.
This function calculates the mean position (centroid) in PCA space for cells
belonging to each level of a categorical variable. Useful for visualizing
how different conditions (e.g., treatment phases, timepoints) relate to
the archetypal structure.
Following R template patterns:
- Uses ALL PCs for centroid calculation (equivalent to R's colMeans)
- Stores full PC centroid but extracts first 3 for 3D visualization
- Excludes 'no_archetype' and 'archetype_0' cells by default
Parameters
----------
adata : AnnData
Annotated data object with PCA coordinates in adata.obsm[pca_key].
condition_column : str
Name of categorical column in adata.obs to group by.
Examples: 'treatment_phase', 'timepoint', 'batch'.
pca_key : str, default: "X_pca"
Key in adata.obsm containing PCA coordinates.
store_key : str, default: "conditional_centroids"
Key in adata.uns to store results.
exclude_archetypes : list, optional
Archetype labels to exclude from centroid calculation.
Default: ['no_archetype', 'archetype_0'] (following R template).
Set to empty list [] to include all cells.
groupby : str, optional
Second categorical column for multi-group trajectories.
If provided, centroids are computed for each (group, level) combination.
Example: groupby='response_group' to get separate trajectories per response.
verbose : bool, default: True
Whether to print progress messages.
Returns
-------
dict
Dictionary with keys:
- ``condition_column`` : str - name of the condition column
- ``n_levels`` : int - number of unique levels
- ``levels`` : List[str] - list of level names
- ``centroids`` : Dict[str, List[float]] - level → full PCA coordinates
- ``centroids_3d`` : Dict[str, List[float]] - level → [x, y, z] first 3 PCs
- ``cell_counts`` : Dict[str, int] - level → cell count
- ``pca_key`` : str - PCA key used
- ``exclude_archetypes`` : List[str] - archetypes excluded
- ``groupby`` : Optional[str] - groupby column if used
- ``group_centroids`` : Optional[Dict] - if groupby: {group: {level: coords}}
- ``group_centroids_3d`` : Optional[Dict] - if groupby: {group: {level: [x,y,z]}}
- ``group_cell_counts`` : Optional[Dict] - if groupby: {group: {level: count}}
Raises
------
ValueError
If condition_column not in adata.obs or PCA coordinates not found.
Stores
------
The function stores results in AnnData:
- ``adata.uns[store_key][condition_column]`` : dict
Full results dictionary as returned.
Examples
--------
>>> # Compute centroids for treatment phase
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase")
>>> print(result["centroids_3d"])
{'chemo-naive': [1.2, 0.5, -0.3], 'IDS': [0.8, 1.1, 0.2]}
>>> # Then visualize with trajectory
>>> fig = pc.pl.archetypal_space(
... adata, show_centroids=True, centroid_condition="treatment_phase", centroid_order=["chemo-naive", "IDS"]
... )
>>> # Multi-group centroids for trajectory comparison
>>> result = pc.tl.compute_conditional_centroids(adata, "treatment_phase", groupby="response_group")
>>> fig = pc.pl.archetypal_space(
... adata,
... show_centroids=True,
... centroid_condition="treatment_phase",
... centroid_groupby="response_group",
... centroid_order=["chemo-naive", "IDS"],
... centroid_colors={"long": "magenta", "short": "cyan"},
... )
See Also
--------
peach.pl.archetypal_space : Visualize with centroid trajectory overlay
"""
from .._core.utils.analysis import compute_conditional_centroids as _compute_centroids
return _compute_centroids(
adata=adata,
condition_column=condition_column,
pca_key=pca_key,
store_key=store_key,
exclude_archetypes=exclude_archetypes,
groupby=groupby,
verbose=verbose,
)
[docs]
def assign_to_centroids(
adata,
condition_column: str,
*,
pca_key: str = "X_pca",
centroid_key: str = "conditional_centroids",
bin_prop: float = 0.15,
obs_key: str = "centroid_assignments",
exclude_archetypes: list = None,
verbose: bool = True,
) -> None:
"""Assign cells to nearest centroid based on distance (top bin_prop% closest).
This function mirrors assign_archetypes but for condition-based centroids.
It enables using treatment phase centroids as trajectory endpoints in
single_trajectory_analysis by creating categorical assignments that CellRank
can use as terminal states.
Parameters
----------
adata : AnnData
Annotated data object. Must have:
- PCA coordinates in adata.obsm[pca_key]
- Centroids computed via compute_conditional_centroids in adata.uns[centroid_key]
condition_column : str
Name of the condition column used in compute_conditional_centroids.
This identifies which centroid set to use.
pca_key : str, default: "X_pca"
Key in adata.obsm containing PCA coordinates.
centroid_key : str, default: "conditional_centroids"
Key in adata.uns containing centroid results from compute_conditional_centroids.
bin_prop : float, default: 0.15
Proportion of cells to assign to each centroid (top 15% closest).
Similar to percentage_per_archetype in assign_archetypes.
obs_key : str, default: "centroid_assignments"
Key in adata.obs to store assignments.
exclude_archetypes : list, optional
Archetype labels to exclude from assignment.
Default: ['no_archetype'] - these cells get 'unassigned'.
verbose : bool, default: True
Whether to print progress messages.
Returns
-------
None
Modifies adata.obs[obs_key] with Categorical assignments.
Values are condition levels (e.g., 'chemo_naive', 'IDS') or 'unassigned'.
Examples
--------
>>> # First compute centroids for treatment phases
>>> pc.tl.compute_conditional_centroids(adata, "treatment_stage")
>>>
>>> # Then assign cells to nearest centroid (top 15% closest)
>>> pc.tl.assign_to_centroids(adata, "treatment_stage", bin_prop=0.15)
>>>
>>> # Check assignments
>>> print(adata.obs["centroid_assignments"].value_counts())
>>>
>>> # Now can use with CellRank for trajectory analysis
>>> # (setup_cellrank can use centroid_assignments as terminal states)
See Also
--------
compute_conditional_centroids : Compute centroids for condition levels
assign_archetypes : Similar function for archetype assignments
single_trajectory_analysis : Uses centroid assignments for trajectory analysis
"""
from .._core.utils.analysis import assign_to_centroids as _assign_to_centroids
return _assign_to_centroids(
adata=adata,
condition_column=condition_column,
pca_key=pca_key,
centroid_key=centroid_key,
bin_prop=bin_prop,
obs_key=obs_key,
exclude_archetypes=exclude_archetypes,
verbose=verbose,
)