# Standard library imports
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
# Third-party imports
import lightning as L
import numpy as np
import timm.models.vision_transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Block, PatchEmbed
from minerva.models.nets.base import SimpleSupervisedModel
# Local imports
from minerva.utils.position_embedding import get_2d_sincos_pos_embed
###################################
############## SETR ###############
###################################
[docs]
class MMAdaptivePadding(nn.Module):
    """
    Pad an input so that a convolution with the given geometry covers it fully.

    The input is padded (never cropped) so that a convolution with
    ``kernel_size``/``stride``/``dilation`` and no extra padding produces an
    output of spatial size ``ceil(input_size / stride)`` along each axis.

    Parameters
    ----------
    kernel_size : int or Tuple[int, int]
        Size of the convolution kernel. An int is used for both axes.
    stride : int or Tuple[int, int]
        Stride of the convolution. An int is used for both axes.
    dilation : int or Tuple[int, int]
        Dilation rate of the convolution. An int is used for both axes.
    padding : str, default="corner"
        ``"corner"`` pads only at the bottom/right; ``"same"`` splits the
        padding between both sides (the odd pixel goes to the bottom/right).

    Raises
    ------
    ValueError
        If ``padding`` is neither ``"same"`` nor ``"corner"``.
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        padding: str = "corner",
    ):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if padding not in ("same", "corner"):
            raise ValueError(
                f'padding must be "same" or "corner", but got {padding!r}'
            )
        self.kernel_size = self._to_2tuple(kernel_size)
        self.stride = self._to_2tuple(stride)
        self.dilation = self._to_2tuple(dilation)
        self.padding = padding

    @staticmethod
    def _to_2tuple(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
        """Expand an int to ``(value, value)``; convert a 2-sequence to a tuple."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    def get_pad_shape(self, input_shape) -> Tuple[int, int]:
        """
        Compute the padding needed for an input of spatial size ``input_shape``.

        Parameters
        ----------
        input_shape : Tuple[int, int]
            Spatial size ``(H, W)`` of the input (e.g. ``x.shape[-2:]``).

        Returns
        -------
        Tuple[int, int]
            ``(pad_h, pad_w)``, both non-negative.
        """
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        dilation_h, dilation_w = self.dilation
        # Target output size is ceil(input / stride); pad until the last
        # (dilated) kernel window fits entirely inside the padded input.
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max(
            (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h, 0
        )
        pad_w = max(
            (output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w, 0
        )
        return pad_h, pad_w

    def forward(self, x):
        """Zero-pad the last two dimensions of ``x`` as computed by ``get_pad_shape``."""
        pad_h, pad_w = self.get_pad_shape(x.shape[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == "corner":
                # All padding goes to the right/bottom edges.
                x = F.pad(x, (0, pad_w, 0, pad_h))
            else:
                # "same": split evenly, extra pixel on the right/bottom.
                x = F.pad(
                    x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
                )
        return x
[docs]
class MMPatchEmbed(nn.Module):
    """
    Convert an image into patch embeddings with a convolutional projection.

    The input is first padded by ``MMAdaptivePadding`` so the convolution
    covers it fully, then projected by a ``Conv2d`` whose kernel is the patch
    size. The resulting map is flattened into a token sequence, optionally
    normalized.

    Parameters
    ----------
    in_channels : int
        Number of input image channels.
    embed_dims : int
        Dimensionality of the output patch embeddings.
    patch_size : int
        Size of the square patches.
    stride : Optional[int]
        Stride for the convolution. If None, defaults to patch size.
    dilation : int
        Dilation applied to the convolution.
    bias : bool
        Whether to include a bias term in the projection.
    norm_type : Optional[type]
        Normalization layer class (e.g., nn.LayerNorm).
    norm_params : Optional[dict]
        Extra keyword arguments for ``norm_type`` (e.g. ``{"eps": 1e-6}``).
        A list/tuple is also accepted and forwarded positionally.
    patch_norm : bool
        Whether to apply normalization after patch embedding.
    padding_type : str, default="corner"
        Padding strategy for adaptive padding ("same" or "corner").
    """

    def __init__(
        self,
        in_channels: int,
        embed_dims: int,
        patch_size: int,
        stride: Optional[int],
        dilation: int,
        bias: bool,
        norm_type: Optional[type],
        norm_params: Optional[dict],
        patch_norm: bool,
        padding_type: str = "corner",
    ):
        super().__init__()
        # No explicit stride means non-overlapping patches.
        conv_stride = patch_size if stride is None else stride
        self.adapt_padding = MMAdaptivePadding(
            kernel_size=(patch_size, patch_size),
            stride=(conv_stride, conv_stride),
            dilation=(dilation, dilation),
            padding=padding_type,
        )
        self.projection = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=(patch_size, patch_size),
            stride=(conv_stride, conv_stride),
            dilation=(dilation, dilation),
            padding=0,  # padding is done by adapt_padding
            bias=bias,
        )
        if patch_norm and norm_type is not None:
            self.norm = self._build_norm(norm_type, embed_dims, norm_params)
        else:
            self.norm = None

    @staticmethod
    def _build_norm(norm_type, embed_dims, norm_params):
        """Instantiate ``norm_type(embed_dims, ...)``; a dict is passed as kwargs, a sequence positionally."""
        if norm_params is None:
            return norm_type(embed_dims)
        if isinstance(norm_params, dict):
            # `*dict` would pass the dict's keys positionally (e.g. eps="eps").
            return norm_type(embed_dims, **norm_params)
        return norm_type(embed_dims, *norm_params)

    def forward(self, x):
        """
        Embed an image batch into patch tokens.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, in_channels, H, W).

        Returns
        -------
        Tuple[torch.Tensor, Tuple[int, int]]
            Tokens of shape (B, H' * W', embed_dims) and the patch grid size
            ``(H', W')`` as a plain tuple of ints.
        """
        x = self.adapt_padding(x)
        x = self.projection(x)  # (B, C, H', W')
        # A plain tuple (as mmseg returns), not a torch.Size/tensor.
        out_size = tuple(x.shape[2:])
        x = x.flatten(2).transpose(1, 2)  # (B, H' * W', C)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
[docs]
class MMMultiheadAttention(nn.Module):
    """
    Residual self-attention wrapper around ``nn.MultiheadAttention``.

    Computes ``identity + dropout_layer(proj_drop(attn(x, x, x)))``, where
    ``identity`` defaults to ``x``.

    Parameters
    ----------
    embed_dims : int
        Dimensionality of each token embedding.
    num_heads : int
        Number of attention heads.
    attn_drop : float
        Dropout rate for attention weights.
    proj_drop : float
        Dropout rate for the attention output.
    batch_first : bool
        If True, inputs are (B, L, C). They are transposed to (L, B, C) for
        ``nn.MultiheadAttention``, which is built in its default
        sequence-first layout, and transposed back afterwards.
    bias : bool
        If True, add bias terms to the query, key, and value projections.
    dropout_layer : Optional[nn.Module], default=None
        Extra regularizer applied to the attention branch before the residual
        sum (e.g. a DropPath module). ``None`` means identity, so
        ``proj_drop`` is applied exactly once.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int,
        attn_drop: float,
        proj_drop: float,
        batch_first: bool,
        bias: bool,
        dropout_layer: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.batch_first = batch_first
        self.attn = nn.MultiheadAttention(
            embed_dims, num_heads, dropout=attn_drop, bias=bias
        )
        self.proj_drop = nn.Dropout(proj_drop)
        # Defaults to identity: reusing nn.Dropout(proj_drop) here would apply
        # the projection dropout twice in training mode.
        self.dropout_layer = (
            dropout_layer if dropout_layer is not None else nn.Identity()
        )

    def forward(self, x, identity=None):
        """
        Apply self-attention with a residual connection.

        Parameters
        ----------
        x : torch.Tensor
            Tokens, (B, L, C) if ``batch_first`` else (L, B, C).
        identity : Optional[torch.Tensor]
            Residual added to the attention output; defaults to ``x``.

        Returns
        -------
        torch.Tensor
            Same layout as ``x``.
        """
        if identity is None:
            identity = x
        # MMSeg-style: run nn.MultiheadAttention in (seq_len, batch, dim).
        if self.batch_first:
            x = x.transpose(0, 1)
        out = self.attn(x, x, x)[0]
        if self.batch_first:
            out = out.transpose(0, 1)  # back to (batch, seq_len, dim)
        return identity + self.dropout_layer(self.proj_drop(out))
[docs]
class MMFFN(nn.Module):
    """
    Two-layer feed-forward network with a residual connection.

    The forward pass returns ``identity + dropout_layer(layers(x))``. Here
    ``layers`` is ``[Linear -> activation -> Dropout] -> Linear -> Dropout``.
    The module layout (``layers.0.0``, ``layers.1``) is kept so checkpoint keys
    stay compatible.

    Parameters
    ----------
    embed_dims : int
        Input/output token dimensionality.
    feedforward_channels : int
        Hidden width of the FFN.
    dropout_type : type
        Module class wrapped around the FFN output before the residual sum
        (e.g. nn.Dropout or a DropPath class).
    dropout_params : Optional[dict]
        Keyword arguments for ``dropout_type``; falsy means no arguments.
    act_type : type
        Activation module class (e.g. nn.GELU).
    act_params : Optional[dict]
        Keyword arguments for ``act_type``; falsy means no arguments.
    num_fcs : int
        Number of linear layers; only 2 is supported.
    ffn_drop : float
        Dropout probability after each linear block.

    Raises
    ------
    ValueError
        If ``num_fcs`` is not 2.
    """

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: int,
        dropout_type: type,
        dropout_params: Optional[dict],
        act_type: type,
        act_params: Optional[dict],
        num_fcs: int,
        ffn_drop: float,
    ):
        super().__init__()
        if num_fcs != 2:
            raise ValueError(
                "A implementação atual do FFN suporta apenas num_fcs=2 como no MMSeg."
            )
        self.activate = self._instantiate(act_type, act_params)
        expand = nn.Sequential(
            nn.Linear(embed_dims, feedforward_channels),
            self.activate,
            nn.Dropout(ffn_drop),
        )
        self.layers = nn.Sequential(
            expand,
            nn.Linear(feedforward_channels, embed_dims),
            nn.Dropout(ffn_drop),
        )
        self.dropout_layer = self._instantiate(dropout_type, dropout_params)

    @staticmethod
    def _instantiate(cls, params):
        """Build ``cls(**params)``, or ``cls()`` when ``params`` is falsy."""
        return cls(**params) if params else cls()

    def forward(self, x, identity=None):
        """Return ``identity + dropout_layer(layers(x))``; ``identity`` defaults to ``x``."""
        residual = x if identity is None else identity
        return residual + self.dropout_layer(self.layers(x))
[docs]
class SetrVitBackbone(nn.Module):
    """
    Vision Transformer (ViT) backbone for semantic segmentation, following the SETR architecture.

    The input goes through these steps:

    - It is split into patch tokens by ``MMPatchEmbed``.
    - A learnable class token is prepended.
    - Learnable positional embeddings are added. They are resized on the fly
      when the input's patch grid differs from the one implied by
      ``img_size``.
    - The tokens pass through ``num_layers`` transformer encoder layers.

    Parameters
    ----------
    original_resolution : Optional[tuple]
        (H, W) resolution the pretrained positional embeddings were built for.
        Required by ``interpolate_pos_embeddings`` (and hence
        ``load_backbone``) when a checkpoint's ``pos_embed`` does not match
        ``img_size``.
    img_size : int or tuple
        Target image size (H, W). An int or a 1-tuple is used for both sides.
    patch_size : int
        Size of square patches.
    embed_dims : int
        Dimensionality of patch embeddings.
    interpolate_mode : str
        Interpolation method for resizing positional embeddings.
    in_channels : int
        Number of input channels.
    patch_norm : bool
        Whether to apply normalization after patch embedding.
    stride : Optional[int]
        Convolution stride for patch embedding (None means patch size).
    dilatation : int
        Dilation factor for the patch-embedding convolution.
    bias : bool
        Whether to use bias in the patch-embedding convolution.
    norm_type : type
        Normalization layer class.
    norm_params : Optional[dict]
        Parameters for normalization layers.
    padding_type : str
        Padding type for adaptive padding ("same" or "corner").
    num_layers : int
        Number of transformer encoder layers.
    num_heads : int
        Number of attention heads.
    out_indices : Union[int, List[int], Tuple[int, ...]]
        Indices of layers whose outputs are returned. Negative indices count
        from the end (``-1`` is the last layer).
    drop_rate : float
        Dropout rate after positional encoding (also forwarded to the layers).
    with_cls_token : bool
        Whether the class token is kept in the encoder input.
    mlp_ratio : int
        Expansion ratio for the hidden layer in FFN.
    attn_drop_rate : float
        Dropout rate in attention.
    drop_path_rate : float
        Maximum stochastic depth rate (linearly increased over the layers).
    num_fcs : int
        Number of FCs in FFN. Must be 2.
    qkv_bias : bool
        Whether to use bias in QKV projections.
    output_cls_token : bool
        Whether to return the class token with each selected output.
        Requires ``with_cls_token=True``.
    act_type : type
        Activation function class.
    act_params : dict
        Parameters for the activation function.
    with_cp : bool
        Whether to use checkpointing for memory savings.
    dropout_type : type
        Dropout class used in FFN.
    dropout_params : Optional[dict]
        Parameters for dropout.
    batch_first : bool, default=True
        If True, inputs/outputs are in shape (B, L, C).

    Raises
    ------
    ValueError
        If ``img_size`` does not describe one or two sides, or if
        ``output_cls_token`` is set without ``with_cls_token``.
    TypeError
        If ``out_indices`` is not an int, list or tuple.
    """

    def __init__(
        self,
        original_resolution: Optional[tuple],
        img_size: tuple,
        patch_size: int,
        embed_dims: int,
        interpolate_mode: str,
        in_channels: int,
        patch_norm: bool,
        stride: Optional[int],
        dilatation: int,
        bias: bool,
        norm_type: type,
        norm_params: Optional[dict],
        padding_type: str,
        num_layers: int,
        num_heads: int,
        out_indices: Union[int, List[int], Tuple[int, ...]],
        drop_rate: float,
        with_cls_token: bool,
        mlp_ratio: int,
        attn_drop_rate: float,
        drop_path_rate: float,
        num_fcs: int,
        qkv_bias: bool,
        output_cls_token: bool,
        act_type: type,
        act_params: dict,
        with_cp: bool,
        dropout_type: type,
        dropout_params: Optional[dict],
        batch_first: bool = True,
    ):
        super().__init__()
        # Normalize img_size to an (H, W) tuple; int and 1-tuples mean square.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        elif len(img_size) == 1:
            img_size = (img_size[0], img_size[0])
        if len(img_size) != 2:
            raise ValueError(
                f"The size of image should have length 1 or 2, but got {len(img_size)}"
            )
        img_size = tuple(img_size)

        if output_cls_token and not with_cls_token:
            # Without a class token, x[:, 0] in forward() would be a patch token.
            raise ValueError(
                "with_cls_token must be True if set output_cls_token to True, "
                f"but got {with_cls_token}"
            )

        self.original_resolution = original_resolution
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dims = embed_dims
        self.interpolate_mode = interpolate_mode

        self.patch_embed = MMPatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            patch_size=patch_size,
            patch_norm=patch_norm,
            stride=stride,
            dilation=dilatation,
            bias=bias,
            norm_type=norm_type,
            norm_params=norm_params,
            padding_type=padding_type,
        )
        num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)

        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        # One extra position for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dims))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # Normalize output indices: accept int/list/tuple and map negative
        # indices (e.g. -1 for the last layer) onto layer positions.
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        elif not isinstance(out_indices, (list, tuple)):
            raise TypeError("out_indices must be type of int, list or tuple")
        normalized = [idx + num_layers if idx < 0 else idx for idx in out_indices]
        self.out_indices = (
            tuple(normalized) if isinstance(out_indices, tuple) else normalized
        )

        # Stochastic depth decay rule: drop-path rates grow linearly from 0 to
        # drop_path_rate, so deeper layers are dropped more often in training.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                MMTransformerEncoderLayer(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=mlp_ratio * embed_dims,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[i],
                    num_fcs=num_fcs,
                    qkv_bias=qkv_bias,
                    act_params=act_params,
                    norm_params=norm_params,
                    with_cp=with_cp,
                    batch_first=batch_first,
                    dropout_type=dropout_type,
                    dropout_params=dropout_params,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )

    def _pos_embeding(self, patched_img, hw_shape: tuple, pos_embed):
        """
        Add positional embeddings (resized if needed) and apply dropout.

        Args:
            patched_img (torch.Tensor): Tokens of shape [B, L1, C], class
                token first.
            hw_shape (tuple): Patch-grid resolution (H', W') of the input.
            pos_embed (torch.Tensor): Positional embeddings of shape
                [1, L2, C] built for ``self.img_size``.

        Returns:
            torch.Tensor: ``dropout(patched_img + pos_embed)``.

        Raises:
            ValueError: If ``pos_embed`` does not correspond to ``self.img_size``.
        """
        assert (
            patched_img.ndim == 3 and pos_embed.ndim == 3
        ), "the shapes of patched_img and pos_embed must be [B, L, C]"
        x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
        if x_len != pos_len:
            pos_h = self.img_size[0] // self.patch_size
            pos_w = self.img_size[1] // self.patch_size
            if pos_len != pos_h * pos_w + 1:
                raise ValueError(
                    "Unexpected shape of pos_embed, got {}.".format(pos_embed.shape)
                )
            pos_embed = self.resize_pos_embed(
                pos_embed, hw_shape, (pos_h, pos_w), self.interpolate_mode
            )
        # Positional information is added here only; the attention layers
        # must not add it again.
        return self.drop_after_pos(patched_img + pos_embed)

    @staticmethod
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode: str = "bicubic"):
        """
        Resize positional-embedding weights to a new patch grid.

        The class-token entry (index 0) is kept unchanged; the last
        ``pos_h * pos_w`` entries are treated as a (pos_h, pos_w) grid and
        interpolated to ``input_shape``.

        Args:
            pos_embed (torch.Tensor): Position embedding weights, [1, L, C].
            input_shape (tuple): Target grid (H', W') of the downsampled input.
            pos_shape (tuple): Grid (pos_h, pos_w) the weights were built for.
            mode (str): Interpolation algorithm passed to ``F.interpolate``,
                e.g. ``'bilinear'`` or ``'bicubic'``. Default: ``'bicubic'``.

        Returns:
            torch.Tensor: Resized embeddings of shape [1, 1 + H' * W', C].
        """
        assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
        pos_h, pos_w = pos_shape
        # Slicing with 0:1 keeps the token dimension.
        cls_token_weight = pos_embed[:, 0:1]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]
        ).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            input=pos_embed_weight,
            size=input_shape,
            scale_factor=None,
            mode=mode,
            align_corners=False,
        )
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        return torch.cat((cls_token_weight, pos_embed_weight), dim=1)

    def interpolate_pos_embeddings(
        self, pretrained_pos_embed, new_img_size, patch_size=16
    ):
        """
        Resize pretrained positional embeddings to the grid of ``new_img_size``.

        Args:
            pretrained_pos_embed (torch.Tensor): [1, 1 + N, C] embeddings built
                for ``self.original_resolution``; index 0 is the class token.
            new_img_size (tuple): Target image size (H, W).
            patch_size (int): Patch size used to derive both grids.

        Returns:
            torch.Tensor: [1, 1 + (H // patch_size) * (W // patch_size), C].

        Raises:
            ValueError: If ``self.original_resolution`` is not set.
        """
        h, w = new_img_size[0] // patch_size, new_img_size[1] // patch_size
        if self.original_resolution is None:
            raise ValueError(
                "original_resolution must be set to interpolate pos_embed."
            )
        original_h, original_w = self.original_resolution
        pos_embed_reshaped = pretrained_pos_embed[:, 1:].reshape(
            1, original_h // patch_size, original_w // patch_size, -1
        )
        pos_embed_interpolated = (
            F.interpolate(
                pos_embed_reshaped.permute(0, 3, 1, 2),
                size=(h, w),
                mode=self.interpolate_mode,
                align_corners=False,
            )
            .permute(0, 2, 3, 1)
            .reshape(1, -1, pos_embed_reshaped.shape[-1])
        )
        cls_token = pretrained_pos_embed[:, :1]
        return torch.cat((cls_token, pos_embed_interpolated), dim=1)

    def load_backbone(self, path: str):
        """
        Load pretrained weights, resizing ``pos_embed`` if necessary.

        Only keys present in this module's ``state_dict`` are loaded
        (``strict=False``); missing/unexpected keys are printed.

        Args:
            path (str): Checkpoint path. A top-level ``"state_dict"`` entry
                (e.g. from a Lightning checkpoint) is unwrapped.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on any machine;
        # load_state_dict copies the tensors onto the parameters' device.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        # Weights may come from a Lightning (or similar) checkpoint.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # Interpolate pos_embed if its grid differs from this model's.
        if "pos_embed" in state_dict:
            image_size = (
                (self.img_size, self.img_size)
                if isinstance(self.img_size, int)
                else self.img_size
            )
            expected_shape = (
                1,
                (image_size[0] // self.patch_size) * (image_size[1] // self.patch_size)
                + 1,
                self.embed_dims,
            )
            if state_dict["pos_embed"].shape != expected_shape:
                print("🔄 Interpolando pos_embed para nova resolução...")
                with torch.no_grad():
                    state_dict["pos_embed"] = self.interpolate_pos_embeddings(
                        state_dict["pos_embed"],
                        new_img_size=image_size,
                        patch_size=self.patch_size,
                    )
        else:
            print("⚠️ Arquivo .pth não tem pos_embed; pulando interpolação.")
        # Keep only the weights that exist in this model.
        model_keys = self.state_dict().keys()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
        missing, unexpected = self.load_state_dict(filtered_state_dict, strict=False)
        print(f"🔍 Missing keys: {missing}")
        print(f"🚫 Unexpected keys: {unexpected}")

    @staticmethod
    def _tokens_to_feature_map(tokens, hw_shape):
        """Reshape patch tokens (B, H' * W', C) into a feature map (B, C, H', W')."""
        batch, _, channels = tokens.shape
        return (
            tokens.reshape(batch, hw_shape[0], hw_shape[1], channels)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

    def forward(self, x):
        """
        Encode an image batch.

        Parameters
        ----------
        x : torch.Tensor
            Input images, (B, in_channels, H, W).

        Returns
        -------
        Tuple[torch.Tensor, tuple]
            The first element is the last layer's patch tokens as a
            (B, embed_dims, H', W') map. The second is a tuple with one entry
            per selected layer in ``out_indices``: a (B, embed_dims, H', W')
            map, or ``[map, cls_token]`` when ``output_cls_token`` is set.
        """
        x, hw_shape = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1 + H' * W', C)
        # Add positional embeddings and apply dropout.
        x = self._pos_embeding(x, hw_shape, self.pos_embed)
        if not self.with_cls_token:
            x = x[:, 1:]  # encoder runs on patch tokens only

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                patch_tokens = x[:, 1:] if self.with_cls_token else x
                out = self._tokens_to_feature_map(patch_tokens, hw_shape)
                if self.output_cls_token:
                    out = [out, x[:, 0]]
                outs.append(out)

        final_tokens = x[:, 1:] if self.with_cls_token else x
        return self._tokens_to_feature_map(final_tokens, hw_shape), tuple(outs)
###################################
############### SFM ###############
###################################
[docs]
class MaskedAutoencoderViT(L.LightningModule):
    """
    Masked Autoencoder (MAE) with a VisionTransformer backbone.

    The encoder works as follows:

    - It embeds the image into patches and adds fixed 2D sin-cos positional
      embeddings.
    - It keeps a random subset of patches (see ``random_masking``).
    - It runs the kept patches, together with a class token, through
      transformer blocks.

    The decoder projects the encoded tokens and appends learnable mask
    tokens. It then restores the original patch order and predicts the pixel
    values of every patch. The loss is the per-patch MSE, averaged over the
    masked patches only.

    Args:
        img_size (int): Size of input image.
        patch_size (int): Size of image patch.
        in_chans (int): Number of input channels.
        embed_dim (int): Dimension of token embeddings.
        depth (int): Number of transformer blocks.
        num_heads (int): Number of attention heads.
        decoder_embed_dim (int): Dimension of decoder embeddings.
        decoder_depth (int): Number of decoder transformer blocks.
        decoder_num_heads (int): Number of decoder attention heads.
        mlp_ratio (float): Ratio of MLP hidden layer size to embedding size.
        norm_layer (Callable[[int], nn.Module]): Normalization layer factory.
        norm_pix_loss (bool): If True, each target patch is normalized to zero
            mean and unit variance before computing the loss.

    References:
        - timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
        - DeiT: https://github.com/facebookresearch/deit
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=1,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        norm_pix_loss=False,
    ):
        super().__init__()

        # --- MAE encoder ---
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) sin-cos embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
        )
        self.in_chans = in_chans
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # --- MAE decoder ---
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Fixed (non-trainable) sin-cos embedding, filled in initialize_weights().
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim),
            requires_grad=False,
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # Decoder tokens -> flattened pixel values of one patch.
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True
        )

        self.norm_pix_loss = norm_pix_loss
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin-cos positional embeddings and initialize all trainable weights."""
        # Both embeddings assume a square patch grid (side = sqrt(num_patches)).
        grid_size = int(self.patch_embed.num_patches**0.5)
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], grid_size, cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], grid_size, cls_token=True
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )
        # Initialize the patch projection like a linear layer: Xavier on the
        # kernel flattened to (out_channels, in_channels * kh * kw).
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm set to identity (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guard against LayerNorms built without affine parameters.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Extract patches from input images.

        Args:
            imgs (torch.Tensor): Input images of shape (N, C, H, W) with
                H == W and H divisible by the patch size.

        Returns:
            torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans).
        """
        p = self.patch_embed.patch_size[0]
        # Only square images whose side is a multiple of the patch size.
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        h = w = imgs.shape[2] // p
        x = imgs.reshape((imgs.shape[0], self.in_chans, h, p, w, p))  # (N, C, h, p, w, p)
        x = torch.einsum("nchpwq->nhwpqc", x)  # (N, h, w, p, p, C)
        x = x.reshape((imgs.shape[0], h * w, p**2 * self.in_chans))  # (N, h*w, p*p*C)
        return x

    def unpatchify(self, x):
        """
        Reconstruct images from patches (inverse of ``patchify``).

        Args:
            x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans),
                with L a perfect square.

        Returns:
            torch.Tensor: Reconstructed images of shape (N, in_chans, H, W).
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        # Use self.in_chans (not a fixed 3) so this mirrors patchify().
        x = x.reshape((x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape((x.shape[0], self.in_chans, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.

        Args:
            x (torch.Tensor): Input tensor of shape (N, L, D).
            mask_ratio (float): Fraction of tokens to drop.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - kept tokens, (N, int(L * (1 - mask_ratio)), D);
                - binary mask, (N, L), 0 = kept and 1 = removed, in the
                  original token order;
                - indices that undo the shuffle, (N, L).
        """
        N, seq_len, D = x.shape  # avoid shadowing the `L` (lightning) alias
        len_keep = int(seq_len * (1 - mask_ratio))
        noise = torch.rand(N, seq_len, device=x.device)
        # Sorting random noise yields a random permutation per sample; the
        # first len_keep positions of the permutation are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        mask = torch.ones(N, seq_len, device=x.device)
        mask[:, :len_keep] = 0
        # Put the mask back into the original token order.
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """
        Forward pass through the encoder.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).
            mask_ratio (float): Fraction of patches to mask.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded tokens
            (class token first), binary mask, and restore indices.
        """
        x = self.patch_embed(x)
        # Positional embedding is added before masking (index 0 is the cls slot).
        x = x + self.pos_embed[:, 1:, :]
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """
        Forward pass through the decoder.

        Args:
            x (torch.Tensor): Encoded tokens of shape (N, 1 + L_keep, D),
                class token first.
            ids_restore (torch.Tensor): Indices restoring the original patch order.

        Returns:
            torch.Tensor: Per-patch predictions of shape
            (N, L, patch_size^2 * in_chans); the class token is dropped.
        """
        x = self.decoder_embed(x)
        # One mask token per removed patch: L + 1 - (1 + L_keep).
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # without cls token
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token again
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        return x[:, 1:, :]

    def forward_loss(self, imgs, pred, mask):
        """
        Calculate the reconstruction loss on masked patches.

        Args:
            imgs (torch.Tensor): Input images of shape (N, C, H, W).
            pred (torch.Tensor): Predictions of shape (N, L, patch_size^2 * in_chans).
            mask (torch.Tensor): Binary mask of shape (N, L), 1 = masked.

        Returns:
            torch.Tensor: Mean per-patch MSE over masked patches.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # mean loss per patch
        return (loss * mask).sum() / mask.sum()

    def forward(self, imgs, mask_ratio=0.75):
        """
        Full MAE pass: encode with masking, decode, and compute the loss.

        Args:
            imgs (torch.Tensor): Input images of shape (N, C, H, W).
            mask_ratio (float): Fraction of patches to mask.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value,
            per-patch predictions, binary mask.
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Args:
            batch (Tuple[torch.Tensor, Any]): Images and (ignored) labels.
            batch_idx (int): Index of the current batch.

        Returns:
            Dict[str, torch.Tensor]: ``{"loss": loss}``.
        """
        imgs, _ = batch
        loss, _, _ = self(imgs)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        Validation step.

        Args:
            batch (Tuple[torch.Tensor, Any]): Images and (ignored) labels.
            batch_idx (int): Index of the current batch.

        Returns:
            Dict[str, torch.Tensor]: ``{"val_loss": loss}``.
        """
        imgs, _ = batch
        loss, _, _ = self(imgs)
        self.log("val_loss", loss)
        return {"val_loss": loss}
# Predefined MAE model architectures.
#
# Each preset is a ``functools.partial`` of ``MaskedAutoencoderViT``; the
# remaining constructor arguments (``img_size``, ``in_chans``,
# ``norm_pix_loss``) are supplied at call time. All presets share
# ``mlp_ratio=4`` and ``LayerNorm(eps=1e-6)``.


def _mae_preset(**arch_kwargs):
    """Return a ``partial`` of ``MaskedAutoencoderViT`` with the shared MLP/norm settings."""
    return partial(
        MaskedAutoencoderViT,
        **arch_kwargs,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )


# Encoder: 768 dim, 6 blocks, 12 heads. Decoder: 512 dim, 4 blocks, 16 heads.
# NOTE: embed_dim/num_heads equal the "base" preset below; only the encoder
# depth (6 vs 12) and decoder depth (4 vs 8) are reduced.
mae_vit_small_patch16 = _mae_preset(
    patch_size=16,
    embed_dim=768,
    depth=6,
    num_heads=12,
    decoder_embed_dim=512,
    decoder_depth=4,
    decoder_num_heads=16,
)

# Encoder: 768 dim, 12 blocks, 12 heads. Decoder: 512 dim, 8 blocks, 16 heads.
mae_vit_base_patch16 = _mae_preset(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
)

# Encoder: 1024 dim, 24 blocks, 16 heads. Decoder: 512 dim, 8 blocks, 16 heads.
mae_vit_large_patch16 = _mae_preset(
    patch_size=16,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
)

# Patch size 14. Encoder: 1280 dim, 32 blocks, 16 heads.
# Decoder: 512 dim, 8 blocks, 16 heads.
mae_vit_huge_patch14 = _mae_preset(
    patch_size=14,
    embed_dim=1280,
    depth=32,
    num_heads=16,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
)

# Encoder: 1024 dim, 24 blocks, 16 heads. Decoder: 256 dim, 4 blocks, 16 heads.
mae_vit_large_patch16D4d256 = _mae_preset(
    patch_size=16,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    decoder_embed_dim=256,
    decoder_depth=4,
    decoder_num_heads=16,
)

# Encoder: 768 dim, 12 blocks, 12 heads. Decoder: 256 dim, 4 blocks, 16 heads.
mae_vit_base_patch16D4d256 = _mae_preset(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    decoder_embed_dim=256,
    decoder_depth=4,
    decoder_num_heads=16,
)
################################################################################
# SFM DOWNSTREAM TASKS
################################################################################
[docs]
class Conv2dReLU(nn.Sequential):
    """
    ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple): Convolution kernel size.
        padding (int or tuple): Convolution padding. Default: 0.
        stride (int or tuple): Convolution stride. Default: 1.
        use_batchnorm (bool): If True, apply ``BatchNorm2d`` after the
            convolution and drop the convolution bias, which BatchNorm would
            cancel out. If False, the convolution keeps its bias and no
            normalization is applied. Default: True.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        # Always three children (conv, norm, relu) so child indices, and thus
        # state_dict keys, are the same whether or not BatchNorm is used.
        norm = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        relu = nn.ReLU(inplace=True)
        super().__init__(conv, norm, relu)
[docs]
class DecoderBlock(nn.Module):
    """
    Fuse an optional skip feature, upsample 2x, and refine with two conv blocks.

    Forward order:

    1. Concatenate ``skip`` to the input along the channel axis. This happens
       at the input's resolution, so ``skip`` must match its spatial size.
    2. Bilinearly upsample by 2.
    3. Apply two 3x3 ``Conv2dReLU`` layers.

    Args:
        in_channels (int): Channels of the main input.
        out_channels (int): Channels produced by both conv layers.
        skip_channels (int): Channels of the skip feature (0 if none). Default: 0.
        use_batchnorm (bool): Forwarded to ``Conv2dReLU``. Default: True.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv2dReLU(in_channels + skip_channels, out_channels, **conv_kwargs)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **conv_kwargs)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        """Return ``conv2(conv1(up(cat(x, skip))))``; the concat is skipped when ``skip`` is None."""
        features = x if skip is None else torch.cat([x, skip], dim=1)
        features = self.up(features)
        return self.conv2(self.conv1(features))
[docs]
class SegmentationHead(nn.Sequential):
    """
    Final prediction head: a convolution followed by optional bilinear upsampling.

    The convolution uses ``padding = kernel_size // 2``, which preserves the
    spatial size for odd kernel sizes. The second child is
    ``UpsamplingBilinear2d(upsampling)`` when ``upsampling > 1``, otherwise
    ``Identity``.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels (e.g. number of classes).
        kernel_size (int): Convolution kernel size. Default: 3.
        upsampling (int or float): Upsampling scale factor. Default: 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        classifier = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(classifier, resize)
[docs]
class DecoderCup(nn.Module):
    """
    Cascaded-upsampling decoder turning ViT tokens into a dense feature map.

    ``forward`` processes the final tokens as follows:

    - It reshapes them from (B, N, hidden) into a square
      (B, hidden, sqrt(N), sqrt(N)) map.
    - It projects the map to 512 channels.
    - It runs four ``DecoderBlock`` stages with 256, 128, 64 and 16 output
      channels; each stage upsamples by 2.

    When ``features`` (four intermediate token tensors) are given, feature
    ``i`` is used as the skip connection of stage ``i``. It is projected to
    ``skip_channels[i]`` channels and upsampled by 1, 2, 4 or 8 so that its
    spatial size equals that stage's input.

    Args:
        hidden_size (int): Token embedding size of the encoder. Default: 1024.
    """

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        head_channels = 512
        self.conv_more = Conv2dReLU(
            hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = (256, 128, 64, 16)
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels
        # Channels each skip feature is projected to, one per decoder stage.
        self.skip_channels = [512, 256, 128, 64]
        self.conv_feature1 = Conv2dReLU(
            hidden_size, self.skip_channels[0], kernel_size=3, padding=1, use_batchnorm=True
        )
        self.conv_feature2 = Conv2dReLU(
            hidden_size, self.skip_channels[1], kernel_size=3, padding=1, use_batchnorm=True
        )
        self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv_feature3 = Conv2dReLU(
            hidden_size, self.skip_channels[2], kernel_size=3, padding=1, use_batchnorm=True
        )
        self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
        self.conv_feature4 = Conv2dReLU(
            hidden_size, self.skip_channels[3], kernel_size=3, padding=1, use_batchnorm=True
        )
        self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(in_ch, out_ch, sk_ch)
                for in_ch, out_ch, sk_ch in zip(
                    in_channels, out_channels, self.skip_channels
                )
            ]
        )

    @staticmethod
    def _tokens_to_map(tokens):
        """
        Reshape tokens (B, N, hidden) into a square map (B, hidden, sqrt(N), sqrt(N)).

        Raises:
            ValueError: If N is not a perfect square.
        """
        batch, n_patch, hidden = tokens.size()
        side = int(np.sqrt(n_patch))
        if side * side != n_patch:
            raise ValueError(
                f"Expected a square number of patch tokens, got {n_patch}."
            )
        return tokens.permute(0, 2, 1).contiguous().view(batch, hidden, side, side)

    def TransShape(self, x, head_channels=512, up=0):
        """
        Turn intermediate tokens into the skip feature for decoder stage ``up``.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, hidden), N a perfect square.
            head_channels (int): Unused. The output channel count is fixed by
                the ``conv_feature*`` layer selected by ``up``.
            up (int): Stage index 0-3; selects the projection and the
                upsampling factor (1, 2, 4, 8). Other values return the
                reshaped map unchanged.

        Returns:
            torch.Tensor: Skip feature map.
        """
        x = self._tokens_to_map(x)
        if up == 0:
            x = self.conv_feature1(x)
        elif up == 1:
            x = self.up2(self.conv_feature2(x))
        elif up == 2:
            x = self.up3(self.conv_feature3(x))
        elif up == 3:
            x = self.up4(self.conv_feature4(x))
        return x

    def forward(self, hidden_states, features=None):
        """
        Decode final tokens, optionally fusing four intermediate token tensors.

        Args:
            hidden_states (torch.Tensor): Final tokens (B, N, hidden), N a perfect square.
            features (Optional[Sequence[torch.Tensor]]): Four token tensors
                shaped like ``hidden_states``, one per decoder stage.

        Returns:
            torch.Tensor: Feature map with 16 channels at 16x the token-grid resolution.
        """
        x = self.conv_more(self._tokens_to_map(hidden_states))
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = self.TransShape(
                    features[i], head_channels=self.skip_channels[i], up=i
                )
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
[docs]
class MLAHead(nn.Module):
    """
    Multi-level aggregation head: refine four feature maps, upsample each 4x, and concatenate.

    Each input ``mla_pK`` (K = 2..5) passes through its own ``headK``: two
    stages of ``3x3 Conv (no bias) -> BatchNorm2d -> ReLU`` mapping
    ``mla_channels -> mlahead_channels -> mlahead_channels``. The result is
    bilinearly resized to 4x its spatial size (``align_corners=True``). The
    four results are concatenated along channels, giving
    ``4 * mlahead_channels`` output channels.

    Args:
        mla_channels (int): Channels of each input map. Default: 256.
        mlahead_channels (int): Channels produced by each head. Default: 128.
        norm_cfg: Unused by this implementation.
    """

    def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None):
        super(MLAHead, self).__init__()
        # Registered as head2..head5 (in that order) to keep state_dict keys.
        for level in (2, 3, 4, 5):
            setattr(self, f"head{level}", self._make_head(mla_channels, mlahead_channels))

    @staticmethod
    def _make_head(in_channels, channels):
        """Build two ``Conv3x3 (bias=False) -> BatchNorm2d -> ReLU`` stages."""
        stages = []
        for stage_in in (in_channels, channels):
            stages.extend(
                [
                    nn.Conv2d(stage_in, channels, 3, padding=1, bias=False),
                    nn.BatchNorm2d(channels),
                    nn.ReLU(),
                ]
            )
        return nn.Sequential(*stages)

    def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
        """Return ``cat([up4(head2(p2)), ..., up4(head5(p5))], dim=1)``."""
        pairs = (
            (self.head2, mla_p2),
            (self.head3, mla_p3),
            (self.head4, mla_p4),
            (self.head5, mla_p5),
        )
        upsampled = [
            F.interpolate(
                head(feat),
                (4 * feat.shape[-2], 4 * feat.shape[-1]),
                mode="bilinear",
                align_corners=True,
            )
            for head, feat in pairs
        ]
        return torch.cat(upsampled, dim=1)
[docs]
class VIT_MLAHead(nn.Module):
    """Segmentation head that fuses four ViT token sequences with ``MLAHead``.

    Each of the four inputs is a token sequence of shape ``(B, n_patch, hidden)``.
    Each is folded back onto an ``h x w`` grid and passed through an
    ``MLAHead``, which upsamples every level 4x and concatenates them. The
    result is classified by a 3x3 convolution and finally resized bilinearly
    to ``(h * patch_size, w * patch_size)``.

    Parameters
    ----------
    img_size : int, default=768
        Stored on the instance; not used by ``forward``.
    mla_channels : int, default=256
        Input channels of every ``MLAHead`` level. The token ``hidden`` size is
        expected to equal this value.
    mlahead_channels : int, default=128
        Channels produced by each ``MLAHead`` level.
    num_classes : int, default=6
        Output channels of the final convolution.
    norm_layer : type, default=nn.BatchNorm2d
        Stored as ``self.BatchNorm``; not used by this class.
    norm_cfg : optional
        Forwarded to ``MLAHead``, which does not use it.
    patch_size : int, default=16
        Pixels per token side. The logits are resized to ``patch_size`` times
        the token grid. The default reproduces the previously hard-coded factor
        of 16; pass 14 for patch-14 backbones.
    **kwargs
        Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        img_size=768,
        mla_channels=256,
        mlahead_channels=128,
        num_classes=6,
        norm_layer=nn.BatchNorm2d,
        norm_cfg=None,
        patch_size=16,
        **kwargs,
    ):
        super(VIT_MLAHead, self).__init__(**kwargs)
        self.img_size = img_size
        self.norm_cfg = norm_cfg
        self.mla_channels = mla_channels
        self.BatchNorm = norm_layer
        self.mlahead_channels = mlahead_channels
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.mlahead = MLAHead(
            mla_channels=self.mla_channels,
            mlahead_channels=self.mlahead_channels,
            norm_cfg=self.norm_cfg,
        )
        # MLAHead concatenates four levels -> 4 * mlahead_channels inputs.
        self.cls = nn.Conv2d(4 * self.mlahead_channels, self.num_classes, 3, padding=1)

    def forward(self, x1, x2, x3, x4, h=14, w=14):
        """Produce per-pixel class logits from four token sequences.

        Parameters
        ----------
        x1, x2, x3, x4 : torch.Tensor
            Token tensors of shape ``(B, n_patch, hidden)``. All four are
            reshaped using the dimensions of ``x1``.
        h, w : int, default=14
            Token-grid size. When ``h == w`` (including the default) the grid
            is instead inferred as ``int(sqrt(n_patch))`` per side, so only
            non-square grids need to be passed explicitly.

        Returns
        -------
        torch.Tensor
            Logits of shape ``(B, num_classes, h * patch_size, w * patch_size)``.
        """
        B, n_patch, hidden = x1.size()
        if h == w:
            # Square grids are always derived from the token count.
            h = w = int(np.sqrt(n_patch))

        def to_grid(tokens):
            # (B, n_patch, hidden) -> (B, hidden, h, w)
            return tokens.permute(0, 2, 1).contiguous().view(B, hidden, h, w)

        x = self.mlahead(to_grid(x1), to_grid(x2), to_grid(x3), to_grid(x4))
        x = self.cls(x)
        out_size = (h * self.patch_size, w * self.patch_size)
        return F.interpolate(x, size=out_size, mode="bilinear", align_corners=True)
[docs]
def vit_base_patch16_downstream_regression(**kwargs):
    """Build a ``VisionTransformer`` with the ViT-Base/16 configuration.

    Fixed settings: patch size 16, embedding dimension 768, depth 12,
    12 heads, MLP ratio 4, QKV bias, and LayerNorm with ``eps=1e-6``. All
    keyword arguments are forwarded. Repeating one of the fixed keys raises
    ``TypeError`` (duplicate keyword argument).
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**architecture, **kwargs)
[docs]
def vit_large_patch16_downstream_regression(**kwargs):
    """Build a ``VisionTransformer`` with the ViT-Large/16 configuration.

    Fixed settings: patch size 16, embedding dimension 1024, depth 24,
    16 heads, MLP ratio 4, QKV bias, and LayerNorm with ``eps=1e-6``. All
    keyword arguments are forwarded. Repeating one of the fixed keys raises
    ``TypeError`` (duplicate keyword argument).
    """
    architecture = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**architecture, **kwargs)
