# Source code for peach.tl.hyperparameters

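# NOTE(review): Everything from here up to the live module docstring below is a
# commented-out implementation of ``hyperparameter_search``. Dead code is better
# removed (version control keeps the history). If it is ever revived, fix two
# issues first: ``n_archetypes_range`` uses a mutable list as its default value
# (shared across calls; prefer ``None`` + a fallback), and the ``device``
# argument is silently ignored whenever ``base_model_config`` is supplied.
# """Hyperparameter optimization for archetypal analysis."""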

# from typing import List, Dict, Any
# from anndata import AnnData
# from torch.utils.data import DataLoader
# import torch

# # Import existing battle-tested functions
# from .._core.utils.hyperparameter_search import ArchetypalGridSearch, SearchConfig
# from .._core.utils.grid_search_results import CVSummary


# def hyperparameter_search(
#     adata: AnnData,
#     *,
#     n_archetypes_range: List[int] = [3, 4, 5, 6],
#     cv_folds: int = 3,
#     max_epochs_cv: int = 15,
#     pca_key: str = "X_pca",
#     device: str = "cpu",
#     base_model_config: Dict[str, Any] | None = None,
#     **kwargs
# ) -> CVSummary:
#     """Perform cross-validation hyperparameter search.

#     Systematically searches hyperparameter space using cross-validation
#     to find optimal model configurations for archetypal analysis.

#     Parameters
#     ----------
#     adata : AnnData
#         Annotated data object with PCA coordinates
#     n_archetypes_range : List[int], default: [3, 4, 5, 6]
#         Range of archetype numbers to test
#     cv_folds : int, default: 3
#         Number of cross-validation folds
#     max_epochs_cv : int, default: 15
#         Maximum epochs per CV fold (early stopping recommended)
#     pca_key : str, default: "X_pca"
#         Key in adata.obsm containing PCA coordinates
#     device : str, default: "cpu"
#         Device to use for training ('cpu', 'cuda', or 'mps')
#         Default is 'cpu' for stability on Apple Silicon
#     base_model_config : dict | None, default: None
#         Base model configuration to extend

#     Returns
#     -------
#     CVSummary
#         Complete cross-validation results with ranking and analysis methods

#     Examples
#     --------
#     >>> cv_summary = pc.tl.hyperparameter_search(
#     ...     adata,
#     ...     n_archetypes_range=[3, 4, 5],
#     ...     device='cpu'
#     ... )
#     >>> print(cv_summary.summary_report())
#     >>> top_configs = cv_summary.rank_by_metric('archetype_r2')
#     >>> fig = cv_summary.plot_elbow_curve(['archetype_r2', 'rmse'])
#     """
#     # Input validation
#     if pca_key not in adata.obsm:
#         raise ValueError(f"adata.obsm['{pca_key}'] not found. Run sc.pp.pca() first.")

#     # Create DataLoader
#     from ..pp.basic import prepare_training
#     dataloader = prepare_training(adata, pca_key=pca_key)

#     # Configure search
#     # Note: SearchConfig doesn't have latent_dim_offset_range parameter
#     # Filter out the latent_dim_offset_range if passed
#     search_kwargs = dict(kwargs)
#     if 'latent_dim_offset_range' in search_kwargs:
#         search_kwargs.pop('latent_dim_offset_range')

#     search_config = SearchConfig(
#         n_archetypes_range=n_archetypes_range,
#         cv_folds=cv_folds,
#         max_epochs_cv=max_epochs_cv,
#         **search_kwargs
#     )

#     # Default base model configuration
#     if base_model_config is None:
#         base_model_config = {
#             'input_dim': adata.obsm[pca_key].shape[1],
#             'barycentric_mode': True,
#             'device': device
#         }

#     # Run search
#     grid_search = ArchetypalGridSearch(search_config)
#     cv_summary = grid_search.fit(dataloader, base_model_config)

#     return cv_summary

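"""Hyperparameter optimization for archetypal analysis.

User-facing wrapper for cross-validation hyperparameter search.
Provides a simple interface to the core grid search machinery
(``ArchetypalGridSearch`` configured via ``SearchConfig``); results are
exposed through a ``CVSummary`` object (see Examples).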

Examples
--------
>>> import peach as pc
>>> # Basic search
>>> cv_summary = pc.tl.hyperparameter_search(adata, n_archetypes_range=[3, 4, 5, 6], cv_folds=5)
>>> # View results
>>> print(cv_summary.summary_report())
>>> top_configs = cv_summary.rank_by_metric("archetype_r2")[:3]
>>> fig = cv_summary.plot_elbow_r2()
"""

from typing import Any

from anndata import AnnData

from .._core.utils.grid_search_results import CVSummary
from .._core.utils.hyperparameter_search import ArchetypalGridSearch, SearchConfig