Source code for minerva.callback.specific_checkpoint_callback
from lightning import Callback, LightningModule, Trainer
from typing import Optional, List, Union, Dict
from pathlib import Path
[docs]
class SpecificCheckpointCallback(Callback):
def __init__(
self,
specific_epochs: Optional[List[Union[Dict, int]]] = None,
specific_steps: Optional[List[Union[Dict, int]]] = None,
epoch_var_name: Optional[str] = None,
step_var_name: Optional[str] = None,
):
"""
Callback to save model checkpoints at specific epochs and/or steps.
Parameters
----------
specific_epochs : list of dict or int, optional
A list specifying the epoch indices at which to save the checkpoints.
Each item can be an integer epoch index (starting at 0) or a dictionary
defining a range of epoch indexes. If -1 is included, the model initial
random weights will be saved.
specific_steps : list of dict or int, optional
A list specifying the step indices at which to save the checkpoints.
Each item can be an integer step index (starting at 1) or a dictionary
defining a range of step indexes.
epoch_var_name : string, optional
The name of the trainer attribute that holds the current epoch, by default
None. If None, 'current_epoch' is used.
step_var_name : string, optional
The name of the trainer attribute that holds the current step, by default
None. If None, 'global_step' is used.
"""
super().__init__()
self.checkpoint_path = None
self.specific_epochs = specific_epochs or []
self.specific_steps = specific_steps or []
self.epoch_var_name = epoch_var_name or "current_epoch"
self.step_var_name = step_var_name or "global_step"
epochs_expanded = []
for value in self.specific_epochs:
if type(value) == int:
epochs_expanded += [value]
elif type(value) == dict:
epochs_expanded += list(
range(value["start"], value["stop"], value["step"])
)
self.specific_epochs = epochs_expanded
steps_expanded = []
for value in self.specific_steps:
if type(value) == int:
steps_expanded += [value]
elif type(value) == dict:
steps_expanded += list(
range(value["start"], value["stop"], value["step"])
)
self.specific_steps = steps_expanded
[docs]
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
"""
It creates the checkpoints folder at the start of the training. If required,
it also saves the model initial random weights.
"""
self.checkpoint_path = Path(trainer.log_dir) / "checkpoints"
self.checkpoint_path.mkdir(exist_ok=True)
if -1 in self.specific_epochs:
filename = f"epoch=-1.ckpt"
trainer.save_checkpoint(self.checkpoint_path / filename)
[docs]
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""
Checks the epoch index and saves the checkpoint if specified.
"""
curernt_epoch = None
if hasattr(pl_module, self.epoch_var_name):
current_epoch = getattr(pl_module, self.epoch_var_name)
elif hasattr(trainer, self.epoch_var_name):
current_epoch = getattr(trainer, self.epoch_var_name)
if current_epoch is not None and current_epoch in self.specific_epochs:
filename = f"epoch={current_epoch}.ckpt"
trainer.save_checkpoint(self.checkpoint_path / filename)
[docs]
def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx
):
"""
Checks the step index and saves the checkpoint if specified.
"""
global_step = None
if hasattr(pl_module, self.step_var_name):
global_step = getattr(pl_module, self.step_var_name)
elif hasattr(trainer, self.step_var_name):
global_step = getattr(trainer, self.step_var_name)
if global_step is not None and global_step in self.specific_steps:
filename = f"step={global_step}.ckpt"
trainer.save_checkpoint(self.checkpoint_path / filename)