minerva.analysis.metrics.balanced_accuracy
Classes
BalancedAccuracy: Base class for all metrics present in the Metrics API.
Module Contents
- class minerva.analysis.metrics.balanced_accuracy.BalancedAccuracy(num_classes, task, adjusted=False)
Bases: torchmetrics.Metric
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
Handles the transfer of metric states to the correct device.
Handles the synchronization of metric states across processes.
Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are add_state(), forward(), and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().
- Args:
kwargs: additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu:
If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
If metric state should synchronize on forward(). Default is False.
- process_group:
The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn:
Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
If results from compute should be cached. Default is True.
Compute the balanced accuracy.
The balanced accuracy in binary, multiclass, and multilabel classification problems deals with imbalanced datasets. It is defined as the average of recall obtained on each class.
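The definition above (the mean of per-class recall) can be sketched in plain Python. This is a minimal illustration, not the class's actual implementation; the helper name `balanced_accuracy` is hypothetical and it averages only over classes that occur in `y_true`.

```python
from collections import Counter


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes that occur in y_true."""
    # Number of true samples per class.
    support = Counter(y_true)
    # Number of correctly predicted samples per class.
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    recalls = [hits[c] / support[c] for c in support]
    return sum(recalls) / len(recalls)


y_true = [0, 1, 0, 0, 1, 0]
y_pred = [0, 1, 0, 0, 0, 1]
# Recall for class 0 is 3/4, recall for class 1 is 1/2.
print(balanced_accuracy(y_true, y_pred))  # 0.625
```

Plain accuracy on the same data would be 4/6, because the majority class dominates the count; averaging recall gives each class equal weight.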
Parameters
- num_classes : int
The number of classes in the target data.
- task : str
The type of classification task; should be one of ‘binary’ or ‘multiclass’.
- adjusted : bool, optional (default=False)
When true, the result is adjusted for chance, so that random performance would score 0, while keeping perfect performance at a score of 1.
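One common way to obtain this behaviour is to rescale the score by the chance level 1/num_classes, as scikit-learn's `balanced_accuracy_score` does. The sketch below assumes that convention; the source does not state the exact formula this class uses, and the helper name `adjust_for_chance` is hypothetical.

```python
def adjust_for_chance(score, num_classes):
    """Rescale a balanced accuracy so chance maps to 0 and perfect maps to 1.

    Assumes the chance level is 1 / num_classes.
    """
    chance = 1.0 / num_classes
    return (score - chance) / (1.0 - chance)


print(adjust_for_chance(0.625, 2))  # 0.25
print(adjust_for_chance(0.5, 2))    # 0.0 (random performance)
print(adjust_for_chance(1.0, 2))    # 1.0 (perfect performance)
```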
Attributes
- confmat : torch.Tensor
Confusion matrix to keep track of true positives, false positives, true negatives, and false negatives.
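To illustrate how a confusion matrix can serve as accumulated state, here is a minimal plain-Python sketch of the update/compute pattern. It does not use torch or torchmetrics, and the class name `ConfusionMatrixSketch` and its layout (rows are true classes, columns are predicted classes) are assumptions made for illustration, not the library's code.

```python
class ConfusionMatrixSketch:
    """Accumulates a confusion matrix over batches (hypothetical example)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # confmat[t][p] counts samples with true class t predicted as class p.
        self.confmat = [[0] * num_classes for _ in range(num_classes)]

    def update(self, y_pred, y_true):
        # Add one batch of predictions to the running counts.
        for p, t in zip(y_pred, y_true):
            self.confmat[t][p] += 1

    def compute(self):
        # Recall of class c is the diagonal entry divided by the row sum.
        # Classes with no true samples are skipped; this sketch assumes
        # at least one sample has been seen before compute() is called.
        recalls = []
        for c, row in enumerate(self.confmat):
            total = sum(row)
            if total > 0:
                recalls.append(row[c] / total)
        return sum(recalls) / len(recalls)


metric = ConfusionMatrixSketch(num_classes=2)
metric.update([0, 1, 0], [0, 1, 0])  # first batch
metric.update([0, 0, 1], [0, 1, 0])  # second batch
print(metric.confmat)    # [[3, 1], [1, 1]]
print(metric.compute())  # 0.625
```

Keeping only the counts means the result over all batches is identical to computing the metric once on the concatenated data.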
Examples
>>> import torch
>>> from minerva.analysis.metrics.balanced_accuracy import BalancedAccuracy
>>> y_true = torch.tensor([0, 1, 0, 0, 1, 0])
>>> y_pred = torch.tensor([0, 1, 0, 0, 0, 1])
>>> metric = BalancedAccuracy(num_classes=2, task='binary')
>>> metric(y_pred, y_true)
0.625
- adjusted = False
- num_classes
- task
- Parameters:
num_classes (int)
task (str)
adjusted (bool)