minerva.callback.embedding_logger_callback

Classes

EmbeddingLoggerCallback

Callback that extracts embeddings from data using the model's backbone and logs them.

Module Contents

class minerva.callback.embedding_logger_callback.EmbeddingLoggerCallback(data_X, logger, data_Y=None, feature_preffix='EMB-', backbone_names_list=['backbone', 'encoder'])

Bases: lightning.Callback

Callback that passes the input data through the model's backbone and logs the resulting embeddings.

It subclasses lightning.Callback and overrides the on_train_start and on_train_epoch_end hooks.

Parameters

data_X : torch.Tensor

Tensor with the input data.

logger : CSVLogger

The logger used to record the embeddings.

data_Y : torch.Tensor, optional

Tensor with the target data, by default None.

feature_preffix : str, optional

The prefix to use for the feature names, by default 'EMB-'.

backbone_names_list : List[str], optional

List of names that may identify the backbone in the model, by default ['backbone', 'encoder'].
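The parameters above suggest how the pieces fit together: a backbone is presumably found in the model under one of the names in backbone_names_list, each sample of data_X is embedded, and every embedding dimension becomes a feature named with feature_preffix. The sketch below illustrates that flow under stated assumptions and is not minerva's implementation. It uses plain Python lists and a per-sample callable instead of torch.Tensor and an nn.Module, assumes the first matching attribute name wins, and names the optional target column "y" purely for illustration. The helpers resolve_backbone and embedding_rows are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Sequence


def resolve_backbone(model: Any, backbone_names: Sequence[str]) -> Callable:
    """Hypothetical helper: return the first attribute of `model` named in `backbone_names`."""
    for name in backbone_names:
        backbone = getattr(model, name, None)
        if backbone is not None:
            return backbone
    raise AttributeError(f"model has none of the attributes {list(backbone_names)}")


def embedding_rows(
    backbone: Callable[[Sequence[float]], Sequence[float]],
    data_x: Sequence[Sequence[float]],
    data_y: Optional[Sequence[Any]] = None,
    feature_prefix: str = "EMB-",
) -> List[Dict[str, Any]]:
    """Hypothetical helper: embed each sample and name its features <prefix><index>."""
    rows = []
    for i, sample in enumerate(data_x):
        embedding = backbone(sample)
        row = {f"{feature_prefix}{j}": value for j, value in enumerate(embedding)}
        if data_y is not None:
            row["y"] = data_y[i]  # hypothetical name for the target column
        rows.append(row)
    return rows


# Usage: the model has no `backbone` attribute, so `encoder` is picked.
model = SimpleNamespace(encoder=lambda xs: [2 * v for v in xs])
backbone = resolve_backbone(model, ["backbone", "encoder"])
rows = embedding_rows(backbone, [[1.0, 2.0], [3.0, 4.0]], data_y=[0, 1])
print(rows[0])  # {'EMB-0': 2.0, 'EMB-1': 4.0, 'y': 0}
```

With real tensors, one would typically pass the whole data_X batch through the backbone inside torch.no_grad() rather than looping over samples.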

backbone_names_list = ['backbone', 'encoder']
data_X
data_Y = None
feature_preffix = 'EMB-'
logger
on_train_epoch_end(trainer, pl_module)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch here
        # cache the per-step loss so a hook can aggregate it at epoch end
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Parameters:
  • trainer (lightning.Trainer)

  • pl_module (lightning.LightningModule)

Return type:

None

on_train_start(trainer, pl_module)

Called when training begins.

Parameters:
  • trainer (lightning.Trainer)

  • pl_module (lightning.LightningModule)

Return type:

None

Constructor parameters:
  • data_X (torch.Tensor)

  • logger (lightning.pytorch.loggers.CSVLogger)

  • data_Y (torch.Tensor)

  • feature_preffix (str)

  • backbone_names_list (List[str])
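Because the class overrides on_train_start and on_train_epoch_end, a plausible reading is that it logs one embedding snapshot before training and another after each epoch. The dependency-free sketch below wires that assumed behavior together; none of it is confirmed by the minerva source. EmbeddingLoggerSketch and RecordingLogger are hypothetical names. The sketch does not subclass lightning.Callback (so it runs without Lightning), uses trainer.current_epoch as the logging step, and assumes a logger exposing log_metrics(metrics, step), the method Lightning loggers such as CSVLogger provide.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Sequence


class RecordingLogger:
    """Hypothetical stand-in for a logger with log_metrics(metrics, step)."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        # Store each logged row together with the step it was logged at.
        self.records.append({"step": step, **metrics})


class EmbeddingLoggerSketch:
    """Hypothetical sketch; the real callback subclasses lightning.Callback."""

    def __init__(
        self,
        data_x: Sequence[Sequence[float]],
        logger: Any,
        data_y: Optional[Sequence[Any]] = None,
        feature_prefix: str = "EMB-",
        backbone_names: Sequence[str] = ("backbone", "encoder"),
    ) -> None:
        self.data_x = data_x
        self.logger = logger
        self.data_y = data_y
        self.feature_prefix = feature_prefix
        self.backbone_names = list(backbone_names)

    def _find_backbone(self, pl_module: Any) -> Callable:
        # Assumption: the first attribute whose name matches is the backbone.
        for name in self.backbone_names:
            backbone = getattr(pl_module, name, None)
            if backbone is not None:
                return backbone
        raise AttributeError(f"no backbone among {self.backbone_names}")

    def _log_embeddings(self, trainer: Any, pl_module: Any) -> None:
        backbone = self._find_backbone(pl_module)
        for i, sample in enumerate(self.data_x):
            row = {f"{self.feature_prefix}{j}": v for j, v in enumerate(backbone(sample))}
            if self.data_y is not None:
                row["y"] = self.data_y[i]  # hypothetical target column name
            # Assumption: the current epoch index is used as the logging step.
            self.logger.log_metrics(row, step=trainer.current_epoch)

    def on_train_start(self, trainer: Any, pl_module: Any) -> None:
        # Assumed: log a snapshot of the embeddings before any training step.
        self._log_embeddings(trainer, pl_module)

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        # Assumed: log fresh embeddings after every training epoch.
        self._log_embeddings(trainer, pl_module)


# Usage with stand-ins for the trainer and the LightningModule.
logger = RecordingLogger()
callback = EmbeddingLoggerSketch([[1.0], [2.0]], logger, data_y=[0, 1])
module = SimpleNamespace(encoder=lambda xs: [10 * v for v in xs])
trainer = SimpleNamespace(current_epoch=0)
callback.on_train_start(trainer, module)
callback.on_train_epoch_end(trainer, module)
print(len(logger.records))  # 4
```

The sketch uses a tuple default for backbone_names rather than the list default shown in the signature, which avoids sharing one mutable default object between instances.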