from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import Any, Dict, List, Optional, Tuple, Union
import lightning as L
import torch
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import (
ModelCheckpoint,
TQDMProgressBar as ProgressBar,
)
import numpy as np
import pandas as pd
import torchmetrics
import tqdm
from minerva.data.data_modules.base import MinervaDataModule
from minerva.pipelines.base import Pipeline
from minerva.utils.typing import PathLike
from dataclasses import asdict, dataclass
from minerva.utils.string_ops import tree_like_formating, indent_text
from minerva.utils.output import Tee
from minerva.models.loaders import FromPretrained
import os
[docs]
class ModelInstantiator(ABC):
    """Abstract base class for lazy instantiation of PyTorch Lightning models.

    This interface defines a standardized way to construct models in three
    common training scenarios:

    1. Training from scratch: the entire model (backbone + head) is randomly
       initialized.
    2. Finetuning: a pretrained backbone is loaded from a checkpoint, while
       the head is randomly initialized.
    3. Inference/Evaluation: the full model is restored from a previously
       saved checkpoint. Usually, this checkpoint is generated using one of
       the two scenarios above.

    This abstraction allows for flexible and decoupled model construction
    across various stages of the machine learning lifecycle. Thus, it is
    expected that the model's architecture follows the pattern below::

        +-------------------------------+
        |   Model (LightningModule)     |
        |                               |
        |   +-----------------+         |
        |   |    Backbone     |         |  --> Feature extractor
        |   +-----------------+         |
        |            |                  |
        |            v                  |
        |      +----------+             |
        |      |   Head   |             |  --> Task-specific layers
        |      +----------+             |
        +-------------------------------+

    Definitions
    -----------
    - Backbone: Core feature extractor (e.g., ResNet, Transformer encoder).
    - Head: Task-specific layers (e.g., classification head, regression
      head).

    Implementations of this class should handle the appropriate model loading
    logic for each use case described above.
    """

    @abstractmethod
    def create_model_randomly_initialized(self) -> L.LightningModule:
        """Create a model with both backbone and head randomly initialized.

        Typically used when training a model from scratch.

        Returns
        -------
        L.LightningModule
            A Lightning model fully initialized with random weights, ready
            for training.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is reached (e.g. through
            ``super()``).
        """
        raise NotImplementedError(
            "create_model_randomly_initialized must be implemented."
        )

    @abstractmethod
    def create_model_and_load_backbone(
        self, backbone_checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Create a model for finetuning with a pretrained backbone and a
        new head (randomly initialized).

        This method should load the backbone weights from the specified
        checkpoint and attach a freshly initialized head for the downstream
        task. The user must handle the logic to load the backbone weights
        into the model's state dict.

        Parameters
        ----------
        backbone_checkpoint_path : PathLike
            Path to the checkpoint containing pretrained backbone weights.
            The checkpoint must be compatible with the model architecture.

        Returns
        -------
        L.LightningModule
            The model ready for finetuning (pretrained backbone, new head).

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is reached (e.g. through
            ``super()``).
        """
        # Fail loudly, like the sibling abstract methods, instead of silently
        # returning None when a subclass delegates to super().
        raise NotImplementedError(
            "create_model_and_load_backbone must be implemented in the subclass."
        )

    @abstractmethod
    def load_model_from_checkpoint(
        self, checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Load the full model (backbone and head) from a saved checkpoint.

        Typically used for resuming training, evaluation, or inference when
        the model must be restored in its entirety. In practice, the
        checkpoint should be one created using
        `create_model_and_load_backbone` or
        `create_model_randomly_initialized`. The checkpoint must be
        compatible with the model architecture.

        Parameters
        ----------
        checkpoint_path : PathLike
            Path to the checkpoint file containing the full model state.

        Returns
        -------
        L.LightningModule
            A Lightning model fully restored from checkpoint, ready for
            evaluation or inference.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation is reached (e.g. through
            ``super()``).
        """
        raise NotImplementedError(
            "load_model_from_checkpoint must be implemented in the subclass."
        )
[docs]
class ModelConfig:
    """Bundle everything needed to build and describe a model.

    A `ModelConfig` pairs two collaborators:

    - `ModelInstantiator`: builds the model lazily in one of three modes:
        - from scratch (randomly initialized)
        - finetuning (pretrained backbone, new head)
        - from checkpoint (fully restored model)
    - `ModelInformation`: descriptive metadata about the model, such as
      input/output shapes, number of classes, backbone used, and task type.

    It is the main handle used to reach the model configuration during
    training, evaluation, or deployment.
    """

    def __init__(
        self,
        instantiator: ModelInstantiator,
        information: ModelInformation,
    ):
        """Store the instantiator and the metadata of a model.

        Parameters
        ----------
        instantiator : ModelInstantiator
            Object that constructs the model in the different training modes
            (random init, load backbone, load full checkpoint), enabling
            lazy instantiation depending on the training phase.
        information : ModelInformation
            Metadata describing the model's architecture and behavior
            (input/output shapes, task type, number of classes, ...), useful
            for logging, validation, and downstream processing.
        """
        self.instantiator = instantiator
        self.information = information

    def __str__(self):
        """Return a tree-like, human-readable description of the config."""
        instantiator_name = self.instantiator.__class__.__name__
        information_tree = tree_like_formating(asdict(self.information))
        pieces = [
            "ModelConfig\n",
            f"├── Instantiator: {instantiator_name}\n",
            indent_text(information_tree, spaces=0),
        ]
        return "".join(pieces)
# -------- Functional interfaces --------
[docs]
def get_trainer(
    log_dir: Path,
    max_epochs: int = 100,
    limit_train_batches: Optional[Union[int, float]] = None,
    limit_val_batches: Optional[Union[int, float]] = None,
    limit_test_batches: Optional[Union[int, float]] = None,
    limit_predict_batches: Optional[Union[int, float]] = None,
    accelerator: str = "auto",
    strategy: str = "auto",
    devices: Optional[Union[int, list[int], str]] = "auto",
    num_nodes: int = 1,
    progress_bar_refresh_rate: int = 1,
    enable_logging: bool = True,
    checkpoint_metrics: Optional[List[Dict[str, str]]] = None,
    precision: str = "32-true",
    accumulate_grad_batches: int = 1,
    deterministic: bool = False,
    benchmark: bool = True,
    profiler: Optional[str] = None,
    overfit_batches: Union[int, float] = 0.0,
    sync_batchnorm: bool = False,
    callbacks: Optional[List[L.Callback]] = None,
) -> L.Trainer:
    """Creates and configures a PyTorch Lightning Trainer instance.

    This function encapsulates all necessary options for flexible training,
    evaluation, or inference, including logging, checkpointing, device setup,
    precision, and more.

    Parameters
    ----------
    log_dir : Path
        Directory path where logs and checkpoints will be saved. The CSV
        logger is configured from its last three components
        (``<save_dir>/<name>/<version>``). Plain strings are accepted and
        converted to ``Path``.
    max_epochs : int, default=100
        Maximum number of epochs for training.
    limit_train_batches : int or float, optional
        Limit on the number of training batches per epoch. Can be an integer
        (absolute number) or a float (fraction of total batches).
    limit_val_batches : int or float, optional
        Limit on the number of validation batches per epoch.
    limit_test_batches : int or float, optional
        Limit on the number of test batches per epoch.
    limit_predict_batches : int or float, optional
        Limit on the number of prediction batches.
    accelerator : str, default="auto"
        Hardware accelerator to use (e.g., "gpu", "cpu", "tpu", "auto").
    strategy : str, default="auto"
        Distributed training strategy (e.g., "ddp", "deepspeed", etc.).
    devices : int, list of int, or str, optional, default="auto"
        Devices to use for training (e.g., 1, [0,1], "auto").
    num_nodes : int, default=1
        Number of nodes to use for distributed training.
    progress_bar_refresh_rate : int, default=1
        Frequency (in steps) at which the progress bar is updated.
        Set to 0 to disable.
    enable_logging : bool, default=True
        Whether to enable CSV logging.
    checkpoint_metrics : list of dict, optional
        List of dictionaries containing checkpoint configurations. Each
        dictionary should specify "monitor", "mode", and "filename".
    precision : str, default="32-true"
        Numerical precision to use during training (e.g., 32-true, 16-mixed).
    accumulate_grad_batches : int, default=1
        Number of batches for which gradients should be accumulated before
        performing an optimizer step.
    deterministic : bool, default=False
        If True, sets deterministic behavior for reproducibility.
    benchmark : bool, default=True
        Enables the cudnn.benchmark flag for optimized performance on fixed
        input sizes.
    profiler : str, optional
        Enables performance profiling (e.g., "simple", "advanced").
    overfit_batches : int or float, default=0.0
        Uses a fraction or number of batches for both training and validation
        to quickly debug overfitting behavior.
    sync_batchnorm : bool, default=False
        Synchronizes batch norm layers across devices during distributed
        training.
    callbacks : list of L.Callback, optional
        List of PyTorch Lightning callbacks to be used during training. The
        list is copied, never modified in place. Checkpointing is enabled
        when `checkpoint_metrics` is provided or when this list already
        contains a `ModelCheckpoint`.

    Returns
    -------
    L.Trainer
        A configured PyTorch Lightning Trainer instance.

    Raises
    ------
    KeyError
        If an entry of `checkpoint_metrics` lacks "monitor", "mode" or
        "filename".
    """
    log_dir = Path(log_dir)

    if enable_logging:
        logger = CSVLogger(
            save_dir=log_dir.parents[1],
            name=log_dir.parents[0].name,
            version=log_dir.name,
        )
    else:
        logger = False

    # Work on a private copy: appending to the caller's list would make the
    # checkpoint / progress-bar callbacks pile up every time the same list is
    # reused for another trainer.
    callbacks = list(callbacks) if callbacks else []

    for ckpt_metric in checkpoint_metrics or []:
        callbacks.append(
            ModelCheckpoint(
                monitor=ckpt_metric["monitor"],
                mode=ckpt_metric["mode"],
                filename=ckpt_metric["filename"],
                save_last=False,
                enable_version_counter=False,
            )
        )

    # Lightning refuses `enable_checkpointing=False` when a ModelCheckpoint is
    # present in the callbacks, so derive the flag from the final list (this
    # also covers checkpoint callbacks supplied directly by the caller).
    enable_checkpointing = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)

    enable_progress_bar = progress_bar_refresh_rate != 0
    if enable_progress_bar:
        callbacks.append(ProgressBar(refresh_rate=progress_bar_refresh_rate))

    return L.Trainer(
        accelerator=accelerator,
        devices=devices,  # type: ignore
        logger=logger,
        callbacks=callbacks,
        enable_checkpointing=enable_checkpointing,
        enable_progress_bar=enable_progress_bar,
        enable_model_summary=True,
        log_every_n_steps=None,  # keep Lightning's default logging interval
        max_epochs=max_epochs,
        strategy=strategy,
        num_nodes=num_nodes,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
        limit_predict_batches=limit_predict_batches,
        precision=precision,  # type: ignore
        accumulate_grad_batches=accumulate_grad_batches,
        deterministic=deterministic,
        benchmark=benchmark,
        inference_mode=True,
        profiler=profiler,
        overfit_batches=overfit_batches,
        sync_batchnorm=sync_batchnorm,
    )
