minerva.engines.patch_inferencer_engine

Classes

PatchInferencer

This class acts as a normal L.LightningModule that wraps a SimpleSupervisedModel model, allowing it to perform inference in patches.

PatchInferencerEngine

Main interface for Engine classes. Engines are used to alter the behavior of a model's prediction.

Module Contents

class minerva.engines.patch_inferencer_engine.PatchInferencer(model, input_shape, output_shape=None, weight_function=None, offsets=None, padding=None, return_tuple=None)[source]

Bases: lightning.LightningModule

This class acts as a normal L.LightningModule that wraps a SimpleSupervisedModel model, allowing it to perform inference in patches. This is useful when the model’s default input size is smaller than the desired input size (sample size). In this case, the engine splits the input tensor into patches, performs inference on each patch, and combines the results into a single output of the desired size. The combination of patches can be parametrized by a weight_function, allowing a customizable combination of patches (e.g., a weighted average). Note that only the model’s forward method is wrapped; thus, any method that relies on forward (e.g., training_step, predict_step) will be performed in patches, transparently to the user.

Wrap a SimpleSupervisedModel model’s forward method to perform inference in patches, transparently splitting the input tensor into patches, performing inference in each patch, and combining them into a single output of the desired size.
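The split, infer, and combine flow described above can be sketched without any dependencies. This is a minimal illustration on nested Python lists, not Minerva's implementation: split_2d, merge_2d, and patch_inference are hypothetical names, patches do not overlap, and the sample height and width are assumed to be exact multiples of the patch size (the real class works on torch.Tensor and also supports offsets, padding, and weighting).

```python
from typing import Callable, List

Grid = List[List[float]]  # a 2-D sample as nested lists (rows of columns)


def split_2d(x: Grid, ph: int, pw: int) -> List[List[Grid]]:
    """Split x into a grid of non-overlapping ph x pw tiles."""
    h, w = len(x), len(x[0])
    if h % ph or w % pw:
        raise ValueError("sample height/width must be multiples of the patch size")
    return [
        [[row[j:j + pw] for row in x[i:i + ph]] for j in range(0, w, pw)]
        for i in range(0, h, ph)
    ]


def merge_2d(tiles: List[List[Grid]]) -> Grid:
    """Stitch a grid of tiles back into a single 2-D sample."""
    out: Grid = []
    for tile_row in tiles:
        for r in range(len(tile_row[0])):  # rows inside this band of tiles
            merged: List[float] = []
            for tile in tile_row:
                merged.extend(tile[r])
            out.append(merged)
    return out


def patch_inference(model: Callable[[Grid], Grid], x: Grid, ph: int, pw: int) -> Grid:
    """Split x into patches, run model on each one, and stitch the predictions."""
    tiles = split_2d(x, ph, pw)
    predictions = [[model(tile) for tile in tile_row] for tile_row in tiles]
    return merge_2d(predictions)
```

Because a toy model that acts element-wise ignores spatial context, its patched result matches applying it to the whole sample; real models generally differ near patch borders, which offsets and weight_function are intended to help mitigate.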

Parameters

model: SimpleSupervisedModel

Model to be wrapped.

input_shape: Tuple[int, …]

Expected input shape of the wrapped model.

output_shape: Tuple[int, …], optional

Expected output shape of the wrapped model. For models that return logits (e.g., classification models), the output_shape must include an additional dimension at the beginning to accommodate the number of output classes. For example, if the model processes an input tensor of shape (1, 128, 128) and outputs logits for 10 classes, the expected output_shape should be (10, 1, 128, 128). If the model does not return logits (e.g., it returns a tensor after applying an argmax operation, or it is a regression model that usually returns a tensor with the same shape as the input tensor), the output_shape should have the same number of dimensions as the input_shape. Defaults to None, which assumes the output shape is the same as the input_shape parameter.

weight_function: Callable[[Tuple[int, …]], torch.Tensor], optional

Function that receives a tensor shape and returns the weights for each position of a tensor with that shape. This is useful, for instance, when prediction quality degrades closer to the patch borders.

offsets: List[Tuple[int, …]], optional

List of tuples with offsets that determine the shift of the initial position of the patch subdivision.

padding: Dict[str, Any], optional
Dictionary describing the padding strategy. Keys:
  • pad (mandatory): tuple with the pad width (int) for each dimension, e.g., (0, 3, 3) when working with a tensor with 3 dimensions.

  • mode (optional): ‘constant’, ‘reflect’, ‘replicate’ or ‘circular’. Defaults to ‘constant’.

  • value (optional): fill value for ‘constant’. Defaults to 0.

If None, no padding is applied.

return_tuple: int, optional

Some models may return multiple outputs for a single sample (e.g., outputs from multiple auxiliary heads). This integer defines the number of outputs the model generates. Defaults to None, which indicates that the model produces a single output for a single input.
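The padding dictionary documented above can be validated and completed with its defaults as follows. normalize_padding is a hypothetical helper written for illustration; it only encodes the keys, modes, and defaults listed in the padding entry and does not apply any padding itself.

```python
from typing import Any, Dict, Optional

# Modes listed in the padding documentation above.
VALID_MODES = ("constant", "reflect", "replicate", "circular")


def normalize_padding(padding: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: check a padding dict and fill in the documented defaults."""
    if padding is None:
        return None  # no padding is applied
    if "pad" not in padding:
        raise ValueError("padding requires the mandatory 'pad' key")
    mode = padding.get("mode", "constant")
    if mode not in VALID_MODES:
        raise ValueError(f"unsupported padding mode: {mode!r}")
    return {
        "pad": tuple(padding["pad"]),
        "mode": mode,
        "value": padding.get("value", 0),  # only meaningful for 'constant'
    }


print(normalize_padding({"pad": (0, 3, 3)}))
# {'pad': (0, 3, 3), 'mode': 'constant', 'value': 0}
```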

__call__(x)[source]
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

_single_step(batch, batch_idx, step_name)[source]

Perform a single step of the training/validation loop.

Parameters

batch: torch.Tensor

The input data.

batch_idx: int

The index of the batch.

step_name: str

The name of the step, either “train” or “val”.

Returns

torch.Tensor

The loss value.

Parameters:
  • batch (torch.Tensor)

  • batch_idx (int)

  • step_name (str)

Return type:

torch.Tensor

forward(x)[source]

Perform inference in patches.

Parameters

x: torch.Tensor

Batch of input data.

Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

model
patch_inferencer
test_step(batch, batch_idx)[source]

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Args:

  • batch: The output of your data iterable, normally a DataLoader.

  • batch_idx: The index of this batch.

  • dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders used).

Return:
  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

import torch
import torchvision


# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note:

If you don’t need to test you don’t need to implement this method.

Note:

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

Parameters:
  • batch (torch.Tensor)

  • batch_idx (int)

training_step(batch, batch_idx)[source]

Here you compute and return the training loss and some additional metrics, e.g., for the progress bar or logger.

Args:

  • batch: The output of your data iterable, normally a DataLoader.

  • batch_idx: The index of this batch.

  • dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders used).

Return:
  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

Parameters:
  • batch (torch.Tensor)

  • batch_idx (int)

validation_step(batch, batch_idx)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Args:

  • batch: The output of your data iterable, normally a DataLoader.

  • batch_idx: The index of this batch.

  • dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders used).

Return:
  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

import torch
import torchvision


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note:

If you don’t need to validate you don’t need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

Parameters:
  • batch (torch.Tensor)

  • batch_idx (int)

Parameters:
  • model (minerva.models.nets.base.SimpleSupervisedModel)

  • input_shape (Tuple[int, Ellipsis])

  • output_shape (Optional[Tuple[int, Ellipsis]])

  • weight_function (Optional[Callable[[Tuple[int, Ellipsis]], torch.Tensor]])

  • offsets (Optional[List[Tuple[int, Ellipsis]]])

  • padding (Optional[Dict[str, Any]])

  • return_tuple (Optional[int])

class minerva.engines.patch_inferencer_engine.PatchInferencerEngine(input_shape, output_shape=None, offsets=None, padding=None, weight_function=None, return_tuple=None)[source]

Bases: minerva.engines.engine._Engine

Main interface for Engine classes. Engines are used to alter the behavior of a model’s prediction. An engine should be able to take a model and input data x and return a prediction. A use case for Engines is patched inference, where the model’s default input size is smaller than the desired input size. The engine can be used to make predictions in patches and combine these predictions into a single output.
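The contract described above (an object that receives a model and an input x and returns a prediction) can be sketched as a small abstract base class. Engine, ChunkedEngine, and EngineWrappedModel are hypothetical names, and this is not the minerva.engines.engine._Engine API. The wrapper only mirrors the idea that PatchInferencer keeps a model and a patch_inferencer engine and delegates prediction to the engine, which is an assumption about its internals.

```python
from abc import ABC, abstractmethod
from typing import Callable, List

Model = Callable[[List[float]], List[float]]


class Engine(ABC):
    """Hypothetical engine interface: (model, input) -> prediction."""

    @abstractmethod
    def __call__(self, model: Model, x: List[float]) -> List[float]:
        ...


class ChunkedEngine(Engine):
    """Runs the model on consecutive, non-overlapping chunks of a 1-D input."""

    def __init__(self, chunk_size: int):
        self.chunk_size = chunk_size

    def __call__(self, model: Model, x: List[float]) -> List[float]:
        if len(x) % self.chunk_size:
            raise ValueError("input length must be a multiple of chunk_size")
        out: List[float] = []
        for start in range(0, len(x), self.chunk_size):
            out.extend(model(x[start:start + self.chunk_size]))
        return out


class EngineWrappedModel:
    """Delegates every prediction to an engine instead of calling the model directly."""

    def __init__(self, model: Model, engine: Engine):
        self.model = model
        self.engine = engine

    def __call__(self, x: List[float]) -> List[float]:
        return self.engine(self.model, x)
```

With a model that only accepts inputs of length 2, EngineWrappedModel(model, ChunkedEngine(2)) can still process inputs of length 4, 6, and so on, without the caller knowing about the chunking.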

Parameters

input_shape: Tuple[int, …]

Shape of each patch to process.

output_shape: Tuple[int, …], optional

Expected output shape of the model. For models that return logits, the output_shape must include an additional dimension at the beginning to accommodate the number of output classes. Else, the output_shape should have the same number of dimensions as the input_shape (i.e., no logits are returned). Defaults to input_shape.

offsets: List[Tuple[int, …]], optional

List of tuples with offsets that determine the shift of the initial position of the patch subdivision.

padding: Dict[str, Any], optional
Padding configuration with keys:
  • ‘pad’: Tuple of padding for each expected final dimension, e.g., (0, 512, 512) - (c, h, w).

  • ‘mode’: Padding mode, e.g., ‘constant’, ‘reflect’.

  • ‘value’: Padding value if mode is ‘constant’.

Defaults to None, which means no padding is applied.

weight_function: Callable, optional

Function to calculate the weight of each patch. Defaults to None.

return_tuple: int, optional

Number of outputs to return. This is useful when the model returns multiple outputs for a single input (e.g., from multiple auxiliary heads). Defaults to None.
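When return_tuple is set, every patch yields several outputs, and each output has to be collected (and later combined) separately. The sketch below shows one way to regroup per-patch results by output head; group_outputs is a hypothetical helper, and patches are plain lists rather than tensors.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union


def group_outputs(
    model: Callable[[Any], Any],
    patches: Sequence[Any],
    return_tuple: Optional[int] = None,
) -> Union[List[Any], Tuple[List[Any], ...]]:
    """Run model on each patch; regroup results per output when return_tuple is set."""
    results = [model(p) for p in patches]
    if return_tuple is None:
        return results  # single output: one result per patch
    for r in results:
        if len(r) != return_tuple:
            raise ValueError(f"expected {return_tuple} outputs, got {len(r)}")
    # Transpose [patch][output] into [output][patch].
    return tuple([r[k] for r in results] for k in range(return_tuple))
```

Each list in the returned tuple can then go through the same combination step that a single-output model would use.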

__call__(model, x)[source]

Perform inference in patches on the input tensor x using the given model.

Parameters

model: Union[L.LightningModule, torch.nn.Module]

Model to perform inference.

x: torch.Tensor

Input tensor of the sample. It can be a single sample or a batch of samples.

Parameters:
  • model (Union[lightning.LightningModule, torch.nn.Module])

  • x (torch.Tensor)

_adjust_patches(arrays, ref_shape, offset, pad_value=0)[source]

Pads reconstructed patches with pad_value so that they have the same shape as the reference shape from the base patch set.

Parameters:
  • arrays (List[torch.Tensor])

  • ref_shape (Tuple[int])

  • offset (Tuple[int])

  • pad_value (int)

Return type:

List[torch.Tensor]
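For a one-dimensional sample, padding a shifted reconstruction back to the reference shape can be sketched as below. adjust_to_reference is a hypothetical helper; it assumes the shifted patch set starts offset positions into the sample and fills the uncovered positions on both sides with pad_value, which is an interpretation of the description above rather than the exact behavior of _adjust_patches.

```python
from typing import List, Sequence


def adjust_to_reference(
    arr: Sequence[float], ref_len: int, offset: int, pad_value: float = 0
) -> List[float]:
    """Place a reconstruction that begins at `offset` inside a canvas of length ref_len."""
    right = ref_len - offset - len(arr)  # uncovered positions after the reconstruction
    if right < 0:
        raise ValueError("reconstruction does not fit inside the reference shape")
    return [pad_value] * offset + list(arr) + [pad_value] * right
```

For example, a patch set shifted by 2 that covers 4 positions of an 8-long reference becomes [0, 0, a, b, c, d, 0, 0], so it can be combined position by position with the base reconstruction.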

_combine_patches(results, offsets, indexes)[source]

Performs the combination of patches based on the weight function.

Parameters:
  • results (List[torch.Tensor])

  • offsets (List[Tuple[int]])

  • indexes (List[Tuple[int]])

Return type:

torch.Tensor
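Combining aligned reconstructions with a weight function amounts to a per-position weighted average. The sketch below assumes a weight function that, like weight_function, receives a shape and returns per-position weights (here a plain list instead of a torch.Tensor) that grow toward the patch centre. triangular_weights and weighted_combine are hypothetical names, and positions where all weights are zero (for example, padding added for offset patch sets) are set to 0.0 by choice.

```python
from typing import List, Sequence, Tuple


def triangular_weights(shape: Tuple[int]) -> List[float]:
    """1-D weight map: 1 at the borders, increasing toward the centre."""
    (n,) = shape
    return [float(min(i + 1, n - i)) for i in range(n)]


def weighted_combine(
    preds: Sequence[Sequence[float]], weights: Sequence[Sequence[float]]
) -> List[float]:
    """Per-position weighted average of several aligned predictions."""
    out: List[float] = []
    for i in range(len(preds[0])):
        total_w = sum(w[i] for w in weights)
        acc = sum(p[i] * w[i] for p, w in zip(preds, weights))
        out.append(acc / total_w if total_w else 0.0)
    return out
```

With border-decaying weights, a position near the edge of one patch set but near the centre of another is dominated by the prediction that had more surrounding context.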

_compute_base_padding(tensor)[source]

Computes the padding for the base patch set based on the input tensor shape and the model’s input shape.

Parameters:

tensor (torch.Tensor)
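One common way to compute such padding is to add, in each dimension, just enough elements for the sample to split into a whole number of patches. The sketch below uses that rule; compute_base_padding is a hypothetical helper, and whether Minerva pads only at the end of each dimension or splits the padding between both sides is not stated here.

```python
from typing import Tuple


def compute_base_padding(
    sample_shape: Tuple[int, ...], patch_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Per-dimension padding that makes each sample dim a multiple of the patch dim."""
    if len(sample_shape) != len(patch_shape):
        raise ValueError("sample_shape and patch_shape must have the same rank")
    # (-s) % p is the distance from s up to the next multiple of p (0 if already aligned).
    return tuple((-s) % p for s, p in zip(sample_shape, patch_shape))
```

For a (3, 250, 256) sample and (3, 128, 128) patches this yields (0, 6, 0): only the height needs 6 extra rows to reach 256.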

_compute_output_shape(tensor)[source]

Computes the PatchInferencer output shape based on the input tensor shape and the model’s input and output shapes.

Parameters:

tensor (torch.Tensor)

Return type:

Tuple[int]
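Following the output_shape rules documented for PatchInferencer, the full output shape can be derived from the number of patches along each dimension, with any leading logits dimensions prepended. compute_output_shape is a hypothetical helper that assumes the sample shape has already been padded to a multiple of input_shape; it is not the library's implementation.

```python
from typing import Optional, Tuple


def compute_output_shape(
    sample_shape: Tuple[int, ...],
    input_shape: Tuple[int, ...],
    output_shape: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, ...]:
    """Full output shape for a sample processed in patches of input_shape."""
    if output_shape is None:
        output_shape = input_shape  # documented default: same as input_shape
    n_logits = len(output_shape) - len(input_shape)  # leading logits dims (0 if none)
    if n_logits < 0:
        raise ValueError("output_shape cannot have fewer dims than input_shape")
    logits_dims = tuple(output_shape[:n_logits])
    per_patch_out = output_shape[n_logits:]
    spatial = []
    for s, i, o in zip(sample_shape, input_shape, per_patch_out):
        if s % i:
            raise ValueError("sample_shape must be a multiple of input_shape")
        spatial.append((s // i) * o)  # patches along this dim x per-patch output size
    return logits_dims + tuple(spatial)
```

With the documented example (input_shape (1, 128, 128), output_shape (10, 1, 128, 128)), a (1, 256, 256) sample maps to a (10, 1, 256, 256) output.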

_extract_patches(data, patch_shape)[source]

Patch extraction method. It will be called once for the base patch set and also for the requested offsets (overlapping patch sets).

Parameters:
  • data (torch.Tensor)

  • patch_shape (Tuple[int])

Return type:

Tuple[torch.Tensor, Tuple[int]]
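The base patch set and a shifted (offset) patch set can be illustrated in one dimension. extract_patches_1d is a hypothetical helper: it returns the patches together with the patch-grid shape, echoing the Tuple[torch.Tensor, Tuple[int]] return type above, and it simply drops an incomplete trailing patch, which is an assumption rather than documented behavior.

```python
from typing import List, Sequence, Tuple


def extract_patches_1d(
    data: Sequence[float], patch_size: int, offset: int = 0
) -> Tuple[List[List[float]], Tuple[int]]:
    """Split data[offset:] into consecutive, non-overlapping patches of patch_size."""
    n = (len(data) - offset) // patch_size  # number of complete patches
    patches = [
        list(data[offset + k * patch_size : offset + (k + 1) * patch_size])
        for k in range(n)
    ]
    return patches, (n,)
```

On list(range(8)) with patch_size=4, the base set is [[0, 1, 2, 3], [4, 5, 6, 7]] while offset=2 gives [[2, 3, 4, 5]]. The shifted patch straddles the boundary between the two base patches, so positions near that boundary receive a second prediction computed with different context.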

_reconstruct_patches(patches, index)[source]

Rearranges patches to reconstruct the area of interest from patches and weights.

Parameters:
  • patches (torch.Tensor)

  • index (Tuple[int])

Return type:

Tuple[torch.Tensor, torch.Tensor]

input_shape
logits_dim
output_shape
output_simplified_shape = ()
return_tuple = None
weight_function = None
Parameters:
  • input_shape (Tuple[int, Ellipsis])

  • output_shape (Optional[Tuple[int, Ellipsis]])

  • offsets (Optional[List[Tuple[int, Ellipsis]]])

  • padding (Optional[Dict[str, Any]])

  • weight_function (Optional[Callable])

  • return_tuple (Optional[int])