from typing import Optional

import torch

from minerva.models.nets.base import SimpleSupervisedModel
[docs]
class _WiseNet(torch.nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# self.norm = nn.BatchNorm3d(1)
self.conv1 = torch.nn.Conv3d(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
)
self.relu = torch.nn.ReLU()
self.pool1 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv2 = torch.nn.Conv3d(
in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
)
self.pool2 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv3 = torch.nn.Conv3d(
in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
)
self.pool3 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv4 = torch.nn.Conv3d(
in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
)
self.pool4 = torch.nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1))
self.conv5 = torch.nn.ConvTranspose2d(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv6 = torch.nn.ConvTranspose2d(
in_channels=128,
out_channels=64,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv7 = torch.nn.ConvTranspose2d(
in_channels=64,
out_channels=32,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv8 = torch.nn.Conv2d(
in_channels=32,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
)
[docs]
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.relu(x)
x = self.pool3(x)
x = self.conv4(x)
x = self.relu(x)
x = self.pool4(x)
x = x.view(
x.size(0), x.size(1), x.size(3), x.size(4)
) # (batch_size, channels, height, width)
x = self.conv5(x)
x = self.relu(x)
x = self.conv6(x)
x = self.relu(x)
x = self.conv7(x)
x = self.relu(x)
x = self.conv8(x)
return x
class WiseNet(SimpleSupervisedModel):
    """Supervised wrapper around the ``_WiseNet`` backbone.

    Because the backbone's pooling uses ``ceil_mode=True``, its output can be
    slightly larger than the target; predictions are therefore cropped to the
    target's height/width before the loss (or prediction) is returned.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input volume.
    out_channels : int
        Number of channels of the predicted 2D map.
    loss_fn : Optional[torch.nn.Module]
        Loss module; defaults to ``torch.nn.MSELoss`` when ``None``.
    learning_rate : float
        Optimizer learning rate, forwarded to ``SimpleSupervisedModel``.
    **kwargs
        Extra keyword arguments forwarded to ``SimpleSupervisedModel``.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        # Fixed annotation: default is None, so the parameter is Optional.
        loss_fn: Optional[torch.nn.Module] = None,
        learning_rate: float = 1e-3,
        **kwargs,
    ):
        super().__init__(
            backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels),
            fc=torch.nn.Identity(),
            loss_fn=loss_fn or torch.nn.MSELoss(),
            learning_rate=learning_rate,
            flatten=False,  # keep spatial output; loss compares 2D maps
            **kwargs,
        )

    @staticmethod
    def _crop_to_target(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Crop ``y_hat`` to the spatial (height, width) size of ``y``.

        Shared by training/validation steps and prediction so the cropping
        rule is defined in exactly one place.
        """
        return y_hat[:, :, : y.size(2), : y.size(3)]

    def _single_step(
        self, batch: torch.Tensor, batch_idx: int, step_name: str
    ) -> torch.Tensor:
        """Run one forward pass, crop to the target size, and log the loss."""
        x, y = batch
        y_hat = self._crop_to_target(self.forward(x), y)
        loss = self._loss_func(y_hat, y)
        self.log(
            f"{step_name}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the model prediction cropped to the target's spatial size."""
        x, y = batch
        return self._crop_to_target(self.forward(x), y)