[docs]
def vit_huge_patch14_downstream_regression(**kwargs):
    """Build a ``VisionTransformer`` with the ViT-Huge/14 configuration.

    Fixed settings: patch size 14, embedding dimension 1280, depth 32,
    16 heads, MLP ratio 4, QKV bias, and LayerNorm with ``eps=1e-6``. All
    keyword arguments are forwarded. Repeating one of the fixed keys raises
    ``TypeError`` (duplicate keyword argument).
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**architecture, **kwargs)
[docs]
def interpolate_pos_embed(model, checkpoint_model, newsize1=None, newsize2=None):
    """Resize ``checkpoint_model["pos_embed"]`` in place to fit a new patch grid.

    The leading extra tokens are the entries of ``model.pos_embed`` beyond
    ``model.patch_embed.num_patches`` (e.g. a class token). They are copied
    unchanged. The remaining patch tokens, assumed to form a square grid in
    the checkpoint, are resized with bicubic interpolation.

    Parameters
    ----------
    model : nn.Module
        Target model; must expose ``patch_embed.num_patches`` and ``pos_embed``.
    checkpoint_model : dict
        State dict to patch in place. Nothing happens if it has no
        ``"pos_embed"`` key.
    newsize1, newsize2 : int, optional
        Target grid height and width.
        * If ``newsize1`` is None, the model's grid is used (assumed square,
          ``int(sqrt(num_patches))`` per side).
        * If only ``newsize2`` is None, it defaults to ``newsize1``.

    Notes
    -----
    The embedding is left untouched when the target grid already equals the
    checkpoint grid.
    """
    if "pos_embed" not in checkpoint_model:
        return

    pos_embed_checkpoint = checkpoint_model["pos_embed"]
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens that precede the patch tokens (e.g. class / distillation tokens).
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # Side length of the (square) patch grid stored in the checkpoint.
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # Side length of the model's patch grid, assuming it is square.
    new_size = int(num_patches**0.5)

    if newsize1 is None:
        newsize1, newsize2 = new_size, new_size
    elif newsize2 is None:
        newsize2 = newsize1

    # Compare against the requested target, not only the square model grid.
    # Otherwise an explicit (possibly non-square) size would be ignored
    # whenever new_size happens to equal orig_size.
    if (newsize1, newsize2) == (orig_size, orig_size):
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size, orig_size, newsize1, newsize2)
    )
    # Extra tokens are kept as-is; only the patch tokens are interpolated.
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    # (1, N, C) -> (1, C, orig, orig) so interpolate works on the spatial axes.
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(
        0, 3, 1, 2
    )
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens,
        size=(newsize1, newsize2),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, H, W) -> (1, H*W, C)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((extra_tokens, pos_tokens), dim=1)
[docs]
class SFM_BasePatch16_Downstream(SimpleSupervisedModel):
    def __init__(
        self,
        img_size: Union[int, Tuple[int, ...]] = (512, 512),
        num_classes: int = 6,
        in_chans: int = 1,
        loss_fn: Optional[torch.nn.Module] = None,
        learning_rate: float = 1e-3,
        **kwargs,
    ):
        """SFM downstream model built on a ViT-Base/16 backbone.

        The backbone comes from ``vit_base_patch16_downstream_regression``:
        patch size 16, embedding dimension 768, depth 12, 12 attention heads.
        The ``fc`` stage is an identity, and ``flatten=False`` is passed to
        ``SimpleSupervisedModel``.

        Parameters
        ----------
        img_size : Union[int, Tuple[int, ...]], default=(512, 512)
            Input image size. To use the default pre-trained SFM model, keep
            ``(512, 512)``.
        num_classes : int, default=6
            Number of classes of the segmentation head.
        in_chans : int, default=1
            Number of input channels of the backbone.
        loss_fn : Optional[torch.nn.Module], default=None
            Loss function; a falsy value selects ``torch.nn.CrossEntropyLoss()``.
        learning_rate : float, default=1e-3
            Learning rate value.
        **kwargs
            Forwarded to ``SimpleSupervisedModel``.
        """
        super().__init__(
            backbone=vit_base_patch16_downstream_regression(
                img_size=img_size,
                num_classes=num_classes,
                in_chans=in_chans,
            ),
            fc=torch.nn.Identity(),
            loss_fn=loss_fn if loss_fn else torch.nn.CrossEntropyLoss(),
            learning_rate=learning_rate,
            flatten=False,
            **kwargs,
        )

    @staticmethod
    def _keep_first_channel(x: torch.Tensor) -> torch.Tensor:
        """Cast ``x`` to float and, if it has more than one channel (dim 1), keep only the first."""
        x = x.float()
        return x[:, 0:1, :, :] if x.shape[1] > 1 else x

    def _single_step(
        self, batch: torch.Tensor, batch_idx: int, step_name: str
    ) -> torch.Tensor:
        """Normalize an ``(x, y)`` batch, then defer to the parent step.

        ``x`` is cast to float and reduced to its first channel. A 4-D ``y``
        is reduced to its first channel along dim 1 and cast to ``long``.
        Other ``y`` tensors are passed through unchanged.
        """
        x, y = batch
        x = self._keep_first_channel(x)
        if y.ndim == 4:
            y = y[:, 0, :, :].long()
        return super()._single_step((x, y), batch_idx, step_name)

    def predict_step(
        self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        """Return ``self.backbone.model`` applied to the normalized input of ``batch``.

        The target half of the ``(x, _)`` batch is ignored.

        NOTE(review): this calls ``self.backbone.model`` directly instead of
        ``self.backbone`` / ``self.forward``. Confirm that the backbone exposes
        a ``.model`` attribute and that bypassing ``fc`` is intended.
        """
        x, _ = batch
        return self.backbone.model(self._keep_first_channel(x))