minerva.losses.weighted_dice_loss
=================================

.. py:module:: minerva.losses.weighted_dice_loss


Classes
-------

.. autoapisummary::

   minerva.losses.weighted_dice_loss.BinaryWeightedDiceLoss
   minerva.losses.weighted_dice_loss.WeightedDiceLoss


Module Contents
---------------

.. py:class:: BinaryWeightedDiceLoss(smooth = 1.0)

   Bases: :py:obj:`torch.nn.Module`


   Implements Weighted Dice Loss for binary segmentation.

   Applies a sigmoid to the predictions to obtain probabilities, and computes foreground and background weights from the pixel frequencies in the target so that both classes contribute in a balanced way.

   Dice Loss formula:

   .. math::

      \text{dice\_loss} = 1 - \frac{2 \sum_i w_i \, p_i \, t_i + \text{smooth}}{\sum_i w_i \, p_i + \sum_i w_i \, t_i + \text{smooth}}

   where :math:`w_i` is the weight of pixel :math:`i`, :math:`p_i` is its prediction after the sigmoid, and :math:`t_i` is its binary target.

   Parameters
   ----------
   smooth : float, optional
       Smoothing constant that avoids division by zero. Default is 1.0.


   .. py:method:: forward(pred, target)

      Calculates the binary Weighted Dice Loss.

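
      The computation described by the formula can be sketched in plain PyTorch as follows. This is a minimal, hypothetical sketch rather than the minerva implementation: the helper name ``binary_weighted_dice_loss`` is illustrative, and the weighting rule (each foreground pixel weighted by the background fraction, each background pixel by the foreground fraction) is an assumed reading of "weights based on pixel frequency in the target".

      .. code-block:: python

         import torch


         def binary_weighted_dice_loss(pred, target, smooth=1.0):
             """Hypothetical binary weighted Dice loss (sketch, not minerva's code).

             pred: logits (B, 1, H, W); target: binary labels (B, 1, H, W).
             """
             # Logits -> probabilities via sigmoid, as the class description states.
             probs = torch.sigmoid(pred)
             target = target.float()
             # ASSUMPTION: foreground pixels are weighted by the background fraction
             # and background pixels by the foreground fraction (inverse frequency).
             fg_ratio = target.mean()
             weights = torch.where(target > 0.5, 1.0 - fg_ratio, fg_ratio)
             # Weighted terms of the Dice formula.
             intersection = (weights * probs * target).sum()
             denominator = (weights * probs).sum() + (weights * target).sum()
             dice = (2.0 * intersection + smooth) / (denominator + smooth)
             return 1.0 - dice


         logits = torch.tensor([[[[4.0, -4.0], [-4.0, 4.0]]]])  # (B=1, 1, H=2, W=2)
         target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
         good = binary_weighted_dice_loss(logits, target)   # confident and correct
         bad = binary_weighted_dice_loss(-logits, target)   # confident and inverted

      With this assumed rule, a target that is entirely background gives every pixel a weight of 0, so the sketch returns a loss of 0 for it; a real implementation may guard that case differently.
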

      Parameters
      ----------
      pred : torch.Tensor
          Raw predictions (logits) with shape (B, 1, H, W).
      target : torch.Tensor
          Binary target tensor (values 0 or 1) with shape (B, 1, H, W).

      Returns
      -------
      torch.Tensor
          Scalar tensor containing the loss value.


   .. py:method:: get_weight(target)

      Calculates per-pixel weights from the ratio of foreground to background pixels in the target.

      Parameters
      ----------
      target : torch.Tensor
          Binary target tensor (values 0 or 1) with shape (B, 1, H, W).

      Returns
      -------
      torch.Tensor
          Per-pixel weights with the same shape as ``target``.


   .. py:attribute:: smooth
      :value: 1.0


.. py:class:: WeightedDiceLoss(num_classes, smooth = 1.0)

   Bases: :py:obj:`torch.nn.Module`


   Implements Weighted Dice Loss for multiclass segmentation tasks.

   The Dice Loss is computed for each class individually and the results are averaged; weights derived from the class frequencies compensate for class imbalance (see the note below on how they are applied).

   Dice Loss formula for each class :math:`c`:

   .. math::

      \text{Dice}_c = \frac{2 \, I_c + \text{smooth}}{S_c + \text{smooth}}, \qquad \text{DiceLoss}_c = 1 - \text{Dice}_c

   where :math:`I_c` is the intersection of the predictions and the target for class :math:`c`, and :math:`S_c` is the sum of both.

   The weights are inversely proportional to the frequency of each class in the target, giving more importance to classes with fewer pixels.

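
   As an illustration of inverse-frequency weighting, the hypothetical helper below counts the pixels of each class and normalizes the reciprocals of those counts. The helper name, the zero weight for classes absent from the target, and the reading of "normalized" as "sums to 1" are assumptions; ``get_weight`` may differ in these details.

   .. code-block:: python

      import torch


      def inverse_frequency_weights(target, num_classes):
          """Hypothetical helper: normalized inverse-frequency class weights.

          target: integer labels (B, H, W). Returns a tensor of shape (num_classes,).
          """
          # Pixel count of every class over the whole batch.
          counts = torch.bincount(target.flatten(), minlength=num_classes).float()
          # ASSUMPTION: classes absent from the target receive weight 0.
          inverse = torch.zeros_like(counts)
          present = counts > 0
          inverse[present] = 1.0 / counts[present]
          # ASSUMPTION: "normalized" means the weights sum to 1.
          return inverse / inverse.sum()


      target = torch.tensor([[[0, 0, 0], [0, 1, 1], [0, 0, 2]]])  # (B=1, H=3, W=3)
      weights = inverse_frequency_weights(target, num_classes=3)  # ≈ [0.1, 0.3, 0.6]

   In this 3 by 3 example, classes 0, 1 and 2 cover 6, 2 and 1 pixels, so the reciprocals 1/6, 1/2 and 1 normalize to 0.1, 0.3 and 0.6: the rarest class receives the largest weight.
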

   .. note::

      The weights are applied inside the intersection and union calculations, by multiplying the individual terms, rather than as an overall weight on each class loss.

   Parameters
   ----------
   num_classes : int
       Total number of classes in the segmentation.
   smooth : float, optional
       Smoothing constant that avoids division by zero. Default is 1.0.


   .. py:method:: forward(pred, target)

      Calculates the Weighted Dice Loss.

      Parameters
      ----------
      pred : torch.Tensor
          Raw predictions (logits) with shape (B, C, H, W).
      target : torch.Tensor
          Target tensor of integer class labels with shape (B, H, W).

      Returns
      -------
      torch.Tensor
          Scalar tensor containing the weighted average loss value.


   .. py:method:: get_weight(target)

      Calculates a weight for each class from its inverse frequency in the target.

      Parameters
      ----------
      target : torch.Tensor
          Target tensor of integer class labels with shape (B, H, W).

      Returns
      -------
      torch.Tensor
          1D tensor of normalized class weights with shape (num_classes,).


   .. py:attribute:: num_classes


   .. py:attribute:: smooth
      :value: 1.0
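
   Putting the per-class formula and the note on weighting together, the multiclass loss can be sketched as follows. This is a hypothetical sketch, not minerva's code; the function name ``weighted_dice_loss`` is illustrative, and three points are assumptions not confirmed by this page: softmax converts the logits to probabilities, the class weights are passed in precomputed (for example, as ``get_weight`` would return them), and the per-class losses are averaged uniformly into the final scalar.

   .. code-block:: python

      import torch
      import torch.nn.functional as F


      def weighted_dice_loss(pred, target, class_weights, smooth=1.0):
          """Hypothetical multiclass weighted Dice loss (sketch, not minerva's code).

          pred: logits (B, C, H, W); target: integer labels (B, H, W);
          class_weights: 1D tensor of shape (C,).
          """
          num_classes = pred.shape[1]
          # ASSUMPTION: softmax over the class dimension gives probabilities.
          probs = torch.softmax(pred, dim=1)
          # Integer labels (B, H, W) -> one-hot (B, C, H, W).
          one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
          # Broadcast the per-class weights over batch and spatial dimensions.
          w = class_weights.view(1, -1, 1, 1)
          dims = (0, 2, 3)  # sum over everything except the class dimension
          # Weights multiply the terms inside both sums, as the note describes.
          intersection = (w * probs * one_hot).sum(dim=dims)
          union = (w * probs).sum(dim=dims) + (w * one_hot).sum(dim=dims)
          dice = (2.0 * intersection + smooth) / (union + smooth)
          # ASSUMPTION: uniform mean of the per-class losses.
          return (1.0 - dice).mean()


      target = torch.tensor([[[0, 1], [1, 1]]])  # (B=1, H=2, W=2)
      # Inverse-frequency weights for this target (class 0: 1 pixel, class 1: 3).
      class_weights = torch.tensor([0.75, 0.25])
      good_logits = F.one_hot(target, 2).permute(0, 3, 1, 2).float() * 8.0 - 4.0
      good = weighted_dice_loss(good_logits, target, class_weights)
      bad = weighted_dice_loss(-good_logits, target, class_weights)

   Confident, correct logits give a loss close to 0, while inverting them pushes the loss above 0.5.
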