Source code for peach.pl.cellrank_viz

"""Visualization functions for CellRank archetypal trajectories."""

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


[docs] def fate_probabilities(adata, lineages=None, basis="X_umap", same_plot=False, ncols=3, figsize=None, **kwargs): """ Plot fate probabilities on embedding. Visualizes the probability of cells committing to each lineage/archetype using CellRank's plotting functionality. Parameters ---------- adata : AnnData Annotated data matrix with fate probabilities computed. lineages : list of str, optional Specific lineages to plot. If None, plots all lineages. basis : str, default: 'X_umap' Embedding to use ('X_umap', 'X_pca', etc.). same_plot : bool, default: False If True, plots all lineages on same axes (pie chart style). If False, creates separate panel per lineage. ncols : int, default: 3 Number of columns for multi-panel plot (if same_plot=False). figsize : tuple, optional Figure size. If None, automatically determined. **kwargs Additional arguments passed to CellRank's plot function. Returns ------- None Displays matplotlib figure. CellRank's plotting functions display directly rather than returning figure objects. Raises ------ ImportError If CellRank is not installed. NotImplementedError If GPCCA object not stored in adata.uns['cellrank_gpcca']. Examples -------- >>> pc.pl.fate_probabilities(adata, same_plot=False, ncols=3) >>> pc.pl.fate_probabilities(adata, lineages=["archetype_3", "archetype_5"], same_plot=True) """ try: import cellrank as cr except ImportError: raise ImportError("CellRank not installed. Install with: pip install cellrank") if "cellrank_gpcca" not in adata.uns: # Reconstruct GPCCA object from stored fate probabilities raise NotImplementedError( "Direct plotting from stored fate probabilities not yet implemented. " "Please store GPCCA object: adata.uns['cellrank_gpcca'] = g" ) g = adata.uns["cellrank_gpcca"] # Use CellRank's plotting function # Note: CellRank uses 'states' parameter, not 'lineages' g.plot_fate_probabilities( states=lineages, same_plot=same_plot, basis=basis, ncols=ncols, figsize=figsize, **kwargs )
[docs] def lineage_drivers(adata, lineage, n_genes=20, driver_key=None, figsize=(10, 8), **kwargs): """ Plot heatmap of top driver genes for a lineage. Visualizes expression of genes most correlated with commitment to a specific lineage/archetype. Parameters ---------- adata : AnnData Annotated data matrix lineage : str Target lineage name (e.g., 'archetype_5') n_genes : int, optional (default: 20) Number of top genes to plot driver_key : str, optional Key in adata.varm containing driver gene scores. If None, computes drivers on-the-fly using correlation method figsize : tuple, optional (default: (10, 8)) Figure size **kwargs Additional arguments passed to seaborn.heatmap Returns ------- fig : matplotlib.figure.Figure Figure object Examples -------- Plot top 20 driver genes: >>> import peach as pc >>> pc.pl.lineage_drivers(adata, lineage="archetype_5", n_genes=20) Custom number of genes: >>> pc.pl.lineage_drivers(adata, lineage="archetype_3", n_genes=30, figsize=(12, 10)) Using pre-computed drivers: >>> drivers = pc.tl.compute_lineage_drivers(adata, lineage="archetype_5") >>> adata.var["driver_scores"] = drivers[f"archetype_5_corr"] >>> pc.pl.lineage_drivers(adata, lineage="archetype_5", driver_key="driver_scores") Notes ----- - If driver_key is None, uses simple correlation method - Heatmap rows are cells ordered by fate probability - Heatmap columns are top driver genes See Also -------- gene_trends : Plot expression trends along pseudotime fate_probabilities : Visualize fate probability distributions """ import pandas as pd from scipy.stats import spearmanr # Get or compute driver genes if driver_key is not None: if driver_key not in adata.var.columns: raise ValueError(f"Driver key '{driver_key}' not found in adata.var") driver_scores = adata.var[driver_key] top_genes = driver_scores.nlargest(n_genes).index.tolist() else: # Compute on-the-fly using correlation if "lineage_names" not in adata.uns: raise ValueError("Run setup_cellrank() first") if lineage not in adata.uns["lineage_names"]: raise ValueError(f"Lineage '{lineage}' not found") lineage_idx = adata.uns["lineage_names"].index(lineage) fate_prob = adata.obsm["fate_probabilities"][:, lineage_idx] # Compute correlations results = [] for gene in adata.var_names: expr = adata[:, gene].X.toarray().flatten() if hasattr(adata.X, "toarray") else adata[:, gene].X.flatten() corr, _ = spearmanr(fate_prob, expr) results.append({"gene": gene, "corr": corr}) df = pd.DataFrame(results) top_genes = df.nlargest(n_genes, "corr")["gene"].tolist() # Get expression matrix for top genes expr_matrix = adata[:, top_genes].X if hasattr(expr_matrix, "toarray"): expr_matrix = expr_matrix.toarray() # Order cells by fate probability lineage_idx = adata.uns["lineage_names"].index(lineage) fate_prob = adata.obsm["fate_probabilities"][:, lineage_idx] order = np.argsort(fate_prob) # Create heatmap fig, ax = plt.subplots(figsize=figsize) sns.heatmap( expr_matrix[order, :].T, xticklabels=False, yticklabels=top_genes, cmap="viridis", cbar_kws={"label": "Expression"}, ax=ax, **kwargs, ) ax.set_xlabel(f"Cells (ordered by {lineage} fate probability)", fontsize=12) ax.set_ylabel("Genes", fontsize=12) ax.set_title(f"Top {n_genes} Driver Genes for {lineage}", fontsize=14, fontweight="bold") plt.tight_layout() return fig