Source code for minerva.models.nets.image.unetplusplus_resnet50

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import lightning as L
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall
from torchvision.models import resnet50, ResNet50_Weights
from minerva.losses.dice import MultiClassDiceCELoss
from typing import Any, Callable, Optional, Tuple, Union, List


class DeepLabV3ResNet50Backbone(nn.Module):
    """ResNet-50 feature extractor with DeepLabV3-style dilated late stages.

    layer3 and layer4 trade their stride for dilation (layer3: dilation=2,
    layer4: dilation=4, both stride 1), so c3, c4 and c5 all stay at
    H/8 x W/8 instead of shrinking further.

    References
    ----------
    .. [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual
       learning for image recognition. In Proceedings of the IEEE conference
       on computer vision and pattern recognition (pp. 770-778).
    """

    def __init__(self, pretrained: bool = True) -> None:
        """Build the backbone from torchvision's ``resnet50``.

        Parameters
        ----------
        pretrained : bool, optional
            When True, load ``ResNet50_Weights.IMAGENET1K_V1``; otherwise the
            network is randomly initialised (default is True).
        """
        super().__init__()
        if pretrained:
            weights = ResNet50_Weights.IMAGENET1K_V1
        else:
            weights = None
        # [layer2, layer3, layer4] -> keep layer2 strided, dilate the last two.
        net = resnet50(weights=weights, replace_stride_with_dilation=[False, True, True])

        # Stem. Attribute names are kept so state_dict keys stay stable.
        self.conv1, self.bn1, self.relu, self.maxpool = (
            net.conv1,
            net.bn1,
            net.relu,
            net.maxpool,
        )
        # Residual stages. The classifier head (avgpool/fc) is deliberately dropped.
        self.layer1, self.layer2, self.layer3, self.layer4 = (
            net.layer1,
            net.layer2,
            net.layer3,
            net.layer4,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the backbone and return the feature map of every stage.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 3, H, W).

        Returns
        -------
        tuple of torch.Tensor
            ``(c1, c2, c3, c4, c5)``:
            - ``c1`` is the stem output (conv1 -> bn1 -> relu -> maxpool);
            - ``c2`` to ``c5`` are the outputs of layer1 to layer4.
        """
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = [stem]
        # Each residual stage consumes the previous stage's output.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features.append(stage(features[-1]))
        return tuple(features)
class ConvBlock(nn.Module):
    """Double ``Conv3x3 -> BatchNorm -> ReLU`` block used as a UNet++ decoder node.

    Two consecutive 3x3 convolutions, each followed by batch normalisation
    and ReLU, as is customary in U-Net style decoders for feature refinement.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming feature map.
    out_channels : int
        Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels and the second keeps
        # the width. padding=1 keeps the spatial size of a 3x3 convolution unchanged.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                (
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )
        # A single ``conv`` Sequential keeps the state_dict keys at conv.0 ... conv.5.
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/BN/ReLU stages.

        Parameters
        ----------
        x : torch.Tensor
            Feature tensor of shape (batch_size, in_channels, H, W).

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, out_channels, H, W).
        """
        refined = self.conv(x)
        return refined
class UNetPlusPlusDeepLabV3(nn.Module):
    """UNet++ with DeepLabV3 ResNet50 backbone for semantic segmentation.

    Combines DeepLabV3's multi-scale feature extraction with UNet++'s nested
    skip connections for robust semantic segmentation, particularly suited
    for seismic image segmentation tasks.

    Parameters
    ----------
    in_channels : int, optional
        Number of input channels (default is 3 for RGB images). When it differs
        from 3, the backbone's first convolution is rebuilt to accept
        ``in_channels`` channels (see ``_adapt_input_conv``).
    num_classes : int, optional
        Number of segmentation classes (default is 6).
    deep_supervision : bool, optional
        If True, ``forward`` returns a list of three logit maps, one per nested
        decoder depth, for auxiliary losses (default is True).
    pretrained : bool, optional
        If True, uses ImageNet pre-trained ResNet50 backbone (default is True).
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 6,
        deep_supervision: bool = True,
        pretrained: bool = True,
    ) -> None:
        """Notes
        -----
        The architecture includes:
        - DeepLabV3 ResNet50 backbone with dilated convolutions.
        - UNet++ decoder with nested skip connections for feature refinement.
        - Bilinear upsampling to restore original input resolution.
        - Optional deep supervision for improved training stability.

        Raises
        ------
        ValueError
            If ``in_channels`` is smaller than 1.

        References
        ----------
        .. [1] Zhou, Z., Rahman Siddiquee, M. M., Tajbakhsh, N., & Liang, J.
           (2018). Unet++: A nested u-net architecture for medical image
           segmentation. In Deep learning in medical image analysis and
           multimodal learning for clinical decision support (pp. 3-11).
           Springer.
        .. [2] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual
           learning for image recognition. In Proceedings of the IEEE
           conference on computer vision and pattern recognition (pp. 770-778).
        """
        super(UNetPlusPlusDeepLabV3, self).__init__()
        self.deep_supervision = deep_supervision
        filters = [128, 256, 512, 1024]

        self.backbone = DeepLabV3ResNet50Backbone(pretrained=pretrained)
        # The ResNet stem only accepts 3 channels. Rebuild it so that the
        # ``in_channels`` argument is honoured (no-op when in_channels == 3).
        self.backbone.conv1 = self._adapt_input_conv(
            self.backbone.conv1, in_channels, pretrained
        )

        # 1x1 projections from the ResNet stage widths to the decoder widths.
        self.proj0 = nn.Conv2d(256, filters[0], kernel_size=1)
        self.proj1 = nn.Conv2d(512, filters[1], kernel_size=1)
        self.proj2 = nn.Conv2d(1024, filters[2], kernel_size=1)
        self.proj3 = nn.Conv2d(2048, filters[3], kernel_size=1)

        # Only up[0] is used in forward. The others are kept for attribute compatibility.
        self.up = nn.ModuleList(
            [
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
                for _ in range(3)
            ]
        )
        # up_conv[i] maps filters[i + 1] -> filters[i] channels.
        self.up_conv = nn.ModuleList(
            [nn.Conv2d(filters[i + 1], filters[i], kernel_size=1) for i in range(3)]
        )

        # Nested decoder nodes X^{i,j}. The input width is (j + 1) * filters[i].
        self.conv0_1 = ConvBlock(2 * filters[0], filters[0])
        self.conv1_1 = ConvBlock(2 * filters[1], filters[1])
        self.conv2_1 = ConvBlock(2 * filters[2], filters[2])
        self.conv0_2 = ConvBlock(3 * filters[0], filters[0])
        self.conv1_2 = ConvBlock(3 * filters[1], filters[1])
        self.conv0_3 = ConvBlock(4 * filters[0], filters[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(filters[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(filters[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(filters[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(filters[0], num_classes, kernel_size=1)

    @staticmethod
    def _adapt_input_conv(
        conv: nn.Conv2d, in_channels: int, pretrained: bool
    ) -> nn.Conv2d:
        """Return a version of ``conv`` that accepts ``in_channels`` input channels.

        Parameters
        ----------
        conv : nn.Conv2d
            Original first convolution (assumed ``groups == 1``).
        in_channels : int
            Desired number of input channels.
        pretrained : bool
            If True, the new kernel is derived from ``conv``'s weights: the mean
            over input channels is replicated and rescaled by
            ``conv.in_channels / in_channels``. This keeps the summed kernel,
            and hence the response to channel-constant input, equal to the
            original. Any bias is copied as-is. If False, the kernel uses
            Kaiming-normal (fan_out, relu) initialisation.

        Returns
        -------
        nn.Conv2d
            ``conv`` itself when the channel count already matches, otherwise
            a new convolution with the same geometry.

        Raises
        ------
        ValueError
            If ``in_channels`` is smaller than 1.
        """
        if in_channels < 1:
            raise ValueError(
                f"in_channels must be a positive integer, got {in_channels}"
            )
        if in_channels == conv.in_channels:
            return conv

        new_conv = nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
        )
        with torch.no_grad():
            if pretrained:
                mean_weight = conv.weight.mean(dim=1, keepdim=True)
                scale = conv.in_channels / in_channels
                new_conv.weight.copy_(mean_weight.repeat(1, in_channels, 1, 1) * scale)
                if conv.bias is not None:
                    new_conv.bias.copy_(conv.bias)
            else:
                nn.init.kaiming_normal_(
                    new_conv.weight, mode="fan_out", nonlinearity="relu"
                )
        return new_conv

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Forward pass through the UNet++ DeepLabV3 network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, H, W).

        Returns
        -------
        torch.Tensor or list of torch.Tensor
            Logits of shape (batch_size, num_classes, H, W). With deep
            supervision this is a list ``[out1, out2, out3]`` computed from
            X^{0,1}, X^{0,2} and X^{0,3} respectively (the last is the deepest).
        """
        input_size = x.shape[2:]
        # The stem output (first element) is not used by the decoder.
        _, c2, c3, c4, c5 = self.backbone(x)

        x0_0 = self.proj0(c2)
        # c3-c5 share one resolution (dilated backbone), so x1_0, x2_0 and x3_0
        # can be concatenated without resampling.
        x1_0 = self.proj1(c3)
        x2_0 = self.proj2(c4)
        x3_0 = self.proj3(c5)

        x1_1 = self.conv1_1(torch.cat([x1_0, self.up_conv[1](x2_0)], dim=1))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up_conv[2](x3_0)], dim=1))

        # Moving from row 1 to row 0 needs a 2x upsample. _safe_upsample then
        # fixes off-by-one sizes caused by odd input dimensions.
        x1_0_up = self._safe_upsample(self.up_conv[0](self.up[0](x1_0)), x0_0)
        x0_1 = self.conv0_1(torch.cat([x0_0, x1_0_up], dim=1))

        x2_1_up = self._safe_upsample(self.up_conv[1](x2_1), x1_0)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, x2_1_up], dim=1))

        x1_1_up = self._safe_upsample(self.up_conv[0](self.up[0](x1_1)), x0_0)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, x1_1_up], dim=1))

        x1_2_up = self._safe_upsample(self.up_conv[0](self.up[0](x1_2)), x0_0)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, x1_2_up], dim=1))

        if self.deep_supervision:
            heads = (self.final1, self.final2, self.final3)
            nodes = (x0_1, x0_2, x0_3)
            return [
                F.interpolate(
                    head(node), size=input_size, mode="bilinear", align_corners=True
                )
                for head, node in zip(heads, nodes)
            ]

        output = self.final(x0_3)
        return F.interpolate(
            output, size=input_size, mode="bilinear", align_corners=True
        )

    def _safe_upsample(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``target``'s spatial size when the two differ.

        Handles cases where spatial dimensions may not align due to rounding
        in strided pooling/convolution or 2x upsampling.

        Parameters
        ----------
        x : torch.Tensor
            Tensor to be resized.
        target : torch.Tensor
            Reference tensor whose spatial dimensions (dims 2 onward) are matched.

        Returns
        -------
        torch.Tensor
            ``x`` unchanged if the sizes already match, otherwise a bilinearly
            interpolated copy.
        """
        if x.size()[2:] != target.size()[2:]:
            x = F.interpolate(
                x, size=target.size()[2:], mode="bilinear", align_corners=True
            )
        return x
class LitUNetPlusPlusDeepLabV3(L.LightningModule):
    """Lightning wrapper that trains and evaluates ``UNetPlusPlusDeepLabV3``.

    Supplies the training/validation/test loops, an Adam optimizer and
    test-time multi-class segmentation metrics.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 6,
        deep_supervision: bool = True,
        lr: float = 3e-4,
        pretrained: bool = True,
    ) -> None:
        """Build the network, the loss function and the test metrics.

        Parameters
        ----------
        in_channels : int, optional
            Number of input image channels (default is 3).
        num_classes : int, optional
            Number of segmentation classes (default is 6).
        deep_supervision : bool, optional
            If True, the wrapped model returns one prediction per decoder depth
            (default is True).
        lr : float, optional
            Learning rate for the Adam optimizer (default is 3e-4).
        pretrained : bool, optional
            If True, uses ImageNet pre-trained backbone weights (default is True).

        Notes
        -----
        Test metrics are accuracy, F1-score, mean IoU, precision and recall,
        all from torchmetrics. Only F1 sets ``average="macro"`` explicitly. The
        others use whatever default their torchmetrics task wrapper has.
        """
        super().__init__()
        # Store the constructor arguments via Lightning's save_hyperparameters.
        self.save_hyperparameters()

        self.model = UNetPlusPlusDeepLabV3(
            in_channels=in_channels,
            num_classes=num_classes,
            deep_supervision=deep_supervision,
            pretrained=pretrained,
        )
        self.loss_fn = MultiClassDiceCELoss()
        self.lr = lr
        self.num_classes = num_classes

        shared = {"task": "multiclass", "num_classes": num_classes}
        self.test_accuracy = Accuracy(**shared)
        self.test_f1 = F1Score(average="macro", **shared)
        self.test_miou = JaccardIndex(**shared)
        self.test_precision = Precision(**shared)
        self.test_recall = Recall(**shared)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Delegate to the wrapped UNet++ model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images of shape (batch_size, in_channels, H, W).
        """
        return self.model(x)

    def _predict_and_score(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor, torch.Tensor]:
        """Run the model on ``batch`` and compute the loss.

        Returns ``(preds, masks, loss)``. ``preds`` is passed to the loss
        unchanged, even when it is a list (deep supervision).
        NOTE(review): this assumes MultiClassDiceCELoss accepts a list of
        predictions. Confirm against its implementation.
        """
        images, targets = batch
        predictions = self(images)
        return predictions, targets, self.loss_fn(predictions, targets)

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute and log the training loss for one ``(images, masks)`` batch.

        Returns
        -------
        torch.Tensor
            The loss to back-propagate.
        """
        _, _, loss = self._predict_and_score(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute and log the validation loss for one ``(images, masks)`` batch."""
        _, _, loss = self._predict_and_score(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the test loss and update the epoch-level test metrics.

        Metrics are evaluated on the last prediction when the model returns
        a list (deep supervision), otherwise on the single prediction.
        """
        preds, masks, loss = self._predict_and_score(batch)
        final_preds = preds[-1] if isinstance(preds, list) else preds

        self.log("test_loss", loss)

        test_metrics = {
            "test_accuracy": self.test_accuracy,
            "test_f1": self.test_f1,
            "test_miou": self.test_miou,
            "test_precision": self.test_precision,
            "test_recall": self.test_recall,
        }
        # First update every metric with this batch, then register them for
        # epoch-level logging.
        for metric in test_metrics.values():
            metric(final_preds, masks)
        for name, metric in test_metrics.items():
            self.log(name, metric, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters at learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)