minerva.losses.weighted_dice_loss
Classes
- BinaryWeightedDiceLoss: implements Weighted Dice Loss for binary segmentation.
- WeightedDiceLoss: implements Weighted Dice Loss for multiclass segmentation tasks.
Module Contents
- class minerva.losses.weighted_dice_loss.BinaryWeightedDiceLoss(smooth=1.0)
Bases: torch.nn.Module

Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
- Variables:
training (bool) – Boolean that represents whether this module is in training or evaluation mode.
- Parameters:
smooth (float)
Implements Weighted Dice Loss for binary segmentation.
Applies sigmoid to predictions to obtain probabilities and calculates weights for foreground and background based on pixel frequency in the target, to balance the contribution of each class.
Dice Loss formula: dice_loss = 1 - (2 * Σ(w_i * p_i * t_i) + smooth) / (Σ(w_i * p_i) + Σ(w_i * t_i) + smooth)
Where w_i are the weights per pixel, p_i are the predictions after sigmoid, and t_i are the binary targets.
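The formula above can be traced with a short NumPy sketch. It is framework-agnostic and is not minerva's implementation: the per-pixel weights are passed in directly (standing in for get_weight), and the sums run over every pixel in the batch, which is an assumption about how the batch is reduced. The function name weighted_binary_dice_loss is hypothetical.

```python
import numpy as np

def weighted_binary_dice_loss(logits, target, weights, smooth=1.0):
    """Weighted binary Dice loss following the formula above.

    logits, target and weights all have shape (B, 1, H, W). The sums run
    over every pixel in the batch (an assumption: the library may instead
    average per-sample losses).
    """
    p = 1.0 / (1.0 + np.exp(-logits))                  # sigmoid -> probabilities
    intersection = np.sum(weights * p * target)        # sum(w_i * p_i * t_i)
    denominator = np.sum(weights * p) + np.sum(weights * target)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice

# Zero logits give p = 0.5 everywhere; one foreground pixel out of four.
logits = np.zeros((1, 1, 2, 2))
target = np.array([[[[1.0, 0.0], [0.0, 0.0]]]])
weights = np.ones_like(target)
loss = weighted_binary_dice_loss(logits, target, weights)
print(loss)  # 0.5
```

Here the intersection is 0.5 and the denominator is 2 + 1, so the Dice ratio is (1 + 1) / (3 + 1) = 0.5.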
Parameters
- smooth : float, optional
Smoothing constant added to both the numerator and the denominator of the Dice ratio to avoid division by zero. Default is 1.0.
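As a worked illustration of the role of smooth in the formula above, take a target with no foreground pixels, so every t_i = 0, and predictions p_i that are all close to zero. Both sums involving t_i vanish and Σ(w_i * p_i) is close to zero:

$$
\text{dice\_loss} = 1 - \frac{2 \cdot 0 + \text{smooth}}{\sum_i w_i p_i + 0 + \text{smooth}} \approx 1 - \frac{\text{smooth}}{\text{smooth}} = 0
$$

so a correct all-background prediction gets a loss near 0. With smooth = 0 the ratio would instead be 0/0 when every p_i is exactly zero, and 0 / Σ(w_i * p_i) = 0 (a loss of 1) for any positive predictions, however small.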
- forward(pred, target)
Calculates the binary Weighted Dice Loss.
Parameters
- pred : torch.Tensor
Tensor of raw predictions (logits) with shape (B, 1, H, W).
- target : torch.Tensor
Binary target tensor (values 0 or 1) with shape (B, 1, H, W).
- Returns:
torch.Tensor: scalar tensor containing the loss value.
- Parameters:
pred (torch.Tensor)
target (torch.Tensor)
- get_weight(target)
Calculates per-pixel weights based on the ratio of foreground and background in the target.
Parameters
- target : torch.Tensor
Binary target tensor (values 0 or 1) with shape (B, 1, H, W).
- Returns:
torch.Tensor: tensor of per-pixel weights, same shape as target.
- Parameters:
target (torch.Tensor)
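The exact weighting formula used by get_weight is not spelled out here. As an illustration only, the NumPy sketch below uses one common balancing scheme (an assumption): foreground pixels get N / (2 * N_fg) and background pixels get N / (2 * N_bg), where N is the total pixel count, so both groups contribute the same total weight. The helper binary_pixel_weights is hypothetical, and counts are taken over the whole batch rather than per sample.

```python
import numpy as np

def binary_pixel_weights(target):
    """Per-pixel weights that balance foreground and background.

    Hypothetical scheme: w_fg = N / (2 * N_fg) and w_bg = N / (2 * N_bg),
    so foreground and background pixels contribute equal total weight.
    """
    n_total = target.size
    n_fg_raw = float(target.sum())
    n_fg = max(n_fg_raw, 1.0)                  # clamp so an empty class cannot divide by zero
    n_bg = max(n_total - n_fg_raw, 1.0)
    w_fg = n_total / (2.0 * n_fg)
    w_bg = n_total / (2.0 * n_bg)
    return np.where(target > 0.5, w_fg, w_bg)  # same shape as target

target = np.array([[[[1.0, 0.0], [0.0, 0.0]]]])  # shape (B, 1, H, W)
weights = binary_pixel_weights(target)
print(weights[0, 0])  # foreground pixel gets 2.0, each background pixel 4/6
```

Under this scheme the single foreground pixel and the three background pixels each sum to a total weight of 2.0.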
- smooth = 1.0
- class minerva.losses.weighted_dice_loss.WeightedDiceLoss(num_classes, smooth=1.0)
Bases: torch.nn.Module
- Parameters:
num_classes (int)
smooth (float)
Implements Weighted Dice Loss for multiclass segmentation tasks.
The Dice Loss is computed for each class individually, and the per-class terms are then combined using weights derived from the class frequencies to compensate for class imbalance.
Dice Loss formula for each class c:
Dice_c = (2 * intersection_c + smooth) / (Σ p_c + Σ t_c + smooth)
DiceLoss_c = 1 - Dice_c
where intersection_c = Σ(p_c * t_c), p_c are the predictions for class c, and t_c is the target mask for class c.
The weights are inversely proportional to the frequency of each class in the target, giving more importance to classes with fewer pixels.
Note: The weights are applied within the intersection and union calculation, by multiplying the terms, and not directly as an overall weight on the class loss.
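One well-known way to apply class weights inside the intersection and union terms, rather than to each class's loss, is the Generalized Dice formulation: per-class intersections and sums are multiplied by the class weights and summed before a single Dice ratio is taken. The NumPy sketch below follows that formulation under the assumption that it matches this class's aggregation; it expects probabilities and one-hot targets that already share the shape (B, C, H, W), and weighted_multiclass_dice_loss is a hypothetical name.

```python
import numpy as np

def weighted_multiclass_dice_loss(probs, onehot, class_weights, smooth=1.0):
    """Class-weighted Dice loss in the Generalized Dice style.

    probs and onehot have shape (B, C, H, W); class_weights has shape (C,).
    Per-class intersection and sum terms are computed first, multiplied by
    the class weights, and only then combined into a single Dice ratio.
    """
    axes = (0, 2, 3)                                     # sum over batch and pixels, keep classes
    intersection = np.sum(probs * onehot, axis=axes)     # shape (C,)
    sums = np.sum(probs, axis=axes) + np.sum(onehot, axis=axes)
    numerator = 2.0 * np.sum(class_weights * intersection) + smooth
    denominator = np.sum(class_weights * sums) + smooth
    return 1.0 - numerator / denominator

# A perfect prediction on a 2-class, 1x2 image: class 0 then class 1.
onehot = np.array([[[[1.0, 0.0]], [[0.0, 1.0]]]])        # shape (1, 2, 1, 2)
probs = onehot.copy()
weights = np.array([0.5, 0.5])
loss = weighted_multiclass_dice_loss(probs, onehot, weights)
print(loss)  # 0.0
```

Because the weights multiply the terms before the ratio is formed, a rare class with a large weight dominates both the numerator and the denominator instead of simply scaling its own loss.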
Parameters
- num_classes : int
Total number of classes in the segmentation task.
- smooth : float, optional
Smoothing constant that avoids division by zero. Default is 1.0.
- forward(pred, target)
Calculates the Weighted Dice Loss.
Parameters
- pred : torch.Tensor
Tensor of raw predictions (logits) with shape (B, C, H, W).
- target : torch.Tensor
Target tensor of integer class labels with shape (B, H, W).
- Returns:
torch.Tensor: scalar tensor containing the weighted average loss value.
- Parameters:
pred (torch.Tensor)
target (torch.Tensor)
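Because pred has shape (B, C, H, W) while target holds integer labels of shape (B, H, W), the two must be brought to a common shape before per-class terms can be computed. The NumPy sketch below shows one way to do this: a softmax over the class axis (an assumption, since the activation used in the multiclass case is not stated) and one-hot encoding of the labels. The helper softmax_and_onehot is hypothetical.

```python
import numpy as np

def softmax_and_onehot(logits, labels, num_classes):
    """Convert (B, C, H, W) logits and (B, H, W) labels to matching shapes.

    Applying softmax over the class axis is an assumption about how the
    logits become probabilities.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)           # (B, C, H, W)
    onehot = np.eye(num_classes)[labels]                   # (B, H, W, C)
    onehot = onehot.transpose(0, 3, 1, 2)                  # (B, C, H, W)
    return probs, onehot

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 4, 4))                     # B=2, C=3, H=W=4
labels = rng.integers(0, 3, size=(2, 4, 4))                # integer class labels
probs, onehot = softmax_and_onehot(logits, labels, num_classes=3)
print(probs.shape, onehot.shape)  # (2, 3, 4, 4) (2, 3, 4, 4)
```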
- get_weight(target)
Calculates weights for each class based on the inverse frequency in the target.
Parameters
- target : torch.Tensor
Target tensor of integer class labels with shape (B, H, W).
- Returns:
torch.Tensor: 1D tensor of normalized weights, one per class, with shape (num_classes,).
- Parameters:
target (torch.Tensor)
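Inverse-frequency class weighting can be sketched in NumPy as follows. Two details are assumptions rather than documented behavior: the weights are normalized to sum to 1, and a class that is absent from the target is counted as if it had one pixel so that its weight stays finite. The helper inverse_frequency_weights is hypothetical.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes):
    """Class weights inversely proportional to class frequency, summing to 1.

    Normalizing to a sum of 1, and counting absent classes as one pixel,
    are assumptions made for this sketch.
    """
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(float)
    inverse = 1.0 / np.maximum(counts, 1.0)   # rarer classes get larger values
    return inverse / inverse.sum()            # shape (num_classes,)

labels = np.array([[[0, 0], [0, 1]]])         # shape (B, H, W): three 0s, one 1
w = inverse_frequency_weights(labels, num_classes=2)
print(w)  # approximately [0.25, 0.75]
```

With three pixels of class 0 and one of class 1, the rarer class receives three times the weight of the common one.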
- num_classes
- smooth = 1.0