"""
Archetypal visualization functions.

This module provides publication-ready visualization tools for archetypal
analysis results. The archetypal-space, training-metric, and elbow-curve plots
are built with Plotly for interactivity and can be exported to various formats
for publication; the archetype-position helpers are documented as returning
Matplotlib figures.

Main Functions:
- archetypal_space(): Interactive 3D visualization of archetypal coordinate space
- archetypal_space_multi(): Compare multiple archetypal fits side-by-side
- training_metrics(): Training diagnostics and convergence analysis
- elbow_curve(): Hyperparameter selection support with cross-validation
- archetype_positions(): Archetype positions in PC1-PC2 space with a distance matrix
- archetype_positions_3d(): Archetype positions in 3D PCA space
- archetype_statistics(): Summary statistics of archetype positions

Features:
- Interactive Plotly-based plots with zoom, pan, and hover
- Gene expression coloring with smart layer selection
- Publication-ready aesthetics and customization options
- Automatic legend and colorbar positioning
"""
from pathlib import Path
from typing import Any
import plotly.graph_objects as go
from anndata import AnnData
from .._core.viz.results_viz import visualize_archetypal_space_3d_multi as _viz_3d_multi
# Import existing battle-tested functions
from .._core.viz.results_viz import visualize_archetypal_space_3d_single as _viz_3d
from .._core.viz.training_viz import plot_training_metrics as _plot_training
# [docs]  <- Sphinx viewcode link artifact, not code
def archetypal_space(
    adata: AnnData,
    *,
    archetype_coords_key: str = "archetype_coordinates",
    pca_key: str = "X_pca",
    color_by: str | None = None,
    use_layer: str = "logcounts",
    cell_size: float = 2.0,
    cell_opacity: float = 0.6,
    archetype_size: float = 8.0,
    archetype_color: str = "red",
    show_archetype_labels: bool = True,
    show_connections: bool = True,
    color_scale: str = "viridis",
    categorical_colors: dict | None = None,
    title: str = "Archetypal Space Visualization",
    auto_scale: bool = True,
    save_path: str | None = None,
    fixed_ranges: dict | None = None,
    legend_marker_scale: float = 1.0,
    legend_font_size: int = 12,
    # Conditional centroid parameters
    show_centroids: bool = False,
    centroid_condition: str | None = None,
    centroid_order: list | None = None,
    centroid_groupby: str | None = None,
    centroid_size: float = 20.0,
    centroid_start_symbol: str = "circle",
    centroid_end_symbol: str = "diamond",
    centroid_line_width: float = 6.0,
    centroid_colors: dict | None = None,
    **kwargs,
) -> go.Figure:
    """Plot cells and archetypes in an interactive 3D archetypal space.

    Cells are drawn as a 3D scatter in PCA space together with the archetype
    positions; cells may optionally be colored by gene expression or metadata.
    This function is a thin public wrapper: once the presence of archetypal
    distances on ``adata`` has been confirmed, every argument is forwarded
    unchanged to the internal single-dataset 3D visualization routine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object carrying archetypal coordinates.
    archetype_coords_key : str, default: "archetype_coordinates"
        Key in adata.uns holding archetype coordinates [n_archetypes, n_pcs].
    pca_key : str, default: "X_pca"
        Key in adata.obsm holding PCA coordinates [n_cells, n_pcs].
    color_by : str | None, default: None
        Either a column of adata.obs (categorical/continuous) or a gene name
        from adata.var.index, used for expression coloring.
    use_layer : str, default: "logcounts"
        Layer used for gene expression. Falls back to adata.X if not found.
    cell_size : float, default: 2.0
        Marker size of cell points.
    cell_opacity : float, default: 0.6
        Opacity of cell points (0-1).
    archetype_size : float, default: 8.0
        Size of the archetype diamond markers.
    archetype_color : str, default: "red"
        Color of the archetype markers.
    show_archetype_labels : bool, default: True
        Whether 'Arch1', 'Arch2', etc. labels are shown.
    show_connections : bool, default: True
        Whether lines connecting all archetype pairs are drawn.
    color_scale : str, default: "viridis"
        Plotly color scale applied to continuous variables.
    categorical_colors : dict | None, default: None
        Custom colors for categorical variables {category: color}.
    title : str, default: "Archetypal Space Visualization"
        Plot title.
    auto_scale : bool, default: True
        Whether axes are auto-scaled using 1st-99th percentiles.
    save_path : str | None, default: None
        Path where an HTML file is saved.
    fixed_ranges : dict | None, default: None
        Fixed axis ranges {'x': (min, max), 'y': (min, max), 'z': (min, max)}.
    legend_marker_scale : float, default: 1.0
        Scale factor applied to legend marker sizes.
    legend_font_size : int, default: 12
        Font size of the legend text.
    show_centroids : bool, default: False
        Whether condition centroids are displayed on the plot.
        Requires centroids computed via pc.tl.compute_conditional_centroids().
    centroid_condition : str | None, default: None
        Column name in adata.obs for condition centroids.
        Must have centroids pre-computed via pc.tl.compute_conditional_centroids().
    centroid_order : list | None, default: None
        Order of condition levels for the trajectory line.
        If provided, a line is drawn connecting centroids in this order.
        Example: ['chemo-naive', 'IDS'] for treatment timeline.
    centroid_groupby : str | None, default: None
        Column name for multi-group trajectories.
        If provided, a separate trajectory is drawn per group, each in its
        own color.
    centroid_size : float, default: 20.0
        Size of centroid markers.
    centroid_start_symbol : str, default: "circle"
        Plotly symbol for the first centroid in a trajectory.
    centroid_end_symbol : str, default: "diamond"
        Plotly symbol for the last centroid in a trajectory.
    centroid_line_width : float, default: 6.0
        Width of the trajectory line connecting centroids.
    centroid_colors : dict | None, default: None
        Custom colors for centroid markers/lines.
        If centroid_groupby used: {group: color} (e.g., {'long': 'magenta', 'short': 'cyan'}).
        Otherwise: {'default': color}.
    **kwargs
        Additional arguments passed through to the underlying visualization.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive 3D scatter plot containing:

        - Cell points colored by color_by (with colorbar if continuous)
        - Archetype positions as diamond markers
        - Archetype labels (if show_archetype_labels=True)
        - Hull edges connecting archetypes (if show_connections=True)
        - Condition centroids with trajectory lines (if show_centroids=True)

    Raises
    ------
    ValueError
        If adata.obsm['archetype_distances'] is missing (run
        pc.tl.archetypal_coordinates() first).

    Examples
    --------
    >>> # Color by cell type metadata
    >>> fig = pc.pl.archetypal_space(adata, color_by="cell_type")
    >>> fig.show()

    >>> # Color by gene expression
    >>> fig = pc.pl.archetypal_space(adata, color_by="CD3D")
    >>> fig.show()

    >>> # Custom styling
    >>> fig = pc.pl.archetypal_space(
    ...     adata, color_by="pseudotime", color_scale="plasma", cell_opacity=0.4, archetype_size=12.0
    ... )

    >>> # With condition trajectory centroids
    >>> pc.tl.compute_conditional_centroids(adata, "treatment_phase")
    >>> fig = pc.pl.archetypal_space(
    ...     adata,
    ...     show_centroids=True,
    ...     centroid_condition="treatment_phase",
    ...     centroid_order=["chemo-naive", "IDS"],
    ...     centroid_colors={"default": "magenta"},
    ... )

    >>> # Multi-group trajectories (treatment × response)
    >>> pc.tl.compute_conditional_centroids(adata, "treatment_phase", groupby="response")
    >>> fig = pc.pl.archetypal_space(
    ...     adata,
    ...     show_centroids=True,
    ...     centroid_condition="treatment_phase",
    ...     centroid_groupby="response",
    ...     centroid_order=["chemo-naive", "IDS"],
    ...     centroid_colors={"long": "magenta", "short": "cyan"},
    ... )
    """
    # Fail early, with an actionable message, when the prerequisite step was skipped.
    distances_present = "archetype_distances" in adata.obsm
    if not distances_present:
        raise ValueError("Archetypal distances not found. Run pc.tl.archetypal_coordinates() first.")

    # Options controlling how cells, archetypes, axes and the legend are drawn.
    scene_options = {
        "archetype_coords_key": archetype_coords_key,
        "pca_key": pca_key,
        "color_by": color_by,
        "use_layer": use_layer,
        "cell_size": cell_size,
        "cell_opacity": cell_opacity,
        "archetype_size": archetype_size,
        "archetype_color": archetype_color,
        "show_archetype_labels": show_archetype_labels,
        "show_connections": show_connections,
        "color_scale": color_scale,
        "categorical_colors": categorical_colors,
        "title": title,
        "auto_scale": auto_scale,
        "save_path": save_path,
        "fixed_ranges": fixed_ranges,
        "legend_marker_scale": legend_marker_scale,
        "legend_font_size": legend_font_size,
    }

    # Options for the optional conditional-centroid overlay.
    centroid_options = {
        "show_centroids": show_centroids,
        "centroid_condition": centroid_condition,
        "centroid_order": centroid_order,
        "centroid_groupby": centroid_groupby,
        "centroid_size": centroid_size,
        "centroid_start_symbol": centroid_start_symbol,
        "centroid_end_symbol": centroid_end_symbol,
        "centroid_line_width": centroid_line_width,
        "centroid_colors": centroid_colors,
    }

    # Hand everything to the existing implementation and return its figure.
    return _viz_3d(adata=adata, **scene_options, **centroid_options, **kwargs)
