WORKFLOW 02: Hyperparameter Search with Cross-Validation

This workflow demonstrates how to find optimal hyperparameters for archetypal analysis:

  1. Load preprocessed data (with PCA from WORKFLOW_01)

  2. Configure search space for hyperparameters

  3. Run cross-validation grid search

  4. Analyze results and select best configuration

The hyperparameter search tests combinations of:

  • Number of archetypes (discrete values)

  • Hidden layer dimensions (architecture options)

  • Inflation factor (PCA inflation for Deep AA)

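Cross-validation folds and training settings (epochs, early stopping, subsampling) are fixed for every combination rather than searched.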
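The search space is a full Cartesian grid over the three lists defined in the Configuration cell, so 3 × 2 × 2 = 12 combinations are evaluated. This matches the 12 configurations reported in the summary. The minimal sketch below enumerates that grid with the standard library. It assumes PEACH iterates the grid in this nested order, which is consistent with the row order of cv_summary.summary_df but is not documented here.

```python
from itertools import product

# Same search space as the Configuration cell
n_archetypes_range = [7, 9, 11]
hidden_dims_options = [[128, 64], [256, 128, 64]]
inflation_factor_range = [1.0, 1.5]

# product() varies the last list fastest: inflation factor, then hidden_dims, then n_archetypes
grid = [
    {"n_archetypes": k, "hidden_dims": dims, "inflation_factor": lam}
    for k, dims, lam in product(n_archetypes_range, hidden_dims_options, inflation_factor_range)
]

print(len(grid))  # 12
print(grid[0])    # {'n_archetypes': 7, 'hidden_dims': [128, 64], 'inflation_factor': 1.0}
```

Each combination is trained once per fold, so with cv_folds = 3 the full search runs 12 × 3 = 36 training fits. This is why the grid should stay small.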

Output structure (CVSummary):

  • cv_summary.ranked_configs: List of configs ranked by performance

  • cv_summary.summary_df: DataFrame with all results

  • cv_summary.config_results: Detailed per-config results

Example usage: python WORKFLOW_02.py

Requirements:

  • peach

  • scanpy

  • Data with PCA (from WORKFLOW_01 or equivalent)

[ ]:
import scanpy as sc
import peach as pc
from pathlib import Path

Configuration

[ ]:
# Data path - should have PCA already computed
data_path = Path("data/HSC.h5ad")

# Hyperparameter search space
n_archetypes_range = [7, 9, 11]  # Number of archetypes to test
hidden_dims_options = [
    [128, 64],           # Simpler architecture
    [256, 128, 64],      # Deeper architecture
]
inflation_factor_range = [1.0, 1.5]  # PCA inflation factor

# Cross-validation settings
cv_folds = 3                    # Number of CV folds
max_epochs_cv = 50              # Max epochs per fold
early_stopping_patience = 5     # Stop if no improvement
speed_preset = 'fast'           # 'fast', 'balanced', or 'thorough'
subsample_fraction = 0.5        # Use 50% of data for CV
max_cells_cv = 5000             # Maximum cells per CV run
random_state = 42               # For reproducibility
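The cross-validation settings combine subsampling with k-fold splitting. The sketch below shows one plausible way these parameters interact. It draws min(subsample_fraction × n_obs, max_cells_cv) cells without replacement, then splits them into cv_folds disjoint validation folds. This is an assumption for illustration only: the helper cv_fold_indices is hypothetical, and PEACH's internal sampling may differ.

```python
import numpy as np

def cv_fold_indices(n_obs, cv_folds=3, subsample_fraction=0.5,
                    max_cells_cv=5000, random_state=42):
    """Hypothetical helper: subsample cells, then split them into k folds.

    Returns a list of (train_idx, val_idx) pairs of cell indices.
    """
    rng = np.random.default_rng(random_state)
    # Assumed cap: the requested fraction of cells, but never more than max_cells_cv
    n_use = min(int(n_obs * subsample_fraction), max_cells_cv)
    # Sample without replacement; the result is already in random order
    cells = rng.choice(n_obs, size=n_use, replace=False)
    folds = np.array_split(cells, cv_folds)
    splits = []
    for k in range(cv_folds):
        # Fold k is held out for validation; the remaining folds are used for training
        train_idx = np.concatenate([folds[j] for j in range(cv_folds) if j != k])
        splits.append((train_idx, folds[k]))
    return splits

splits = cv_fold_indices(n_obs=20_000)
for k, (train_idx, val_idx) in enumerate(splits):
    print(f"fold {k}: {len(train_idx)} train / {len(val_idx)} val cells")
```

Under this assumption, a 20,000-cell dataset is capped at 5,000 cells for CV, so runtime depends on max_cells_cv rather than on dataset size.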

Step 1: Load Data with PCA

[ ]:
print("Loading data...")
adata = sc.read_h5ad(data_path)
print(f"  Shape: {adata.n_obs:,} cells × {adata.n_vars:,} genes")

# Ensure PCA exists (required for archetypal analysis)
if 'X_pca' not in adata.obsm:
    print("  Running PCA (required for archetypal analysis)...")
    sc.tl.pca(adata, n_comps=13)
    print(f"  PCA computed: {adata.obsm['X_pca'].shape}")
else:
    print(f"  PCA found: {adata.obsm['X_pca'].shape}")

NB: In my experience, archetypal analysis generally gives the best results when you use the smallest number of principal components that together explain more than 99% of the variance. The remaining low-variance PCs mostly add noise to archetypal training. For many datasets I have gotten the best results with 5-11 PCs. Use Scanpy's sc.pl.pca_variance_ratio() to explore.
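A small helper makes this rule concrete. It returns the smallest number of leading PCs whose cumulative explained-variance ratio reaches a threshold. The name smallest_n_pcs is hypothetical. Scanpy's sc.tl.pca stores the per-component ratios in adata.uns["pca"]["variance_ratio"]. These ratios are fractions of the total variance, so if the computed components never reach the threshold, the helper returns None and you may need to compute more components.

```python
import numpy as np

def smallest_n_pcs(variance_ratio, threshold=0.99):
    """Hypothetical helper: smallest number of leading PCs whose cumulative
    explained-variance ratio reaches `threshold`; None if it is never reached."""
    cumulative = np.cumsum(np.asarray(variance_ratio, dtype=float))
    reached = np.flatnonzero(cumulative >= threshold)
    return int(reached[0]) + 1 if reached.size else None

# With Scanpy, after sc.tl.pca(adata, ...):
#   n_pcs = smallest_n_pcs(adata.uns["pca"]["variance_ratio"])
print(smallest_n_pcs([0.6, 0.3, 0.095, 0.004, 0.001]))  # 3
```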

Step 3: Analyze Results

[5]:
print("\nAnalyzing results...")

# Access ranked configurations (best to worst)
ranked_configs = cv_summary.ranked_configs

print(f"\nTop 3 configurations:")
for i, config in enumerate(ranked_configs[:3], 1):
    print(f"\n  {i}. Configuration:")
    print(f"     Performance (R²): {config['metric_value']:.4f} ± {config['std_error']:.4f}")
    print(f"     Settings: {config['config_summary']}")
    # Access hyperparameters dict
    hparams = config['hyperparameters']
    print(f"     Details:")
    print(f"       - n_archetypes: {hparams['n_archetypes']}")
    print(f"       - hidden_dims: {hparams['hidden_dims']}")
    print(f"       - inflation_factor: {hparams['inflation_factor']}")

# Get best configuration
best_config = ranked_configs[0]
print(f"\nRecommended configuration:")
print(f"  n_archetypes = {best_config['hyperparameters']['n_archetypes']}")
print(f"  hidden_dims = {best_config['hyperparameters']['hidden_dims']}")
print(f"  inflation_factor = {best_config['hyperparameters']['inflation_factor']}")
print(f"  Expected R² = {best_config['metric_value']:.4f}")

Analyzing results...

Top 3 configurations:

  1. Configuration:
     Performance (R²): 0.6089 ± 0.0057
     Settings: 11 archetypes, [256, 128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 11
       - hidden_dims: [256, 128, 64]
       - inflation_factor: 1.5

  2. Configuration:
     Performance (R²): 0.6027 ± 0.0095
     Settings: 11 archetypes, [128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 11
       - hidden_dims: [128, 64]
       - inflation_factor: 1.5

  3. Configuration:
     Performance (R²): 0.5921 ± 0.0013
     Settings: 9 archetypes, [128, 64] hidden dims, λ=1.5
     Details:
       - n_archetypes: 9
       - hidden_dims: [128, 64]
       - inflation_factor: 1.5

Recommended configuration:
  n_archetypes = 11
  hidden_dims = [256, 128, 64]
  inflation_factor = 1.5
  Expected R² = 0.6089

Using ‘best_config = ranked_configs[0]’ returns the configuration with the highest R². This is a convenient way to access the best configuration programmatically when you run this analysis as a script (e.g., as part of a Snakemake workflow). You can also inspect the results visually with pc.pl.elbow_curve() below and choose the configuration that best suits your dataset.
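The ranking printed in Step 3 can be reproduced from the summary table. Sorting by mean_val_archetype_r2 in descending order gives the same top three (0.6089, 0.6027, 0.5921). That PEACH ranks on exactly this column is an inference from the matching numbers, not something documented here. The pandas sketch below rebuilds a subset of the table from the values printed in Step 4. In practice, you would sort cv_summary.summary_df directly.

```python
import pandas as pd

# Subset of cv_summary.summary_df, copied from the output printed in this notebook.
# In practice, use cv_summary.summary_df directly instead of rebuilding it.
summary_df = pd.DataFrame({
    "n_archetypes": [7] * 4 + [9] * 4 + [11] * 4,
    "hidden_dims": [[128, 64], [128, 64], [256, 128, 64], [256, 128, 64]] * 3,
    "inflation_factor": [1.0, 1.5] * 6,
    "mean_val_archetype_r2": [
        0.523364, 0.525295, 0.513390, 0.526204,
        0.545984, 0.592101, 0.552910, 0.577809,
        0.575601, 0.602715, 0.575154, 0.608874,
    ],
})

# Rank by mean validation R² (higher is better)
ranked = summary_df.sort_values("mean_val_archetype_r2", ascending=False).reset_index(drop=True)
best = ranked.iloc[0]

print(ranked.head(3))
print(f"Best: {best['n_archetypes']} archetypes, {best['hidden_dims']}, "
      f"inflation={best['inflation_factor']}, R² = {best['mean_val_archetype_r2']:.4f}")
```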

Step 4: Visualize Results with Elbow Plot

The pc.pl.elbow_curve() function creates an interactive visualization showing:

  • Performance metrics (R², RMSE) across different numbers of archetypes

  • Results for all tested hyperparameter combinations (hidden_dims, inflation_factor)

  • Error bars from cross-validation folds
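To read the elbow numerically, take the best validation R² at each number of archetypes and look at the gain from each step up. The pandas sketch below uses the mean_val_archetype_r2 values from the summary table printed below. Collapsing each number of archetypes to its best configuration is a simplification for illustration and is not necessarily what pc.pl.elbow_curve() plots.

```python
import pandas as pd

# n_archetypes and mean_val_archetype_r2 for the 12 configurations in cv_summary.summary_df
results = pd.DataFrame({
    "n_archetypes": [7] * 4 + [9] * 4 + [11] * 4,
    "mean_val_archetype_r2": [
        0.523364, 0.525295, 0.513390, 0.526204,
        0.545984, 0.592101, 0.552910, 0.577809,
        0.575601, 0.602715, 0.575154, 0.608874,
    ],
})

# Best R² per number of archetypes (maximum over hidden_dims and inflation_factor)
best_per_k = results.groupby("n_archetypes")["mean_val_archetype_r2"].max()

# Marginal gain from each increase in n_archetypes; a shrinking gain marks the elbow
gain = best_per_k.diff().dropna()

for k, r2 in best_per_k.items():
    print(f"k={k}: best R² = {r2:.4f}")
for k, g in gain.items():
    print(f"up to k={k}: gain = {g:+.4f}")
```

With these values, the gain drops from about 0.066 (7 to 9 archetypes) to about 0.017 (9 to 11), which indicates diminishing returns. Whether that curve is flat enough to stop is a judgment call for your dataset.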

[6]:
print("\nGenerating elbow plot using pc.pl.elbow_curve()...")

# Use PEACH's built-in elbow curve visualization
# Shows multiple metrics across all hyperparameter configurations
fig = pc.pl.elbow_curve(
    cv_summary,
    metrics=["archetype_r2", "mean_val_rmse"],  # Show both R² and RMSE
)

# Display interactive plot
fig.show()

# Display the summary DataFrame to see all hyperparameters explored
print("\nFull hyperparameter search results:")
print(cv_summary.summary_df.to_string())

Generating elbow plot using pc.pl.elbow_curve()...

[Interactive Plotly figure: elbow curve of archetype_r2 and mean_val_rmse across all tested configurations]


Full hyperparameter search results:
    n_archetypes     hidden_dims  inflation_factor  use_pcha_init  use_inflation  mean_convergence_epoch  mean_val_archetype_r2  mean_val_mae  mean_archetype_r2  mean_val_rmse  mean_early_stopped  std_convergence_epoch  std_val_archetype_r2  std_val_mae  std_archetype_r2  std_val_rmse  std_early_stopped  training_time  early_stopping_rate
0              7       [128, 64]               1.0           True          False                    25.0               0.523364      1.223734           0.523364       1.800805                 0.0                    0.0              0.006811     0.018873          0.006811      0.026436                0.0      14.967988                  0.0
1              7       [128, 64]               1.5           True           True                    25.0               0.525295      1.216038           0.525295       1.797145                 0.0                    0.0              0.002924     0.003426          0.002924      0.019772                0.0      14.652772                  0.0
2              7  [256, 128, 64]               1.0           True          False                    25.0               0.513390      1.233822           0.513390       1.819534                 0.0                    0.0              0.016819     0.027618          0.016819      0.047850                0.0      17.667443                  0.0
3              7  [256, 128, 64]               1.5           True           True                    25.0               0.526204      1.221582           0.526204       1.795450                 0.0                    0.0              0.004735     0.005992          0.004735      0.024167                0.0      17.644125                  0.0
4              9       [128, 64]               1.0           True          False                    25.0               0.545984      1.198935           0.545984       1.757194                 0.0                    0.0              0.012670     0.013127          0.012670      0.016610                0.0      17.817975                  0.0
5              9       [128, 64]               1.5           True           True                    25.0               0.592101      1.156319           0.592101       1.665881                 0.0                    0.0              0.001274     0.000567          0.001274      0.015886                0.0      17.814216                  0.0
6              9  [256, 128, 64]               1.0           True          False                    25.0               0.552910      1.202490           0.552910       1.743383                 0.0                    0.0              0.020284     0.017922          0.020284      0.030704                0.0      20.795003                  0.0
7              9  [256, 128, 64]               1.5           True           True                    25.0               0.577809      1.172580           0.577809       1.694558                 0.0                    0.0              0.011997     0.011833          0.011997      0.022728                0.0      20.965029                  0.0
8             11       [128, 64]               1.0           True          False                    25.0               0.575601      1.178364           0.575601       1.699096                 0.0                    0.0              0.006369     0.012037          0.006369      0.009075                0.0      21.915368                  0.0
9             11       [128, 64]               1.5           True           True                    25.0               0.602715      1.156228           0.602715       1.644016                 0.0                    0.0              0.009476     0.014831          0.009476      0.029191                0.0      22.049284                  0.0
10            11  [256, 128, 64]               1.0           True          False                    25.0               0.575154      1.175513           0.575154       1.698644                 0.0                    0.0              0.034274     0.037607          0.034274      0.065987                0.0      25.127929                  0.0
11            11  [256, 128, 64]               1.5           True           True                    25.0               0.608874      1.130446           0.608874       1.631127                 0.0                    0.0              0.005681     0.004980          0.005681      0.006469                0.0      24.669151                  0.0

Summary

[7]:
print("\n" + "="*70)
print("WORKFLOW 02 COMPLETE")
print("="*70)
print(f"Configurations tested: {len(ranked_configs)}")
print(f"Best performance: R² = {best_config['metric_value']:.4f}")
print(f"\nBest hyperparameters:")
print(f"  • n_archetypes: {best_config['hyperparameters']['n_archetypes']}")
print(f"  • hidden_dims: {best_config['hyperparameters']['hidden_dims']}")
print(f"  • inflation_factor: {best_config['hyperparameters']['inflation_factor']}")
print(f"\nKey outputs:")
print(f"  • cv_summary.ranked_configs - Ranked configurations")
print(f"  • cv_summary.summary_df - Full results DataFrame")
print(f"  • pc.pl.elbow_curve() - Interactive visualization")
print("\nNext workflow: WORKFLOW_03 (Model Training with best config)")
print("="*70)

======================================================================
WORKFLOW 02 COMPLETE
======================================================================
Configurations tested: 12
Best performance: R² = 0.6089

Best hyperparameters:
  • n_archetypes: 11
  • hidden_dims: [256, 128, 64]
  • inflation_factor: 1.5

Key outputs:
  • cv_summary.ranked_configs - Ranked configurations
  • cv_summary.summary_df - Full results DataFrame
  • pc.pl.elbow_curve() - Interactive visualization

Next workflow: WORKFLOW_03 (Model Training with best config)
======================================================================