minerva.losses.batchwise_barlowtwins_loss
Classes
- BarlowTwinsLoss: Implementation of the Barlow Twins loss function for self-supervised learning.
- BatchWiseBarlowTwinLoss: Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
Functions
- _normalize: Normalizes each embedding tensor independently across the batch.
- _off_diagonal: Returns a flattened view of the off-diagonal elements of a square matrix.
Module Contents
- class minerva.losses.batchwise_barlowtwins_loss.BarlowTwinsLoss(lambda_param=0.005, gather_distributed=False)
Bases: torch.nn.Module
Implementation of the Barlow Twins loss function for self-supervised learning.
The loss encourages embeddings of two augmented views of the same input to be similar (invariance) while reducing redundancy between the components of their representations (decorrelation).
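For reference, the objective described above is usually written as follows in the standard Barlow Twins formulation. This is a sketch assuming both embedding batches have already been standardized along the batch dimension; N is the batch size, lambda corresponds to lambda_param, and hat-z^A, hat-z^B denote the normalized embeddings of the two views:

```latex
C_{ij} = \frac{1}{N} \sum_{b=1}^{N} \hat{z}^{A}_{b,i} \, \hat{z}^{B}_{b,j},
\qquad
\mathcal{L}_{\mathrm{BT}}
  = \underbrace{\sum_{i} \left(1 - C_{ii}\right)^{2}}_{\text{invariance}}
  + \lambda \underbrace{\sum_{i} \sum_{j \neq i} C_{ij}^{2}}_{\text{redundancy reduction}}
```

The diagonal term pulls each C_ii toward 1 (the two views agree feature by feature), while the off-diagonal term pushes each C_ij toward 0 (different features are decorrelated).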
Initializes the BarlowTwinsLoss module.
Parameters
- lambda_param : float, optional
Weight of the off-diagonal (redundancy-reduction) penalty in the loss. Defaults to 5e-3.
- gather_distributed : bool, optional
If True, performs an all-reduce on the cross-correlation matrix across GPUs. Defaults to False.
Raises
- ValueError
If gather_distributed is True but torch.distributed is not available.
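The gather_distributed option can be pictured with the sketch below. It assumes, for illustration only, that the local cross-correlation matrices are averaged across processes (divided by the world size, then summed with all_reduce) and that every process uses the same batch size. The helper names check_distributed_support and maybe_all_reduce are hypothetical and not part of minerva.

```python
import torch
import torch.distributed as dist


def check_distributed_support(gather_distributed: bool) -> None:
    """Hypothetical helper mirroring the documented ValueError."""
    if gather_distributed and not dist.is_available():
        raise ValueError(
            "gather_distributed=True requires torch.distributed to be available."
        )


def maybe_all_reduce(c: torch.Tensor, gather_distributed: bool) -> torch.Tensor:
    """Average a local cross-correlation matrix across processes (hypothetical).

    Returns ``c`` unchanged when aggregation is disabled or when no process
    group has been initialized (e.g. single-GPU or CPU runs).
    """
    # is_available() is checked first: is_initialized() only exists when
    # torch.distributed is available.
    if not gather_distributed or not (dist.is_available() and dist.is_initialized()):
        return c
    world_size = dist.get_world_size()
    if world_size > 1:
        c = c / world_size  # scale first so that the SUM below becomes a mean
        dist.all_reduce(c)  # default op is ReduceOp.SUM, applied in place
    return c
```

Outside a launched process group the helper is a no-op, which keeps the same loss code usable on a single device.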
- forward(z_a, z_b)
Computes the Barlow Twins loss.
Parameters
- z_a : Tensor
Embedding tensor from the first view. Shape: [batch_size, dim].
- z_b : Tensor
Embedding tensor from the second view. Shape: [batch_size, dim].
Returns
- Tensor
Scalar loss value combining invariance and redundancy reduction terms.
- Parameters:
z_a (torch.Tensor)
z_b (torch.Tensor)
- Return type:
torch.Tensor
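A minimal sketch of the forward computation in plain PyTorch. It assumes the embeddings are standardized per feature along the batch dimension and that the [dim, dim] cross-correlation matrix is divided by the batch size; the function names are hypothetical, distributed aggregation is omitted, and the actual minerva implementation may differ in details such as the epsilon.

```python
import torch


def standardize(z: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Zero mean, unit (biased) variance per feature, computed over the batch dim.
    mean = z.mean(dim=0, keepdim=True)
    var = ((z - mean) ** 2).mean(dim=0, keepdim=True)
    return (z - mean) / torch.sqrt(var + eps)


def barlow_twins_loss_sketch(
    z_a: torch.Tensor, z_b: torch.Tensor, lambda_param: float = 5e-3
) -> torch.Tensor:
    n, d = z_a.shape
    z_a_norm, z_b_norm = standardize(z_a), standardize(z_b)
    # [dim, dim] cross-correlation matrix between the two views.
    c = z_a_norm.T @ z_b_norm / n
    # Invariance term: push the diagonal toward 1.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    # Redundancy-reduction term: push every off-diagonal entry toward 0.
    off_mask = ~torch.eye(d, dtype=torch.bool, device=c.device)
    off_diag = c[off_mask].pow(2).sum()
    return on_diag + lambda_param * off_diag
```

A boolean mask is used here for readability; a reshape-based selection of the off-diagonal entries gives the same values without allocating the mask.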
- gather_distributed = False
- lambda_param = 0.005
- Parameters:
lambda_param (float)
gather_distributed (bool)
- class minerva.losses.batchwise_barlowtwins_loss.BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=False)
Bases: torch.nn.modules.loss._Loss
Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
Initializes the BatchWiseBarlowTwinLoss class.
Parameters
- diag_lambda : float, optional
The value of the diagonal lambda parameter. Default is 0.01.
- normalize : bool, optional
Whether to normalize the loss. Default is False.
- diag_lambda = 0.01
- forward(prediction_data, projection_data)
Calculates the loss between the prediction and projection data using a batch-wise version of the Barlow Twins loss function.
Parameters
- prediction_data : torch.Tensor
Prediction data tensor.
- projection_data : torch.Tensor
Projection data tensor.
Returns
- torch.Tensor
The computed batch-wise Barlow Twins loss.
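The batch-wise variant can be sketched as below. This is an interpretation under stated assumptions, not the minerva source. It assumes that the correlation is taken across samples (a [batch_size, batch_size] matrix of cosine similarities between L2-normalized rows) rather than across features. It also assumes that diag_lambda weights the off-diagonal term the way lambda_param does in BarlowTwinsLoss, and that normalize=True divides the loss by the batch size. The function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def batchwise_barlow_twins_sketch(
    prediction_data: torch.Tensor,
    projection_data: torch.Tensor,
    diag_lambda: float = 0.01,
    normalize: bool = False,
) -> torch.Tensor:
    # L2-normalize each sample's embedding, so entries of c are cosine similarities.
    p = F.normalize(prediction_data, dim=1)
    z = F.normalize(projection_data, dim=1)
    # [batch_size, batch_size] matrix: correlation across samples, not features.
    c = p @ z.T
    n = c.shape[0]
    # Matching pairs (same sample in both inputs) are pulled toward similarity 1.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    # Non-matching pairs are pushed toward similarity 0 (assumed weighting).
    off_mask = ~torch.eye(n, dtype=torch.bool, device=c.device)
    off_diag = c[off_mask].pow(2).sum()
    loss = on_diag + diag_lambda * off_diag
    if normalize:
        loss = loss / n  # assumed meaning of normalize: average over the batch
    return loss
```

Compared with the [dim, dim] matrix of the standard loss, this matrix grows with the batch size, so each sample is encouraged to match its own counterpart and to differ from the other samples in the batch.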
- normalize = False
- Parameters:
diag_lambda (float)
normalize (bool)
- minerva.losses.batchwise_barlowtwins_loss._normalize(z_a, z_b)
Normalizes each embedding tensor independently across the batch.
Parameters
- z_a : Tensor
Embeddings from the first view.
- z_b : Tensor
Embeddings from the second view.
Returns
- Tuple[Tensor, Tensor]
A tuple containing the normalized versions of z_a and z_b.
- Parameters:
z_a (torch.Tensor)
z_b (torch.Tensor)
- Return type:
Tuple[torch.Tensor, torch.Tensor]
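A sketch of what per-view normalization along the batch dimension typically looks like. Each feature is shifted to zero mean and scaled to unit variance using statistics computed over the batch, separately for z_a and z_b. The function name normalize_pair and the eps stabilizer are assumptions for illustration.

```python
from typing import Tuple

import torch


def normalize_pair(
    z_a: torch.Tensor, z_b: torch.Tensor, eps: float = 1e-5
) -> Tuple[torch.Tensor, torch.Tensor]:
    def standardize(z: torch.Tensor) -> torch.Tensor:
        # Statistics are computed per feature over the batch dimension (dim 0).
        mean = z.mean(dim=0, keepdim=True)
        var = ((z - mean) ** 2).mean(dim=0, keepdim=True)
        return (z - mean) / torch.sqrt(var + eps)

    # Each view uses its own statistics, independently of the other view.
    return standardize(z_a), standardize(z_b)
```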
- minerva.losses.batchwise_barlowtwins_loss._off_diagonal(x)
Returns a flattened view of the off-diagonal elements of a square matrix.
Parameters
- x : Tensor
A square 2D tensor (cross-correlation matrix).
Returns
- Tensor
A 1D tensor containing the flattened off-diagonal elements of the input matrix.
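One common way to extract the off-diagonal entries without building a boolean mask is the reshape trick sketched below. It is widely used in public Barlow Twins code; whether minerva uses exactly this form is an assumption, and the function name off_diagonal is hypothetical.

```python
import torch


def off_diagonal(x: torch.Tensor) -> torch.Tensor:
    n, m = x.shape
    if n != m:
        raise ValueError("expected a square matrix")
    # Drop the last element (the final diagonal entry), then view the remaining
    # n*n - 1 = (n - 1) * (n + 1) elements as rows of length n + 1. Each such row
    # starts with a diagonal entry, so slicing off column 0 leaves only the
    # off-diagonal entries, in row-major order.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
```

For a 3x3 matrix holding 0 to 8 in row-major order, this returns the entries 1, 2, 3, 5, 6, 7. The final flatten() copies the non-contiguous column slice, so the result is a new tensor rather than a view in the strict PyTorch sense.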