# [docs]  <- Sphinx viewcode link artifact, not code
def archetypal_space_multi(
    adata_list: list[AnnData],
    *,
    archetype_coords_key: str = "archetype_coordinates",
    pca_key: str = "X_pca",
    labels_list: list[str] | None = None,
    color_by: str | list[str] | None = None,
    color_values: Any | list[Any] | None = None,
    cell_size: float = 2.0,
    cell_opacity: float = 0.6,
    archetype_size: float = 8.0,
    archetype_colors: list[str] | None = None,
    show_labels: bool | list[int] = True,
    auto_scale: bool = True,
    range_reference: int | Any | None = None,
    fixed_ranges: dict[str, tuple[float, float]] | None = None,
    color_scale: str = "viridis",
    categorical_colors: dict[str, str] | None = None,
    title: str = "Multi-Archetypal Space Comparison",
    save_path: str | Path | None = None,
) -> go.Figure:
    """Overlay several archetypal analysis fits in one 3D PCA-space plot.

    Produces an interactive 3D scatter plot that places multiple archetypal
    fits side by side, which is useful when comparing conditions, treatments,
    or parameter settings. All arguments are forwarded unchanged to the
    internal multi-dataset 3D visualization routine.

    Parameters
    ----------
    adata_list : list of AnnData
        AnnData objects holding PCA coordinates and archetype results.
    archetype_coords_key : str, default: "archetype_coordinates"
        Key in adata.uns containing archetype coordinates.
    pca_key : str, default: "X_pca"
        Key in adata.obsm containing PCA coordinates.
    labels_list : list of str | None, default: None
        Label for each dataset (defaults to 'Set 1', 'Set 2', etc.).
    color_by : str | list of str | None, default: None
        Column(s) to color cells by - a single string or one entry per dataset.
    color_values : array | list of arrays | None, default: None
        Direct color values - a single array or one entry per dataset.
    cell_size : float, default: 2.0
        Size of cell points.
    cell_opacity : float, default: 0.6
        Opacity of cell points (0-1).
    archetype_size : float, default: 8.0
        Size of archetype markers.
    archetype_colors : list of str | None, default: None
        Archetype marker colors, one per dataset.
    show_labels : bool | list of int, default: True
        Which datasets get archetype labels (bool, or list of indices).
    auto_scale : bool, default: True
        Whether axes are auto-scaled based on all data.
    range_reference : int | AnnData | None, default: None
        Reference dataset index or AnnData used for axis scaling.
    fixed_ranges : dict | None, default: None
        Fixed axis ranges {'x': (min, max), 'y': (min, max), 'z': (min, max)}.
    color_scale : str, default: 'viridis'
        Plotly color scale for continuous variables.
    categorical_colors : dict | None, default: None
        Custom colors for categorical variables.
    title : str, default: 'Multi-Archetypal Space Comparison'
        Plot title.
    save_path : str | Path | None, default: None
        Optional path where an HTML file is saved.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive 3D comparison plot.

    Examples
    --------
    >>> # Compare treatment conditions
    >>> fig = pc.pl.archetypal_space_multi(
    ...     adata_list=[adata_control, adata_treated],
    ...     labels_list=["Control", "Treated"],
    ...     color_by=["cell_type", "cell_type"],
    ...     title="Treatment Effect on Archetypal Space",
    ... )
    >>> fig.show()

    >>> # Compare different archetype numbers
    >>> fig = pc.pl.archetypal_space_multi(
    ...     adata_list=[adata_k3, adata_k5, adata_k7],
    ...     labels_list=["K=3", "K=5", "K=7"],
    ...     show_labels=[2],  # Only show labels for K=7
    ...     title="Archetype Number Comparison",
    ... )
    >>> fig.show()
    """
    # Gather the forwarded options once so the delegation call stays short.
    comparison_options = {
        "archetype_coords_key": archetype_coords_key,
        "pca_key": pca_key,
        "labels_list": labels_list,
        "color_by": color_by,
        "color_values": color_values,
        "cell_size": cell_size,
        "cell_opacity": cell_opacity,
        "archetype_size": archetype_size,
        "archetype_colors": archetype_colors,
        "show_labels": show_labels,
        "auto_scale": auto_scale,
        "range_reference": range_reference,
        "fixed_ranges": fixed_ranges,
        "color_scale": color_scale,
        "categorical_colors": categorical_colors,
        "title": title,
        "save_path": save_path,
    }

    # The existing implementation does all the work; return its figure as-is.
    return _viz_3d_multi(adata_list=adata_list, **comparison_options)
