minerva.models.nets.image.wisenet
Classes
- WiseNet: Simple pipeline for supervised models.
- _WiseNet: Base class for all neural network modules.
Module Contents
- class minerva.models.nets.image.wisenet.WiseNet(in_channels=1, out_channels=1, loss_fn=None, learning_rate=0.001, **kwargs)
Bases:
minerva.models.nets.base.SimpleSupervisedModel
Simple pipeline for supervised models.
This class implements a very common deep learning pipeline, which is composed of the following steps:
1. Make a forward pass with the input data on the backbone model;
2. Make a forward pass with the backbone output on the FC model;
3. Compute the loss between the output and the label data;
4. Optimize the model (backbone and FC) parameters with respect to the loss.
This reduces code duplication for autoencoder models and makes it easier to implement new models by changing only the backbone model. More complex models that do not follow this pipeline should not inherit from this class. Note that, for this class, the input data is a tuple of tensors, where the first tensor is the input data and the second tensor is the mask or label.
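The four pipeline steps can be sketched in plain PyTorch as follows. This is a minimal sketch, not minerva's implementation: the backbone, the FC layer, the 8x8 input size, the three classes and the `training_step` helper are all hypothetical, and Adam with lr=1e-3 only mirrors the documented default learning rate.

```python
import torch
from torch import nn

# Hypothetical stand-ins: any nn.Module can play the backbone or FC role.
backbone = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU())
fc = nn.Linear(4 * 8 * 8, 3)  # assumes 8x8 inputs and 3 classes
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    list(backbone.parameters()) + list(fc.parameters()), lr=1e-3
)


def training_step(batch):
    """One iteration of the backbone -> FC -> loss -> optimize pipeline."""
    x, y = batch                      # (input, label) tuple
    features = backbone(x)            # 1. forward pass on the backbone
    logits = fc(features.flatten(1))  # 2. forward pass on the FC model
    loss = loss_fn(logits, y)         # 3. loss between output and label
    optimizer.zero_grad()
    loss.backward()                   # 4. gradients for backbone and FC
    optimizer.step()                  #    updated by a single optimizer
    return loss.detach()


torch.manual_seed(0)
batch = (torch.randn(2, 1, 8, 8), torch.tensor([0, 2]))
loss = training_step(batch)
print(loss.item())
```

Because a single optimizer owns both parameter groups, swapping in a different backbone leaves the rest of the loop untouched, which is the point of the pipeline.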
Initialize the model with the backbone, FC, loss function and metrics. Metrics are used to evaluate the model during training, validation, testing or prediction. They are logged with the Lightning logger at the end of each epoch. Metrics should implement the torchmetrics.Metric interface.
Parameters
- backbone : torch.nn.Module
The backbone model. Usually the encoder/decoder part of the model.
- fc : torch.nn.Module
The fully connected model, usually used for classification tasks. Use torch.nn.Identity() if no FC model is needed.
- loss_fn : torch.nn.Module
The function used to compute the loss.
- learning_rate : float, optional
The learning rate for the Adam optimizer, by default 1e-3.
- flatten : bool, optional
If True, the input data will be flattened before passing through the fc model, by default True.
- train_metrics : Dict[str, Metric], optional
The metrics to be used during training, by default None.
- val_metrics : Dict[str, Metric], optional
The metrics to be used during validation, by default None.
- test_metrics : Dict[str, Metric], optional
The metrics to be used during testing, by default None.
- predict_metrics : Dict[str, Metric], optional
The metrics to be used during prediction, by default None.
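A short sketch of how the `flatten` and `fc` options shape the data. It assumes that flattening keeps the batch dimension and collapses all remaining dimensions (the exact operation minerva uses is not shown on this page); the tensor sizes are arbitrary.

```python
import torch
from torch import nn

features = torch.randn(2, 4, 8, 8)  # hypothetical backbone output (N, C, H, W)

# flatten=True (assumed behaviour): keep the batch dimension, collapse the rest
# so that a torch.nn.Linear FC model can consume the features.
flat = features.view(features.size(0), -1)
print(flat.shape)  # torch.Size([2, 256])

# fc=torch.nn.Identity(): the FC stage returns its input unchanged, which is
# what "no FC model" amounts to, e.g. when the backbone already emits masks.
identity_out = nn.Identity()(features)
print(identity_out.shape)  # torch.Size([2, 4, 8, 8])
```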
- _single_step(batch, batch_idx, step_name)
Perform a single train/validation/test step. It consists of making a forward pass with the input data on the backbone model, computing the loss between the output and the mask, and logging the loss.
Parameters
- batch : torch.Tensor
The input data. It must be a 2-element tuple of tensors, where the first tensor is the input data and the second tensor is the mask.
- batch_idx : int
The index of the batch.
- step_name : str
The name of the step. It will be used to log the loss. The possible values are “train”, “val” and “test”. The loss will be logged as “{step_name}_loss”.
Returns
- torch.Tensor
A tensor with the loss value.
- Parameters:
batch (torch.Tensor)
batch_idx (int)
step_name (str)
- Return type:
torch.Tensor
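The step described above can be sketched as a standalone function. The `single_step` name, the `record` callable (standing in for Lightning's `self.log(name, value)`), the 1x1 convolution and the MSE loss are hypothetical choices for illustration, not minerva's code.

```python
import torch
from torch import nn


def single_step(model, loss_fn, batch, batch_idx, step_name, log):
    """Forward pass, loss against the mask, then log "{step_name}_loss".

    `log` is a hypothetical stand-in for Lightning's `self.log(name, value)`.
    `batch_idx` is unused here; it only mirrors the documented signature.
    """
    x, y = batch              # 2-element tuple: (input, mask)
    y_hat = model(x)          # forward pass
    loss = loss_fn(y_hat, y)  # loss between the output and the mask
    log(f"{step_name}_loss", loss)
    return loss


logged = {}


def record(name, value):
    # Store the scalar value under the logged name.
    logged[name] = value.item()


model = nn.Conv2d(1, 1, kernel_size=1)  # hypothetical per-pixel model
batch = (torch.randn(2, 1, 4, 4), torch.rand(2, 1, 4, 4))
loss = single_step(model, nn.MSELoss(), batch, 0, "val", record)
print(sorted(logged))  # ['val_loss']
```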
- predict_step(batch, batch_idx, dataloader_idx=None)
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multiple devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of the epoch.
The BasePredictionWriter should be used when using a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or when training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), as predictions won’t be returned.
- Args:
batch: The output of your data iterable, normally a DataLoader.
batch_idx: The index of this batch.
dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Return:
Predicted output (optional).
Example
    from lightning.pytorch import LightningModule, Trainer

    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=2)
    predictions = trainer.predict(model, dm)
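A plain-PyTorch sketch (no Lightning Trainer) of a predict step for batches that are (input, mask) tuples, as described for _single_step. Whether minerva's own predict_step unpacks the batch this way is an assumption; the `TuplePredictor` class is hypothetical, and the list comprehension only mimics the per-batch calls that Trainer.predict() performs.

```python
import torch
from torch import nn


class TuplePredictor(nn.Module):
    """Hypothetical model whose batches are (input, mask) tuples."""

    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(1, 2, kernel_size=1)

    def forward(self, x):
        return self.net(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, _ = batch  # the mask is not needed at inference time
        return self(x)


model = TuplePredictor().eval()
batches = [(torch.randn(2, 1, 4, 4), torch.zeros(2, 4, 4)) for _ in range(3)]
with torch.no_grad():  # inference only, no autograd graph
    predictions = [model.predict_step(b, i) for i, b in enumerate(batches)]
print(len(predictions), predictions[0].shape)  # 3 torch.Size([2, 2, 4, 4])
```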
- Parameters:
in_channels (int)
out_channels (int)
loss_fn (torch.nn.Module)
learning_rate (float)
- class minerva.models.nets.image.wisenet._WiseNet(in_channels=1, out_channels=1)
Bases:
torch.nn.Module
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
- Variables:
training (bool) – Boolean that represents whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
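The following sketch checks the behaviour described above with a module like the one in the example: attribute assignment registers the submodules, and eval() flips the training flag on the module and on its children. The `TwoConv` name and the 28x28 input size are illustrative choices.

```python
import torch
from torch import nn


class TwoConv(nn.Module):
    def __init__(self):
        super().__init__()  # must run before submodules are assigned
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        return torch.relu(self.conv2(torch.relu(self.conv1(x))))


model = TwoConv()
# Attribute assignment registered both convolutions as child modules.
print([name for name, _ in model.named_children()])  # ['conv1', 'conv2']
print(model.training)  # True: modules start in training mode
model.eval()           # switches this module and all of its children
print(model.training, model.conv1.training)  # False False
out = model(torch.randn(1, 1, 28, 28))  # each 5x5 conv trims 4 pixels
print(out.shape)  # torch.Size([1, 20, 20, 20])
```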
- conv1
- conv2
- conv3
- conv4
- conv5
- conv6
- conv7
- conv8
- in_channels = 1
- out_channels = 1
- pool1
- pool2
- pool3
- pool4
- relu