minerva.losses.batchwise_barlowtwins_loss

Classes

BarlowTwinsLoss

Implementation of the Barlow Twins loss function for self-supervised learning.

BatchWiseBarlowTwinLoss

Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).

Functions

_normalize(z_a, z_b)

Normalizes each embedding tensor independently across the batch.

_off_diagonal(x)

Returns a flattened view of the off-diagonal elements of a square matrix.

Module Contents

class minerva.losses.batchwise_barlowtwins_loss.BarlowTwinsLoss(lambda_param=0.005, gather_distributed=False)

Bases: torch.nn.Module

Implementation of the Barlow Twins loss function for self-supervised learning.

The loss encourages embeddings of two augmented views of the same input to be similar (invariance) while reducing redundancy between the components of their representations (decorrelation).
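For reference, the standard Barlow Twins objective (Zbontar et al., 2021) can be written as follows. The formula assumes that each embedding dimension of z_a and z_b has first been standardized across the batch of N samples (written \hat{z}^{A} and \hat{z}^{B}) and that \lambda corresponds to lambda_param. It is offered as a reading aid; that this module matches it term for term is an assumption based on the description above.

```latex
\mathcal{C}_{ij} = \frac{1}{N} \sum_{b=1}^{N} \hat{z}^{A}_{b,i} \, \hat{z}^{B}_{b,j}

\mathcal{L}_{\mathrm{BT}} = \underbrace{\sum_{i} \left(1 - \mathcal{C}_{ii}\right)^{2}}_{\text{invariance}}
  + \lambda \underbrace{\sum_{i} \sum_{j \neq i} \mathcal{C}_{ij}^{2}}_{\text{redundancy reduction}}
```

The first sum drives matching dimensions of the two views towards perfect correlation; the second drives different dimensions towards zero correlation.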

Initializes the BarlowTwinsLoss module.

Parameters

lambda_param : float, optional

Coefficient weighting the off-diagonal (redundancy-reduction) penalty in the loss. Defaults to 5e-3.

gather_distributed : bool, optional

If True, performs an all-reduce on the cross-correlation matrix across GPUs. Defaults to False.

Raises

ValueError

If gather_distributed is True but torch.distributed is not available.
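The gather_distributed option and its ValueError can be pictured with the minimal sketch below. It assumes that the all-reduce averages each process's cross-correlation matrix over the process group; the helper names check_distributed and reduce_cross_correlation are hypothetical and are not part of minerva.

```python
import torch
import torch.distributed as dist


def check_distributed(gather_distributed: bool) -> None:
    """Hypothetical constructor check mirroring the documented ValueError."""
    if gather_distributed and not dist.is_available():
        raise ValueError(
            "gather_distributed=True requires torch.distributed to be available."
        )


def reduce_cross_correlation(c: torch.Tensor, gather_distributed: bool) -> torch.Tensor:
    """Hypothetical helper: average a [dim, dim] cross-correlation matrix over processes."""
    if not gather_distributed:
        return c
    # Communicate only when a process group has actually been initialized.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        if world_size > 1:
            # Divide first so that the summing all-reduce produces the mean.
            c = c / world_size
            dist.all_reduce(c, op=dist.ReduceOp.SUM)
    return c


# Single-process usage: no process group exists, so the matrix is returned unchanged.
check_distributed(gather_distributed=False)
c_local = torch.eye(4)
c_global = reduce_cross_correlation(c_local, gather_distributed=True)
```

Pre-dividing by the world size before a summing all-reduce yields the mean, so the matrix keeps roughly the same scale as in a single-process run, assuming every process sees the same batch size.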

forward(z_a, z_b)

Computes the Barlow Twins loss.

Parameters

z_a : Tensor

Embedding tensor from the first view. Shape: [batch_size, dim].

z_b : Tensor

Embedding tensor from the second view. Shape: [batch_size, dim].

Returns

Tensor

Scalar loss value combining invariance and redundancy reduction terms.

Parameters:
  • z_a (torch.Tensor)

  • z_b (torch.Tensor)

Return type:

torch.Tensor
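The computation described for forward can be sketched as a self-contained function. This is an illustration under stated assumptions, not minerva's implementation: barlow_twins_loss_sketch and _standardize are hypothetical names, and the small eps added to the variance is an assumed guard against division by zero.

```python
import torch


def _standardize(z: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Zero mean and unit (biased) variance per embedding dimension, over the batch.
    mean = z.mean(dim=0, keepdim=True)
    var = (z - mean).pow(2).mean(dim=0, keepdim=True)
    return (z - mean) / torch.sqrt(var + eps)


def barlow_twins_loss_sketch(
    z_a: torch.Tensor, z_b: torch.Tensor, lambda_param: float = 5e-3
) -> torch.Tensor:
    batch_size, dim = z_a.shape
    z_a_norm = _standardize(z_a)
    z_b_norm = _standardize(z_b)
    # Feature-wise cross-correlation matrix, shape [dim, dim].
    c = z_a_norm.T @ z_b_norm / batch_size
    # Invariance term: push the diagonal entries towards 1.
    invariance = (torch.diagonal(c) - 1).pow(2).sum()
    # Redundancy-reduction term: push the off-diagonal entries towards 0.
    off_diag_mask = ~torch.eye(dim, dtype=torch.bool, device=c.device)
    redundancy = c[off_diag_mask].pow(2).sum()
    return invariance + lambda_param * redundancy


torch.manual_seed(0)
z = torch.randn(64, 8)
loss_same = barlow_twins_loss_sketch(z, z)  # identical views: invariance term near 0
loss_diff = barlow_twins_loss_sketch(z, torch.randn(64, 8))  # unrelated views
```

Identical views give a much smaller loss than unrelated views, because only the lambda-weighted off-diagonal term remains.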

gather_distributed = False
lambda_param = 0.005
Parameters:
  • lambda_param (float)

  • gather_distributed (bool)

class minerva.losses.batchwise_barlowtwins_loss.BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=False)

Bases: torch.nn.modules.loss._Loss

Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).

Initializes the BatchWiseBarlowTwinLoss class.

Parameters

diag_lambda : float, optional

The value of the diagonal lambda parameter. Default is 0.01.

normalize : bool, optional

Whether to normalize the loss. Default is False.

bt_loss_bs(p, z, lambd=0.01, normalize=False)
diag_lambda = 0.01
forward(prediction_data, projection_data)

Calculates the loss between the prediction and projection data using a batch-wise version of the Barlow Twins loss function.

Parameters

prediction_data : torch.Tensor

Prediction data tensor.

projection_data : torch.Tensor

Projection data tensor.

Returns

torch.Tensor

The computed batch-wise Barlow Twins loss.
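One way to picture a batch-wise variant is to correlate samples rather than features, giving a [batch_size, batch_size] matrix. The sketch below rests on several assumptions that this page does not confirm: rows are L2-normalized so entries are cosine similarities, the lambda value (diag_lambda, passed as lambd) weights the off-diagonal term as in standard Barlow Twins, and normalize divides the loss by the batch size. batchwise_bt_loss_sketch is a hypothetical name and should not be read as the body of bt_loss_bs.

```python
import torch
import torch.nn.functional as F


def batchwise_bt_loss_sketch(
    p: torch.Tensor, z: torch.Tensor, lambd: float = 0.01, normalize: bool = False
) -> torch.Tensor:
    # ASSUMPTION: rows are L2-normalized, so entries of c are cosine similarities.
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    # Batch-wise correlation matrix, shape [batch_size, batch_size].
    c = p @ z.T
    n = c.shape[0]
    # Matching samples (the diagonal) should agree ...
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    # ... while different samples (off-diagonal) should be decorrelated.
    off_mask = ~torch.eye(n, dtype=torch.bool, device=c.device)
    off_diag = c[off_mask].pow(2).sum()
    loss = on_diag + lambd * off_diag
    # ASSUMPTION: "normalize" divides the loss by the batch size.
    if normalize:
        loss = loss / n
    return loss


torch.manual_seed(0)
prediction = torch.randn(16, 32)
projection = torch.randn(16, 32)
loss = batchwise_bt_loss_sketch(prediction, projection)
loss_normalized = batchwise_bt_loss_sketch(prediction, projection, normalize=True)
```

Compared with the feature-wise loss, the matrix size here scales with the batch size instead of the embedding dimension.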

normalize = False
Parameters:
  • diag_lambda (float)

  • normalize (bool)

minerva.losses.batchwise_barlowtwins_loss._normalize(z_a, z_b)

Normalizes each embedding tensor independently across the batch.

Parameters

z_a : Tensor

Embeddings from the first view.

z_b : Tensor

Embeddings from the second view.

Returns

Tuple[Tensor, Tensor]

A tuple containing the normalized versions of z_a and z_b.

Parameters:
  • z_a (torch.Tensor)

  • z_b (torch.Tensor)

Return type:

Tuple[torch.Tensor, torch.Tensor]
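"Independently across the batch" means each tensor is standardized with its own per-dimension mean and variance, computed over the batch dimension. A minimal sketch, assuming zero-mean/unit-variance standardization with a small eps for stability; normalize_sketch is a hypothetical name, not minerva's function.

```python
from typing import Tuple

import torch


def normalize_sketch(
    z_a: torch.Tensor, z_b: torch.Tensor, eps: float = 1e-5
) -> Tuple[torch.Tensor, torch.Tensor]:
    def standardize(z: torch.Tensor) -> torch.Tensor:
        # Statistics come from this tensor alone, computed over the batch dimension.
        mean = z.mean(dim=0, keepdim=True)
        var = (z - mean).pow(2).mean(dim=0, keepdim=True)
        return (z - mean) / torch.sqrt(var + eps)

    return standardize(z_a), standardize(z_b)


torch.manual_seed(0)
z_a = 5.0 * torch.randn(32, 4) + 3.0
z_b = torch.randn(32, 4)
a_norm, b_norm = normalize_sketch(z_a, z_b)
```

Because the statistics are not shared, rescaling z_b leaves the normalized z_a unchanged.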

minerva.losses.batchwise_barlowtwins_loss._off_diagonal(x)

Returns a flattened view of the off-diagonal elements of a square matrix.

Parameters

x : Tensor

A square 2D tensor (cross-correlation matrix).

Returns

Tensor

A 1D tensor containing the flattened off-diagonal elements of the input matrix.
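A common way to extract the off-diagonal entries without building a boolean mask is the reshaping trick from the original Barlow Twins reference code. Whether _off_diagonal uses exactly this trick is an assumption; off_diagonal_sketch is a hypothetical name.

```python
import torch


def off_diagonal_sketch(x: torch.Tensor) -> torch.Tensor:
    n, m = x.shape
    if n != m:
        raise ValueError(f"expected a square matrix, got shape {tuple(x.shape)}")
    # Dropping the last element leaves n*n - 1 = (n - 1)(n + 1) entries. Viewed as
    # rows of length n + 1, every row starts with a diagonal entry x[k, k],
    # so slicing off column 0 removes exactly the diagonal.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


x = torch.arange(9.0).view(3, 3)
# x = [[0, 1, 2],
#      [3, 4, 5],
#      [6, 7, 8]]
print(off_diagonal_sketch(x))  # tensor([1., 2., 3., 5., 6., 7.])
```

The result lists the off-diagonal elements in row-major order, so an n x n input always yields n * (n - 1) values.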