# [docs]  <- Sphinx viewcode link artifact, not code
def training_metrics(
    history: dict,
    *,
    height: int = 400,
    width: int = 800,
    display: bool = True,
    **kwargs,
) -> go.Figure:
    """Plot how training metrics evolve across epochs.

    Builds an interactive Plotly figure covering loss components, stability
    metrics, and convergence analysis by delegating to the internal
    ``plot_training_metrics`` routine.

    Parameters
    ----------
    history : dict
        Training history dictionary from pc.tl.train_archetypal().
        Expected keys: 'loss', 'archetypal_loss', 'KLD', 'rmse',
        'vertex_stability_latent', 'vertex_stability_pca', 'loss_delta'.
    height : int, default: 400
        Base plot height in pixels (actual height is 2x for 3 rows).
    width : int, default: 800
        Plot width in pixels.
    display : bool, default: True
        Whether the plot is displayed immediately via fig.show().
    **kwargs
        Additional arguments passed to plot_training_metrics.

    Returns
    -------
    plotly.graph_objects.Figure or None
        Interactive training metrics plot with 3-row layout:

        - Row 1 (40%): Loss metrics (loss, archetypal_loss, KLD, rmse)
        - Row 2 (30%): Stability metrics (vertex_stability_latent/pca)
        - Row 3 (30%): Convergence (loss_delta with rolling mean)

        Returns None only if history is empty.

    Examples
    --------
    >>> results = pc.tl.train_archetypal(adata, n_archetypes=5)
    >>> fig = pc.pl.training_metrics(results["history"], display=False)
    >>> fig.write_html("training.html")
    """
    # Sizing and display flags are forwarded verbatim, followed by any extras.
    figure = _plot_training(
        history=history,
        height=height,
        width=width,
        display=display,
        **kwargs,
    )
    return figure
