Source code for minerva.losses.weighted_dice_loss

import torch
import torch.nn.functional as F


[docs] class WeightedDiceLoss(torch.nn.Module): def __init__(self, num_classes: int, smooth: float = 1.0): """ Implements Weighted Dice Loss for multiclass segmentation tasks. The Dice Loss is calculated for each class individually and then weighted averaged by the class frequencies to compensate for imbalance. Dice Loss formula for each class c: Dice = (2 * intersection + smooth) / (sum of both + smooth) DiceLoss = 1 - Dice The weights are calculated inversely proportional to the frequency of the class in the target, giving more importance to classes with fewer pixels. Note: The weights are applied within the intersection and union calculation, by multiplying the terms, and not directly as an overall weight on the class loss. Parameters ---------- num_classes : int total number of classes in the segmentation. smooth : float, optional value to smooth the calculation and avoid division by zero. Default is 1.0. """ super().__init__() self.num_classes = num_classes self.smooth = smooth
[docs] def forward(self, pred: torch.Tensor, target: torch.Tensor): """ Calculates the Weighted Dice Loss. Parameters ---------- pred : torch.Tensor tensor of raw predictions (logits) with shape (B, C, H, W). target : torch.Tensor target tensor with integer class labels, shape (B, H, W). Returns: torch.Tensor: scalar tensor containing the weighted average loss value. """ B, C, H, W = pred.shape weights = self.get_weight(target) pred = F.softmax(pred, dim=1) target = F.one_hot(target, self.num_classes).permute(0, 3, 1, 2) target = target.contiguous().view(B, C, H, W) total_loss = 0 for class_idx in range(self.num_classes): class_true = target[:, class_idx, ...] class_pred = pred[:, class_idx, ...] intersection = torch.sum(weights[class_idx] * (class_true * class_pred)) union = torch.sum(weights[class_idx] * class_true) + torch.sum( weights[class_idx] * class_pred ) dice_loss = 1 - (2.0 * intersection + self.smooth) / (union + self.smooth) total_loss += dice_loss return total_loss / self.num_classes
[docs] def get_weight(self, target: torch.Tensor): """ Calculates weights for each class based on the inverse frequency in the target. Parameters ---------- target : torch.Tensor target tensor with integer class labels, shape(B, H, W). Returns: torch.Tensor: 1D tensor with normalized weights for each class, shape(num_classes,). """ cls = torch.arange(self.num_classes).reshape(-1, 1).to(target.device) counts = torch.bincount(torch.flatten(target), minlength=self.num_classes) cls_num = counts[cls] denominator = torch.where(cls_num != 0, cls_num.float(), torch.tensor(1e10)) alpha = 1 / denominator alpha_norm = alpha / alpha.sum() return alpha_norm
[docs] class BinaryWeightedDiceLoss(torch.nn.Module): def __init__(self, smooth: float = 1.0): """ Implements Weighted Dice Loss for binary segmentation. Applies sigmoid to predictions to obtain probabilities and calculates weights for foreground and background based on pixel frequency in the target, to balance the contribution of each class. Dice Loss formula: dice_loss = 1 - (2 * Σ(w_i * p_i * t_i) + smooth) / (Σ(w_i * p_i) + Σ(w_i * t_i) + smooth) Where w_i are the weights per pixel, p_i are the predictions after sigmoid, and t_i are the binary targets. Parameters ---------- smooth : float, optional value to smooth calculation and avoid division by zero. Default is 1.0. """ super().__init__() self.smooth = smooth
[docs] def forward(self, pred: torch.Tensor, target: torch.Tensor): """ Calculates the binary Weighted Dice Loss. Parameters ---------- pred : torch.Tensor tensor of raw predictions (logits) with shape (B, 1, H, W). target : torch.Tensor target binary tensor (0 or 1), shape (B, 1, H, W). Returns: torch.Tensor: scalar tensor containing the loss value. """ pred = torch.sigmoid(pred) weights = self.get_weight(target) intersection = torch.sum(weights * (target * pred)) union = torch.sum(weights * target) + torch.sum(weights * pred) dice_loss = 1 - (2.0 * intersection + self.smooth) / (union + self.smooth) return dice_loss
[docs] def get_weight(self, target: torch.Tensor): """ Calculates per-pixel weights based on the ratio of foreground and background in the target. Parameters ---------- target : torch.Tensor target binary tensor (0 or 1), shape (B, 1, H, W). Returns: torch.Tensor: tensor of per-pixel weights, same shape as target. """ num_foreground = torch.sum(target) num_background = torch.sum(1 - target) total = num_foreground + num_background alpha_foreground = num_background / total alpha_background = num_foreground / total weights = target * alpha_foreground + (1 - target) * alpha_background return weights