Source code for minerva.models.nets.image.wisenet

import torch

from minerva.models.nets.base import SimpleSupervisedModel


[docs] class _WiseNet(torch.nn.Module): def __init__(self, in_channels=1, out_channels=1): super().__init__() self.in_channels = in_channels self.out_channels = out_channels # self.norm = nn.BatchNorm3d(1) self.conv1 = torch.nn.Conv3d( in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1, ) self.relu = torch.nn.ReLU() self.pool1 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True) self.conv2 = torch.nn.Conv3d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1 ) self.pool2 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True) self.conv3 = torch.nn.Conv3d( in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1 ) self.pool3 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True) self.conv4 = torch.nn.Conv3d( in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, ) self.pool4 = torch.nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1)) self.conv5 = torch.nn.ConvTranspose2d( in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1, ) self.conv6 = torch.nn.ConvTranspose2d( in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1, ) self.conv7 = torch.nn.ConvTranspose2d( in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, ) self.conv8 = torch.nn.Conv2d( in_channels=32, out_channels=out_channels, kernel_size=3, stride=1, padding=1, )
[docs] def forward(self, x): x = self.conv1(x) x = self.relu(x) x = self.pool1(x) x = self.conv2(x) x = self.relu(x) x = self.pool2(x) x = self.conv3(x) x = self.relu(x) x = self.pool3(x) x = self.conv4(x) x = self.relu(x) x = self.pool4(x) x = x.view( x.size(0), x.size(1), x.size(3), x.size(4) ) # (batch_size, channels, height, width) x = self.conv5(x) x = self.relu(x) x = self.conv6(x) x = self.relu(x) x = self.conv7(x) x = self.relu(x) x = self.conv8(x) return x
[docs] class WiseNet(SimpleSupervisedModel): def __init__( self, in_channels: int = 1, out_channels: int = 1, loss_fn: torch.nn.Module = None, learning_rate: float = 1e-3, **kwargs, ): super().__init__( backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels), fc=torch.nn.Identity(), loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, flatten=False, **kwargs, )
[docs] def _single_step( self, batch: torch.Tensor, batch_idx: int, step_name: str ) -> torch.Tensor: x, y = batch y_hat = self.forward(x) y_hat = y_hat[:, :, : y.size(2), : y.size(3)] loss = self._loss_func(y_hat, y) self.log( f"{step_name}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, ) return loss
[docs] def predict_step(self, batch, batch_idx, dataloader_idx=None): x, y = batch y_hat = self.forward(x) y_hat = y_hat[:, :, : y.size(2), : y.size(3)] return y_hat