minerva.analysis.metrics

Classes

BalancedAccuracy

Compute the balanced accuracy.

PixelAccuracy

Compute the pixel accuracy.

Package Contents

class minerva.analysis.metrics.BalancedAccuracy(num_classes, task, adjusted=False)

Bases: torchmetrics.Metric

Inherits from torchmetrics.Metric, the base class for all metrics present in the Metrics API.

The base class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class, add_state(), forward() and reset(), should almost never be overridden by child classes. Instead, child classes should override update() and compute().
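The split between update() and compute() matters because a metric accumulates state over all batches and derives the final value once, instead of averaging per-batch results. The plain-Python sketch below mimics that lifecycle for a simple accuracy. It is not the torchmetrics API: the class AccumulatingAccuracy and its attributes are hypothetical, chosen only to illustrate the accumulate-then-compute pattern.

```python
class AccumulatingAccuracy:
    """Hypothetical stand-in mimicking the update()/compute()/reset() lifecycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Metric state: running counts, analogous to states registered with add_state()
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # Accumulate sufficient statistics for one batch; no final value yet
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Derive the final metric from the accumulated state
        return self.correct / self.total


batches = [([1], [1]), ([0, 1, 1], [1, 0, 1])]

metric = AccumulatingAccuracy()
for preds, target in batches:
    metric.update(preds, target)
print(metric.compute())  # 2 correct out of 4 -> 0.5

# Averaging per-batch accuracies instead weights the one-sample batch too heavily
per_batch = [sum(p == t for p, t in zip(pr, tg)) / len(tg) for pr, tg in batches]
print(sum(per_batch) / len(per_batch))  # (1.0 + 1/3) / 2
```

Keeping counts as state is what makes the result independent of how the data is split into batches.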

Args:

kwargs: additional keyword arguments, see Metric kwargs for more info.

  • compute_on_cpu:

    Whether to store the metric state on CPU during computations. Only works for list states.

  • dist_sync_on_step:

    Whether to synchronize the metric state on every forward() call. Default is False.

  • process_group:

    The process group on which synchronization is called. Default is the world.

  • dist_sync_fn:

    Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn:

    Function that checks whether the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute:

    Whether to synchronize the metric state when compute() is called. Default is True.

  • compute_with_cache:

    Whether to cache the results of compute(). Default is True.

Compute the balanced accuracy.

Balanced accuracy is used in binary and multiclass classification problems to deal with imbalanced datasets. It is defined as the average of the recall obtained on each class.
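As a plain-Python illustration of this definition (not the BalancedAccuracy implementation; the helper name balanced_accuracy is hypothetical), recall is computed for each class present in the targets and the recalls are averaged. With the values from the example on this page, the recall is 3/4 for class 0 and 1/2 for class 1, giving 0.625.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall (hypothetical helper for illustration)."""
    recalls = []
    # Only classes present in y_true contribute, similar to scikit-learn's behaviour
    for c in sorted(set(y_true)):
        # Indices of samples whose true label is class c
        idx = [i for i, t in enumerate(y_true) if t == c]
        hits = sum(1 for i in idx if y_pred[i] == c)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


print(balanced_accuracy([0, 1, 0, 0, 1, 0], [0, 1, 0, 0, 0, 1]))  # 0.625
```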

Parameters

num_classes : int

The number of classes in the target data.

task : str

The type of classification task; should be one of 'binary' or 'multiclass'.

adjusted : bool, optional (default=False)

When True, the result is adjusted for chance, so that random performance scores 0 while perfect performance still scores 1.
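One common chance adjustment, used by scikit-learn's balanced_accuracy_score(adjusted=True), rescales the score so that the chance level 1/n_classes maps to 0 and a perfect score stays at 1. Whether BalancedAccuracy uses exactly this formula is an assumption, and the helper adjust_for_chance is hypothetical.

```python
def adjust_for_chance(score, n_classes):
    """Rescale a balanced-accuracy score so chance level maps to 0 and perfect to 1.

    Assumes the scikit-learn convention: chance = 1 / n_classes.
    """
    chance = 1.0 / n_classes
    return (score - chance) / (1.0 - chance)


print(adjust_for_chance(0.625, n_classes=2))  # (0.625 - 0.5) / 0.5 = 0.25
print(adjust_for_chance(1.0, n_classes=2))    # perfect performance stays at 1.0
```

Under this convention, scores below chance become negative.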

Attributes

confmat : torch.Tensor

Confusion matrix to keep track of true positives, false positives, true negatives, and false negatives.
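A sketch of how a confusion-matrix state can be accumulated batch by batch and reduced to balanced accuracy at the end, in plain Python. Rows index the true class and columns the predicted class (a convention assumed here, not confirmed for confmat); the diagonal then holds per-class true positives and each row sum is that class's support. The function names are hypothetical.

```python
def update_confmat(confmat, preds, target):
    # Increment cell [true][pred] for every sample in the batch
    for p, t in zip(preds, target):
        confmat[t][p] += 1


def balanced_accuracy_from_confmat(confmat):
    # Recall of class c = TP_c / (TP_c + FN_c) = diagonal / row sum;
    # classes with no true samples (empty rows) are skipped
    recalls = [row[c] / sum(row) for c, row in enumerate(confmat) if sum(row) > 0]
    return sum(recalls) / len(recalls)


num_classes = 2
confmat = [[0] * num_classes for _ in range(num_classes)]
# The example data from this page, split into two batches
update_confmat(confmat, preds=[0, 1, 0], target=[0, 1, 0])
update_confmat(confmat, preds=[0, 0, 1], target=[0, 1, 0])
print(confmat)  # [[3, 1], [1, 1]]
print(balanced_accuracy_from_confmat(confmat))  # 0.625
```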

Examples

>>> import torch
>>> from minerva.analysis.metrics import BalancedAccuracy
>>> y_true = torch.tensor([0, 1, 0, 0, 1, 0])
>>> y_pred = torch.tensor([0, 1, 0, 0, 0, 1])
>>> metric = BalancedAccuracy(num_classes=2, task='binary')
>>> metric(y_pred, y_true)
0.625
adjusted = False
compute()
num_classes
task
update(preds, target)
Parameters:
  • preds (torch.Tensor)

  • target (torch.Tensor)


class minerva.analysis.metrics.PixelAccuracy(dist_sync_on_step=False)

Bases: torchmetrics.Metric


Initializes a PixelAccuracy metric object.

Parameters

dist_sync_on_step : bool, optional

Whether to synchronize metric state across processes at each step. Defaults to False.
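Synchronizing metric state, whether at every step or when compute() is called, means combining each process's accumulated counters before the final value is computed. The plain-Python simulation below assumes sum-reduced (correct, total) counter states, a common choice that is not confirmed for PixelAccuracy; the helper names are hypothetical. It shows why summing counters, rather than averaging each process's local accuracy, reproduces the single-process result when processes see different amounts of data.

```python
def local_counts(pairs):
    # Counters a single process would accumulate over its own (pred, target) pixels
    correct = sum(int(p == t) for p, t in pairs)
    return correct, len(pairs)


def sum_reduce(states):
    # Element-wise sum of (correct, total) tuples, like an all-gather followed by a sum
    return tuple(sum(values) for values in zip(*states))


rank0 = [(1, 1), (0, 0)]                  # 2 pixels, both correct
rank1 = [(1, 0), (2, 2), (0, 1), (1, 1)]  # 4 pixels, 2 correct

states = [local_counts(rank0), local_counts(rank1)]
correct, total = sum_reduce(states)
print(correct / total)  # 4 / 6, computed from the synchronized counters

# Averaging the local accuracies instead gives (1.0 + 0.5) / 2 = 0.75
print(sum(c / n for c, n in states) / len(states))
```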

compute()

Computes the pixel accuracy, i.e. the fraction of pixels whose predicted label matches the target label.

Returns:

float: The pixel accuracy.

Return type:

float
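For a single pair of segmentation masks, pixel accuracy reduces to counting positions whose labels agree. The plain-Python sketch below uses a hypothetical helper and represents masks as nested lists of class indices; it ignores details such as ignore labels or batched tensors that an actual implementation may handle differently.

```python
def pixel_accuracy(pred_mask, target_mask):
    """Fraction of pixels whose predicted label equals the target label."""
    correct = total = 0
    for pred_row, target_row in zip(pred_mask, target_mask):
        for p, t in zip(pred_row, target_row):
            correct += int(p == t)
            total += 1
    return correct / total


pred = [[0, 1, 1],
        [2, 2, 0]]
target = [[0, 1, 2],
          [2, 2, 2]]
print(pixel_accuracy(pred, target))  # 4 of 6 pixels match
```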

update(preds, target)

Updates the metric state with the predictions and targets.

Parameters

preds : torch.Tensor

The predicted tensor.

target : torch.Tensor

The target tensor.
