# Source code for minerva.losses.batchwise_barlowtwins_loss
import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from typing import Tuple
from torch import Tensor
import torch.distributed as dist
[docs]
def _off_diagonal(x):
    """
    Collect every off-diagonal entry of a square matrix into a 1D tensor.

    The entries are returned in row-major order, with the diagonal
    entries skipped.

    Parameters
    ----------
    x : Tensor
        A square 2D tensor (e.g. a cross-correlation matrix).

    Returns
    -------
    Tensor
        A 1D tensor holding the ``n * (n - 1)`` off-diagonal elements of `x`.
        NOTE(review): depending on the size and memory layout of `x`, the
        result may share storage with `x`; do not rely on either behaviour.

    Raises
    ------
    ValueError
        If `x` is not 2-dimensional (the shape cannot be unpacked into two
        sizes).
    AssertionError
        If `x` is not square.
    """
    rows, cols = x.shape
    assert rows == cols

    # Dropping the last element leaves n*n - 1 == (n - 1) * (n + 1) values.
    # Laid out as (n - 1) rows of (n + 1) columns, every remaining diagonal
    # element lands in column 0, so removing that column removes the diagonal.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()
[docs]
def _normalize(z_a: Tensor, z_b: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Standardizes two embedding tensors along the batch dimension.

    Both inputs are stacked and passed through a single, parameter-free
    batch normalization call, so the per-feature mean and variance are
    computed jointly over the samples of *both* views (``2 * N`` rows), not
    separately for each view.

    Parameters
    ----------
    z_a : Tensor
        Embeddings from the first view.
    z_b : Tensor
        Embeddings from the second view. Must have the same shape as `z_a`
        (the two tensors are stacked).

    Returns
    -------
    Tuple[Tensor, Tensor]
        The normalized versions of `z_a` and `z_b`, each with the shape of
        its input.
    """
    # Shape: 2 x N x D
    stacked = torch.stack((z_a, z_b), dim=0)
    # Merge the view axis into the batch axis: (2 * N) x D
    merged = stacked.flatten(0, 1)
    # No running statistics and no affine parameters: pure standardization
    # using the statistics of the current (merged) batch.
    standardized = F.batch_norm(
        merged,
        None,
        None,
        weight=None,
        bias=None,
        training=True,
    )
    per_view = standardized.view_as(stacked)
    return per_view[0], per_view[1]
[docs]
class BatchWiseBarlowTwinLoss(_Loss):
    """
    Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).

    Unlike the feature-wise Barlow Twins loss, the similarity matrix is built
    between *samples*: entry ``(i, j)`` is the cosine similarity between row
    ``i`` of the first input and row ``j`` of the second input. Diagonal
    entries are pushed towards 1 and off-diagonal entries towards 0.
    """

    def __init__(self, diag_lambda: float = 0.01, normalize: bool = False):
        """
        Initializes the BatchWiseBarlowTwinLoss class.

        Parameters
        ----------
        diag_lambda : float, optional
            Weight applied to the off-diagonal term of the loss.
            Default is 0.01.
        normalize : bool, optional
            Whether to divide the loss by the batch size (number of rows of
            the first input). Default is False.
        """
        super().__init__()
        self.diag_lambda = diag_lambda
        self.normalize = normalize

    def forward(
        self, prediction_data: torch.Tensor, projection_data: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculates the loss between the prediction and projection data using a batch-wise
        version of the Barlow Twins loss function.

        Parameters
        ----------
        prediction_data : torch.Tensor
            Prediction data tensor, one row per sample.
        projection_data : torch.Tensor
            Projection data tensor with the same shape as `prediction_data`.

        Returns
        -------
        torch.Tensor
            The computed batch-wise Barlow Twins loss.
        """
        return self.bt_loss_bs(
            prediction_data, projection_data, self.diag_lambda, self.normalize
        )

    def bt_loss_bs(self, p, z, lambd=0.01, normalize=False):
        """
        Barlow Twins loss computed over the batch dimension.

        Parameters
        ----------
        p : torch.Tensor
            First 2D input, one row per sample.
        z : torch.Tensor
            Second 2D input with the same shape as `p`, so that the
            similarity matrix is square.
        lambd : float, optional
            Weight of the off-diagonal term. Default is 0.01.
        normalize : bool, optional
            If True, the loss is divided by ``p.shape[0]``. Default is False.

        Returns
        -------
        torch.Tensor
            Scalar loss ``sum((diag - 1)^2) + lambd * sum(off_diag^2)``.
        """
        # Rows are L2-normalized, so c[i, j] is the cosine similarity between
        # sample i of `p` and sample j of `z`.
        c = torch.matmul(F.normalize(p), F.normalize(z).T)
        # NOTE: no strict (-1, 1) range assertion here. Cosine similarities
        # legitimately reach exactly +/-1 (e.g. perfectly aligned rows, which
        # is the optimum of this loss), so a strict check would raise on
        # valid input. As a consequence NaN inputs are no longer rejected
        # here and propagate into the returned loss.
        # Out-of-place ops keep `c` intact for the off-diagonal term.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = _off_diagonal(c).pow(2).sum()
        loss = on_diag + lambd * off_diag
        if normalize:
            loss = loss / p.shape[0]
        return loss