[docs]
def save_predictions(
    predictions: Union[np.ndarray, torch.Tensor], path: PathLike
) -> None:
    """Save predictions to a given path.

    Paths ending in ``.npz`` are written with ``np.savez`` under the
    ``predictions`` key; every other path is written with ``np.save`` (which
    appends ``.npy`` when the name does not already end with it). Missing
    parent directories are created.

    Parameters
    ----------
    predictions : Union[np.ndarray, torch.Tensor]
        The prediction data to save. Tensors are detached and moved to the
        CPU before being converted to numpy.
    path : PathLike
        The path where the predictions will be saved.

    Raises
    ------
    ValueError
        If `predictions` is neither a numpy array nor a torch tensor.
    """
    if isinstance(predictions, torch.Tensor):
        # detach() first: numpy() refuses tensors that still require grad.
        predictions = predictions.detach().cpu().numpy()
    if not isinstance(predictions, np.ndarray):
        raise ValueError("Predictions must be a numpy array.")

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Path.suffix keeps the leading dot, so compare against ".npz".
    if path.suffix == ".npz":
        np.savez(path, predictions=predictions)
    else:
        np.save(path, predictions)
    print(f"Predictions saved to {path}")
[docs]
def load_predictions(path: PathLike) -> np.ndarray:
    """Load a prediction from a given path.

    Files ending in ``.npz`` are expected to hold the array under the
    ``predictions`` key; any other file is loaded directly with ``np.load``.

    Parameters
    ----------
    path : PathLike
        The path to the prediction file.

    Returns
    -------
    np.ndarray
        The loaded prediction data.

    Raises
    ------
    ValueError
        If an ``.npz`` archive does not contain a ``predictions`` entry.
    """
    # NOTE(security): allow_pickle=True executes pickled objects found in the
    # file. Only load prediction files that come from a trusted source.
    path = Path(path)

    # Path.suffix keeps the leading dot, so compare against ".npz".
    if path.suffix == ".npz":
        # The context manager closes the archive's file handle once the array
        # has been read.
        with np.load(path, allow_pickle=True) as data:
            if "predictions" not in data:
                raise ValueError("No 'predictions' key found in the npz file.")
            return data["predictions"]

    return np.load(path, allow_pickle=True)
