Source code for minerva.models.nets.image.deeplabv3

from typing import Any, Dict, Optional, Sequence, Tuple
import os
from collections import OrderedDict
from torch import Tensor, nn, optim
from torchmetrics import Metric
from torchvision.models import ResNet50_Weights
from torchvision.models.resnet import resnet50
from torchvision.models.segmentation.deeplabv3 import ASPP
import torch
from minerva.models.nets.base import SimpleSupervisedModel


[docs] class DeepLabV3(SimpleSupervisedModel): """A DeeplabV3 with a ResNet50 backbone References ---------- Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam. "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017 """ def __init__( self, backbone: Optional[nn.Module] = None, pred_head: Optional[nn.Module] = None, loss_fn: Optional[nn.Module] = None, learning_rate: float = 0.001, num_classes: int = 6, pretrained: bool = False, weights_path: Optional[str] = None, train_metrics: Optional[Dict[str, Metric]] = None, val_metrics: Optional[Dict[str, Metric]] = None, test_metrics: Optional[Dict[str, Metric]] = None, optimizer: type = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, lr_scheduler: Optional[type] = None, lr_scheduler_kwargs: Optional[Dict[str, Any]] = None, output_shape: Optional[Tuple[int, ...]] = None, freeze_backbone: bool = False, interpolate_mode: Optional[str] = "bilinear", flatten: bool = False, loss_squeeze: bool = True, loss_long: bool = True, ): """ Initializes a DeepLabV3 model. Parameters ---------- backbone: Optional[nn.Module] The backbone network. Defaults to None, which will use a ResNet50 backbone. pred_head: Optional[nn.Module] The prediction head network. Defaults to None, which will use a DeepLabV3PredictionHead with specified number of classes. loss_fn: Optional[nn.Module] The loss function. Defaults to None, which will use a CrossEntropyLoss. learning_rate: float The learning rate for the optimizer. Defaults to 0.001. num_classes: int The number of classes for prediction. Defaults to 6. pretrained: bool Whether to use pretrained weights. Defaults to False. weights_path: Optional[str] Path to local pretrained weights file. If provided with pretrained=True, loads weights from this path instead of downloading. Defaults to None. train_metrics: Optional[Dict[str, Metric]] The metrics to be computed during training. Defaults to None. val_metrics: Optional[Dict[str, Metric]] The metrics to be computed during validation. Defaults to None. test_metrics: Optional[Dict[str, Metric]] The metrics to be computed during testing. Defaults to None. optimizer: type Optimizer class to be instantiated. By default, it is set to `torch.optim.Adam`. Should be a subclass of `torch.optim.Optimizer` (e.g., `torch.optim.SGD`). optimizer_kwargs : dict, optional Additional kwargs passed to the optimizer constructor. lr_scheduler : type, optional Learning rate scheduler class to be instantiated. By default, it is set to None, which means no scheduler will be used. Should be a subclass of `torch.optim.lr_scheduler.LRScheduler` (e.g., `torch.optim.lr_scheduler.StepLR`). lr_scheduler_kwargs : dict, optional Additional kwargs passed to the scheduler constructor. output_shape: Optional[Tuple[int, ...]] The output shape of the model. If None, the output shape will be the same as the input shape. Defaults to None. This is useful for models that require a specific output shape, that is different from the input shape. freeze_backbone: bool Whether to freeze the backbone weights during training. Defaults to False. interpolate_mode: Optional[str] The interpolation mode to use when upscaling the output to the desired output shape. Defaults to "bilinear". Other options include "nearest", "bicubic", etc. See PyTorch documentation for `torch.nn.functional.interpolate` for all options. Use None to disable upscaling. flatten: bool Whether to flatten the output of the backbone before passing it to the prediction head. Defaults to False. Set to True for classification tasks where the prediction head is a fully connected layer. loss_squeeze: bool Whether to squeeze the target tensor in the loss function. Defaults to True. This is useful for segmentation tasks where the target tensor has a singleton channel dimension (e.g., shape (B, 1, H, W)) and the loss function expects shape (B, H, W). loss_long: bool Whether to convert the target tensor to long type in the loss function. Defaults to True. This is useful for classification tasks where the target tensor is of integer type. """ backbone = backbone or DeepLabV3Backbone( num_classes=num_classes, pretrained=pretrained, weights_path=weights_path ) pred_head = pred_head or DeepLabV3PredictionHead(num_classes=num_classes) loss_fn = loss_fn or nn.CrossEntropyLoss() self.output_shape = output_shape self.interpolate_mode = interpolate_mode self.squeeze_loss = loss_squeeze self.loss_long = loss_long super().__init__( backbone=backbone, fc=pred_head, loss_fn=loss_fn, learning_rate=learning_rate, train_metrics=train_metrics, val_metrics=val_metrics, test_metrics=test_metrics, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, lr_scheduler=lr_scheduler, lr_scheduler_kwargs=lr_scheduler_kwargs, freeze_backbone=freeze_backbone, flatten=flatten, )
[docs] def forward(self, x: Tensor) -> Tensor: """Performs the forward pass of the DeepLabV3 model. Parameters ---------- x : Tensor Input tensor of shape (batch_size, channels, height, width) Returns ------- Tensor Output tensor of shape (batch_size, num_classes, height, width) """ x = x.float() input_shape = self.output_shape or x.shape[-2:] h = self.backbone(x) if isinstance(h, OrderedDict): h = h["out"] if self.flatten: x = x.reshape(x.size(0), -1) if self.adapter is not None: x = self.adapter(x) z = self.fc(h) # Upscaling if self.interpolate_mode is None: return z return nn.functional.interpolate( z, size=input_shape, mode=self.interpolate_mode, align_corners=False )
[docs] def _loss_func(self, y_hat: Tensor, y: Tensor) -> Tensor: """Computes the loss between predictions and ground truth. Parameters ---------- y_hat : Tensor Predicted tensor of shape (batch_size, num_classes, height, width) y : Tensor Ground truth tensor of shape (batch_size, 1, height, width) """ if self.squeeze_loss: y = y.squeeze(1) # Remove channel dim if singleton if self.loss_long: y = y.long() return self.loss_fn(y_hat, y)
[docs] class DeepLabV3Backbone(nn.Module): """A ResNet50 backbone for DeepLabV3""" def __init__( self, num_classes: int = 6, pretrained: bool = False, weights_path: Optional[str] = None, ): """ Initializes the DeepLabV3 backbone model. Parameters ---------- num_classes: int The number of classes for classification. Default is 6. pretrained: bool Whether to use pretrained weights. If True and weights_path is None, will attempt to download ImageNet pretrained weights. Default is False. weights_path: Optional[str] Path to local pretrained weights file. If provided with pretrained=True, loads weights from this path instead of downloading. Default is None. """ super().__init__() if pretrained and weights_path is not None: # Validate file path exists if not os.path.exists(weights_path): raise FileNotFoundError(f"Weights file not found: {weights_path}") # Load from local weights file RN50model = resnet50(replace_stride_with_dilation=[False, True, True]) state_dict = torch.load(weights_path, map_location="cpu") # Handle different weight file formats if "state_dict" in state_dict: state_dict = state_dict["state_dict"] elif "model" in state_dict: state_dict = state_dict["model"] # Filter out classifier weights if they exist (fc layer) # since we don't use them in the backbone filtered_state_dict = { k: v for k, v in state_dict.items() if not k.startswith("fc.") } # Load the filtered state dict missing_keys, unexpected_keys = RN50model.load_state_dict( filtered_state_dict, strict=False ) if missing_keys: print( f"Warning: Missing keys when loading pretrained weights: {missing_keys}" ) if unexpected_keys: print( f"Warning: Unexpected keys when loading pretrained weights: {unexpected_keys}" ) print(f"Successfully loaded pretrained weights from {weights_path}") elif pretrained and weights_path is None: # Use torchvision's pretrained weights (requires internet) RN50model = resnet50( weights=ResNet50_Weights.IMAGENET1K_V1, replace_stride_with_dilation=[False, True, True], ) print("Successfully loaded ImageNet pretrained weights from torchvision") else: # No pretrained weights, random initialization RN50model = resnet50(replace_stride_with_dilation=[False, True, True]) self.RN50model = RN50model
[docs] def freeze_weights(self): """Freezes all parameters in the backbone, making them non-trainable.""" for param in self.RN50model.parameters(): param.requires_grad = False
[docs] def unfreeze_weights(self): """Unfreezes all parameters in the backbone, making them trainable.""" for param in self.RN50model.parameters(): param.requires_grad = True
[docs] def forward(self, x): """Performs the forward pass of the backbone. Parameters ---------- x : Tensor Input tensor of shape (batch_size, channels, height, width) Returns ------- Tensor Feature map tensor from the ResNet50 backbone """ x = self.RN50model.conv1(x) x = self.RN50model.bn1(x) x = self.RN50model.relu(x) x = self.RN50model.maxpool(x) x = self.RN50model.layer1(x) x = self.RN50model.layer2(x) x = self.RN50model.layer3(x) x = self.RN50model.layer4(x) # x = self.RN50model.avgpool(x) # These should be removed for deeplabv3 # x = torch.RN50model.flatten(x, 1) # These should be removed for deeplabv3 # x = self.RN50model.fc(x) # These should be removed for deeplabv3 return x
[docs] class DeepLabV3PredictionHead(nn.Sequential): """The prediction head for DeepLabV3""" def __init__( self, in_channels: int = 2048, num_classes: int = 6, atrous_rates: Sequence[int] = (12, 24, 36), ) -> None: """ Initializes the DeepLabV3 prediction head. Parameters ---------- in_channels: int Number of input channels. Defaults to 2048. num_classes: int Number of output classes. Defaults to 6. atrous_rates: Sequence[int] A sequence of atrous rates for the ASPP module. Defaults to (12, 24, 36). """ super().__init__( ASPP(in_channels, atrous_rates), nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), nn.Conv2d(256, num_classes, 1), )
[docs] def forward(self, input) -> Tensor: """Performs the forward pass of the prediction head. Parameters ---------- input : Tensor Input tensor from the backbone Returns ------- Tensor Output tensor with class predictions """ assert input.shape[0] > 1, "Batch size must be greater than 1 due to BatchNorm" return super().forward(input)
[docs] class DeepLabV3RegressionHead(nn.Sequential): """Regression head for DeepLabV3 (continuous per-pixel/voxel prediction).""" def __init__( self, in_channels: int = 2048, out_channels: int = 1, atrous_rates: Sequence[int] = (12, 24, 36), ): """ Parameters ---------- in_channels : int Number of input channels from the backbone (typically 2048 for ResNet50). out_channels : int Number of output channels (1 for single regression target). atrous_rates : Sequence[int] Atrous (dilation) rates for ASPP. """ super().__init__( ASPP(in_channels, atrous_rates), nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, out_channels, kernel_size=1), )
[docs] def forward(self, x: Tensor) -> Tensor: # type: ignore """Forward pass for regression head.""" assert x.shape[0] > 1, "Batch size must be > 1 due to BatchNorm" return super().forward(x)