minerva.losses.batchwise_barlowtwins_loss
Classes
BatchWiseBarlowTwinLoss | Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
Module Contents
- class minerva.losses.batchwise_barlowtwins_loss.BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=False)
Bases: torch.nn.modules.loss._Loss
Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
Initialize the BatchWiseBarlowTwinLoss class.
Parameters
- diag_lambda: float
The value of the diagonal lambda parameter. By default, 0.01.
- normalize: bool
Whether to normalize the loss or not. By default, False.
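This page does not give the formula, so the following is only a minimal sketch of how `diag_lambda` and `normalize` might enter a batch-wise Barlow Twins loss. It makes four assumptions. Inputs have shape (N, D). "Batch-wise" means correlating samples with samples (an N x N matrix) instead of features with features. `diag_lambda` weights the off-diagonal (redundancy-reduction) term, as lambda does in the original Barlow Twins loss. `normalize=True` divides the loss by the batch size. The function name `batchwise_barlow_twins_loss` is hypothetical and not part of minerva's API.

```python
import torch
import torch.nn.functional as F


def batchwise_barlow_twins_loss(
    prediction_data: torch.Tensor,
    projection_data: torch.Tensor,
    diag_lambda: float = 0.01,
    normalize: bool = False,
) -> torch.Tensor:
    """Hypothetical sketch of a batch-wise Barlow Twins loss (not minerva's code).

    Assumes both inputs have shape (N, D) and that the cross-correlation is
    taken between samples (N x N) rather than between features (D x D).
    """
    if prediction_data.shape != projection_data.shape:
        raise ValueError("prediction_data and projection_data must have the same shape")

    # L2-normalize each sample (row) so every entry of c is a cosine similarity.
    p = F.normalize(prediction_data, dim=1)
    z = F.normalize(projection_data, dim=1)

    # Sample-by-sample cross-correlation matrix, shape (N, N).
    c = p @ z.T
    diag = torch.diagonal(c)

    # Invariance term: pull each matching pair (the diagonal) towards 1.
    on_diag = (diag - 1).pow(2).sum()

    # Redundancy-reduction term: push non-matching pairs towards 0.
    # Computed as (sum of all squared entries) - (sum of squared diagonal entries).
    # Assumption: diag_lambda weights this off-diagonal term.
    off_diag = c.pow(2).sum() - diag.pow(2).sum()

    loss = on_diag + diag_lambda * off_diag

    # Assumption: normalize=True averages the loss over the batch size N.
    if normalize:
        loss = loss / c.shape[0]
    return loss
```

Under these assumptions, two views whose rows are identical one-hot vectors, such as `batchwise_barlow_twins_loss(torch.eye(4), torch.eye(4))`, give a loss of exactly zero. The diagonal is all ones and every off-diagonal entry is zero.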
- diag_lambda = 0.01
- forward(prediction_data, projection_data)
Calculate the loss between the prediction and projection data. This implementation uses a batch-wise version of the Barlow Twins loss function.
Parameters
- prediction_data: torch.Tensor
The prediction data.
- projection_data: torch.Tensor
The projection data.
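The sketch below makes the "batch-wise" distinction concrete by comparing two cross-correlation matrices. The first follows the original Barlow Twins loss: standardize each feature over the batch, then correlate features with features, which gives a D x D matrix. This mirrors the pseudocode in the original Barlow Twins paper. The second is an assumed batch-wise variant: L2-normalize each sample, then correlate samples with samples, which gives an N x N matrix. Both function names are hypothetical, and it is an assumption that minerva builds its matrix this way.

```python
import torch
import torch.nn.functional as F


def feature_wise_cross_correlation(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Original Barlow Twins style: standardize each feature over the batch,
    then correlate features with features. Returns a (D, D) matrix."""
    n = p.shape[0]
    p_std = (p - p.mean(dim=0)) / p.std(dim=0)
    z_std = (z - z.mean(dim=0)) / z.std(dim=0)
    return p_std.T @ z_std / n


def batch_wise_cross_correlation(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Assumed batch-wise variant (hypothetical): L2-normalize each sample,
    then correlate samples with samples. Returns an (N, N) matrix."""
    return F.normalize(p, dim=1) @ F.normalize(z, dim=1).T


# Hypothetical batch: N = 16 samples, D = 256 features per view.
p = torch.randn(16, 256)
z = torch.randn(16, 256)
print(feature_wise_cross_correlation(p, z).shape)  # torch.Size([256, 256])
print(batch_wise_cross_correlation(p, z).shape)    # torch.Size([16, 16])
```

Under this assumption, the batch-wise matrix grows with the batch size rather than with the embedding dimension.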
- normalize = False