minerva.callback.embedding_logger_callback
Classes
EmbeddingLoggerCallback | Callback to extract and log embeddings from input data using the model's backbone.
Module Contents
- class minerva.callback.embedding_logger_callback.EmbeddingLoggerCallback(data_X, logger, data_Y=None, feature_preffix='EMB-', backbone_names_list=['backbone', 'encoder'])
Bases: lightning.Callback
Callback to extract embeddings from the input data using the model's backbone and log them.
As with any lightning.Callback (the abstract base class used to build new callbacks), behavior is customized by overriding the relevant hooks; this class overrides on_train_epoch_end.
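The sketch below is a framework-free illustration of turning per-sample embeddings into named feature columns. It assumes the backbone maps each sample to a flat list of floats. The helper `embeddings_to_rows` and the `target` column name are hypothetical and are not part of minerva's API. The sketch only shows how a prefix such as the default 'EMB-' and optional targets could shape the logged records.

```python
from typing import Callable, Dict, List, Optional, Sequence


def embeddings_to_rows(
    backbone: Callable[[Sequence[float]], List[float]],
    data_x: Sequence[Sequence[float]],
    prefix: str = "EMB-",
    data_y: Optional[Sequence[object]] = None,
) -> List[Dict[str, object]]:
    """Hypothetical helper: embed each sample and name its features <prefix><index>."""
    if data_y is not None and len(data_y) != len(data_x):
        raise ValueError("data_x and data_y must have the same length")
    rows: List[Dict[str, object]] = []
    for i, sample in enumerate(data_x):
        embedding = backbone(sample)
        # One column per embedding dimension, e.g. EMB-0, EMB-1, ...
        row: Dict[str, object] = {f"{prefix}{j}": value for j, value in enumerate(embedding)}
        if data_y is not None:
            row["target"] = data_y[i]  # hypothetical column name for the label
        rows.append(row)
    return rows


# Toy "backbone" that doubles each input value.
rows = embeddings_to_rows(lambda x: [2 * v for v in x], [[1.0, 2.0], [3.0, 4.0]], data_y=[0, 1])
print(rows[0])  # {'EMB-0': 2.0, 'EMB-1': 4.0, 'target': 0}
```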
Parameters
- data_X : torch.Tensor
Tensor with the input data from which the embeddings are extracted.
- logger : lightning.pytorch.loggers.CSVLogger
The logger used to record the embeddings. This argument is required (it has no default value).
- data_Y : torch.Tensor, optional
Tensor with the target data, by default None.
- feature_preffix : str, optional
The prefix used for the embedding feature names, by default 'EMB-'.
- backbone_names_list : List[str], optional
List with the possible names of the backbone in the model, by default ['backbone', 'encoder'].
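One plausible reading of backbone_names_list is an ordered list of attribute names to probe on the model until one is found. The sketch below assumes exactly that behavior, with the first match winning. The function name `find_backbone` is hypothetical, and minerva's actual lookup may differ.

```python
from types import SimpleNamespace
from typing import Any, Sequence


def find_backbone(model: Any, names: Sequence[str] = ("backbone", "encoder")) -> Any:
    """Hypothetical lookup: return the first attribute of `model` named in `names`."""
    for name in names:
        # getattr with a default avoids raising AttributeError for missing names
        candidate = getattr(model, name, None)
        if candidate is not None:
            return candidate
    raise AttributeError(f"model has none of the attributes {list(names)}")


# A stand-in model that exposes its feature extractor as `encoder`.
model = SimpleNamespace(encoder="encoder-module", head="head-module")
print(find_backbone(model))  # encoder-module
```

The sketch uses a tuple for the default, which avoids sharing one mutable default object across calls.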
- backbone_names_list = ['backbone', 'encoder']
- data_X
- data_Y = None
- feature_preffix = 'EMB-'
- logger
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:
```python
import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```
- Parameters:
trainer (lightning.Trainer)
pl_module (lightning.LightningModule)
- Return type:
None
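This is a rough, framework-free illustration of how an embedding-logging callback could use this hook. It stubs trainer and pl_module with SimpleNamespace and looks up the backbone by name. It then embeds each input sample and appends one record per sample, tagged with trainer.current_epoch, to an in-memory list. The class name `SketchEmbeddingCallback`, the `epoch` column, and the list standing in for the logger are all hypothetical. This is not minerva's implementation, which receives a CSVLogger.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Sequence


class SketchEmbeddingCallback:
    """Hypothetical, framework-free stand-in for an embedding-logging callback."""

    def __init__(self, data_x: Sequence[Sequence[float]], prefix: str = "EMB-",
                 backbone_names: Sequence[str] = ("backbone", "encoder")) -> None:
        self.data_x = data_x
        self.prefix = prefix
        self.backbone_names = backbone_names
        self.records: List[Dict[str, Any]] = []  # stands in for the logger

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        # Probe candidate attribute names in order; the first one present is used.
        backbone = next(
            (getattr(pl_module, n) for n in self.backbone_names if hasattr(pl_module, n)),
            None,
        )
        if backbone is None:
            raise AttributeError("no backbone found on pl_module")
        for sample in self.data_x:
            embedding = backbone(sample)
            record: Dict[str, Any] = {"epoch": trainer.current_epoch}
            record.update({f"{self.prefix}{j}": v for j, v in enumerate(embedding)})
            self.records.append(record)


cb = SketchEmbeddingCallback(data_x=[[1.0, 2.0]])
module = SimpleNamespace(encoder=lambda x: [sum(x)])  # toy backbone
for epoch in range(2):
    cb.on_train_epoch_end(SimpleNamespace(current_epoch=epoch), module)
print(cb.records)  # [{'epoch': 0, 'EMB-0': 3.0}, {'epoch': 1, 'EMB-0': 3.0}]
```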
- Parameters (EmbeddingLoggerCallback constructor):
data_X (torch.Tensor)
logger (lightning.pytorch.loggers.CSVLogger)
data_Y (torch.Tensor)
feature_preffix (str)
backbone_names_list (List[str])