# [docs]  <- Sphinx viewcode link artifact, not code
def elbow_curve(cv_summary, *, metrics: list[str] | None = None, **kwargs) -> go.Figure:
    """Plot elbow curves for hyperparameter selection.

    Thin wrapper around ``cv_summary.plot_elbow_curve``: the metric list is
    passed as the first positional argument and any extra keyword arguments
    are forwarded unchanged.

    Parameters
    ----------
    cv_summary : CVSummary
        Cross-validation results from pc.tl.hyperparameter_search().
        Must provide a ``plot_elbow_curve(metrics, **kwargs)`` method.
    metrics : list[str] | None, default: None
        Metrics to plot. When omitted (or None), ``["archetype_r2", "rmse"]``
        is used.
    **kwargs
        Additional arguments passed to plot_elbow_curve.

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive elbow curve plot, exactly as returned by
        ``cv_summary.plot_elbow_curve``.
    """
    # A list literal in the signature would be created once and shared by every
    # call; build a fresh default per call so that a callee mutating the list
    # cannot change what later default calls receive.
    if metrics is None:
        metrics = ["archetype_r2", "rmse"]
    return cv_summary.plot_elbow_curve(metrics, **kwargs)
# [docs]  <- Sphinx viewcode link artifact, not code
def archetype_positions(
    adata: AnnData,
    *,
    coords_key: str = "archetype_coordinates",
    title: str = "Archetype Positions in PCA Space",
    figsize: tuple = (15, 6),
    cmap: str = "tab10",
    show_distances: bool = True,
    save_path: str | None = None,
    **kwargs,
) -> Any:
    """Show archetype positions in PCA space alongside a distance matrix.

    Produces a two-panel figure: archetype positions in the first two
    principal components, and a heatmap of pairwise distances. The archetype
    coordinates are read from ``adata.uns[coords_key]`` and handed to the
    internal ``plot_archetype_positions`` routine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with archetype coordinates.
    coords_key : str, default: "archetype_coordinates"
        Key in adata.uns containing archetype coordinates.
    title : str, default: "Archetype Positions in PCA Space"
        Main figure title.
    figsize : tuple, default: (15, 6)
        Figure size as (width, height).
    cmap : str, default: 'tab10'
        Colormap for archetype points.
    show_distances : bool, default: True
        Whether the distance matrix panel is shown.
    save_path : str | None, default: None
        Path where the figure is saved.
    **kwargs
        Additional arguments passed to plot_archetype_positions.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with archetype position visualizations.

    Raises
    ------
    ValueError
        If ``coords_key`` is not present in adata.uns.

    Examples
    --------
    >>> fig = pc.pl.archetype_positions(adata)
    >>> plt.show()

    >>> # Save high-resolution figure
    >>> fig = pc.pl.archetype_positions(adata, title="Helsinki EOC Archetype Positions", save_path="archetypes.png")

    Notes
    -----
    The visualization includes:

    - Left panel: Archetype positions in PC1-PC2 space with convex hull
    - Right panel: Pairwise distance matrix with values

    Requires at least 2 dimensions in archetype coordinates.
    For 3D visualization, use `archetype_positions_3d()`.
    """
    from .._core.viz.training_viz import plot_archetype_positions as _plot_positions

    # Without stored coordinates there is nothing to draw; point at the fix.
    coords_missing = coords_key not in adata.uns
    if coords_missing:
        raise ValueError(f"adata.uns['{coords_key}'] not found. Run pc.tl.train_archetypal() first.")

    # Figure appearance options, forwarded untouched to the plotting routine.
    figure_options = {
        "title": title,
        "figsize": figsize,
        "cmap": cmap,
        "show_distances": show_distances,
        "save_path": save_path,
    }
    return _plot_positions(archetype_coordinates=adata.uns[coords_key], **figure_options, **kwargs)
