minerva.callback.specific_checkpoint_callback

Classes

SpecificCheckpointCallback

Callback to save model checkpoints at specific epochs and/or steps.

Module Contents

class minerva.callback.specific_checkpoint_callback.SpecificCheckpointCallback(specific_epochs=None, specific_steps=None, epoch_var_name=None, step_var_name=None)

Bases: lightning.Callback

Callback to save model checkpoints at specific epochs and/or steps.

Parameters

specific_epochs : list of dict or int, optional

A list specifying the epoch indices at which to save checkpoints. Each item can be an integer epoch index (starting at 0) or a dictionary defining a range of epoch indices. If -1 is included, the model's initial random weights are also saved.

specific_steps : list of dict or int, optional

A list specifying the step indices at which to save checkpoints. Each item can be an integer step index (starting at 1) or a dictionary defining a range of step indices.

epoch_var_name : string, optional

The name of the trainer attribute that holds the current epoch, by default None. If None, ‘current_epoch’ is used.

step_var_name : string, optional

The name of the trainer attribute that holds the current step, by default None. If None, ‘global_step’ is used.
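Both specific_epochs and specific_steps mix plain integers with range-defining dictionaries. The keys such a dictionary must contain are not documented here, so the sketch below assumes a hypothetical format of "start", "stop" and an optional "step", interpreted like Python's range() (stop exclusive). The helper name expand_indices is also hypothetical; it only illustrates how such a mixed list could be flattened into a set of indices to check against.

```python
from typing import Dict, List, Optional, Set, Union

# ASSUMPTION: range dicts use "start", "stop" and optional "step" keys with
# range() semantics. The real callback's dictionary keys may differ.
IndexSpec = Union[int, Dict[str, int]]


def expand_indices(spec: Optional[List[IndexSpec]]) -> Set[int]:
    """Flatten a list of ints and (hypothetical) range dicts into a set."""
    indices: Set[int] = set()
    for item in spec or []:  # None behaves like an empty list
        if isinstance(item, int):
            indices.add(item)
        elif isinstance(item, dict):
            indices.update(range(item["start"], item["stop"], item.get("step", 1)))
        else:
            raise TypeError(f"unsupported spec item: {item!r}")
    return indices


# -1 requests the initial weights; the dict requests epochs 5, 10 and 15.
print(sorted(expand_indices([-1, 0, {"start": 5, "stop": 20, "step": 5}])))
```

Using a set makes the per-hook membership test constant time, which matters for step-level checks that run after every training batch.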

checkpoint_path = None
epoch_var_name = 'current_epoch'
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Checks the current step index and saves a checkpoint if that step is listed in specific_steps.

Parameters:
  • trainer (lightning.Trainer)

  • pl_module (lightning.LightningModule)
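The step check can be pictured with a minimal, self-contained sketch. It is not the minerva implementation: StepCheckSketch is a hypothetical stand-in, a SimpleNamespace replaces the Lightning trainer, and instead of writing files it records the checkpoint names it would save (the "step=N.ckpt" naming is an assumption). It shows the documented behaviour of reading the step from the trainer attribute named by step_var_name, which defaults to global_step.

```python
from types import SimpleNamespace
from typing import Any, List, Set


class StepCheckSketch:
    """Hypothetical stand-in for the step check in on_train_batch_end."""

    def __init__(self, specific_steps: Set[int], step_var_name: str = "global_step"):
        self.specific_steps = specific_steps
        self.step_var_name = step_var_name
        self.saved: List[str] = []  # names of checkpoints that would be written

    def on_train_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        step = getattr(trainer, self.step_var_name)  # read the configured attribute
        if step in self.specific_steps:
            # A real Lightning callback could call trainer.save_checkpoint(path) here.
            self.saved.append(f"step={step}.ckpt")


cb = StepCheckSketch({2, 4})
trainer = SimpleNamespace(global_step=0)  # fake trainer for illustration only
for batch_idx in range(5):
    trainer.global_step += 1  # simulate one optimizer step per batch, so steps start at 1
    cb.on_train_batch_end(trainer, None, None, None, batch_idx)
print(cb.saved)
```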

on_train_epoch_end(trainer, pl_module)

Checks the current epoch index and saves a checkpoint if that epoch is listed in specific_epochs.

Parameters:
  • trainer (lightning.Trainer)

  • pl_module (lightning.LightningModule)
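The epoch check works the same way, but the sketch below highlights the documented fallback: when epoch_var_name is None, the attribute current_epoch is read, and epoch indices start at 0. EpochCheckSketch and the fake trainer are hypothetical illustrations, and the sketch only records the epochs at which a checkpoint would be written.

```python
from types import SimpleNamespace
from typing import Any, List, Optional, Set


class EpochCheckSketch:
    """Hypothetical stand-in for the epoch check in on_train_epoch_end."""

    def __init__(self, specific_epochs: Set[int], epoch_var_name: Optional[str] = None):
        self.specific_epochs = specific_epochs
        # Mirrors the documented default: None falls back to "current_epoch".
        self.epoch_var_name = epoch_var_name or "current_epoch"
        self.saved: List[int] = []  # epochs at which a checkpoint would be written

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        epoch = getattr(trainer, self.epoch_var_name)  # 0-based epoch index
        if epoch in self.specific_epochs:
            self.saved.append(epoch)  # real code would write a checkpoint file


cb = EpochCheckSketch({0, 2})
trainer = SimpleNamespace(current_epoch=0)  # fake trainer for illustration only
for epoch in range(4):
    trainer.current_epoch = epoch
    cb.on_train_epoch_end(trainer, None)
print(cb.epoch_var_name, cb.saved)
```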

on_train_start(trainer, pl_module)

Creates the checkpoints folder at the start of training. If -1 is included in specific_epochs, it also saves the model's initial random weights.

Parameters:
  • trainer (lightning.Trainer)

  • pl_module (lightning.LightningModule)
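A rough sketch of the start-of-training behaviour follows. Where minerva places the checkpoints folder is not stated on this page, so the directory is passed in explicitly; the function name on_train_start_sketch, the save_fn parameter, and the "epoch=-1.ckpt" file name are all hypothetical. In a Lightning callback, save_fn could plausibly be trainer.save_checkpoint.

```python
import os
import tempfile
from typing import Callable, Set


def on_train_start_sketch(
    checkpoint_dir: str,
    specific_epochs: Set[int],
    save_fn: Callable[[str], None],
) -> None:
    """Create the checkpoint folder and save initial weights if -1 is requested."""
    os.makedirs(checkpoint_dir, exist_ok=True)  # no error if the folder already exists
    if -1 in specific_epochs:
        # Hypothetical file name for the untrained (initial random) weights.
        save_fn(os.path.join(checkpoint_dir, "epoch=-1.ckpt"))


with tempfile.TemporaryDirectory() as root:
    written = []  # collects the paths save_fn was asked to write
    ckpt_dir = os.path.join(root, "checkpoints")
    on_train_start_sketch(ckpt_dir, {-1, 3}, written.append)
    print(os.path.isdir(ckpt_dir), [os.path.basename(p) for p in written])
```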

specific_epochs = []
specific_steps = []
step_var_name = 'global_step'
Constructor parameter types:
  • specific_epochs (Optional[List[Union[Dict, int]]])

  • specific_steps (Optional[List[Union[Dict, int]]])

  • epoch_var_name (Optional[str])

  • step_var_name (Optional[str])