from abc import ABC, abstractmethod
from pathlib import Path
import shutil
from typing import Any, Dict, List, Optional, Tuple, Union
import lightning as L
import torch
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    TQDMProgressBar as ProgressBar,
)
import numpy as np
import pandas as pd
import torchmetrics
import tqdm
from minerva.data.data_modules.base import MinervaDataModule
from minerva.pipelines.base import Pipeline
from minerva.utils.typing import PathLike
from dataclasses import asdict, dataclass
from minerva.utils.string_ops import tree_like_formating, indent_text
class ModelInstantiator(ABC):
    """Abstract base class for lazy instantiation of PyTorch Lightning models.
    This interface defines a standardized way to construct models in three
    common training scenarios:
    1. Training from scratch: the entire model (backbone + head) is randomly
       initialized.
    2. Finetuning: a pretrained backbone is loaded from a checkpoint, while the
       head is randomly initialized.
    3. Inference/Evaluation: the full model is restored from a previously
       saved checkpoint. Usually, this checkpoint is generated using one of the
       two scenarios above.
    This abstraction allows for flexible and decoupled model construction across
    various stages of the machine learning lifecycle. Thus, it is expected that
    the model's architecture follows the pattern shown below:
        +-------------------------------+
        |   Model (LightningModule)     |
        |                               |
        |     +-----------------+       |
        |     |    Backbone     |       |   --> Feature extractor
        |     +-----------------+       |
        |             |                 |
        |             v                 |
        |        +----------+           |
        |        |   Head   |           |   --> Task-specific layers
        |        +----------+           |
        +-------------------------------+
    Definitions
    -----------
    - Backbone: Core feature extractor (e.g., ResNet, Transformer encoder).
    - Head: Task-specific layers (e.g., classification head, regression head).
    Implementations of this class should handle the appropriate model loading
    logic for each use case described above.
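
    Examples
    --------
    A minimal sketch of a concrete implementation. This is illustrative only:
    `MyLitModel` is a hypothetical LightningModule exposing a `backbone`
    attribute, and the pretraining checkpoint is assumed to store the backbone
    weights directly under its "state_dict" key::

        class MyModelInstantiator(ModelInstantiator):
            def create_model_randomly_initialized(self) -> L.LightningModule:
                # Backbone and head are both randomly initialized
                return MyLitModel()

            def create_model_and_load_backbone(
                self, backbone_checkpoint_path: PathLike
            ) -> L.LightningModule:
                model = MyLitModel()
                ckpt = torch.load(backbone_checkpoint_path, map_location="cpu")
                # Assumes the checkpoint stores backbone weights under
                # "state_dict"; the head keeps its random initialization
                model.backbone.load_state_dict(ckpt["state_dict"], strict=False)
                return model

            def load_model_from_checkpoint(
                self, checkpoint_path: PathLike
            ) -> L.LightningModule:
                # Restore the full model (backbone and head) saved by Lightning
                return MyLitModel.load_from_checkpoint(checkpoint_path)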
    """
    @abstractmethod
    def create_model_randomly_initialized(self) -> L.LightningModule:
        """Create a model with both backbone and head randomly initialized.
        Typically used when training a model from scratch.
        Returns
        -------
        L.LightningModule
            A Lightning model fully initialized with random weights, ready for
            training.
        """
        raise NotImplementedError(
            "create_model_randomly_initialized must be implemented."
        ) 
    @abstractmethod
    def create_model_and_load_backbone(
        self, backbone_checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Create a model for finetuning with a pretrained backbone and a
        new head (randomly initialized). This method should load the backbone
        weights from the specified checkpoint and attach a freshly initialized
        head for the downstream task. The user is responsible for loading the
        backbone weights into the model's state dict.
        Parameters
        ----------
        backbone_checkpoint_path : PathLike
            Path to the checkpoint containing pretrained backbone weights. The
            checkpoint must be compatible with the model architecture.
        Returns
        -------
        L.LightningModule
            The model ready for finetuning (pretrained backbone, new head).
        """
        raise NotImplementedError(
            "create_model_and_load_backbone must be implemented."
        )
    @abstractmethod
    def load_model_from_checkpoint(
        self, checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Load the full model (backbone and head) from a saved checkpoint.
        Typically used for resuming training, evaluation, or inference when the
        model must be restored in its entirety. In practice, the checkpoint
        should be one created using `create_model_and_load_backbone` or
        `create_model_randomly_initialized`.
        The checkpoint must be compatible with the model architecture.
        Parameters
        ----------
        checkpoint_path : PathLike
            Path to the checkpoint file containing the full model state.
        Returns
        -------
        L.LightningModule
            A Lightning model fully restored from checkpoint, ready for
            evaluation or inference.
        """
        raise NotImplementedError(
            "load_model_from_checkpoint must be implemented in the subclass."
        ) 
 
class ModelConfig:
    """Encapsulates the full configuration of a model for use in a training or
    inference pipeline.
    A `ModelConfig` brings together two key components:
    - `ModelInstantiator`: Responsible for creating the model in different
      modes (lazily instantiated):
        - From scratch (randomly initialized)
        - Finetuning (load pretrained backbone, new head)
        - From checkpoint (fully restored model)
    - `ModelInformation`: Contains descriptive metadata about the model such
      as input/output shapes, number of classes, backbone used, and task type.
    This class serves as the primary interface for managing and accessing
    model configuration throughout the lifecycle of training, evaluation, or
    deployment.
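
    Examples
    --------
    A minimal sketch of assembling a configuration. This is illustrative only:
    `MyModelInstantiator` is a hypothetical `ModelInstantiator` subclass, and
    the `ModelInformation` fields shown are assumed to match its definition::

        config = ModelConfig(
            instantiator=MyModelInstantiator(),
            information=ModelInformation(
                name="resnet18-classifier",
                input_shape=(3, 224, 224),
                output_shape=(10,),
                num_classes=10,
            ),
        )
        # The model is built lazily, only when it is actually needed
        model = config.instantiator.create_model_randomly_initialized()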
    """
    def __init__(
        self,
        instantiator: ModelInstantiator,
        information: ModelInformation,
    ):
        """Initialize a model configuration.
        Parameters
        ----------
        instantiator : ModelInstantiator
            An instance responsible for constructing the model in various
            training modes (random init, load backbone, load full checkpoint).
            This enables lazy instantiation depending on the training phase.
        information : ModelInformation
            Metadata describing the model's architecture and behavior.
            Includes input/output shapes, task type, number of classes, and
            other relevant info useful for logging, validation, and downstream
            processing.
        """
        self.instantiator = instantiator
        self.information = information
    def __str__(self):
        return (
            f"ModelConfig\n"
            + f"├── Instantiator: {self.instantiator.__class__.__name__}\n"
            + indent_text(tree_like_formating(asdict(self.information)), spaces=0)
        ) 
 
# -------- Functional interfaces --------
def get_trainer(
    log_dir: Path,
    max_epochs: int = 100,
    limit_train_batches: Optional[Union[int, float]] = None,
    limit_val_batches: Optional[Union[int, float]] = None,
    limit_test_batches: Optional[Union[int, float]] = None,
    limit_predict_batches: Optional[Union[int, float]] = None,
    accelerator: str = "auto",
    strategy: str = "auto",
    devices: Optional[Union[int, list[int], str]] = "auto",
    num_nodes: int = 1,
    progress_bar_refresh_rate: int = 1,
    enable_logging: bool = True,
    checkpoint_metrics: Optional[List[Dict[str, str]]] = None,
    precision: str = "32-true",
    accumulate_grad_batches: int = 1,
    deterministic: bool = False,
    benchmark: bool = True,
    profiler: Optional[str] = None,
    overfit_batches: Union[int, float] = 0.0,
    sync_batchnorm: bool = False,
) -> L.Trainer:
    """Creates and configures a PyTorch Lightning Trainer instance.
    This function encapsulates all necessary options for flexible training,
    evaluation, or inference, including logging, checkpointing, device setup,
    precision, and more.
    Parameters
    ----------
    log_dir : Path
        Directory path where logs and checkpoints will be saved. When logging
        is enabled, the last two components of this path are used as the
        CSVLogger name and version, and the remainder as its save directory.
    max_epochs : int, default=100
        Maximum number of epochs for training.
    limit_train_batches : int or float, optional
        Limit on the number of training batches per epoch. Can be an integer
        (absolute number) or a float (fraction of total batches).
    limit_val_batches : int or float, optional
        Limit on the number of validation batches per epoch.
    limit_test_batches : int or float, optional
        Limit on the number of test batches per epoch.
    limit_predict_batches : int or float, optional
        Limit on the number of prediction batches.
    accelerator : str, default="auto"
        Hardware accelerator to use (e.g., "gpu", "cpu", "tpu", "auto").
    strategy : str, default="auto"
        Distributed training strategy (e.g., "ddp", "deepspeed", etc.).
    devices : int, list of int, or str, optional, default="auto"
        Devices to use for training (e.g., 1, [0,1], "auto").
    num_nodes : int, default=1
        Number of nodes to use for distributed training.
    progress_bar_refresh_rate : int, default=1
        Frequency (in steps) at which the progress bar is updated.
        Set to 0 to disable.
    enable_logging : bool, default=True
        Whether to enable CSV logging.
    checkpoint_metrics : list of dict, optional
        List of dictionaries containing checkpoint configurations. Each
        dictionary should specify "monitor", "mode", and "filename".
    precision : str, default="32-true"
        Numerical precision to use during training (e.g., 32-true, 16-mixed).
    accumulate_grad_batches : int, default=1
        Number of batches for which gradients should be accumulated before
        performing an optimizer step.
    deterministic : bool, default=False
        If True, sets deterministic behavior for reproducibility.
    benchmark : bool, default=True
        Enables the cudnn.benchmark flag for optimized performance on fixed
        input sizes.
    profiler : str, optional
        Enables performance profiling (e.g., "simple", "advanced").
    overfit_batches : int or float, default=0.0
        Uses a fraction or number of batches for both training and validation
        to quickly debug overfitting behavior.
    sync_batchnorm : bool, default=False
        Synchronizes batch norm layers across devices during distributed
        training.
    Returns
    -------
    L.Trainer
        A configured PyTorch Lightning Trainer instance.
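
    Examples
    --------
    A minimal sketch of a typical call. The path and metric name are
    illustrative; note that the last two components of `log_dir` become the
    CSV logger's name and version::

        trainer = get_trainer(
            log_dir=Path("logs/my_experiment/my_dataset/run_0"),
            max_epochs=50,
            accelerator="gpu",
            devices=1,
            checkpoint_metrics=[
                {"monitor": "val_loss", "mode": "min", "filename": "best"},
            ],
        )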
    """
    if enable_logging:
        logger = CSVLogger(
            save_dir=log_dir.parents[1],
            name=log_dir.parents[0].name,
            version=log_dir.name,
        )
    else:
        logger = False
    callbacks = []
    enable_checkpointing = False
    if checkpoint_metrics:
        for ckpt_metric in checkpoint_metrics:
            ckpt_kwargs = {
                "monitor": ckpt_metric["monitor"],
                "mode": ckpt_metric["mode"],
                "filename": ckpt_metric["filename"],
                "save_last": False,
                "enable_version_counter": False,
            }
            callbacks.append(ModelCheckpoint(**ckpt_kwargs))
        enable_checkpointing = True
    enable_progress_bar = True
    log_every_n_steps = None
    if progress_bar_refresh_rate == 0:
        enable_progress_bar = False
    else:
        callbacks.append(ProgressBar(refresh_rate=progress_bar_refresh_rate))
    return L.Trainer(
        accelerator=accelerator,
        devices=devices,  # type: ignore
        logger=logger,
        callbacks=callbacks,
        enable_checkpointing=enable_checkpointing,
        enable_progress_bar=enable_progress_bar,
        enable_model_summary=True,
        log_every_n_steps=log_every_n_steps,
        max_epochs=max_epochs,
        strategy=strategy,
        num_nodes=num_nodes,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
        limit_predict_batches=limit_predict_batches,
        precision=precision,  # type: ignore
        accumulate_grad_batches=accumulate_grad_batches,
        deterministic=deterministic,
        benchmark=benchmark,
        inference_mode=True,
        profiler=profiler,
        overfit_batches=overfit_batches,
        sync_batchnorm=sync_batchnorm,
    ) 
def save_predictions(
    predictions: Union[np.ndarray, torch.Tensor], path: PathLike
) -> None:
    """Save predictions to a given path.
    Parameters
    ----------
    predictions : Union[np.ndarray, torch.Tensor]
        The prediction data to save.
    path : PathLike
        The path where the predictions will be saved.
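
    Examples
    --------
    A small round-trip sketch (the file path is illustrative)::

        preds = np.zeros((4, 10), dtype=np.float32)
        save_predictions(preds, "predictions/last.npy")
        loaded = load_predictions("predictions/last.npy")
        assert loaded.shape == (4, 10)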
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.cpu().numpy()
    if not isinstance(predictions, np.ndarray):
        raise ValueError("Predictions must be a numpy array or a torch tensor.")
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.suffix == ".npz":
        np.savez(path, predictions=predictions)
    else:
        np.save(path, predictions)
    print(f"Predictions saved to {path}") 
def load_predictions(path: PathLike) -> np.ndarray:
    """Load a prediction from a given path.
    Parameters
    ----------
    path : PathLike
        The path to the prediction file.
    Returns
    -------
    np.ndarray
        The loaded prediction data.
    """
    path = Path(path)
    if path.suffix == ".npz":
        data = np.load(path, allow_pickle=True)
        if "predictions" in data:
            return data["predictions"]
        else:
            raise ValueError("No 'predictions' key found in the npz file.")
    else:
        return np.load(path, allow_pickle=True) 
def save_results(results: pd.DataFrame, path: PathLike, index: bool = False) -> None:
    """Save results to a given path.
    Parameters
    ----------
    results : pd.DataFrame
        The results data to save.
    path : PathLike
        The path where the results will be saved.
    index : bool, optional
        Whether to save the index of the DataFrame, by default False
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    results.to_csv(path, index=index)
    print(f"Results saved to {path}") 
def load_results(path: PathLike) -> pd.DataFrame:
    """Load results from a given path.
    Parameters
    ----------
    path : PathLike
        The path to the results file.
    Returns
    -------
    pd.DataFrame
        The loaded results data.
    """
    path = Path(path)
    if not path.is_file():
        raise ValueError(f"File {path} does not exist.")
    return pd.read_csv(path) 
class Experiment(Pipeline):
    NUM_DEBUG_EPOCHS = 3
    NUM_DEBUG_BATCHES = 10
    def __init__(
        self,
        # Base parameters
        experiment_name: str,
        model_config: ModelConfig,
        data_module: MinervaDataModule,
        # Logging and checkpointing parameters
        pretrained_backbone_ckpt_path: Optional[PathLike] = None,
        root_log_dir: PathLike = "./logs",
        execution_id: Union[str, int] = 0,
        checkpoint_metrics: Optional[List[Dict[str, str]]] = None,
        # Trainer-related parameters
        max_epochs: int = 100,
        accelerator: str = "gpu",
        devices: Optional[Union[int, list[int], str]] = 1,
        strategy: str = "auto",
        num_nodes: int = 1,
        limit_train_batches: Optional[Union[int, float]] = None,
        limit_val_batches: Optional[Union[int, float]] = None,
        limit_test_batches: Optional[Union[int, float]] = None,
        limit_predict_batches: Optional[Union[int, float]] = None,
        # Prediction parameters
        evaluation_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        per_sample_evaluation_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        # Other parameters
        seed: Optional[int] = None,
        progress_bar_refresh_rate: int = 1,
        profiler: Optional[str] = None,
        save_predictions: bool = True,
        save_results: bool = True,
        add_last_checkpoint: bool = True,
    ):
        """An experiment is a pipeline that contains all the parameters needed
        to train and evaluate a model, as well as to manage the logging,
        checkpointing, prediction, and results processes in a coherent way.
        Parameters
        ----------
        experiment_name : str
            The name of the experiment. This name will be used to create a
            directory for the experiment in the log directory.
        model_config : ModelConfig
            The model configuration. This object contains the model instantiator
            and the model information.
        data_module : MinervaDataModule
            The data module. This object contains the training, validation, and
            test datasets, as well as the data loaders. For now, each dataset
            sample must be a 2-element tuple (input, label).
        pretrained_backbone_ckpt_path : Optional[PathLike], optional
            The path to the pretrained backbone checkpoint. This is used to
            finetune the model. If None, the model will be trained from
            scratch. This parameter drives the lazy instantiation of the model:
            the instantiator's `create_model_and_load_backbone` method is
            called when a path is provided, and
            `create_model_randomly_initialized` is called when it is None. By
            default None
        root_log_dir : PathLike, optional
            Root directory for logging and checkpoints. This directory will be
            used to create a subdirectory for the experiment. By default ./logs
        execution_id : Union[str, int], optional
            The execution ID for the experiment. This ID will be used to create
            a subdirectory for the experiment in the log directory. This is
            useful when running the experiment multiple times with the same
            parameters. By default 0
        checkpoint_metrics : Optional[List[Dict[str, str]]], optional
            The checkpoint metrics. This is a list of dictionaries that contain
            the checkpoint metrics. Each dictionary must contain the keys
            "monitor", "mode", and "filename". The "monitor" key is the name of
            the metric to monitor, the "mode" key is the mode of the metric
            ("min" or "max"), and the "filename" key is the name of the
            checkpoint file. The "monitor" key can be None if the checkpoint is
            the last one. By default None
        max_epochs : int, optional
            Number of epochs to train the model. This parameter is passed to the
            `get_trainer` function. By default 100.
        accelerator : str, optional
            The accelerator to use for training. This parameter is passed to the
            `get_trainer` function. Possible values are "cpu", "gpu", "tpu",
            "ipu", "hpu", "mps", and "auto". If "auto" is selected, the
            accelerator will be chosen automatically based on the available
            hardware. By default "gpu"
        devices : Optional[Union[int, list[int], str]], optional
            Number of accelerators to use for training. This parameter is
            passed to the `get_trainer` function. By default 1.
        strategy : str, optional
            Strategy to use for distributed training. This parameter is passed
            to the `get_trainer` function. By default "auto".
        num_nodes : int, optional
            Number of nodes to use for distributed training. This parameter is
            passed to the `get_trainer` function. By default 1.
        limit_train_batches : Optional[Union[int, float]], optional
            Limit the number of training batches to use. This parameter is
            passed to the `get_trainer` function. By default None. If None, all
            batches will be used. If an integer is provided, it will be the
            absolute number of batches. If a float is provided, it will be the
            fraction of the total number of batches. For example, 0.1 means 10%
            of the training batches will be used.
        limit_val_batches : Optional[Union[int, float]], optional
            Limit the number of validation batches to use. This parameter is
            passed to the `get_trainer` function. By default None. If None, all
            batches will be used. If an integer is provided, it will be the
            absolute number of batches. If a float is provided, it will be the
            fraction of the total number of batches. For example, 0.1 means 10%
            of the validation batches will be used.
        limit_test_batches : Optional[Union[int, float]], optional
            Limit the number of test batches to use. This parameter is
            passed to the `get_trainer` function. By default None. If None, all
            batches will be used. If an integer is provided, it will be the
            absolute number of batches. If a float is provided, it will be the
            fraction of the total number of batches. For example, 0.1 means 10%
            of the test batches will be used.
        limit_predict_batches : Optional[Union[int, float]], optional
            Limit the number of prediction batches to use. This parameter is
            passed to the `get_trainer` function. By default None. If None, all
            batches will be used. If an integer is provided, it will be the
            absolute number of batches. If a float is provided, it will be the
            fraction of the total number of batches. For example, 0.1 means 10%
            of the prediction batches will be used.
        evaluation_metrics : Optional[Dict[str, torchmetrics.Metric]], optional
            A dictionary of evaluation metrics to use for the predictions. The
            keys are the names of the metrics and the values are the
            `torchmetrics.Metric` objects. These metrics are calculated using
            all the predictions. By default None.
        per_sample_evaluation_metrics : Optional[Dict[str, torchmetrics.Metric]], optional
            A dictionary of evaluation metrics to use for the predictions. The
            keys are the names of the metrics and the values are the
            `torchmetrics.Metric` objects. These metrics are calculated on
            each prediction separately, that is, applied per sample. By
            default None.
        seed : Optional[int], optional
            The seed to use for the experiment, by default None
        progress_bar_refresh_rate : int, optional
            The refresh rate of the progress bar (in batches). If 0, the
            progress bar is disabled. If 1, the progress bar is updated every
            batch. By default 1
        profiler : Optional[str], optional
            A profiler to use for the experiment. This parameter is passed to
            the `get_trainer` function. By default None.
        save_predictions : bool, optional
            If True, the predictions will be saved to the log directory. By
            default True
        save_results : bool, optional
            If True, the results will be saved to the log directory. By
            default True
        add_last_checkpoint : bool, optional
            If True, the last checkpoint will be added to the list of checkpoint
            metrics. By default True.
        Raises
        ------
        ValueError
            If the checkpoint metrics are not valid or do not contain the
            required keys.
        Notes
        -----
        - This class assumes that the `MinervaDataModule` class returns an
            (input, label) tuple for each sample in the dataset. The input is
            the data and the label is the ground truth/target.
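
        Examples
        --------
        A minimal sketch of a finetuning run. This is illustrative only:
        `my_model_config` and `my_data_module` stand for a `ModelConfig` and a
        `MinervaDataModule` built elsewhere, the checkpoint path is
        hypothetical, and it is assumed that the `Pipeline` base class exposes
        a public entrypoint that forwards these arguments to `_run`::

            experiment = Experiment(
                experiment_name="image_classification",
                model_config=my_model_config,
                data_module=my_data_module,
                pretrained_backbone_ckpt_path="pretrained/backbone.ckpt",
                root_log_dir="./logs",
                execution_id=0,
                checkpoint_metrics=[
                    {"monitor": "val_loss", "mode": "min", "filename": "best"},
                ],
                max_epochs=100,
                evaluation_metrics={
                    "accuracy": torchmetrics.Accuracy(
                        task="multiclass", num_classes=10
                    ),
                },
            )
            # Assumed public entrypoint forwarding to `_run`
            results = experiment.run(task="fit-evaluate")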
        """
        # ------- Base parameters -------
        self.experiment_name = experiment_name
        self.model_config = model_config
        self.data_module = data_module
        # ------- Logging and checkpointing parameters -------
        self.pretrained_backbone_ckpt_path = pretrained_backbone_ckpt_path
        self.root_log_dir = Path(root_log_dir)
        self.execution_id = str(execution_id)
        self.checkpoint_metrics = checkpoint_metrics or []
        # Check if checkpoint metrics are valid
        for ckpt_metric in self.checkpoint_metrics:
            if not isinstance(ckpt_metric, dict):
                raise ValueError("Checkpoint metric must be a dictionary.")
            for key in ["monitor", "mode", "filename"]:
                if key not in ckpt_metric:
                    raise ValueError(f"Checkpoint metric must contain a '{key}' key.")
        # Add the "last" checkpoint metric if not already present
        if add_last_checkpoint:
            if not any(
                ckpt_metric.get("filename") == "last"
                for ckpt_metric in self.checkpoint_metrics
            ):
                self.checkpoint_metrics.append(
                    {"monitor": None, "mode": "min", "filename": "last"}  # type: ignore
                )
        # -------  Trainer-related parameters -------
        self.max_epochs = max_epochs
        self.accelerator = accelerator
        self.devices = devices
        self.strategy = strategy
        self.num_nodes = num_nodes
        self.limit_train_batches = limit_train_batches
        self.limit_val_batches = limit_val_batches
        self.limit_test_batches = limit_test_batches
        self.limit_predict_batches = limit_predict_batches
        # ------- Prediction parameters -------
        self.evaluation_metrics = evaluation_metrics or {}
        self.per_sample_evaluation_metrics = per_sample_evaluation_metrics or {}
        # ------- Other parameters -------
        self.seed = seed
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.profiler = profiler
        self.save_predictions = save_predictions
        self.save_results = save_results
        # ------- Initialize the pipeline -------
        log_dir = (
            self.root_log_dir
            / self.experiment_name
            / self.data_module.dataset_name
            / self.model_config.information.name
            / self.execution_id
        )
        super().__init__(
            log_dir=log_dir,
            cache_result=False,
            save_run_status=False,
            seed=seed,
            ignore=["model_config", "data_module"],
        )
        self._checkpoint_dir = log_dir / "checkpoints"
        self._predictions_dir = log_dir / "predictions"
        self._results_dir = log_dir / "results"
        self._training_metrics_path = log_dir / "metrics.csv"
    # ------------ Accessors ------------
    # Here we have access to:
    # - checkpoint paths
    # - metrics and metrics path
    # - prediction paths
    # - results and results path
    @property
    def checkpoint_paths(self) -> Dict[str, Path]:
        """Returns a dictionary of checkpoint paths for the experiment.
        The keys are the checkpoint names, and the values are the corresponding
        paths to the checkpoints.
        Returns
        -------
        Dict[str, Path]
            A dictionary mapping checkpoint names to their respective paths.
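
        Examples
        --------
        With "best.ckpt" and "last.ckpt" present in the checkpoint directory,
        the returned mapping would look like (paths are illustrative)::

            {
                "best": Path("logs/exp/dataset/model/0/checkpoints/best.ckpt"),
                "last": Path("logs/exp/dataset/model/0/checkpoints/last.ckpt"),
            }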
        """
        return {p.stem: p for p in self._checkpoint_dir.glob("*.ckpt") if p.is_file()}
    @property
    def training_metrics_path(self) -> Optional[Path]:
        """The path to the training metrics file.
        Returns
        -------
        Optional[Path]
            The path to the metrics file if it exists, otherwise None.
        """
        if self._training_metrics_path.is_file():
            return self._training_metrics_path
        return None
    @property
    def training_metrics(self) -> Optional[pd.DataFrame]:
        """Returns the training metrics as a pandas DataFrame.
        If the metrics file does not exist, returns None.
        Returns
        -------
        Optional[pd.DataFrame]
            A DataFrame containing the training metrics.
        """
        # Check if the metrics file exists and is a file
        path = self.training_metrics_path
        if path:
            return pd.read_csv(self._training_metrics_path)
        else:
            return None
    @property
    def prediction_paths(self) -> Dict[str, Path]:
        """Returns a dictionary of prediction paths for the experiment.
        The keys are the prediction names, and the values are the corresponding
        paths to the predictions.
        Returns
        -------
        Dict[str, Path]
            A dictionary mapping prediction names to their respective paths.
        """
        return {p.stem: p for p in self._predictions_dir.glob("*.npy") if p.is_file()}
    def load_predictions_of_ckpt(self, name: str) -> np.ndarray:
        """Load predictions from a file.
        Parameters
        ----------
        name : str
            The name of the prediction file (without extension).
        Returns
        -------
        np.ndarray
            The loaded predictions as a numpy array.
        """
        try:
            path = self.prediction_paths[name]
            return load_predictions(path)
        except KeyError:
            raise Exception(
                f"Prediction file '{name}' not found in {self._predictions_dir}"
            ) 
    @property
    def results_paths(self) -> Dict[str, Path]:
        """Returns a dictionary of results paths for the experiment.
        The keys are the result names, and the values are the corresponding
        paths to the results.
        Returns
        -------
        Dict[str, Path]
            A dictionary mapping result names to their respective paths.
        """
        return {p.stem: p for p in self._results_dir.glob("*.csv") if p.is_file()}
    def load_results_of_ckpt(self, name: str) -> pd.DataFrame:
        """Load results from a file.
        Parameters
        ----------
        name : str
            The name of the result file (without extension).
        Returns
        -------
        pd.DataFrame
            The loaded results as a pandas DataFrame.
        """
        try:
            path = self.results_paths[name]
            return load_results(path)
        except KeyError:
            raise Exception(f"Results file '{name}' not found in {self._results_dir}") 
    # ---------- Trainer ---------
    def _trainer_parameters(
        self, enable_logging: bool = True, debug: bool = False
    ) -> Dict[str, Any]:
        """Return the parameters for the trainer based on the current on debug
        and logging settings.
        Parameters
        ----------
        enable_logging : bool, optional
            If True, logging will be enabled, by default True
        debug : bool, optional
            If True, the model will be trained on only a few batches and for a
            few epochs. Logging is always disabled in debug mode. By default False
        Returns
        -------
        Dict[str, Any]
            All the parameters for the `get_trainer` function.
        """
        return {
            "log_dir": self.log_dir,
            "max_epochs": self.NUM_DEBUG_EPOCHS if debug else self.max_epochs,
            "limit_train_batches": (
                self.NUM_DEBUG_BATCHES if debug else self.limit_train_batches
            ),
            "limit_val_batches": (
                self.NUM_DEBUG_BATCHES if debug else self.limit_val_batches
            ),
            "limit_test_batches": (
                self.NUM_DEBUG_BATCHES if debug else self.limit_test_batches
            ),
            "limit_predict_batches": (
                self.NUM_DEBUG_BATCHES if debug else self.limit_predict_batches
            ),
            "accelerator": self.accelerator,
            "strategy": self.strategy,
            "devices": self.devices,
            "num_nodes": self.num_nodes,
            "progress_bar_refresh_rate": self.progress_bar_refresh_rate,
            "enable_logging": False if debug else enable_logging,
            "checkpoint_metrics": None if debug else self.checkpoint_metrics,
            "precision": "32-true",
            "deterministic": False,
            "benchmark": True,
            "profiler": self.profiler,
        } 
    # ---------- FIT Experiment methods ---------
    @staticmethod
    def __typing_string(value):
        if isinstance(value, torch.Tensor) or isinstance(value, np.ndarray):
            return f"shape={tuple(value.shape)}"
        else:
            return f"scalar with type={type(value).__name__}"
    def _print_train_summary(
        self,
        model: L.LightningModule,
        trainer_params: Dict[str, Any],
        debug: bool = False,
        resume_from_ckpt: Optional[str] = None,
    ) -> None:
        print("\n" + "=" * 80)
        print(
            f"Experiment: {self.experiment_name} {'(DEBUG)' if debug else ''}".center(
                80
            )
        )
        print("=" * 80)
        #  ------------ Model info ------------
        finetune_backbone = self.pretrained_backbone_ckpt_path is not None
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("🧠 Model")
        print(f"   ├── Name: {self.model_config.information.name}")
        print(f"   ├── Finetune: {'Yes' if finetune_backbone else 'No'}")
        if finetune_backbone:
            print(
                f"   |    └── Pretrained Backbone Checkpoint: {self.pretrained_backbone_ckpt_path}"
            )
        print(f"   ├── Resumed From: {resume_from_ckpt or 'Beginning'}")
        print(
            f"   ├── Expected Input Shape: {self.model_config.information.input_shape}"
        )
        print(
            f"   ├── Expected Output Shape: {self.model_config.information.output_shape}"
        )
        print(f"   ├── Total Params: {total_params:,}")
        try:
            print(
                f"   └── Trainable Params: {trainable_params:,} ({trainable_params / total_params:.2%})"
            )
        except ZeroDivisionError:
            print("   └── Trainable Params: 0 (0.00%)")
        # ------------ Dataset info ------------
        print("\n📊 Dataset")
        train_data = self.data_module.train_dataset
        val_data = self.data_module.val_dataset
        if train_data:
            x, y = train_data[0]
            print(f"   ├── Train Samples: {len(train_data)}")
            print(f"   |   ├── Input Shape: {self.__typing_string(x)}")
            print(f"   |   └── Label Shape: {self.__typing_string(y)}")
        else:
            print("   ├── Train Dataset: None")
        if val_data:
            x, y = val_data[0]
            print(f"   └── Validation Samples: {len(val_data)}")
            print(f"       ├── Input Shape: {self.__typing_string(x)}")
            print(f"       └── Label Shape: {self.__typing_string(y)}")
        else:
            print("   └── Validation Dataset: None")
        # ------------ Logging & checkpoints ------------
        ckpt_filenames = ", ".join(
            [
                f"{m['filename']}.ckpt"
                for m in self.checkpoint_metrics
                if m.get("filename")
            ]
        )
        checkpoints_exist = self.checkpoint_paths
        print("\n💾 Logging & Checkpoints")
        print(f"   ├── Log Dir: {self.log_dir}")
        print(f"   ├── Metrics Path: {self._training_metrics_path}")
        print(f"   └── Checkpoints Dir: {self._checkpoint_dir}")
        if checkpoints_exist:
            print(f"       ├── Files: {ckpt_filenames or 'None'}")
            print(f"       └── ⚠️ Existing checkpoints found! It will be overwritten!")
        else:
            print(f"       └── Files: {ckpt_filenames or 'None'}")
        # ------------ Trainer configuration ------------
        print("\n⚙️ Trainer Config")
        print(f"   ├── Max Epochs: {trainer_params['max_epochs']}")
        print(f"   ├── Train Batches: {trainer_params['limit_train_batches']}")
        print(f"   ├── Accelerator: {trainer_params['accelerator']}")
        print(f"   ├── Strategy: {trainer_params['strategy']}")
        print(f"   ├── Devices: {trainer_params['devices']}")
        print(f"   ├── Num Nodes: {trainer_params['num_nodes']}")
        print(f"   └── Seed: {self.seed}") 
    def _train_model(
        self,
        resume_from_ckpt: Optional[str] = None,
        debug: bool = False,
        print_summary: bool = True,
    ) -> Dict[str, Any]:
        data_module = self.data_module
        # If pre-trained backbone is provided, load the model with the
        # pre-trained backbone (usually for finetuning)
        if self.pretrained_backbone_ckpt_path is not None:
            model = self.model_config.instantiator.create_model_and_load_backbone(
                self.pretrained_backbone_ckpt_path
            )
        # If no pre-trained backbone is provided, create a model with
        # randomly initialized backbone and head (full supervised training)
        else:
            model = self.model_config.instantiator.create_model_randomly_initialized()
        # Get the trainer
        trainer_params = self._trainer_parameters(
            enable_logging=True,
            debug=debug,
        )
        trainer = get_trainer(**trainer_params)
        # Check if need to resume from a checkpoint
        checkpoints = self.checkpoint_paths
        if resume_from_ckpt:
            if resume_from_ckpt not in checkpoints:
                raise ValueError(
                    f"Checkpoint '{resume_from_ckpt}' not found in {list(checkpoints.keys())}"
                )
            resume_from_ckpt = checkpoints[resume_from_ckpt]  # type: ignore
        # Print the training summary
        if print_summary:
            self._print_train_summary(
                model=model,
                trainer_params=trainer_params,
                debug=debug,
                resume_from_ckpt=resume_from_ckpt,
            )
        # Train the model
        perform_train(
            data_module=data_module,
            model=model,
            trainer=trainer,
            resume_from_ckpt=resume_from_ckpt,
        )
        return {
            "data_module": data_module,
            "model": model,
            "trainer": trainer,
            "log_dir": self.log_dir,
            "metrics_path": self._training_metrics_path,
            "checkpoints": self.checkpoint_paths,
        } 
    # ---------- EVALUATE Experiment methods ---------
    def _print_evaluation_summary(
        self,
        trainer_params: Dict[str, Any],
        debug: bool = False,
        ckpt_path: Optional[PathLike] = None,
        predictions_path: Optional[PathLike] = None,
        results_path: Optional[PathLike] = None,
    ) -> None:
        print("\n" + "=" * 80)
        print(
            f"Evaluation: {self.experiment_name} ({ckpt_path.name}) {'(DEBUG)' if debug else ''}".center(
                80
            )
        )
        print("=" * 80)
        # ------------ Checkpoint Info ------------
        print("💾 Checkpoint")
        print(f"   ├── Checkpoint Path: {ckpt_path}")
        print(f"   └── Predictions Path: {predictions_path or 'Not saved'}")
        # ------------ Dataset Info ------------
        print("\n📊 Dataset")
        predict_data = self.data_module.predict_dataset
        if predict_data:
            x, y = predict_data[0]
            print(f"   ├── Predict Samples: {len(predict_data)}")
            print(f"   ├── Input: {self.__typing_string(x)}")
            print(f"   └── Label: {self.__typing_string(y)}")
        else:
            print("   └── Predict Dataset: None")
        # ------------ Evaluation Metrics ------------
        print("\n📈 Evaluation Metrics")
        if self.evaluation_metrics or self.per_sample_evaluation_metrics:
            for name, metric in self.evaluation_metrics.items():
                print(f"   ├── {name}: {metric.__class__.__name__}")
            for name, metric in self.per_sample_evaluation_metrics.items():
                print(f"   ├── {name}: {metric.__class__.__name__} (PER_SAMPLE)")
        else:
            print("   └── No evaluation metrics provided.")
        # ------------ Trainer Configuration ------------
        print("\n⚙️ Trainer Config")
        print(f"   ├── Max Epochs: {trainer_params['max_epochs']}")
        print(f"   ├── Predict Batches: {trainer_params['limit_predict_batches']}")
        print(f"   ├── Accelerator: {trainer_params['accelerator']}")
        print(f"   ├── Strategy: {trainer_params['strategy']}")
        print(f"   ├── Devices: {trainer_params['devices']}")
        print(f"   ├── Num Nodes: {trainer_params['num_nodes']}")
        print(f"   └── Seed: {self.seed}") 
    def _evaluate_model(
        self,
        ckpts_to_evaluate: Optional[Union[str, List[str]]] = None,
        print_summary: bool = True,
        debug: bool = False,
    ):
        # --------- Checkpoints -------
        checkpoints_to_use = self.checkpoint_paths
        # Check which checkpoints to evaluate (else, evaluate all)
        if ckpts_to_evaluate is not None:
            if isinstance(ckpts_to_evaluate, str):
                ckpts_to_evaluate = [ckpts_to_evaluate]
            try:
                checkpoints_to_use = {
                    ckpt: checkpoints_to_use[ckpt] for ckpt in ckpts_to_evaluate
                }
            except KeyError as e:
                raise ValueError(f"Checkpoint {e} not found in {checkpoints_to_use}")
        # Check if any checkpoint is found
        if len(checkpoints_to_use) == 0:
            raise ValueError(f"No checkpoints found in {self._checkpoint_dir}")
        # --------- Dataset -------
        checkpoint_results = {}
        data_module = self.data_module
        if data_module.predict_dataset is None:
            raise ValueError(
                "No predict dataset found in the data module. Please provide a predict dataset to perform evaluation."
            )
        for ckpt_name, ckpt_path in checkpoints_to_use.items():
            predictions_file = None
            results_filename = None
            results_filename_per_sample = None
            results = None
            per_sample_results = None
            if self.save_predictions and not debug:
                predictions_file = self._predictions_dir / f"{ckpt_name}.npy"
            if self.save_results and not debug:
                results_filename = self._results_dir / f"{ckpt_name}.csv"
                results_filename_per_sample = (
                    self._results_dir / f"{ckpt_name}_per_sample.csv"
                )
            # Load the model from the checkpoint
            model = self.model_config.instantiator.load_model_from_checkpoint(ckpt_path)
            # Trainer
            trainer_params = self._trainer_parameters(
                enable_logging=False,
                debug=debug,
            )
            trainer = get_trainer(**trainer_params)
            if print_summary:
                self._print_evaluation_summary(
                    trainer_params=trainer_params,
                    debug=debug,
                    ckpt_path=ckpt_path,
                    predictions_path=predictions_file,
                    results_path=results_filename,
                )
            # Perform prediction
            predictions = perform_predict(
                data_module=data_module,
                model=model,
                trainer=trainer,
            )
            if predictions_file is not None:
                save_predictions(predictions, predictions_file)
            else:
                print("Predictions not saved...")
            # Perform evaluation
            if self.evaluation_metrics:
                results = perform_evaluation(
                    evaluation_metrics=self.evaluation_metrics,
                    data_module=data_module,
                    predictions=predictions,
                    argmax_axis=(
                        1 if self.model_config.information.return_logits else None
                    ),
                    per_sample=False,
                    batch_size=self.data_module._predict_dataloader_kwargs[
                        "batch_size"
                    ],
                    device="cuda" if self.accelerator == "gpu" else "cpu",
                )
                if results_filename is not None:
                    save_results(results, results_filename, index=False)
                else:
                    print(f"Results not saved...")
            else:
                print("No evaluation metrics provided. Skipping evaluation.")
            # Perform per-sample evaluation
            if self.per_sample_evaluation_metrics:
                per_sample_results = perform_evaluation(
                    evaluation_metrics=self.per_sample_evaluation_metrics,
                    data_module=data_module,
                    predictions=predictions,
                    argmax_axis=(
                        1 if self.model_config.information.return_logits else None
                    ),
                    per_sample=True,
                    batch_size=self.data_module._predict_dataloader_kwargs[
                        "batch_size"
                    ],
                    device="cuda" if self.accelerator == "gpu" else "cpu",
                )
                if results_filename_per_sample is not None:
                    save_results(
                        per_sample_results,
                        results_filename_per_sample,
                        index=False,
                    )
                else:
                    print(f"Results not saved...")
            else:
                print(
                    "No per-sample evaluation metrics provided. Skipping per-sample evaluation."
                )
            # Store the results
            checkpoint_results[ckpt_name] = {
                "predictions_path": predictions_file,
                "results_path": results_filename,
                "results_path_per_sample": results_filename_per_sample,
                "results": results,
                "results_per_sample": per_sample_results,
            }
            print(f"Checkpoint {ckpt_name} evaluated!")
        return checkpoint_results 
    # ---------- Default pipeline entrypoints and other methods ---------
    def _run(
        self,
        task: str,
        debug: bool = False,
        resume_from_ckpt: Optional[str] = None,
        print_summary: bool = True,
        ckpts_to_evaluate: Optional[Union[str, List[str]]] = None,
    ) -> Dict[str, Any]:
        if task == "fit":
            return self._train_model(
                resume_from_ckpt=resume_from_ckpt,
                print_summary=print_summary,
                debug=debug,
            )
        elif task == "evaluate":
            return self._evaluate_model(
                ckpts_to_evaluate=ckpts_to_evaluate,
                print_summary=print_summary,
                debug=debug,
            )
        elif task == "fit-evaluate":
            # Train the model
            self._train_model(
                resume_from_ckpt=resume_from_ckpt,
                print_summary=print_summary,
                debug=debug,
            )
            # Evaluate the model
            eval_results = self._evaluate_model(
                ckpts_to_evaluate=ckpts_to_evaluate,
                print_summary=print_summary,
                debug=debug,
            )
            return eval_results
        else:
            raise ValueError(
                f"Unknown task '{task}'. Supported tasks are: 'fit', 'evaluate', or 'fit-evaluate'"
            ) 
    def cleanup(self):
        """Clean up the experiment by removing the log directory."""
        if self.log_dir.exists():
            shutil.rmtree(self.log_dir)
            print(f"Experiment at '{self.log_dir}' cleaned up.")
        else:
            print(f"Experiment at '{self.log_dir}' not found.") 
    @property
    def status(self) -> Dict[str, Any]:
        d = {}
        d["experiment_name"] = self.experiment_name
        d["log_dir"] = self.log_dir
        d["checkpoints"] = self.checkpoint_paths
        d["training_metrics"] = self.training_metrics_path
        d["prediction_paths"] = self.prediction_paths
        d["results_paths"] = self.results_paths
        state = "not executed"
        if len(d["checkpoints"]) > 0:
            state = "executed"
        if len(d["prediction_paths"]) > 0:
            state = "predicted"
        if len(d["results_paths"]) > 0:
            state = "evaluated"
        d["state"] = state
        return d
    # ---------- Python methods ---------
    def __str__(self) -> str:
        def indent_text(text, spaces=6):
            """Indent each line of a string by a given number of spaces."""
            if not text:
                return "No data."
            return "\n".join(
                " " * spaces + line if line.strip() else line
                for line in text.split("\n")
            )
        exp_name = f"🚀 Experiment: {self.experiment_name} 🚀"
        pretrained_backbone = (
            self.pretrained_backbone_ckpt_path
            if self.pretrained_backbone_ckpt_path
            else "FROM SCRATCH"
        )
        return (
            f"{'=' * 80}\n"
            f"{' ' * ((80 - len(exp_name)) // 2)}{exp_name}\n"
            f"{'=' * 80}\n"
            f"\n🛠 Execution Details\n"
            f"   ├── Execution ID: {self.execution_id}\n"
            f"   ├── Log Dir: {self.log_dir}\n"
            f"   ├── Seed: {self.seed}\n"
            f"   ├── Accelerator: {self.accelerator}\n"
            f"   ├── Devices: {self.devices}\n"
            f"   ├── Max Epochs: {self.max_epochs}\n"
            f"   ├── Train Batches Limit: {self.limit_train_batches or 'all'}\n"
            f"   ├── Val Batches Limit: {self.limit_val_batches or 'all'}\n"
            f"   └── Test Batches Limit: {self.limit_test_batches or 'all'}\n"
            f"\n"
            # f"{'=' * 50}\n"
            f"🧠 Model Information\n"
            f"   ├── Model Name: {self.model_config.information.name}\n"
            f"   ├── Pretrained Backbone: {pretrained_backbone}\n"
            f"   ├── Input Shape: {self.model_config.information.input_shape}\n"
            f"   ├── Output Shape: {self.model_config.information.output_shape}\n"
            f"   └── Num Classes: {self.model_config.information.num_classes}\n"
            f"\n📂 Dataset Information\n"
            f"{indent_text(str(self.data_module), spaces=6)}\n"
        )