[docs]
class BarlowTwinsLoss(torch.nn.Module):
    """
    Barlow Twins loss for self-supervised learning.

    Given the embeddings of two augmented views of the same inputs, the loss
    drives the diagonal of their feature cross-correlation matrix towards 1
    (invariance between views) and the off-diagonal entries towards 0
    (redundancy reduction between embedding components).
    """

    def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False):
        """
        Initializes the BarlowTwinsLoss module.

        Parameters
        ----------
        lambda_param : float, optional
            Weight of the off-diagonal (redundancy reduction) term.
            Defaults to 5e-3.
        gather_distributed : bool, optional
            If True and a process group with more than one process is
            initialized, the cross-correlation matrix is divided by the world
            size and all-reduced across processes. Defaults to False.

        Raises
        ------
        ValueError
            If gather_distributed is True but torch.distributed is not available.
        """
        super().__init__()
        self.lambda_param = lambda_param
        self.gather_distributed = gather_distributed

        distributed_missing = not dist.is_available()
        if gather_distributed and distributed_missing:
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )

    def forward(self, z_a: Tensor, z_b: Tensor) -> Tensor:
        """
        Computes the Barlow Twins loss.

        Parameters
        ----------
        z_a : Tensor
            Embedding tensor from the first view. Shape: [batch_size, dim].
        z_b : Tensor
            Embedding tensor from the second view. Shape: [batch_size, dim].

        Returns
        -------
        Tensor
            Scalar loss: invariance term plus ``lambda_param`` times the
            redundancy reduction term.
        """
        batch_size = z_a.size(0)

        # Standardize the embeddings along the batch dimension.
        view_a, view_b = _normalize(z_a, z_b)

        # Feature-by-feature cross-correlation matrix, averaged over the batch.
        cross_corr = torch.matmul(view_a.T, view_b)
        cross_corr.div_(batch_size)

        # Combine the matrices of all processes when running distributed:
        # pre-dividing by the world size turns the all-reduce into a mean
        # (assuming the default sum reduction).
        if (
            self.gather_distributed
            and dist.is_initialized()
            and dist.get_world_size() > 1
        ):
            cross_corr = cross_corr / dist.get_world_size()
            dist.all_reduce(cross_corr)

        # Diagonal entries should be 1, off-diagonal entries should be 0.
        on_diagonal_term = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()
        off_diagonal_term = _off_diagonal(cross_corr).pow_(2).sum()
        return on_diagonal_term + self.lambda_param * off_diagonal_term