[docs]
def save_results(results: pd.DataFrame, path: PathLike, index: bool = False) -> None:
    """Write a results table to disk as CSV.

    Parent directories of the target are created when missing, and a
    confirmation message is printed once the file is written.

    Parameters
    ----------
    results : pd.DataFrame
        The table to persist.
    path : PathLike
        Destination of the CSV file.
    index : bool, optional
        Whether the DataFrame index is written as well, by default False
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    results.to_csv(destination, index=index)
    print(f"Results saved to {destination}")
[docs]
def load_results(path: PathLike) -> pd.DataFrame:
    """Read a results CSV back into a DataFrame.

    Parameters
    ----------
    path : PathLike
        Location of the results file.

    Returns
    -------
    pd.DataFrame
        The table stored in the file.

    Raises
    ------
    ValueError
        If `path` does not point to an existing regular file.
    """
    csv_file = Path(path)
    if csv_file.is_file():
        return pd.read_csv(csv_file)
    raise ValueError(f"File {csv_file} does not exist.")
[docs]
class InstantiatedModel(ModelInstantiator):
    """Encapsulates a PyTorch Lightning model into a ModelInstantiator.

    This class is used to wrap an existing instance of a PyTorch Lightning
    model (L.LightningModule) into a ModelInstantiator interface. It allows
    using the model as it was handed in (the "randomly initialized" mode) or
    restoring it from a full checkpoint, but does not support finetuning with
    a pretrained backbone.
    """

    def __init__(self, model: L.LightningModule):
        """Initialize the InstantiatedModel with a PyTorch Lightning model.

        Parameters
        ----------
        model : L.LightningModule
            The PyTorch Lightning model to be wrapped. This model should
            already be defined and instantiated.
        """
        self.model = model

    def create_model_randomly_initialized(self) -> L.LightningModule:
        """Return the wrapped model.

        No new model is built: the very same instance passed to the
        constructor is returned on every call.

        Returns
        -------
        L.LightningModule
            The wrapped model instance.
        """
        return self.model

    def create_model_and_load_backbone(
        self, backbone_checkpoint_path: PathLike
    ) -> L.LightningModule:
        """Unsupported: an already instantiated model cannot load a backbone.

        Parameters
        ----------
        backbone_checkpoint_path : PathLike
            Ignored.

        Raises
        ------
        NotImplementedError
            Always. It is a subclass of ``Exception``, so callers catching
            ``Exception`` keep working.
        """
        raise NotImplementedError(
            "You passed an instance of L.LightningModule as model_config, and does not support "
            + "finetuning with a pretrained backbone. Please use a ModelInstantiator "
        )

    def load_model_from_checkpoint(self, checkpoint_path: PathLike):
        """Restore the wrapped model from a full checkpoint.

        Parameters
        ----------
        checkpoint_path : PathLike
            Path to the checkpoint holding the full model state.

        Returns
        -------
        The object produced by ``FromPretrained(model=..., ckpt_path=...)``
        for the wrapped model (presumably the model with the checkpoint
        weights loaded; confirm against ``FromPretrained``).
        """
        return FromPretrained(
            model=self.model,
            ckpt_path=checkpoint_path,
        )
[docs]
@contextmanager
def using_redirected_outputs(
    output_path: PathLike, error_path: PathLike, mode: str = "a", ignore: bool = False
):
    """Temporarily replace ``sys.stdout`` / ``sys.stderr`` with `Tee` objects.

    While the context is active, ``sys.stdout`` is a ``Tee(output_path, mode)``
    and ``sys.stderr`` is a ``Tee(error_path, mode)``. On exit the original
    streams are restored and both `Tee` objects are closed. If the body raises
    an ``Exception``, its traceback is printed to the redirected stderr before
    the exception is re-raised.

    Parameters
    ----------
    output_path : PathLike
        File that receives standard output. Parent directories are created.
    error_path : PathLike
        File that receives standard error. Parent directories are created.
    mode : str, optional
        Mode handed to `Tee` (presumably a file open mode), by default "a".
    ignore : bool, optional
        If True, nothing is redirected and no file or directory is touched,
        by default False.
    """
    if ignore:
        yield
        return

    output_path = Path(output_path)
    error_path = Path(error_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    error_path.parent.mkdir(parents=True, exist_ok=True)

    original_stdout, original_stderr = sys.stdout, sys.stderr

    tee_out = Tee(output_path, mode)
    try:
        tee_err = Tee(error_path, mode)
    except Exception:
        # Do not leak the first Tee when the second one cannot be created.
        tee_out.close()
        raise

    sys.stdout, sys.stderr = tee_out, tee_err
    try:
        yield
    except Exception:
        # Record the traceback in the redirected stderr before propagating.
        traceback.print_exc(file=sys.stderr)
        raise
    finally:
        try:
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            # Always hand the real streams back, even if a flush failed.
            sys.stdout = original_stdout
            sys.stderr = original_stderr
            tee_out.close()
            tee_err.close()
[docs]
def timestamp_now():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    stamp_format = "%Y-%m-%d_%H-%M-%S"
    current_local_time = time.localtime()
    return time.strftime(stamp_format, current_local_time)
[docs]
def get_run_id():
    """Build an identifier for the current run.

    Returns
    -------
    str
        ``"_slurm_<id>"`` when the ``SLURM_JOB_ID`` environment variable is
        set, otherwise ``"_local_<timestamp>"`` using `timestamp_now`.
    """
    slurm_job = os.getenv("SLURM_JOB_ID", None)
    if slurm_job is None:
        return f"_local_{timestamp_now()}"
    return f"_slurm_{slurm_job}"
[docs]
class Experiment(Pipeline):
NUM_DEBUG_EPOCHS = 3
NUM_DEBUG_BATCHES = 10
def __init__(
    self,
    # Base parameters
    experiment_name: str,
    model_config: Union[ModelConfig, L.LightningModule],
    data_module: MinervaDataModule,
    # Logging and checkpointing parameters
    pretrained_backbone_ckpt_path: Optional[PathLike] = None,
    root_log_dir: PathLike = "./logs",
    execution_id: Union[str, int] = 0,
    checkpoint_metrics: Optional[List[Dict[str, str]]] = None,
    # Trainer-related parameters
    max_epochs: int = 100,
    accelerator: str = "gpu",
    devices: Optional[Union[int, list[int], str]] = 1,
    strategy: str = "auto",
    num_nodes: int = 1,
    limit_train_batches: Optional[Union[int, float]] = None,
    limit_val_batches: Optional[Union[int, float]] = None,
    limit_test_batches: Optional[Union[int, float]] = None,
    limit_predict_batches: Optional[Union[int, float]] = None,
    callbacks: Optional[List[L.Callback]] = None,
    # Prediction parameters
    evaluation_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
    per_sample_evaluation_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
    # Other parameters
    seed: Optional[int] = None,
    progress_bar_refresh_rate: int = 1,
    profiler: Optional[str] = None,
    save_predictions: bool = False,
    save_results: bool = True,
    add_last_checkpoint: bool = True,
    log_outputs: bool = True,
    _run_id: Optional[str] = None,
):
    """An experiment is a pipeline that contains all the parameters needed
    to train and evaluate a model, as well as to manage the logging,
    checkpointing, prediction, and results processes in a coherent way.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment. This name will be used to create a
        directory for the experiment in the log directory.
    model_config : Union[ModelConfig, L.LightningModule]
        The model configuration. This object contains the model instantiator
        and the model information. An already instantiated
        `L.LightningModule` is also accepted; it is wrapped into a
        `ModelConfig` named after the model's class.
    data_module : MinervaDataModule
        The data module. This object contains the training, validation, and
        test datasets, as well as the data loaders. For now, datasets must
        return a 2 element tuple (input, label) for each sample.
    pretrained_backbone_ckpt_path : Optional[PathLike], optional
        The path to the pretrained backbone checkpoint. This is used to
        finetune the model. If None, the model will be trained from
        scratch. This parameter handles the lazy instantiation of the model
        and calls `create_model_and_load_backbone` method of the model
        instantiator if `pretrained_backbone_ckpt_path` is not None or
        `create_model_randomly_initialized` method if it is None. By
        default None
    root_log_dir : PathLike, optional
        Root directory for logging and checkpoints. This directory will be
        used to create a subdirectory for the experiment. By default ./logs
    execution_id : Union[str, int], optional
        The execution ID for the experiment. This ID will be used to create
        a subdirectory for the experiment in the log directory. This is
        useful when running the experiment multiple times with the same
        parameters. By default 0
    checkpoint_metrics : Optional[List[Dict[str, str]]], optional
        The checkpoint metrics. This is a list of dictionaries that contain
        the checkpoint metrics. Each dictionary must contain the keys
        "monitor", "mode", and "filename". The "monitor" key is the name of
        the metric to monitor, the "mode" key is the mode of the metric
        ("min" or "max"), and the "filename" key is the name of the
        checkpoint file. The "monitor" key can be None if the checkpoint is
        the last one. The list is copied, so the caller's list is never
        modified. By default None
    max_epochs : int, optional
        Number of epochs to train the model. This parameter is passed to the
        `get_trainer` function. By default 100.
    accelerator : str, optional
        The accelerator to use for training. This parameter is passed to the
        `get_trainer` function. Possible values are "cpu", "gpu", "tpu",
        "ipu", "hpu", "mps", "auto". If "auto" is selected, the accelerator
        will be automatically selected based on the available hardware. By
        default "gpu"
    devices : Optional[Union[int, list[int], str]], optional
        Number of accelerators to use for training. This parameter is
        passed to the `get_trainer` function. By default 1.
    strategy : str, optional
        Strategy to use for distributed training. This parameter is passed
        to the `get_trainer` function. By default "auto".
    num_nodes : int, optional
        Number of nodes to use for distributed training. This parameter is
        passed to the `get_trainer` function. By default 1.
    limit_train_batches : Optional[Union[int, float]], optional
        Limit the number of training batches to use. This parameter is
        passed to the `get_trainer` function. By default None. If None, all
        batches will be used. If an integer is provided, it will be the
        absolute number of batches. If a float is provided, it will be the
        fraction of the total number of batches. For example, 0.1 means 10%
        of the training batches will be used.
    limit_val_batches : Optional[Union[int, float]], optional
        Limit the number of validation batches to use. Same semantics as
        `limit_train_batches`. By default None.
    limit_test_batches : Optional[Union[int, float]], optional
        Limit the number of test batches to use. Same semantics as
        `limit_train_batches`. By default None.
    limit_predict_batches : Optional[Union[int, float]], optional
        Limit the number of prediction batches to use. Same semantics as
        `limit_train_batches`. By default None.
    callbacks : Optional[List[L.Callback]], optional
        Extra PyTorch Lightning callbacks. This parameter is passed to the
        `get_trainer` function. By default None.
    evaluation_metrics : Optional[Dict[str, torchmetrics.Metric]], optional
        A dictionary of evaluation metrics to use for the predictions. The
        keys are the names of the metrics and the values are the
        `torchmetrics.Metric` objects. These metrics are calculated using
        all the predictions. By default None.
    per_sample_evaluation_metrics : Optional[Dict[str, torchmetrics.Metric]], optional
        A dictionary of evaluation metrics to use for the predictions. The
        keys are the names of the metrics and the values are the
        `torchmetrics.Metric` objects. These metrics are calculated using
        each prediction separately, that is, applied per sample. By
        default None.
    seed : Optional[int], optional
        The seed to use for the experiment, by default None
    progress_bar_refresh_rate : int, optional
        The refresh rate of the progress bar (in batches). If 0, the
        progress bar is disabled. If 1, the progress bar is updated every
        batch. By default 1
    profiler : Optional[str], optional
        A profiler to use for the experiment. This parameter is passed to
        the `get_trainer` function. By default None.
    save_predictions : bool, optional
        If True, the predictions will be saved to the log directory. By
        default False.
    save_results : bool, optional
        If True, the results will be saved to the log directory. By
        default True
    add_last_checkpoint : bool, optional
        If True, the last checkpoint will be added to the list of checkpoint
        metrics. By default True.
    log_outputs : bool, optional
        If True, the standard output and error will be logged to files in the
        log directory. By default True.
    _run_id : Optional[str], optional
        An internal run ID for the experiment. This is used to
        differentiate between different runs of the same experiment. If
        None, a timestamp will be used. By default None.

    Raises
    ------
    ValueError
        If `model_config` is neither a `ModelConfig` nor an
        `L.LightningModule`, if an `L.LightningModule` is combined with
        `pretrained_backbone_ckpt_path`, or if the checkpoint metrics are
        not valid or do not contain the required keys.

    Notes
    ------
    - This class assumes that the `MinervaDataModule` class returns a
      (input, label) tuple for each sample in the dataset. The input is
      the data and the label is the ground truth/target.
    - The ``runs`` sub-directory of the log directory is created on
      construction.
    """
    # ------- Base parameters -------
    self.experiment_name = experiment_name
    self.data_module = data_module

    if isinstance(model_config, L.LightningModule):
        # Wrap the ready-made model so the rest of the pipeline can treat it
        # like any other lazily instantiated model.
        instantiator = InstantiatedModel(model_config)
        self.model_config = ModelConfig(
            instantiator=instantiator,
            information=ModelInformation(
                name=instantiator.model.__class__.__name__,
            ),
        )
        print(
            f"Using an already instantiated model: {self.model_config.information.name}."
        )
        if pretrained_backbone_ckpt_path is not None:
            raise ValueError(
                "You passed an instance of L.LightningModule as model_config, "
                + "as well as a pretrained_backbone_ckpt_path. "
                + "This is not supported. Please use a ModelInstantiator instead."
            )
    elif isinstance(model_config, ModelConfig):
        self.model_config = model_config
    else:
        raise ValueError(
            "model_config must be an instance of ModelConfig or L.LightningModule."
        )

    # ------- Logging and checkpointing parameters -------
    self.pretrained_backbone_ckpt_path = pretrained_backbone_ckpt_path
    self.root_log_dir = Path(root_log_dir)
    self.execution_id = str(execution_id)
    # Copy the list: the "last" entry appended below must not leak into the
    # list object owned by the caller.
    self.checkpoint_metrics = list(checkpoint_metrics) if checkpoint_metrics else []

    # Check if checkpoint metrics are valid
    for ckpt_metric in self.checkpoint_metrics:
        if not isinstance(ckpt_metric, dict):
            raise ValueError("Checkpoint metric must be a dictionary.")
        for key in ["monitor", "mode", "filename"]:
            if key not in ckpt_metric:
                raise ValueError(f"Checkpoint metric must contain a '{key}' key.")

    # Add the "last" checkpoint metric if not already present
    if add_last_checkpoint and not any(
        ckpt_metric.get("filename") == "last"
        for ckpt_metric in self.checkpoint_metrics
    ):
        self.checkpoint_metrics.append(
            {"monitor": None, "mode": "min", "filename": "last"}  # type: ignore
        )

    # ------- Trainer-related parameters -------
    self.max_epochs = max_epochs
    self.accelerator = accelerator
    self.devices = devices
    self.strategy = strategy
    self.num_nodes = num_nodes
    self.limit_train_batches = limit_train_batches
    self.limit_val_batches = limit_val_batches
    self.limit_test_batches = limit_test_batches
    self.limit_predict_batches = limit_predict_batches
    self.callbacks = callbacks

    # ------- Prediction parameters -------
    self.evaluation_metrics = evaluation_metrics or {}
    self.per_sample_evaluation_metrics = per_sample_evaluation_metrics or {}

    # ------- Other parameters -------
    self.seed = seed
    self.progress_bar_refresh_rate = progress_bar_refresh_rate
    self.profiler = profiler
    self.save_predictions = save_predictions
    self.save_results = save_results
    self.log_outputs = log_outputs

    # ------- Initialize the pipeline -------
    # <root>/<experiment>/<dataset>/<model>/<execution id>
    log_dir = (
        self.root_log_dir
        / self.experiment_name
        / self.data_module.dataset_name
        / self.model_config.information.name
        / self.execution_id
    )
    super().__init__(
        log_dir=log_dir,
        cache_result=False,
        save_run_status=False,
        seed=seed,
        ignore=["model_config", "data_module"],
    )

    self._checkpoint_dir = log_dir / "checkpoints"
    self._predictions_dir = log_dir / "predictions"
    self._results_dir = log_dir / "results"
    self._training_metrics_path = log_dir / "metrics.csv"

    self._run_id = _run_id or timestamp_now()
    runs_dir = log_dir / "runs"
    self._stdout_path = runs_dir / f"stdout_{self._run_id}.log"
    self._stderr_path = runs_dir / f"stderr_{self._run_id}.log"
    # Both log files live in the same directory, so one mkdir is enough.
    runs_dir.mkdir(parents=True, exist_ok=True)
@property
def run_id(self) -> str:
    """Identifier of the current run.

    Returns
    -------
    str
        The value stored in ``_run_id`` when the experiment was created.
    """
    return self._run_id
# ------------ Accessors ------------
# Here we have access to:
# - checkpoint paths
# - metrics and metrics path
# - prediction paths
# - results and results path
@property
def checkpoint_paths(self) -> Dict[str, Path]:
    """Checkpoints currently present in the checkpoint directory.

    Every regular ``*.ckpt`` file directly inside the checkpoint directory
    is listed, keyed by its file name without the extension.

    Returns
    -------
    Dict[str, Path]
        Mapping from checkpoint name to checkpoint path.
    """
    found: Dict[str, Path] = {}
    for candidate in self._checkpoint_dir.glob("*.ckpt"):
        if candidate.is_file():
            found[candidate.stem] = candidate
    return found
@property
def training_metrics_path(self) -> Optional[Path]:
    """Location of the training metrics file, when it exists.

    Returns
    -------
    Optional[Path]
        The metrics path if it points to an existing file, otherwise None.
    """
    metrics_file = self._training_metrics_path
    return metrics_file if metrics_file.is_file() else None
@property
def training_metrics(self) -> Optional[pd.DataFrame]:
    """Training metrics read from the metrics CSV file.

    Returns
    -------
    Optional[pd.DataFrame]
        The content of the metrics file, or None when
        `training_metrics_path` reports that the file does not exist.
    """
    if not self.training_metrics_path:
        return None
    return pd.read_csv(self._training_metrics_path)
@property
def prediction_paths(self) -> Dict[str, Path]:
    """Prediction files currently present in the predictions directory.

    Every regular ``*.npy`` file directly inside the predictions directory
    is listed, keyed by its file name without the extension.

    Returns
    -------
    Dict[str, Path]
        Mapping from prediction name to prediction path.
    """
    found: Dict[str, Path] = {}
    for candidate in self._predictions_dir.glob("*.npy"):
        if candidate.is_file():
            found[candidate.stem] = candidate
    return found
[docs]
def load_predictions_of_ckpt(self, name: str) -> np.ndarray:
    """Load predictions from a file.

    Parameters
    ----------
    name : str
        The name of the prediction file (without extension).

    Returns
    -------
    np.ndarray
        The loaded predictions as a numpy array.

    Raises
    ------
    Exception
        If no prediction file called `name` is listed in
        `prediction_paths`.
    """
    # Only the lookup is guarded: a KeyError raised while reading the file
    # must not be reported as "file not found".
    try:
        path = self.prediction_paths[name]
    except KeyError as err:
        raise Exception(
            f"Prediction file '{name}' not found in {self._predictions_dir}"
        ) from err
    return load_predictions(path)
@property
def results_paths(self) -> Dict[str, Path]:
    """Result files currently present in the results directory.

    Every regular ``*.csv`` file directly inside the results directory is
    listed, keyed by its file name without the extension.

    Returns
    -------
    Dict[str, Path]
        Mapping from result name to result path.
    """
    found: Dict[str, Path] = {}
    for candidate in self._results_dir.glob("*.csv"):
        if candidate.is_file():
            found[candidate.stem] = candidate
    return found
[docs]
def load_results_of_ckpt(self, name: str) -> pd.DataFrame:
    """Load results from a file.

    Parameters
    ----------
    name : str
        The name of the result file (without extension).

    Returns
    -------
    pd.DataFrame
        The loaded results as a pandas DataFrame.

    Raises
    ------
    Exception
        If no results file called `name` is listed in `results_paths`.
    """
    # Only the lookup is guarded: a KeyError raised while reading the file
    # must not be reported as "file not found".
    try:
        path = self.results_paths[name]
    except KeyError as err:
        raise Exception(
            f"Results file '{name}' not found in {self._results_dir}"
        ) from err
    return load_results(path)
# ---------- Trainer ---------
[docs]
def _trainer_parameters(
    self, enable_logging: bool = True, debug: bool = False
) -> Dict[str, Any]:
    """Assemble the keyword arguments handed to `get_trainer`.

    Parameters
    ----------
    enable_logging : bool, optional
        Whether logging is requested, by default True. Ignored in debug
        mode.
    debug : bool, optional
        If True, the epoch count is replaced by ``NUM_DEBUG_EPOCHS``, every
        batch limit by ``NUM_DEBUG_BATCHES``, logging is switched off and no
        checkpoint metrics are passed, by default False

    Returns
    -------
    Dict[str, Any]
        All the parameters for the `get_trainer` function. Precision,
        determinism and benchmark are fixed to "32-true", False and True.
    """
    if debug:
        epochs = self.NUM_DEBUG_EPOCHS
        train_limit = self.NUM_DEBUG_BATCHES
        val_limit = self.NUM_DEBUG_BATCHES
        test_limit = self.NUM_DEBUG_BATCHES
        predict_limit = self.NUM_DEBUG_BATCHES
        logging_enabled = False
        ckpt_metrics = None
    else:
        epochs = self.max_epochs
        train_limit = self.limit_train_batches
        val_limit = self.limit_val_batches
        test_limit = self.limit_test_batches
        predict_limit = self.limit_predict_batches
        logging_enabled = enable_logging
        ckpt_metrics = self.checkpoint_metrics

    params: Dict[str, Any] = {}
    params["log_dir"] = self.log_dir
    params["max_epochs"] = epochs
    params["limit_train_batches"] = train_limit
    params["limit_val_batches"] = val_limit
    params["limit_test_batches"] = test_limit
    params["limit_predict_batches"] = predict_limit
    params["accelerator"] = self.accelerator
    params["strategy"] = self.strategy
    params["devices"] = self.devices
    params["num_nodes"] = self.num_nodes
    params["progress_bar_refresh_rate"] = self.progress_bar_refresh_rate
    params["enable_logging"] = logging_enabled
    params["checkpoint_metrics"] = ckpt_metrics
    params["precision"] = "32-true"
    params["deterministic"] = False
    params["benchmark"] = True
    params["profiler"] = self.profiler
    params["callbacks"] = self.callbacks
    return params
# ---------- FIT Experiment methods ---------
@staticmethod
def __typing_string(value):
    """Describe `value` for the training summary.

    Torch tensors and numpy arrays are reported by shape; anything else is
    reported as a scalar together with its type name.
    """
    if isinstance(value, (torch.Tensor, np.ndarray)):
        return f"shape={tuple(value.shape)}"
    return f"scalar with type={type(value).__name__}"
[docs]
def _print_train_summary(
    self,
    model: L.LightningModule,
    trainer_params: Dict[str, Any],
    debug: bool = False,
    resume_from_ckpt: Optional[str] = None,
) -> None:
    """Print an overview of the training run that is about to start.

    The report has four sections: model, dataset, logging/checkpoints and
    trainer configuration.

    Parameters
    ----------
    model : L.LightningModule
        Model whose parameters are counted for the summary.
    trainer_params : Dict[str, Any]
        Trainer arguments; the keys `max_epochs`, `limit_train_batches`,
        `accelerator`, `strategy`, `devices` and `num_nodes` are displayed.
    debug : bool, optional
        If True, the banner is tagged with "(DEBUG)", by default False
    resume_from_ckpt : Optional[str], optional
        Checkpoint the training resumes from; shown as "Beginning" when
        empty, by default None
    """
    rule = "=" * 80
    print("\n" + rule)
    banner = f"Experiment: {self.experiment_name} {'(DEBUG)' if debug else ''}"
    print(banner.center(80))
    print(rule)

    # ------------ Model info ------------
    backbone_ckpt = self.pretrained_backbone_ckpt_path
    # One pass over the parameters feeds both counters.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    print("🧠 Model")
    info = self.model_config.information
    print(f" ├── Name: {info.name}")
    print(f" ├── Finetune: {'No' if backbone_ckpt is None else 'Yes'}")
    if backbone_ckpt is not None:
        print(f" | └── Pretrained Backbone Checkpoint: {backbone_ckpt}")
    print(f" ├── Resumed From: {resume_from_ckpt or 'Beginning'}")
    print(f" ├── Expected Input Shape: {info.input_shape}")
    print(f" ├── Expected Output Shape: {info.output_shape}")
    print(f" ├── Total Params: {total_params:,}")
    if total_params:
        ratio = trainable_params / total_params
        print(f" └── Trainable Params: {trainable_params:,} ({ratio:.2%})")
    else:
        # A model without parameters has no meaningful ratio.
        print(" └── Trainable Params: 0 (0.00%)")

    # ------------ Dataset info ------------
    print("\n📊 Dataset")
    describe = self.__typing_string
    train_set = self.data_module.train_dataset
    val_set = self.data_module.val_dataset
    if not train_set:
        print(" ├── Train Dataset: None")
    else:
        # The first sample is unpacked as an (input, label) pair.
        sample_x, sample_y = train_set[0]
        print(f" ├── Train Samples: {len(train_set)}")
        print(f" | ├── Input Shape: {describe(sample_x)}")
        print(f" | └── Label Shape: {describe(sample_y)}")
    if not val_set:
        print(" └── Validation Dataset: None")
    else:
        sample_x, sample_y = val_set[0]
        print(f" └── Validation Samples: {len(val_set)}")
        print(f" ├── Input Shape: {describe(sample_x)}")
        print(f" └── Label Shape: {describe(sample_y)}")

    # ------------ Logging & checkpoints ------------
    expected_files = []
    for metric_cfg in self.checkpoint_metrics:
        if metric_cfg.get("filename"):
            expected_files.append(f"{metric_cfg['filename']}.ckpt")
    ckpt_filenames = ", ".join(expected_files)
    existing_checkpoints = self.checkpoint_paths
    print("\n💾 Logging & Checkpoints")
    print(f" ├── Log Dir: {self.log_dir}")
    print(f" ├── Metrics Path: {self._training_metrics_path}")
    print(f" └── Checkpoints Dir: {self._checkpoint_dir}")
    if not existing_checkpoints:
        print(f" └── Files: {ckpt_filenames or 'None'}")
    else:
        print(f" ├── Files: {ckpt_filenames or 'None'}")
        print(" └── ⚠️ Existing checkpoints found! It will be overwritten!")

    # ------------ Trainer configuration ------------
    print("\n⚙️ Trainer Config")
    print(f" ├── Max Epochs: {trainer_params['max_epochs']}")
    print(f" ├── Train Batches: {trainer_params['limit_train_batches']}")
    print(f" ├── Accelerator: {trainer_params['accelerator']}")
    print(f" ├── Strategy: {trainer_params['strategy']}")
    print(f" ├── Devices: {trainer_params['devices']}")
    print(f" ├── Num Nodes: {trainer_params['num_nodes']}")
    print(f" └── Seed: {self.seed}")
[docs]
def _train_model(
    self,
    resume_from_ckpt: Optional[str] = None,
    debug: bool = False,
    print_summary: bool = True,
) -> Dict[str, Any]:
    """Build the model and trainer, then run training.

    Parameters
    ----------
    resume_from_ckpt : Optional[str], optional
        Name of an existing checkpoint (a key of `self.checkpoint_paths`)
        to resume from, by default None
    debug : bool, optional
        Forwarded to `_trainer_parameters` to obtain the debug trainer
        settings, by default False
    print_summary : bool, optional
        If True, the training summary is printed before training starts,
        by default True

    Returns
    -------
    Dict[str, Any]
        The data module, model and trainer used, plus the log directory,
        the training metrics path and the checkpoint paths.

    Raises
    ------
    ValueError
        If `resume_from_ckpt` is not one of the known checkpoint names.
    """
    data_module = self.data_module
    instantiator = self.model_config.instantiator
    backbone_ckpt = self.pretrained_backbone_ckpt_path

    if backbone_ckpt is None:
        # Full supervised training: backbone and head randomly initialized.
        model = instantiator.create_model_randomly_initialized()
    else:
        # Usually finetuning: the backbone comes from the checkpoint.
        model = instantiator.create_model_and_load_backbone(backbone_ckpt)

    trainer_params = self._trainer_parameters(enable_logging=True, debug=debug)
    trainer = get_trainer(**trainer_params)

    # Translate the checkpoint name into the stored checkpoint entry.
    known_checkpoints = self.checkpoint_paths
    if resume_from_ckpt:
        if resume_from_ckpt not in known_checkpoints:
            raise ValueError(
                f"Checkpoint '{resume_from_ckpt}' not found in {list(known_checkpoints)}"
            )
        resume_from_ckpt = known_checkpoints[resume_from_ckpt]  # type: ignore

    if print_summary:
        self._print_train_summary(
            model=model,
            trainer_params=trainer_params,
            debug=debug,
            resume_from_ckpt=resume_from_ckpt,
        )

    perform_train(
        data_module=data_module,
        model=model,
        trainer=trainer,
        resume_from_ckpt=resume_from_ckpt,
    )

    outcome: Dict[str, Any] = {
        "data_module": data_module,
        "model": model,
        "trainer": trainer,
        "log_dir": self.log_dir,
        "metrics_path": self._training_metrics_path,
        "checkpoints": self.checkpoint_paths,
    }
    return outcome
# ---------- EVALUATE Experiment methods ---------
[docs]
def _print_evaluation_summary(
    self,
    trainer_params: Dict[str, Any],
    debug: bool = False,
    ckpt_path: Optional[PathLike] = None,
    predictions_path: Optional[PathLike] = None,
    results_path: Optional[PathLike] = None,
) -> None:
    """Print an overview of the evaluation that is about to run.

    The report has four sections: checkpoint/outputs, predict dataset,
    evaluation metrics and trainer configuration.

    Parameters
    ----------
    trainer_params : Dict[str, Any]
        Trainer arguments; the keys `max_epochs`, `limit_predict_batches`,
        `accelerator`, `strategy`, `devices` and `num_nodes` are displayed.
    debug : bool, optional
        If True, the banner is tagged with "(DEBUG)", by default False
    ckpt_path : Optional[PathLike], optional
        Checkpoint being evaluated (string or path object). None means an
        in-memory (cached) model is evaluated, by default None
    predictions_path : Optional[PathLike], optional
        Where predictions are written; "Not saved" is shown when empty,
        by default None
    results_path : Optional[PathLike], optional
        Where results are written; "Not saved" is shown when empty,
        by default None
    """
    if ckpt_path is not None:
        # Wrap in Path so plain strings are accepted as well as path objects.
        ckpt_path_used = Path(ckpt_path).name
    else:
        ckpt_path_used = "cached model (no checkpoint)"
    print("\n" + "=" * 80)
    print(
        f"Evaluation: {self.experiment_name} ({ckpt_path_used}) {'(DEBUG)' if debug else ''}".center(
            80
        )
    )
    print("=" * 80)

    # ------------ Checkpoint Info ------------
    print("💾 Checkpoint")
    print(f" ├── Checkpoint Path: {ckpt_path}")
    print(f" ├── Predictions Path: {predictions_path or 'Not saved'}")
    print(f" └── Results Path: {results_path or 'Not saved'}")

    # ------------ Dataset Info ------------
    print("\n📊 Dataset")
    predict_data = self.data_module.predict_dataset
    if predict_data:
        # The first sample is unpacked as an (input, label) pair.
        x, y = predict_data[0]
        print(f" ├── Predict Samples: {len(predict_data)}")
        print(f" ├── Input: {self.__typing_string(x)}")
        print(f" └── Label: {self.__typing_string(y)}")
    else:
        print(" └── Predict Dataset: None")

    # ------------ Evaluation Metrics ------------
    print("\n📈 Evaluation Metrics")
    # Either collection may be empty or None; treat both as "no metrics".
    metric_lines = [
        f"{name}: {metric.__class__.__name__}"
        for name, metric in (self.evaluation_metrics or {}).items()
    ]
    metric_lines.extend(
        f"{name}: {metric.__class__.__name__} (PER_SAMPLE)"
        for name, metric in (self.per_sample_evaluation_metrics or {}).items()
    )
    if metric_lines:
        for line in metric_lines[:-1]:
            print(f" ├── {line}")
        # Close the tree on the last entry.
        print(f" └── {metric_lines[-1]}")
    else:
        print(" └── No evaluation metrics provided.")

    # ------------ Trainer Configuration ------------
    print("\n⚙️ Trainer Config")
    print(f" ├── Max Epochs: {trainer_params['max_epochs']}")
    print(f" ├── Predict Batches: {trainer_params['limit_predict_batches']}")
    print(f" ├── Accelerator: {trainer_params['accelerator']}")
    print(f" ├── Strategy: {trainer_params['strategy']}")
    print(f" ├── Devices: {trainer_params['devices']}")
    print(f" ├── Num Nodes: {trainer_params['num_nodes']}")
    print(f" └── Seed: {self.seed}")
[docs]
def _evaluate_model(
    self,
    ckpts_to_evaluate: Optional[Union[str, List[str]]] = None,
    print_summary: bool = True,
    debug: bool = False,
    cached_execution: Optional[Dict[str, Any]] = None,
):
    """Run prediction and metric evaluation for each selected checkpoint.

    Parameters
    ----------
    ckpts_to_evaluate : Optional[Union[str, List[str]]], optional
        Name(s) of the checkpoints (keys of `self.checkpoint_paths`) to
        evaluate. None evaluates all of them. Only applies to on-disk
        checkpoints; it is not used when `cached_execution` is given,
        by default None
    print_summary : bool, optional
        If True, an evaluation summary is printed for each checkpoint,
        by default True
    debug : bool, optional
        If True, the debug trainer settings are used and neither
        predictions nor results are written to disk, by default False
    cached_execution : Optional[Dict[str, Any]], optional
        Result of a previous training call. When given, its `"model"`
        entry is evaluated directly under the name "last_cached" instead
        of loading checkpoints, by default None

    Returns
    -------
    dict
        One entry per evaluated checkpoint name, each a dict with the keys
        `predictions_path`, `results_path`, `results_path_per_sample`,
        `results` and `results_per_sample` (None where not produced).

    Raises
    ------
    ValueError
        If a requested checkpoint name is unknown, if there is nothing to
        evaluate, or if the data module has no predict dataset.
    """
    # --------- Checkpoints or cached model -------
    cached = cached_execution is not None
    if cached:
        # The value is a model instance, not a path; there are no named
        # checkpoints to choose from, so the name filter is not applied.
        checkpoints_to_use = {"last_cached": cached_execution["model"]}
    else:
        checkpoints_to_use = self.checkpoint_paths
        # Check which checkpoints to evaluate (else, evaluate all)
        if ckpts_to_evaluate is not None:
            if isinstance(ckpts_to_evaluate, str):
                ckpts_to_evaluate = [ckpts_to_evaluate]
            try:
                checkpoints_to_use = {
                    ckpt: checkpoints_to_use[ckpt] for ckpt in ckpts_to_evaluate
                }
            except KeyError as e:
                # List the known names only, as `_train_model` does.
                raise ValueError(
                    f"Checkpoint {e} not found in {list(checkpoints_to_use)}"
                ) from e

    # Check if any checkpoint is found
    if not checkpoints_to_use:
        raise ValueError(f"No checkpoints found in {self._checkpoint_dir}")

    # --------- Dataset -------
    checkpoint_results = {}
    data_module = self.data_module
    if data_module.predict_dataset is None:
        raise ValueError(
            "No predict dataset found in the data module. Please provide a predict dataset to perform evaluation."
        )

    # The trainer arguments do not depend on the checkpoint; build them once.
    trainer_params = self._trainer_parameters(
        enable_logging=False,
        debug=debug,
    )

    def run_metrics(metrics, predictions, per_sample):
        """Evaluate `metrics` on `predictions` with the shared settings."""
        return perform_evaluation(
            evaluation_metrics=metrics,
            data_module=data_module,
            predictions=predictions,
            # Class indices are taken along axis 1 when the model
            # returns logits.
            argmax_axis=(
                1 if self.model_config.information.return_logits else None
            ),
            per_sample=per_sample,
            batch_size=self.data_module._predict_dataloader_kwargs[
                "batch_size"
            ],
            device="cuda" if self.accelerator == "gpu" else "cpu",
        )

    for ckpt_name, ckpt_path_or_model in checkpoints_to_use.items():
        predictions_file = None
        results_filename = None
        results_filename_per_sample = None
        results = None
        per_sample_results = None

        # Output files are only defined when saving is on and not debugging.
        if self.save_predictions and not debug:
            predictions_file = self._predictions_dir / f"{ckpt_name}.npy"
        if self.save_results and not debug:
            results_filename = self._results_dir / f"{ckpt_name}.csv"
            results_filename_per_sample = (
                self._results_dir / f"{ckpt_name}_per_sample.csv"
            )

        # Load the model from the checkpoint (or reuse the cached one)
        if cached:
            model = ckpt_path_or_model
        else:
            model = self.model_config.instantiator.load_model_from_checkpoint(
                ckpt_path_or_model
            )

        # A fresh trainer is created for every checkpoint.
        trainer = get_trainer(**trainer_params)

        if print_summary:
            self._print_evaluation_summary(
                trainer_params=trainer_params,
                debug=debug,
                ckpt_path=ckpt_path_or_model if not cached else None,
                predictions_path=predictions_file,
                results_path=results_filename,
            )

        # Perform prediction
        predictions = perform_predict(
            data_module=data_module,
            model=model,  # type: ignore
            trainer=trainer,
        )
        if predictions_file is not None:
            save_predictions(predictions, predictions_file)
        else:
            print("Predictions not saved...")

        # Perform evaluation
        if self.evaluation_metrics:
            results = run_metrics(
                self.evaluation_metrics, predictions, per_sample=False
            )
            if results_filename is not None:
                save_results(results, results_filename, index=False)
            else:
                print("Results not saved...")
        else:
            print("No evaluation metrics provided. Skipping evaluation.")

        # Perform per-sample evaluation
        if self.per_sample_evaluation_metrics:
            per_sample_results = run_metrics(
                self.per_sample_evaluation_metrics, predictions, per_sample=True
            )
            if results_filename_per_sample is not None:
                save_results(
                    per_sample_results,
                    results_filename_per_sample,
                    index=False,
                )
            else:
                print("Results not saved...")
        else:
            print(
                "No per-sample evaluation metrics provided. Skipping per-sample evaluation."
            )

        # Store the results
        checkpoint_results[ckpt_name] = {
            "predictions_path": predictions_file,
            "results_path": results_filename,
            "results_path_per_sample": results_filename_per_sample,
            "results": results,
            "results_per_sample": per_sample_results,
        }
        print(f"Checkpoint {ckpt_name} evaluated!")

    return checkpoint_results
# ---------- Default pipeline entrypoints and other methods ---------
[docs]
def _run(
    self,
    task: str,
    debug: bool = False,
    resume_from_ckpt: Optional[str] = None,
    print_summary: bool = True,
    ckpts_to_evaluate: Optional[Union[str, List[str]]] = None,
) -> Dict[str, Any]:
    """Dispatch `task` to training and/or evaluation with outputs redirected.

    Parameters
    ----------
    task : str
        One of 'fit', 'evaluate' or 'fit-evaluate'.
    debug : bool, optional
        Forwarded to the training/evaluation steps; also disables the
        output redirection, by default False
    resume_from_ckpt : Optional[str], optional
        Checkpoint name forwarded to training, by default None
    print_summary : bool, optional
        Forwarded to the training/evaluation steps, by default True
    ckpts_to_evaluate : Optional[Union[str, List[str]]], optional
        Checkpoint name(s) forwarded to evaluation, by default None

    Returns
    -------
    Dict[str, Any]
        The result of training for 'fit', otherwise the evaluation results.

    Raises
    ------
    ValueError
        If `task` is not one of the supported tasks.
    """

    def per_run_path(base: Path) -> Path:
        # Insert "_run_<count>" between the file stem and its suffix.
        return Path(
            str(base.parent / base.stem)
            + f"_run_{self._run_count}"
            + base.suffix
        )

    stdout_path = per_run_path(self._stdout_path)
    stderr_path = per_run_path(self._stderr_path)
    skip_redirect = not self.log_outputs or debug

    with using_redirected_outputs(
        output_path=stdout_path,
        error_path=stderr_path,
        mode="w",
        ignore=skip_redirect,
    ):
        if task == "fit":
            return self._train_model(
                resume_from_ckpt=resume_from_ckpt,
                print_summary=print_summary,
                debug=debug,
            )

        if task == "evaluate":
            return self._evaluate_model(
                ckpts_to_evaluate=ckpts_to_evaluate,
                print_summary=print_summary,
                debug=debug,
            )

        if task != "fit-evaluate":
            raise ValueError(
                f"Unknown task '{task}'. Supported tasks are: 'fit', 'evaluate', or 'fit-evaluate'"
            )

        # fit-evaluate: train first, then evaluate.
        train_outcome = self._train_model(
            resume_from_ckpt=resume_from_ckpt,
            print_summary=print_summary,
            debug=debug,
        )
        # When checkpoints exist, evaluate those instead of the in-memory
        # model returned by training.
        if len(self.checkpoint_paths) > 0:
            train_outcome = None
        return self._evaluate_model(
            ckpts_to_evaluate=ckpts_to_evaluate,
            print_summary=print_summary,
            debug=debug,
            cached_execution=train_outcome,
        )
[docs]
def cleanup(self):
    """Delete the experiment's log directory and report what happened."""
    target = self.log_dir
    if not target.exists():
        print(f"Experiment at '{target}' not found.")
        return
    # Removes the directory together with everything inside it.
    shutil.rmtree(target)
    print(f"Experiment at '{target}' cleaned up.")
