Source code for minerva.analysis.metrics.pixel_accuracy

import torch
from torchmetrics import Metric


[docs] class PixelAccuracy(Metric): def __init__(self, dist_sync_on_step: bool = False): """ Initializes a PixelAccuracy metric object. Parameters ---------- dist_sync_on_step: bool, optional Whether to synchronize metric state across processes at each step. Defaults to False. """ super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
[docs] def update(self, preds: torch.Tensor, target: torch.Tensor): """ Updates the metric state with the predictions and targets. Parameters ---------- preds: torch.Tensor The predicted tensor. target: torch.Tensor The target tensor. """ correct = torch.sum(preds == target) total = target.numel() self.correct += correct self.total += total
[docs] def compute(self) -> float: """ Computes the pixel accuracy. Returns: float: The pixel accuracy. """ return self.correct.float() / self.total