"""
Cross-Validation Hyperparameter Search for Archetypal Analysis
==============================================================

Phase 2 of the PEACH pipeline: systematic hyperparameter evaluation.

This module provides grid search over hyperparameter combinations using
cross-validation to estimate model performance. Results support manual
selection in Phase 3 - NO automatic selection is performed.

Pipeline Position
-----------------
Phase 1: Data Loading → **Phase 2: CV Search** → Phase 3: Manual Selection
→ Phase 4: Final Training → Phase 5: Evaluation

Main Classes
------------
SearchConfig : Configuration for hyperparameter search space and CV settings
ArchetypalGridSearch : Main orchestrator for grid search with cross-validation

Type Definitions
----------------
See ``peach._core.types`` for Pydantic models of return structures.

Examples
--------
>>> from peach._core.utils.hyperparameter_search import ArchetypalGridSearch, SearchConfig
>>> config = SearchConfig(n_archetypes_range=[3, 4, 5, 6], cv_folds=5, max_epochs_cv=50)
>>> grid_search = ArchetypalGridSearch(config)
>>> cv_summary = grid_search.fit(dataloader, base_model_config)
>>> print(cv_summary.summary_report())
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset
from .cv_training import CVTrainingManager
from .grid_search_results import CVResults, CVSummary
[docs]
@dataclass
class SearchConfig:
"""Configuration for hyperparameter search space and CV settings.
Defines the hyperparameter grid to search and cross-validation parameters.
Speed presets automatically adjust epochs and early stopping for different
use cases.
Parameters
----------
n_archetypes_range : list[int] | None, default: None
Range of archetype numbers to test. If None, defaults to [2, 3, 4, 5, 6, 7].
hidden_dims_options : list[list[int]] | None, default: None
Network architectures to test. If None, defaults to standard options:
[[128, 64], [256, 128, 64], [128], [512, 256, 128]].
inflation_factor_range : list[float] | None, default: None
Inflation factors to test. If None, uses [1.5] (Helsinki optimal).
Set to multiple values (e.g., [1.0, 1.5, 2.0]) to search inflation.
cv_folds : int, default: 5
Number of cross-validation folds.
max_epochs_cv : int, default: 100
Maximum epochs per CV fold (overridden by speed_preset).
early_stopping_patience : int, default: 5
Patience for early stopping (overridden by speed_preset).
subsample_fraction : float, default: 0.5
Fraction of data to use for CV when dataset > max_cells_cv.
max_cells_cv : int, default: 15000
Maximum cells for CV. Larger datasets are subsampled.
speed_preset : str, default: "balanced"
Training speed preset. Options:
- ``"fast"`` : 25 epochs, patience=3 (quick exploration)
- ``"balanced"`` : 50 epochs, patience=5 (recommended)
- ``"thorough"`` : 100 epochs, patience=8 (comprehensive)
use_pcha_init : bool, default: True
Whether to use PCHA initialization for archetypes.
random_state : int, default: 42
Random seed for reproducibility.
Attributes
----------
_search_inflation : bool
Internal flag indicating whether inflation is being searched
(True if inflation_factor_range was explicitly provided).
Raises
------
ValueError
If n_archetypes_range contains non-positive integers.
If cv_folds <= 0.
If max_epochs_cv <= 0.
If subsample_fraction not in (0, 1].
If max_cells_cv <= 0.
If inflation_factor_range contains non-positive values.
Examples
--------
>>> # Basic configuration
>>> config = SearchConfig(n_archetypes_range=[3, 4, 5], cv_folds=3, speed_preset="fast")
>>> # Search inflation factors
>>> config = SearchConfig(n_archetypes_range=[4, 5, 6], inflation_factor_range=[1.0, 1.5, 2.0], cv_folds=5)
>>> print(f"Searching inflation: {config._search_inflation}") # True
See Also
--------
ArchetypalGridSearch : Uses this configuration
peach.tl.hyperparameter_search : User-facing wrapper
"""
n_archetypes_range: list[int] = None
hidden_dims_options: list[list[int]] = None
cv_folds: int = 5
max_epochs_cv: int = 100
early_stopping_patience: int = 5
subsample_fraction: float = 0.5
max_cells_cv: int = 15000
speed_preset: str = "balanced"
use_pcha_init: bool = True
inflation_factor_range: list[float] = None
random_state: int = 42
[docs]
def __post_init__(self):
"""Initialize defaults and validate inputs."""
# Set defaults first
if self.n_archetypes_range is None:
self.n_archetypes_range = [2, 3, 4, 5, 6, 7]
if self.hidden_dims_options is None:
self.hidden_dims_options = [
[128, 64], # Standard architecture
[256, 128, 64], # Deeper network
[128], # Simpler network
[512, 256, 128], # Larger capacity
]
# NEW: Inflation factor handling
# If None, use default optimal value (Helsinki breakthrough: 1.5)
# If provided as list, test multiple values
if self.inflation_factor_range is None:
self.inflation_factor_range = [1.5] # Default optimal
self._search_inflation = False # Flag: not searching inflation
else:
self._search_inflation = True # Flag: actively searching inflation
# Validate inputs
if not self.n_archetypes_range or not all(n > 0 for n in self.n_archetypes_range):
raise ValueError("n_archetypes_range must contain positive integers")
if self.cv_folds <= 0:
raise ValueError("cv_folds must be positive")
if self.max_epochs_cv <= 0:
raise ValueError("max_epochs_cv must be positive")
if not (0 < self.subsample_fraction <= 1.0):
raise ValueError("subsample_fraction must be between 0 and 1")
if self.max_cells_cv <= 0:
raise ValueError("max_cells_cv must be positive")
# Validate inflation factors
if not all(f > 0 for f in self.inflation_factor_range):
raise ValueError("inflation_factor_range must contain positive values")
class ArchetypalGridSearch:
"""Main orchestrator for hyperparameter grid search with cross-validation.
Performs systematic search over hyperparameter combinations, evaluating
each configuration using K-fold cross-validation. Designed for large-scale
single-cell datasets with intelligent subsampling and memory management.
Parameters
----------
search_config : SearchConfig | None, default: None
Search configuration. If None, uses default SearchConfig().
Attributes
----------
config : SearchConfig
Active search configuration.
cv_manager : CVTrainingManager | None
CV training manager (initialized during fit).
results : CVSummary | None
Search results (populated after fit).
Examples
--------
>>> from peach._core.utils.hyperparameter_search import ArchetypalGridSearch, SearchConfig
>>> # Configure search
>>> config = SearchConfig(
... n_archetypes_range=[3, 4, 5, 6],
... hidden_dims_options=[[256, 128], [128, 64]],
... cv_folds=5,
... speed_preset="balanced",
... )
>>> # Run search
>>> grid_search = ArchetypalGridSearch(config)
>>> cv_summary = grid_search.fit(dataloader, base_model_config)
>>> # Analyze results
>>> print(cv_summary.summary_report())
>>> top_configs = cv_summary.rank_by_metric("archetype_r2")[:3]
>>> fig = cv_summary.plot_elbow_r2()
See Also
--------
SearchConfig : Configuration class
CVSummary : Return type of fit()
peach.tl.hyperparameter_search : User-facing wrapper
"""
def __init__(self, search_config: SearchConfig = None):
"""Initialize grid search with configuration.
Parameters
----------
search_config : SearchConfig | None
Search configuration. Defaults to SearchConfig() if None.
"""
self.config = search_config or SearchConfig()
self.cv_manager = None
self.results = None
self._best_model = None
# Set random seeds for reproducibility
torch.manual_seed(self.config.random_state)
np.random.seed(self.config.random_state)
# Setup speed preset configurations
self.speed_presets = {
"fast": {"max_epochs_cv": 25, "early_stopping_patience": 3},
"balanced": {"max_epochs_cv": 50, "early_stopping_patience": 5},
"thorough": {"max_epochs_cv": 100, "early_stopping_patience": 8},
}
# Apply speed preset
preset_config = self.speed_presets.get(self.config.speed_preset, {})
for key, value in preset_config.items():
setattr(self.config, key, value)
def fit(
self, dataloader: DataLoader, base_model_config: dict[str, Any], compute_strategy: str = "sequential"
) -> CVSummary:
"""Execute hyperparameter grid search with cross-validation.
Searches all combinations of hyperparameters, evaluating each with
K-fold cross-validation. Results are organized for manual selection.
Parameters
----------
dataloader : DataLoader
DataLoader containing full dataset. Will be subsampled if
larger than ``config.max_cells_cv``.
base_model_config : dict
Base configuration for Deep_AA model:
- ``input_dim`` : int - Input feature dimensions
- ``device`` : str - Computing device ('cpu', 'cuda', 'mps')
- Additional model parameters
compute_strategy : str, default: "sequential"
Execution strategy. Currently only "sequential" is implemented.
"parallel" is reserved for future HPC support.
Returns
-------
CVSummary
Complete cross-validation results with:
- ``config_results`` : dict[str, CVResults] - Results per configuration
- ``summary_df`` : pd.DataFrame - Summary table for analysis
- ``ranked_configs`` : list[dict] - Configurations ranked by R²
- ``cv_info`` : dict - Search metadata
Key methods on CVSummary:
- ``summary_report()`` : Text summary for decision support
- ``rank_by_metric(metric)`` : Rank configs by any metric
- ``plot_elbow_r2()`` : Elbow curve visualization
- ``plot_metric(metric)`` : Generic metric visualization
- ``save(path)`` / ``load(path)`` : Persistence
Notes
-----
- Large datasets (> max_cells_cv) are automatically subsampled
- GPU memory is cleared between configurations
- Results are also stored in ``self.results`` for later access
Examples
--------
>>> base_config = {"input_dim": adata.obsm["X_pca"].shape[1], "device": "cuda"}
>>> cv_summary = grid_search.fit(dataloader, base_config)
>>> # Quick summary
>>> print(cv_summary.summary_report())
>>> # Get top 3 configurations
>>> top3 = grid_search.get_top_configurations(top_k=3)
>>> for config in top3:
... print(f"{config['config_summary']}: R²={config['metric_value']:.4f}")
See Also
--------
CVSummary : Return type with analysis methods
get_top_configurations : Convenience method for top configs
"""
print(" Starting Archetypal Hyperparameter Grid Search")
if self.config._search_inflation:
print(
f" Search space: {len(self.config.n_archetypes_range)} × {len(self.config.hidden_dims_options)} × {len(self.config.inflation_factor_range)} = {len(self._get_hyperparameter_combinations())} combinations"
)
print(f" Searching inflation factors: {self.config.inflation_factor_range}")
else:
print(
f" Search space: {len(self.config.n_archetypes_range)} × {len(self.config.hidden_dims_options)} = {len(self._get_hyperparameter_combinations())} combinations"
)
print(f" Using fixed inflation: {self.config.inflation_factor_range[0]}")
print(f" CV folds: {self.config.cv_folds}")
print(f" Total training runs: {len(self._get_hyperparameter_combinations()) * self.config.cv_folds}")
# Prepare data with intelligent subsampling
cv_splits = self._prepare_cv_data(dataloader)
# Initialize CV training manager
self.cv_manager = CVTrainingManager(base_model_config, self.config)
# Generate hyperparameter combinations
hyperparameter_combinations = self._get_hyperparameter_combinations()
# Execute grid search
cv_results = []
total_combinations = len(hyperparameter_combinations)
for i, hyperparams in enumerate(hyperparameter_combinations):
print(f"\n🧪 Configuration {i + 1}/{total_combinations}: {hyperparams}")
start_time = time.time()
cv_result = self.cv_manager.train_cv_configuration(hyperparams, cv_splits)
elapsed_time = time.time() - start_time
cv_result.training_time = elapsed_time
cv_results.append(cv_result)
# Memory cleanup
torch.cuda.empty_cache() if torch.cuda.is_available() else None
print(f" [OK] Completed in {elapsed_time:.1f}s")
print(f" [STATS] Mean archetype R²: {cv_result.mean_metrics.get('archetype_r2', 0):.4f}")
print(f" [STATS] Mean validation R²: {cv_result.mean_metrics.get('val_archetype_r2', 0):.4f}")
# Compile results using new simplified architecture
self.results = CVSummary.from_cv_results(
cv_results=cv_results, search_config=self.config, data_info=self._get_data_info(dataloader)
)
print("\n Grid search completed!")
print(f" Best configuration: {self.results.ranked_configs[0]['config_summary']}")
print(f" Archetype R²: {self.results.ranked_configs[0]['metric_value']:.4f}")
return self.results
def _prepare_cv_data(self, dataloader: DataLoader) -> list[tuple[DataLoader, DataLoader]]:
"""Prepare cross-validation data splits with intelligent subsampling.
Parameters
----------
dataloader : DataLoader
Original full dataset loader.
Returns
-------
list[tuple[DataLoader, DataLoader]]
List of (train_loader, val_loader) tuples, one per fold.
Notes
-----
- Datasets > max_cells_cv are subsampled to max_cells_cv
- KFold splitting with shuffle for randomization
- DataLoader workers optimized for HPC environments
"""
# Extract full dataset
full_data = []
for batch in dataloader:
if isinstance(batch, (list, tuple)):
full_data.append(batch[0])
else:
full_data.append(batch)
full_data = torch.cat(full_data, dim=0)
n_total_cells = len(full_data)
print(f"[STATS] Dataset info: {n_total_cells:,} cells, {full_data.shape[1]} features")
# Determine subsampling strategy
if n_total_cells > self.config.max_cells_cv:
# Subsample for CV
n_cv_cells = min(self.config.max_cells_cv, int(n_total_cells * self.config.subsample_fraction))
# Stratified subsampling (simple random for now, could add PCA-based stratification)
subsample_indices = torch.randperm(n_total_cells)[:n_cv_cells]
cv_data = full_data[subsample_indices]
print(f" Subsampled to {n_cv_cells:,} cells ({n_cv_cells / n_total_cells * 100:.1f}%) for CV")
else:
cv_data = full_data
n_cv_cells = n_total_cells
print(f" Using full dataset for CV ({n_cv_cells:,} cells)")
# Create KFold splits
kfold = KFold(n_splits=self.config.cv_folds, shuffle=True, random_state=self.config.random_state)
cv_splits = []
for fold_idx, (train_indices, val_indices) in enumerate(kfold.split(cv_data)):
# Create datasets
train_data = cv_data[train_indices]
val_data = cv_data[val_indices]
# Create dataloaders
train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)
# Get optimized DataLoader settings from original loader if available
num_workers = getattr(dataloader, "num_workers", 0)
pin_memory = getattr(dataloader, "pin_memory", False)
# Auto-detect optimal settings if not provided
if num_workers == 0 and not hasattr(torch.backends, "mps"):
# Check for HPC environment
if any([os.environ.get("SLURM_JOB_ID"), os.environ.get("PBS_JOBID"), (os.cpu_count() or 1) > 16]):
num_workers = min(6, max(4, (os.cpu_count() or 1) - 2))
if torch.cuda.is_available():
pin_memory = True
# Build optimized DataLoader kwargs
loader_kwargs = {
"batch_size": dataloader.batch_size,
"num_workers": num_workers,
"pin_memory": pin_memory and torch.cuda.is_available(),
}
if num_workers > 0:
loader_kwargs["persistent_workers"] = True
loader_kwargs["prefetch_factor"] = 2
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
cv_splits.append((train_loader, val_loader))
print(f" Fold {fold_idx + 1}: {len(train_data):,} train, {len(val_data):,} val")
# Report DataLoader optimization
if num_workers > 0:
env_type = "HPC" if num_workers >= 4 else "Local"
print(f"\n[STATS] DataLoader optimization: {env_type} mode with {num_workers} workers")
if pin_memory:
print(" GPU optimizations enabled (pin_memory=True)")
return cv_splits
def _get_hyperparameter_combinations(self) -> list[dict[str, Any]]:
"""Generate all hyperparameter combinations to test.
Returns
-------
list[dict]
Hyperparameter combinations, each dict containing:
- ``n_archetypes`` : int
- ``hidden_dims`` : list[int]
- ``inflation_factor`` : float
- ``use_pcha_init`` : bool
- ``use_inflation`` : bool (True if inflation_factor > 1.0)
Notes
-----
Generates Cartesian product of:
n_archetypes × hidden_dims × inflation_factor
"""
combinations = []
for n_archetypes in self.config.n_archetypes_range:
for hidden_dims in self.config.hidden_dims_options:
for inflation_factor in self.config.inflation_factor_range:
combinations.append(
{
"n_archetypes": n_archetypes,
"hidden_dims": hidden_dims,
"inflation_factor": inflation_factor,
"use_pcha_init": self.config.use_pcha_init,
"use_inflation": inflation_factor > 1.0, # Auto-enable if factor > 1.0
# latent_dim is automatically set to n_archetypes in Deep_AA
}
)
return combinations
def _get_data_info(self, dataloader: DataLoader) -> dict[str, Any]:
"""Extract dataset information for results metadata.
Parameters
----------
dataloader : DataLoader
Data loader to analyze.
Returns
-------
dict
Dataset metadata:
- ``n_total_samples`` : int - Total number of samples
- ``n_features`` : int - Number of input features
- ``batch_size`` : int - Batch size
- ``device`` : str - Device of data tensors
"""
# Get sample batch to determine data characteristics
sample_batch = next(iter(dataloader))
if isinstance(sample_batch, (list, tuple)):
sample_data = sample_batch[0]
else:
sample_data = sample_batch
return {
"n_total_samples": len(dataloader.dataset),
"n_features": sample_data.shape[1],
"batch_size": dataloader.batch_size,
"device": str(sample_data.device),
}
def get_top_configurations(self, metric: str = "archetype_r2", top_k: int = 5) -> list[dict[str, Any]]:
"""Get top-k configurations ranked by specified metric.
Convenience method wrapping ``CVSummary.rank_by_metric()``.
Parameters
----------
metric : str, default: "archetype_r2"
Metric to rank by. Common options:
- ``"archetype_r2"`` : Reconstruction R² (higher is better)
- ``"rmse"`` : Root mean squared error (lower is better)
- ``"val_rmse"`` : Validation RMSE
- ``"convergence_epoch"`` : Training convergence speed
top_k : int, default: 5
Number of top configurations to return.
Returns
-------
list[dict]
Top configurations, each dict containing:
- ``hyperparameters`` : dict - Configuration parameters
- ``metric_value`` : float - Value of ranking metric
- ``std_error`` : float - Standard error across folds
- ``config_summary`` : str - Human-readable summary
Raises
------
ValueError
If ``fit()`` has not been called yet.
Examples
--------
>>> top_configs = grid_search.get_top_configurations(metric="archetype_r2", top_k=3)
>>> for i, config in enumerate(top_configs, 1):
... print(f"{i}. {config['config_summary']}")
... print(f" R²: {config['metric_value']:.4f} ± {config['std_error']:.4f}")
"""
if self.results is None:
raise ValueError("Must run fit() before getting configurations")
return self.results.rank_by_metric(metric)[:top_k]
def save_results(self, path: str | Path) -> None:
"""Save CV summary to disk.
Parameters
----------
path : str | Path
File path for saving (pickle format).
Raises
------
ValueError
If ``fit()`` has not been called yet.
Examples
--------
>>> grid_search.fit(dataloader, base_config)
>>> grid_search.save_results("cv_results.pkl")
"""
if self.results is None:
raise ValueError("No results to save. Run fit() first.")
self.results.save(path)
def load_results(self, path: str | Path) -> CVSummary:
"""Load CV summary from disk.
Parameters
----------
path : str | Path
File path to load from.
Returns
-------
CVSummary
Loaded results (also stored in self.results).
Examples
--------
>>> cv_summary = grid_search.load_results("cv_results.pkl")
>>> print(cv_summary.summary_report())
"""
self.results = CVSummary.load(path)
return self.results
# Future extension points for parallelization
def _execute_parallel(self, combinations: list[dict], cv_splits: list) -> list[CVResults]:
"""Future: Parallel execution of grid search."""
# TODO: Implement when moving to HPC with GPU clusters
# Could use multiprocessing, Ray, or joblib for parallelization
raise NotImplementedError("Parallel execution not yet implemented")
def _setup_gpu_strategy(self) -> None:
"""Future: Setup multi-GPU strategy for large-scale training."""
# TODO: Implement GPU cluster support
# Could use torch.distributed or similar for multi-GPU training
raise NotImplementedError("GPU cluster support not yet implemented")