# [docs]  <- Sphinx viewcode link artifact, not code
def archetype_positions_3d(
    adata: AnnData,
    *,
    coords_key: str = "archetype_coordinates",
    title: str = "Archetype Positions in 3D PCA Space",
    figsize: tuple = (12, 10),
    cmap: str = "tab10",
    save_path: str | None = None,
    **kwargs,
) -> Any:
    """Show archetype positions in 3D PCA space.

    Builds an interactive 3D view of the archetype positions, with convex
    hull edges connecting the archetypes. The archetype coordinates are read
    from ``adata.uns[coords_key]`` and handed to the internal
    ``plot_archetype_distances_3d`` routine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with archetype coordinates.
    coords_key : str, default: "archetype_coordinates"
        Key in adata.uns containing archetype coordinates.
    title : str, default: "Archetype Positions in 3D PCA Space"
        Figure title.
    figsize : tuple, default: (12, 10)
        Figure size as (width, height).
    cmap : str, default: 'tab10'
        Colormap for archetype points.
    save_path : str | None, default: None
        Path where the figure is saved.
    **kwargs
        Additional arguments passed to plot_archetype_distances_3d.

    Returns
    -------
    matplotlib.figure.Figure
        3D visualization of archetypes.

    Raises
    ------
    ValueError
        If ``coords_key`` is not present in adata.uns.

    Examples
    --------
    >>> # Basic 3D visualization
    >>> fig = pc.pl.archetype_positions_3d(adata)
    >>> plt.show()

    >>> # Custom visualization
    >>> fig = pc.pl.archetype_positions_3d(adata, cmap="Set1", title="3D Archetype Hull")

    Notes
    -----
    Requires at least 3 dimensions in archetype coordinates.
    The visualization includes convex hull edges connecting archetypes.
    """
    from .._core.viz.training_viz import plot_archetype_distances_3d as _plot_3d

    # Without stored coordinates there is nothing to draw; point at the fix.
    coords_missing = coords_key not in adata.uns
    if coords_missing:
        raise ValueError(f"adata.uns['{coords_key}'] not found. Run pc.tl.train_archetypal() first.")

    # Figure appearance options, forwarded untouched to the plotting routine.
    figure_options = {
        "title": title,
        "figsize": figsize,
        "cmap": cmap,
        "save_path": save_path,
    }
    return _plot_3d(archetype_coordinates=adata.uns[coords_key], **figure_options, **kwargs)
