minerva.losses.weighted_dice_loss

Classes

BinaryWeightedDiceLoss

Implements Weighted Dice Loss for binary segmentation.

WeightedDiceLoss

Implements Weighted Dice Loss for multiclass segmentation tasks.

Module Contents

class minerva.losses.weighted_dice_loss.BinaryWeightedDiceLoss(smooth=1.0)[source]

Bases: torch.nn.Module


Parameters:

smooth (float)

Implements Weighted Dice Loss for binary segmentation.

Applies sigmoid to predictions to obtain probabilities and calculates weights for foreground and background based on pixel frequency in the target, to balance the contribution of each class.

Dice Loss formula:

$$\mathcal{L}_{\text{dice}} = 1 - \frac{2\sum_i w_i\, p_i\, t_i + \text{smooth}}{\sum_i w_i\, p_i + \sum_i w_i\, t_i + \text{smooth}}$$

where w_i is the weight of pixel i, p_i is the prediction for pixel i after the sigmoid, and t_i is the binary target for pixel i.
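The formula can be evaluated directly in PyTorch. The sketch below is not the minerva implementation: `weighted_binary_dice_loss` is a hypothetical helper that takes the per-pixel weights as an input instead of computing them, and it assumes the sums run over every pixel of the whole batch.

```python
import torch


def weighted_binary_dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    weights: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Evaluate the weighted Dice Loss formula (hypothetical helper).

    pred:    raw logits, shape (B, 1, H, W)
    target:  binary mask (0 or 1), same shape as pred
    weights: per-pixel weights w_i, same shape as pred
    """
    p = torch.sigmoid(pred)   # p_i: probabilities after the sigmoid
    t = target.float()        # t_i: binary targets as floats
    # Assumption: the sums run over every pixel of the whole batch.
    intersection = (weights * p * t).sum()
    denominator = (weights * p).sum() + (weights * t).sum()
    return 1 - (2 * intersection + smooth) / (denominator + smooth)


# Tiny 1x1x2x2 example with uniform weights and confident, correct logits.
logits = torch.tensor([[[[4.0, -4.0], [4.0, -4.0]]]])
mask = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])
loss = weighted_binary_dice_loss(logits, mask, torch.ones_like(mask))
```

With uniform weights the expression reduces to the ordinary soft Dice Loss; non-uniform weights shift how much each pixel contributes to both the intersection and the denominator.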

Parameters

smooth : float, optional

Smoothing constant added to both the numerator and the denominator to avoid division by zero. Default is 1.0.
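The effect of `smooth` is easiest to see in the degenerate case where every sum in the formula is zero. The sketch below uses a hypothetical helper, `dice_loss_from_probs`, that applies the unweighted form (all w_i = 1) directly to probabilities; in the real class the sigmoid never returns exactly 0, so this is only an illustration of the arithmetic.

```python
import torch


def dice_loss_from_probs(p: torch.Tensor, t: torch.Tensor, smooth: float) -> torch.Tensor:
    # Unweighted form of the formula (all w_i = 1), evaluated on probabilities.
    intersection = (p * t).sum()
    return 1 - (2 * intersection + smooth) / (p.sum() + t.sum() + smooth)


# An empty prediction against an empty target: every sum is zero.
p = torch.zeros(1, 1, 4, 4)
t = torch.zeros(1, 1, 4, 4)

print(dice_loss_from_probs(p, t, smooth=0.0))  # nan, because 0 / 0
print(dice_loss_from_probs(p, t, smooth=1.0))  # 0, because (0 + 1) / (0 + 1) = 1
```

With the default of 1.0, a correctly empty prediction is scored as a perfect match instead of producing NaN.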

forward(pred, target)[source]

Calculates the binary Weighted Dice Loss.

Parameters

pred : torch.Tensor

Tensor of raw predictions (logits) with shape (B, 1, H, W).

target : torch.Tensor

Binary target tensor (values 0 or 1) with shape (B, 1, H, W).

Returns:

torch.Tensor: scalar tensor containing the loss value.

Parameters:
  • pred (torch.Tensor)

  • target (torch.Tensor)
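Binary masks often come out of a data loader as integer tensors of shape (B, H, W), without the channel dimension that `forward` expects. The sketch below shows one way to bring such a mask into the documented (B, 1, H, W) layout; the random tensors are placeholders, and converting the target to float is an assumption that simply matches the dtype of the logits.

```python
import torch

# Hypothetical batch of integer masks with shape (B, H, W), e.g. from a data loader.
masks = torch.randint(0, 2, (4, 32, 32))

# forward expects a target of shape (B, 1, H, W) holding 0/1 values,
# matching the single-channel logits.
target = masks.unsqueeze(1).float()
logits = torch.randn(4, 1, 32, 32)  # stand-in for the model output (raw logits)

assert target.shape == logits.shape == (4, 1, 32, 32)
```

The two tensors can then be passed as `BinaryWeightedDiceLoss(smooth=1.0)(logits, target)`; no sigmoid should be applied beforehand, since the loss applies it internally.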

get_weight(target)[source]

Calculates per-pixel weights based on the ratio of foreground and background in the target.

Parameters

target : torch.Tensor

Binary target tensor (values 0 or 1) with shape (B, 1, H, W).

Returns:

torch.Tensor: tensor of per-pixel weights, same shape as target.

Parameters:

target (torch.Tensor)
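The page does not spell out the exact weighting rule. A minimal sketch, assuming each pixel is weighted by the frequency of the *other* class so that the rarer class receives the larger weight, could look like this; `binary_pixel_weights` and its `eps` floor are hypothetical, not the minerva code.

```python
import torch


def binary_pixel_weights(target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical per-pixel weighting based on the foreground/background ratio.

    target:  binary mask (0 or 1), shape (B, 1, H, W)
    returns: weights with the same shape as target
    """
    t = target.float()
    fg_ratio = t.mean()          # fraction of foreground pixels
    bg_ratio = 1.0 - fg_ratio    # fraction of background pixels
    # Foreground pixels get bg_ratio, background pixels get fg_ratio,
    # so whichever class is rarer receives the larger weight.
    weights = t * bg_ratio + (1.0 - t) * fg_ratio
    return weights.clamp_min(eps)  # keep every weight strictly positive


example = torch.tensor([[[[1, 0], [0, 0]]]])  # 1 of 4 pixels is foreground
w = binary_pixel_weights(example)
```

Here the single foreground pixel gets weight 0.75 and each background pixel 0.25. The `eps` floor matters when the mask is entirely background: `fg_ratio` is then 0, and without the floor every weight would vanish.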

smooth = 1.0
class minerva.losses.weighted_dice_loss.WeightedDiceLoss(num_classes, smooth=1.0)[source]

Bases: torch.nn.Module


Parameters:
  • num_classes (int)

  • smooth (float)

Implements Weighted Dice Loss for multiclass segmentation tasks.

The Dice Loss is computed for each class individually and then averaged over the classes; class weights derived from the class frequencies compensate for class imbalance (see the note below on where the weights enter).

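Dice Loss formula for each class c:

$$\text{Dice}_c = \frac{2\sum_i p_{i,c}\, t_{i,c} + \text{smooth}}{\sum_i p_{i,c} + \sum_i t_{i,c} + \text{smooth}}, \qquad \text{DiceLoss}_c = 1 - \text{Dice}_c$$

where p_{i,c} is the predicted probability of class c at pixel i, and t_{i,c} is 1 if pixel i belongs to class c and 0 otherwise. The numerator sum is the intersection; the denominator adds the sums of predictions and targets.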
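The per-class formula can be vectorized over all classes at once by one-hot encoding the integer labels. This is an unweighted sketch with a hypothetical helper name, `per_class_dice_loss`; it assumes a softmax over the class dimension turns the logits into probabilities, which this page does not state explicitly.

```python
import torch
import torch.nn.functional as F


def per_class_dice_loss(pred: torch.Tensor, target: torch.Tensor,
                        num_classes: int, smooth: float = 1.0) -> torch.Tensor:
    """Unweighted per-class Dice Loss, shape (num_classes,). Hypothetical helper.

    pred:   logits, shape (B, C, H, W)
    target: integer class labels (int64), shape (B, H, W)
    """
    probs = torch.softmax(pred, dim=1)                     # assumption: softmax over classes
    one_hot = F.one_hot(target, num_classes)               # (B, H, W, C)
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)  # (B, C, H, W)
    dims = (0, 2, 3)                                       # sum over batch and pixels
    intersection = (probs * one_hot).sum(dims)
    sum_of_both = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + smooth) / (sum_of_both + smooth)
    return 1 - dice


target = torch.tensor([[[0, 1], [2, 1]]])                       # (B=1, H=2, W=2)
pred = 10.0 * F.one_hot(target, 3).permute(0, 3, 1, 2).float()  # confident, correct logits
losses = per_class_dice_loss(pred, target, num_classes=3)
```

Each entry of `losses` is DiceLoss_c for one class. The class weights described below would then enter inside the sums rather than being applied to this vector.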

The weights are calculated inversely proportional to the frequency of the class in the target, giving more importance to classes with fewer pixels.

Note: the weights are applied inside the intersection and denominator sums, by multiplying their terms, rather than as an overall weight on each class's loss.

Parameters

num_classes : int

Total number of classes in the segmentation.

smooth : float, optional

Smoothing value that stabilizes the calculation and avoids division by zero. Default is 1.0.

forward(pred, target)[source]

Calculates the Weighted Dice Loss.

Parameters

pred : torch.Tensor

Tensor of raw predictions (logits) with shape (B, C, H, W).

target : torch.Tensor

Target tensor of integer class labels with shape (B, H, W).

Returns:

torch.Tensor: scalar tensor containing the weighted average loss value.

Parameters:
  • pred (torch.Tensor)

  • target (torch.Tensor)
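Unlike `pred`, the multiclass `target` has no channel dimension and holds integer labels. Label maps sometimes arrive as (B, 1, H, W) float tensors; the sketch below shows a minimal conversion to the documented layout, using random placeholder tensors.

```python
import torch

num_classes = 4
logits = torch.randn(2, num_classes, 16, 16)  # stand-in for the model output, (B, C, H, W)

# Hypothetical label maps with a singleton channel, (B, 1, H, W), stored as floats.
labels = torch.randint(0, num_classes, (2, 1, 16, 16)).float()

# forward expects integer class labels without the channel dimension: (B, H, W).
target = labels.squeeze(1).long()

assert target.shape == (logits.shape[0], *logits.shape[2:])
assert int(target.min()) >= 0 and int(target.max()) < num_classes
```

The range check catches a common mistake: labels outside `[0, num_classes)`, such as an "ignore" value like 255, cannot be mapped to a class.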

get_weight(target)[source]

Calculates weights for each class based on the inverse frequency in the target.

Parameters

target : torch.Tensor

Target tensor of integer class labels with shape (B, H, W).

Returns:

torch.Tensor: 1D tensor of normalized weights, one per class, with shape (num_classes,).

Parameters:

target (torch.Tensor)
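Inverse-frequency class weights can be computed with `torch.bincount`. The sketch below is an assumption about the general technique, not the minerva code. The helper name `inverse_frequency_weights` is hypothetical, and giving absent classes a weight of 0 is a design choice of this sketch so that a class missing from the batch cannot dominate the normalization.

```python
import torch


def inverse_frequency_weights(target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical class weights: inverse pixel frequency, normalized to sum to 1.

    target:  integer class labels (int64), shape (B, H, W)
    returns: 1D tensor of shape (num_classes,)
    """
    counts = torch.bincount(target.flatten(), minlength=num_classes).float()
    # Inverse frequency for classes that occur; weight 0 for absent classes.
    inv = torch.where(counts > 0, 1.0 / counts.clamp_min(1.0), torch.zeros_like(counts))
    return inv / inv.sum()  # normalize so the weights sum to 1


labels = torch.tensor([[[0, 0], [0, 1]]])  # class 0: 3 pixels, class 1: 1 pixel, class 2: none
weights = inverse_frequency_weights(labels, num_classes=3)
```

For this example the weights are 0.25, 0.75, and 0: the single class-1 pixel counts three times as much as each class-0 pixel.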

num_classes
smooth = 1.0