import torch
import torchmetrics
from torchmetrics import Metric
from torchmetrics.functional import confusion_matrix
import warnings
# Balanced accuracy metric (mean of per-class recall) built on torchmetrics.
class BalancedAccuracy(Metric):
    """
    Compute the balanced accuracy.

    The balanced accuracy is meant for classification problems with
    imbalanced datasets. It is defined as the average of the recall obtained
    on each class. Classes that have no true samples in the accumulated data
    have an undefined recall; they are dropped from the average and a warning
    is emitted.

    Parameters
    ----------
    num_classes : int
        The number of classes in the target data. The accumulated confusion
        matrix has shape ``(num_classes, num_classes)``.
        NOTE(review): for ``task='binary'`` torchmetrics is documented to
        return a 2x2 matrix, so ``num_classes`` presumably has to be 2 in
        that case -- confirm against the installed torchmetrics version.
    task : str
        The type of classification task, should be one of 'binary' or
        'multiclass'. It is forwarded unchanged to
        ``torchmetrics.functional.confusion_matrix``.
    adjusted : bool, optional (default=False)
        When true, the result is adjusted for chance, so that random
        performance would score 0, while keeping perfect performance at a
        score of 1.

    Attributes
    ----------
    confmat : torch.Tensor
        Running sum of the confusion matrices of every batch passed to
        ``update``. Registered with ``dist_reduce_fx="sum"``.

    Examples
    --------
    >>> y_true = torch.tensor([0, 1, 0, 0, 1, 0])
    >>> y_pred = torch.tensor([0, 1, 0, 0, 0, 1])
    >>> metric = BalancedAccuracy(num_classes=2, task='binary')
    >>> metric(y_pred, y_true)
    tensor(0.6250)
    """

    def __init__(
        self,
        num_classes: int,
        task: str,
        adjusted: bool = False,
    ) -> None:
        """Store the configuration and register the confusion-matrix state."""
        super().__init__()
        self.num_classes = num_classes
        self.adjusted = adjusted
        self.task = task
        self.add_state(
            "confmat",
            default=torch.zeros((num_classes, num_classes)),
            dist_reduce_fx="sum",
        )

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Add the confusion matrix of one batch to the running total.

        Parameters
        ----------
        preds : torch.Tensor
            Predictions, passed as-is to ``confusion_matrix``.
        target : torch.Tensor
            Ground-truth labels, passed as-is to ``confusion_matrix``.
        """
        self.confmat += confusion_matrix(
            preds, target, num_classes=self.num_classes, task=self.task
        )

    def compute(self) -> torch.Tensor:
        """Return the (optionally chance-adjusted) balanced accuracy.

        Returns
        -------
        torch.Tensor
            A 0-dim tensor on the same device as ``confmat``. It is 0 when no
            class has a defined recall (e.g. nothing has been accumulated).

        Warns
        -----
        UserWarning
            If at least one class has an undefined (NaN) recall because its
            row of the confusion matrix sums to zero.

        Notes
        -----
        With ``adjusted=True`` and exactly one class left after dropping
        undefined recalls, the chance level is 1 and the adjustment divides
        by zero, so the result is NaN or -inf.
        """
        with torch.no_grad():
            # Diagonal over row sums; a row summing to zero yields 0/0 = NaN.
            per_class = torch.diag(self.confmat) / self.confmat.sum(dim=1)
            undefined = torch.isnan(per_class)
            if torch.any(undefined):
                warnings.warn("y_pred contains nan values and not all classes passed")
                per_class = per_class[~undefined]  # keep only defined recalls
            if per_class.numel() == 0:
                # Build the zero from ``per_class`` so it lives on the same
                # device (and has the same dtype) as every other return value.
                return per_class.new_zeros(())
            score = torch.mean(per_class)
            if self.adjusted:
                # Rescale so chance level maps to 0 and a perfect score to 1.
                n_classes = per_class.numel()
                chance = 1 / n_classes
                score -= chance
                score /= 1 - chance
            return score