# [docs]  <- Sphinx viewcode link artifact, not code
def archetype_statistics(
    adata: AnnData,
    *,
    coords_key: str = "archetype_coordinates",
    verbose: bool = True,
) -> dict:
    """Summarize the geometry of the archetype positions.

    Computes pairwise distances, identifies the nearest and farthest
    archetype pairs, and computes convex hull metrics when possible. The
    archetype coordinates are read from ``adata.uns[coords_key]`` and handed
    to the internal ``compute_archetype_statistics`` routine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with archetype coordinates.
    coords_key : str, default: "archetype_coordinates"
        Key in adata.uns containing archetype coordinates.
    verbose : bool, default: True
        Whether statistics are printed to the console.

    Returns
    -------
    dict
        Statistics dictionary with keys:

        - n_archetypes : int - Number of archetypes
        - n_dimensions : int - Embedding dimensions
        - mean_distance : float - Mean pairwise Euclidean distance
        - std_distance : float - Std of pairwise distances
        - min_distance : float - Minimum pairwise distance
        - max_distance : float - Maximum pairwise distance
        - distance_range : float - max - min distance
        - nearest_pair : tuple[int, int] - Indices of nearest pair (0-based)
        - farthest_pair : tuple[int, int] - Indices of farthest pair (0-based)
        - distance_matrix : np.ndarray - Full pairwise distance matrix
        - hull_volume : float | None - Convex hull volume (3D+ only)
        - hull_area : float | None - Convex hull surface area (3D+ only)

    Raises
    ------
    ValueError
        If adata.uns[coords_key] not found.

    Examples
    --------
    >>> stats = pc.pl.archetype_statistics(adata)
    [STATS] Archetype Statistics
    ==================================================
    Number of archetypes: 5
    ...

    >>> # Quiet mode
    >>> stats = pc.pl.archetype_statistics(adata, verbose=False)
    >>> print(f"Nearest pair: A{stats['nearest_pair'][0] + 1}-A{stats['nearest_pair'][1] + 1}")
    """
    from .._core.viz.training_viz import compute_archetype_statistics as _compute_stats

    # Without stored coordinates there is nothing to summarize; point at the fix.
    coords_missing = coords_key not in adata.uns
    if coords_missing:
        raise ValueError(f"adata.uns['{coords_key}'] not found. Run pc.tl.train_archetypal() first.")

    archetype_coords = adata.uns[coords_key]
    statistics = _compute_stats(archetype_coordinates=archetype_coords, verbose=verbose)
    return statistics