Source code for minerva.losses.batchwise_barlowtwins_loss

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class BatchWiseBarlowTwinLoss(_Loss):
    """
    Implementation of the batch-wise Barlow Twins loss function
    (https://arxiv.org/abs/2310.07756).

    The loss is computed from the batch-wise cross-correlation matrix
    ``c = normalize(p) @ normalize(z).T``:

        loss = sum_i (c_ii - 1)^2 + diag_lambda * sum_{i != j} c_ij^2
    """

    def __init__(self, diag_lambda: float = 0.01, normalize: bool = False):
        """
        Initialize the BatchWiseBarlowTwinLoss class.

        Parameters
        ----------
        diag_lambda : float
            Weight applied to the off-diagonal terms of the cross-correlation
            matrix. By default, 0.01.
        normalize : bool
            Whether to normalize the loss by the batch size. By default,
            False.
        """
        super().__init__()
        self.diag_lambda = diag_lambda
        self.normalize = normalize
    def forward(self, prediction_data, projection_data):
        """
        Calculate the loss between the prediction and projection data, using
        a batch-wise version of the Barlow Twins loss function.

        Parameters
        ----------
        prediction_data : torch.Tensor
            The prediction data, of shape (batch_size, embedding_dim).
        projection_data : torch.Tensor
            The projection data, of shape (batch_size, embedding_dim).

        Returns
        -------
        torch.Tensor
            The scalar loss value.
        """
        return self.bt_loss_bs(
            prediction_data, projection_data, self.diag_lambda, self.normalize
        )
    def bt_loss_bs(self, p, z, lambd=0.01, normalize=False):
        # Barlow Twins loss computed over the batch dimension: the
        # cross-correlation matrix is (batch_size, batch_size) rather than
        # (embedding_dim, embedding_dim) as in the original formulation.
        c = torch.matmul(F.normalize(p), F.normalize(z).T)
        # Cosine similarities of L2-normalized rows lie in [-1, 1]
        # (up to floating-point error).
        assert c.min() >= -1.0 - 1e-6 and c.max() <= 1.0 + 1e-6
        # Push diagonal entries towards 1 and off-diagonal entries towards 0.
        # The in-place ops only touch the diagonal, which off_diagonal() does
        # not read, so off_diag is computed from the unmodified entries.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal(c).pow_(2).sum()
        loss = on_diag + lambd * off_diag
        if normalize:
            loss = loss / p.shape[0]
        return loss
    def off_diagonal(self, x):
        # Return a flattened view of the off-diagonal elements of a square
        # matrix: dropping the last element and reshaping to (n - 1, n + 1)
        # aligns the diagonal entries in the first column, which is then
        # sliced away.
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
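
# A minimal usage sketch (illustrative only, not part of the module): the
# batch size (32) and embedding dimension (128) are arbitrary assumptions,
# and random embeddings stand in for a model's predictions/projections.
if __name__ == "__main__":
    criterion = BatchWiseBarlowTwinLoss(diag_lambda=0.01, normalize=True)
    prediction = torch.randn(32, 128)  # (batch_size, embedding_dim)
    projection = torch.randn(32, 128)
    loss = criterion(prediction, projection)
    print(f"Batch-wise Barlow Twins loss: {loss.item():.4f}")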