minerva.losses.batchwise_barlowtwins_loss

Classes

BatchWiseBarlowTwinLoss

Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).

Module Contents

class minerva.losses.batchwise_barlowtwins_loss.BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=False)

Bases: torch.nn.modules.loss._Loss

Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
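For orientation, the objective can be written as below. This is a sketch under stated assumptions, not taken from this page: it assumes the batch-wise variant keeps the standard Barlow Twins form but correlates samples across the batch instead of features. It also assumes that rows of the prediction embeddings P and projection embeddings Z (both N × D) are L2-normalized, that `diag_lambda` plays the role of λ, and that `normalize=True` divides the result by N.

```latex
\hat{p}_i = \frac{p_i}{\lVert p_i \rVert_2}, \qquad
\hat{z}_j = \frac{z_j}{\lVert z_j \rVert_2}, \qquad
C_{ij} = \hat{p}_i^{\top} \hat{z}_j, \qquad i, j = 1, \dots, N

\mathcal{L} = \sum_{i=1}^{N} \left(1 - C_{ii}\right)^2
            + \lambda \sum_{i=1}^{N} \sum_{j \neq i} C_{ij}^2,
\qquad
\mathcal{L}_{\text{normalized}} = \frac{\mathcal{L}}{N}
```

The first term pulls the two embeddings of the same sample towards similarity 1; the second pushes embeddings of different samples towards 0.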

Initialize the BatchWiseBarlowTwinLoss class.

Parameters

diag_lambda: float

The value of the diagonal lambda parameter. By default, 0.01.

normalize: bool

Whether to normalize the loss or not. By default, False.

bt_loss_bs(p, z, lambd=0.01, normalize=False)
diag_lambda = 0.01
forward(prediction_data, projection_data)

Calculate the loss between the prediction and projection data. This implementation uses a batch-wise version of the Barlow Twins loss function.

Parameters

prediction_data: torch.Tensor

The prediction data.

projection_data: torch.Tensor

The projection data.
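A minimal sketch of how a loss with this interface could be assembled is shown below. It is not minerva's source. The class name `BatchWiseBarlowTwinLossSketch` is hypothetical. The sketch assumes that both inputs have shape (N, D), that rows are L2-normalized before building an N × N cross-correlation matrix across the batch, that `diag_lambda` weights the off-diagonal term, and that `normalize=True` divides the loss by the batch size. Only the method names and defaults mirror the documented ones.

```python
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class BatchWiseBarlowTwinLossSketch(_Loss):
    """Hypothetical re-implementation for illustration; not the minerva code."""

    def __init__(self, diag_lambda: float = 0.01, normalize: bool = False):
        super().__init__()
        self.diag_lambda = diag_lambda
        self.normalize = normalize

    @staticmethod
    def off_diagonal(x: torch.Tensor) -> torch.Tensor:
        # All entries of a square matrix except the main diagonal (row-major order).
        n, m = x.shape
        assert n == m, "off_diagonal expects a square matrix"
        mask = ~torch.eye(n, dtype=torch.bool, device=x.device)
        return x[mask]

    def bt_loss_bs(self, p, z, lambd=0.01, normalize=False):
        # Assumption: correlation is taken across the batch, giving an (N, N)
        # matrix of cosine similarities between sample i of p and sample j of z.
        c = F.normalize(p, dim=1) @ F.normalize(z, dim=1).T
        # Pull matched pairs (the diagonal) towards similarity 1.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        # Push mismatched pairs towards 0; assumption: lambd weights this term.
        off_diag = self.off_diagonal(c).pow(2).sum()
        loss = on_diag + lambd * off_diag
        if normalize:
            # Assumption: "normalize" means averaging over the batch size.
            loss = loss / p.shape[0]
        return loss

    def forward(self, prediction_data, projection_data):
        return self.bt_loss_bs(
            prediction_data,
            projection_data,
            lambd=self.diag_lambda,
            normalize=self.normalize,
        )


# Usage on random embeddings: a batch of 8 samples with 32 features each.
torch.manual_seed(0)
criterion = BatchWiseBarlowTwinLossSketch(diag_lambda=0.01)
pred = torch.randn(8, 32, requires_grad=True)
proj = torch.randn(8, 32)
loss = criterion(pred, proj)
loss.backward()
```

Under these assumptions the correlation matrix is N × N, so the cost of the batch-wise variant grows with the batch size rather than with the embedding dimension, unlike the feature-wise D × D matrix of the original Barlow Twins loss.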

normalize = False
off_diagonal(x)
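A helper with this name typically returns the off-diagonal entries of a square matrix. The sketch below uses the reshaping trick popularized by the Barlow Twins reference code. It is an assumption that minerva's `off_diagonal` behaves the same way; this page does not document its body.

```python
import torch


def off_diagonal(x: torch.Tensor) -> torch.Tensor:
    """Flattened off-diagonal entries of a square matrix (illustrative sketch)."""
    n, m = x.shape
    assert n == m, "expected a square matrix"
    # Dropping the last element leaves n*n - 1 = (n - 1) * (n + 1) entries.
    # Laid out in rows of length n + 1, every diagonal element lands in
    # column 0, so slicing that column away keeps only off-diagonal entries.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


x = torch.arange(9.0).view(3, 3)
print(off_diagonal(x))  # tensor([1., 2., 3., 5., 6., 7.])
```

The trick avoids building an explicit boolean mask and returns the entries in the same row-major order that `x[~torch.eye(n, dtype=torch.bool)]` would.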