import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
class BatchWiseBarlowTwinLoss(_Loss):
    """
    Implementation of the Batch-wise Barlow Twins loss function (https://arxiv.org/abs/2310.07756).
    """
    def __init__(self, diag_lambda: float = 0.01, normalize: bool = False):
        """
        Initialize the BatchWiseBarlowtwinsLoss class.
        Parameters
        ----------
        diag_lambda: float
            The value of the diagonal lambda parameter. By default, 0.01.
        normalize: bool
            Whether to normalize the loss or not. By default, False.
        """
        super().__init__()
        self.diag_lambda = diag_lambda
        self.normalize = normalize
    def forward(self, prediction_data, projection_data):
        """
        Calculate the loss between the prediction and projection data. This implementation uses a batch-wise
        version of the Barlow Twins loss function.
        Parameters
        ----------
        prediction_data : torch.Tensor
            The prediction data.
        projection_data : torch.Tensor
            The projection data.
        """
        return self.bt_loss_bs(
            prediction_data, projection_data, self.diag_lambda, self.normalize
        ) 
    def bt_loss_bs(self, p, z, lambd=0.01, normalize=False):
        # Barlow Twins loss computed over the batch dimension: c is the
        # (batch_size, batch_size) cosine-similarity matrix between the
        # row-normalized prediction and projection embeddings.
        c = torch.matmul(F.normalize(p), F.normalize(z).T)
        # Cosine similarities lie in [-1, 1]; closed bounds keep exactly
        # matching pairs from tripping the sanity check.
        assert c.min() >= -1 and c.max() <= 1
        # Push matched pairs (the diagonal) towards 1. The in-place
        # add_/pow_ mutate only the diagonal of c, so the off-diagonal
        # elements read below are unaffected.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        # Push mismatched pairs (the off-diagonal) towards 0.
        off_diag = self.off_diagonal(c).pow_(2).sum()
        loss = on_diag + lambd * off_diag
        if normalize:
            loss = loss / p.shape[0]  # normalize by the batch size
        return loss
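
    # In symbols (notation assumed here, not taken from the source): with
    # C = normalize(P) @ normalize(Z).T computed over a batch of size N,
    #     L = sum_i (1 - C_ii)^2 + lambd * sum_{i != j} C_ij^2,
    # and the result is divided by N when normalize=True.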
    def off_diagonal(self, x):
        # Return a flattened view of the off-diagonal elements of a square
        # matrix (the indexing trick from the Barlow Twins reference
        # implementation): flatten, drop the last element, and reshape to
        # (n - 1, n + 1), which shifts every diagonal element into the
        # first column; slicing that column away leaves exactly the
        # off-diagonal entries.
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
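

# Minimal usage sketch (illustrative, not part of the original module): the
# batch size, feature dimension, and flag values below are assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    criterion = BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=True)
    prediction = torch.randn(8, 128)  # assumed (batch_size, feature_dim)
    projection = torch.randn(8, 128)
    loss = criterion(prediction, projection)
    print(f"batch-wise Barlow Twins loss: {loss.item():.4f}")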