"""Functional API for losses.
"""
import torch
# Borrowed from https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/_functional.py
[docs]
def dice_score(
y_hat: torch.Tensor,
y: torch.Tensor,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
assert y_hat.size() == y.size()
if dims is not None:
intersection = torch.sum(y_hat * y, dim=dims)
cardinality = torch.sum(y_hat + y, dim=dims)
else:
intersection = torch.sum(y_hat * y)
cardinality = torch.sum(y_hat + y)
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
return dice_score