@property
def status(self) -> Dict[str, Any]:
    """Snapshot of the experiment's artifacts and its progress state.

    Returns
    -------
    Dict[str, Any]
        The experiment name, log directory, checkpoints, training metrics
        path, prediction paths and results paths, plus a `state` entry.
        `state` is "evaluated" if results exist, else "predicted" if
        predictions exist, else "executed" if checkpoints exist, else
        "not executed".
    """
    snapshot: Dict[str, Any] = {
        "experiment_name": self.experiment_name,
        "log_dir": self.log_dir,
        "checkpoints": self.checkpoint_paths,
        "training_metrics": self.training_metrics_path,
        "prediction_paths": self.prediction_paths,
        "results_paths": self.results_paths,
    }

    has_checkpoints = len(snapshot["checkpoints"]) > 0
    has_predictions = len(snapshot["prediction_paths"]) > 0
    has_results = len(snapshot["results_paths"]) > 0

    # The most advanced artifact present decides the state.
    if has_results:
        stage = "evaluated"
    elif has_predictions:
        stage = "predicted"
    elif has_checkpoints:
        stage = "executed"
    else:
        stage = "not executed"

    snapshot["state"] = stage
    return snapshot
# ---------- Python methods ---------
[docs]
def __str__(self) -> str:
    """Render the experiment configuration as a multi-section text report.

    The report contains a centered banner followed by the execution
    details, the model information and the string form of the data module.

    Returns
    -------
    str
        The formatted report.
    """

    def indent_text(text, spaces=6):
        """Indent each line of a string by a given number of spaces."""
        if not text:
            return "No data."
        return "\n".join(
            " " * spaces + line if line.strip() else line
            for line in text.split("\n")
        )

    def limit_label(limit):
        # Only an unset limit means "all"; an explicit 0 must be shown as 0.
        return "all" if limit is None else limit

    exp_name = f"🚀 Experiment: {self.experiment_name} 🚀"
    pretrained_backbone = self.pretrained_backbone_ckpt_path or "FROM SCRATCH"
    return (
        f"{'=' * 80}\n"
        f"{' ' * ((80 - len(exp_name)) // 2)}{exp_name}\n"
        f"{'=' * 80}\n"
        f"\n🛠 Execution Details\n"
        f" ├── Execution ID: {self.execution_id}\n"
        f" ├── Log Dir: {self.log_dir}\n"
        f" ├── Seed: {self.seed}\n"
        f" ├── Accelerator: {self.accelerator}\n"
        f" ├── Devices: {self.devices}\n"
        f" ├── Max Epochs: {self.max_epochs}\n"
        f" ├── Train Batches Limit: {limit_label(self.limit_train_batches)}\n"
        f" ├── Val Batches Limit: {limit_label(self.limit_val_batches)}\n"
        f" └── Test Batches Limit: {limit_label(self.limit_test_batches)}\n"
        f"\n"
        f"🧠 Model Information\n"
        f" ├── Model Name: {self.model_config.information.name}\n"
        f" ├── Pretrained Backbone: {pretrained_backbone}\n"
        f" ├── Input Shape: {self.model_config.information.input_shape}\n"
        f" ├── Output Shape: {self.model_config.information.output_shape}\n"
        f" └── Num Classes: {self.model_config.information.num_classes}\n"
        f"\n📂 Dataset Information\n"
        f"{indent_text(str(self.data_module), spaces=6)}\n"
    )