minerva.callback.embedding_logger_callback
==========================================

.. py:module:: minerva.callback.embedding_logger_callback


Classes
-------

.. autoapisummary::

   minerva.callback.embedding_logger_callback.EmbeddingLoggerCallback


Module Contents
---------------

.. py:class:: EmbeddingLoggerCallback(data_X, logger, data_Y = None, feature_preffix = 'EMB-', backbone_names_list = ['backbone', 'encoder'])

   Bases: :py:obj:`lightning.Callback`


   Callback that uses the model's backbone to extract embeddings from the given data and log them.

   Parameters
   ----------
   data_X : torch.Tensor
       Tensor with the input data.
   logger : CSVLogger
       The logger to use.
   data_Y : torch.Tensor, optional
       Tensor with the target data, by default None.
   feature_preffix : str, optional
       The prefix to use for the feature names, by default 'EMB-'.
   backbone_names_list : List[str], optional
       List with the names of the backbones in the model, by default ['backbone', 'encoder'].


   .. py:attribute:: backbone_names_list
      :value: ['backbone', 'encoder']


   .. py:attribute:: data_X


   .. py:attribute:: data_Y
      :value: None


   .. py:attribute:: feature_preffix
      :value: 'EMB-'


   .. py:attribute:: logger


   .. py:method:: on_train_epoch_end(trainer, pl_module)

      Called when the training epoch ends.

      To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
      :class:`lightning.pytorch.core.LightningModule` and access them in this hook:

      .. code-block:: python

          import lightning as L
          import torch


          class MyLightningModule(L.LightningModule):
              def __init__(self):
                  super().__init__()
                  self.training_step_outputs = []

              def training_step(self, batch, batch_idx):
                  loss = ...
                  self.training_step_outputs.append(loss)
                  return loss


          class MyCallback(L.Callback):
              def on_train_epoch_end(self, trainer, pl_module):
                  # do something with all training_step outputs, for example:
                  epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
                  pl_module.log("training_epoch_mean", epoch_mean)
                  # free up the memory
                  pl_module.training_step_outputs.clear()


   .. py:method:: on_train_start(trainer, pl_module)

      Called when training begins.
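
The ``EmbeddingLoggerCallback`` docstring states that embeddings are extracted with the model's backbone and logged, but it does not show how ``backbone_names_list`` and ``feature_preffix`` are applied. The following is a hypothetical, dependency-free sketch, not Minerva's implementation. It assumes that the backbone is the first attribute of the module whose name appears in ``backbone_names_list``, and that each embedding dimension becomes a column named prefix plus index. ``find_backbone``, ``embeddings_to_rows``, ``ToyModel`` and the ``target`` column are illustrative names only.

.. code-block:: python

    from typing import Any, Callable, Dict, List, Optional, Sequence


    def find_backbone(module: Any, backbone_names_list: Sequence[str]) -> Callable:
        """Hypothetical helper: return the first attribute of ``module`` named in the list."""
        for name in backbone_names_list:
            backbone = getattr(module, name, None)
            if backbone is not None:
                return backbone
        raise AttributeError(
            f"None of {list(backbone_names_list)} found on {type(module).__name__}"
        )


    def embeddings_to_rows(
        embeddings: Sequence[Sequence[float]],
        feature_prefix: str = "EMB-",  # plays the role of the class's ``feature_preffix``
        targets: Optional[Sequence[Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Hypothetical helper: one dict per sample, keyed by prefixed feature names."""
        rows = []
        for i, vector in enumerate(embeddings):
            row = {f"{feature_prefix}{j}": value for j, value in enumerate(vector)}
            if targets is not None:
                # Assumed column name for the optional target data (``data_Y``).
                row["target"] = targets[i]
            rows.append(row)
        return rows


    # Toy stand-in for a model: it has no "backbone" attribute, only an
    # "encoder" that doubles every input feature.
    class ToyModel:
        def __init__(self):
            self.encoder = lambda batch: [[2 * x for x in sample] for sample in batch]


    model = ToyModel()
    backbone = find_backbone(model, ["backbone", "encoder"])  # falls back to "encoder"
    rows = embeddings_to_rows(backbone([[1.0, 2.0], [3.0, 4.0]]), "EMB-", targets=[0, 1])
    print(rows)
    # [{'EMB-0': 2.0, 'EMB-1': 4.0, 'target': 0}, {'EMB-0': 6.0, 'EMB-1': 8.0, 'target': 1}]

Searching the names in order lets a single callback work with models that expose their feature extractor under different attribute names. Per-dimension columns map naturally onto a tabular logger such as ``CSVLogger``.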