Source code for minerva.models.nets.image.vit
# Standard library imports
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
# Third-party imports
import lightning as L
import numpy as np
import timm.models.vision_transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import Block, PatchEmbed
from torchvision.models.vision_transformer import (
Conv2dNormActivation,
ConvStemConfig,
EncoderBlock,
_log_api_usage_once,
)
# Local imports
from minerva.models.nets.base import SimpleSupervisedModel
from minerva.utils.position_embedding import get_2d_sincos_pos_embed
class _Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
aux_output: bool = False,
aux_output_layers: Optional[List[int]] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(
nn.LayerNorm, eps=1e-6
),
):
super().__init__()
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiheadAttention() by default
self.aux_output = aux_output
self.aux_output_layers = aux_output_layers
self.pos_embedding = nn.Parameter(
torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)
) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(
input.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}",
)
input = input + self.pos_embedding
if self.aux_output:
aux_outputs = []
for i, layer in enumerate(self.layers):
input = layer(input)
if i in self.aux_output_layers: # type: ignore
aux_outputs.append(self.ln(self.dropout(input)))
return self.ln(self.dropout(input)), aux_outputs
return self.ln(self.layers(self.dropout(input)))
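# Shape sketch for ``_Encoder`` (illustrative only; the hyperparameters below
# are ViT-B/16-style assumptions, not library defaults):
#
#   enc = _Encoder(seq_length=197, num_layers=12, num_heads=12,
#                  hidden_dim=768, mlp_dim=3072,
#                  dropout=0.0, attention_dropout=0.0)
#   tokens = torch.randn(2, 197, 768)   # (batch_size, seq_length, hidden_dim)
#   out = enc(tokens)                   # -> (2, 197, 768)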
class _VisionTransformerBackbone(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: Union[int, Tuple[int, int]],
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
original_resolution: Optional[Tuple[int, int]] = None,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
aux_output: bool = False,
aux_output_layers: Optional[List[int]] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(
nn.LayerNorm, eps=1e-6
),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
"""
Initializes a Vision Transformer (ViT) model.
Parameters
----------
image_size : int or Tuple[int, int]
The size of the input image. If an int is provided, it is assumed
to be a square image. If a tuple of ints is provided, it represents
the height and width of the image.
patch_size : int
The size of each patch in the image.
num_layers : int
The number of transformer layers in the model.
num_heads : int
The number of attention heads in the transformer layers.
hidden_dim : int
The dimensionality of the hidden layers in the transformer.
mlp_dim : int
The dimensionality of the feed-forward MLP layers in the transformer.
original_resolution : Tuple[int, int], optional
The original resolution of the input image in the pre-training
weights. When None, positional embeddings will not be interpolated.
Defaults to None.
dropout : float, optional
The dropout rate to apply. Defaults to 0.0.
attention_dropout : float, optional
The dropout rate to apply to the attention weights. Defaults to 0.0.
num_classes : int, optional
The number of output classes. Defaults to 1000.
aux_output : bool, optional
Whether to also return selected intermediate encoder outputs.
Defaults to False.
aux_output_layers : List[int], optional
Indices of the encoder layers whose outputs are returned when
aux_output is True. Defaults to None.
norm_layer : Callable[..., torch.nn.Module], optional
The normalization layer to use. Defaults to nn.LayerNorm with
epsilon=1e-6.
conv_stem_configs : List[ConvStemConfig], optional
The configuration for the convolutional stem layers.
If provided, the input image will be processed by these
convolutional layers before being passed to the transformer.
Defaults to None.
"""
super().__init__()
_log_api_usage_once(self)
if aux_output:
assert aux_output_layers is not None
assert all(
0 <= i < num_layers for i in aux_output_layers
), "Invalid layer index in aux_output_layers"
if isinstance(image_size, int):
torch._assert(
image_size % patch_size == 0,
"Input shape indivisible by patch size!",
)
elif isinstance(image_size, tuple):
torch._assert(
image_size[0] % patch_size == 0
and image_size[1] % patch_size == 0,
"Input shape indivisible by patch size!",
)
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.norm_layer = norm_layer
self.aux_output = aux_output
self.aux_output_layers = aux_output_layers
self.original_resolution = (
original_resolution if original_resolution else image_size
)
if conv_stem_configs is not None:
# As per https://arxiv.org/abs/2106.14881
seq_proj = nn.Sequential()
prev_channels = 3
for i, conv_stem_layer_config in enumerate(conv_stem_configs):
seq_proj.add_module(
f"conv_bn_relu_{i}",
Conv2dNormActivation(
in_channels=prev_channels,
out_channels=conv_stem_layer_config.out_channels,
kernel_size=conv_stem_layer_config.kernel_size,
stride=conv_stem_layer_config.stride,
norm_layer=conv_stem_layer_config.norm_layer,
activation_layer=conv_stem_layer_config.activation_layer,
),
)
prev_channels = conv_stem_layer_config.out_channels
seq_proj.add_module(
"conv_last",
nn.Conv2d(
in_channels=prev_channels,
out_channels=hidden_dim,
kernel_size=1,
),
)
self.conv_proj: nn.Module = seq_proj
else:
self.conv_proj = nn.Conv2d(
in_channels=3,
out_channels=hidden_dim,
kernel_size=patch_size,
stride=patch_size,
)
if isinstance(image_size, int):
seq_length = (image_size // patch_size) ** 2
elif isinstance(image_size, tuple):
seq_length = (image_size[0] // patch_size) * (
image_size[1] // patch_size
)
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = _Encoder(
seq_length=seq_length,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
norm_layer=norm_layer,
aux_output=aux_output,
aux_output_layers=aux_output_layers,
)
self.seq_length = seq_length
if isinstance(self.conv_proj, nn.Conv2d):
# Init the patchify stem
fan_in = (
self.conv_proj.in_channels
* self.conv_proj.kernel_size[0]
* self.conv_proj.kernel_size[1]
)
nn.init.trunc_normal_(
self.conv_proj.weight, std=math.sqrt(1 / fan_in)
)
if self.conv_proj.bias is not None:
nn.init.zeros_(self.conv_proj.bias)
elif self.conv_proj.conv_last is not None and isinstance(
self.conv_proj.conv_last, nn.Conv2d
):
# Init the last 1x1 conv of the conv stem
nn.init.normal_(
self.conv_proj.conv_last.weight,
mean=0.0,
std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels),
)
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)
def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
"""Process the input tensor and return the reshaped tensor and dimensions.
Args:
x (torch.Tensor): The input tensor.
Returns:
Tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows,
and number of columns.
"""
n, c, h, w = x.shape
p = self.patch_size
if isinstance(self.image_size, int):
torch._assert(
h == self.image_size,
f"Wrong image height! Expected {self.image_size} but got {h}!",
)
torch._assert(
w == self.image_size,
f"Wrong image width! Expected {self.image_size} but got {w}!",
)
elif isinstance(self.image_size, tuple):
torch._assert(
h == self.image_size[0],
f"Wrong image height! Expected {self.image_size[0]} but got {h}!",
)
torch._assert(
w == self.image_size[1],
f"Wrong image width! Expected {self.image_size[1]} but got {w}!",
)
else:
raise ValueError("Invalid image size type!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = x.to(torch.float32)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x, n_h, n_w
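# Worked example of the patch-grid arithmetic above (illustrative): a 224x224
# input with patch_size=16 gives n_h = n_w = 224 // 16 = 14, so conv_proj
# yields (n, hidden_dim, 14, 14), which is flattened and permuted to
# (n, 196, hidden_dim) before the class token is prepended.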
def interpolate_pos_embeddings(self, pretrained_pos_embed, new_img_size):
"""Interpolate encoder's positional embeddings to fit a new input size.
Args:
pretrained_pos_embed (torch.Tensor): Pretrained positional embeddings.
new_img_size (Tuple[int, int]): New height and width of the input image.
"""
h, w = (
new_img_size[0] // self.patch_size,
new_img_size[1] // self.patch_size,
)
new_grid_size = (h, w)
# Reshape pretrained positional embeddings to match the original grid size
original_resolution = (
self.original_resolution
if isinstance(self.original_resolution, tuple)
else (self.original_resolution, self.original_resolution)
)
pos_embed_reshaped = pretrained_pos_embed[:, 1:].reshape(
1,
original_resolution[0] // self.patch_size,
original_resolution[1] // self.patch_size,
-1,
)
# Interpolate positional embeddings to the new grid size
pos_embed_interpolated = (
F.interpolate(
pos_embed_reshaped.permute(
0, 3, 1, 2
), # (1, C, H, W) for interpolation
size=new_grid_size,
mode="bilinear",
align_corners=False,
)
.permute(0, 2, 3, 1)
.reshape(1, -1, pos_embed_reshaped.shape[-1])
)
# Concatenate the CLS token and the interpolated positional embeddings
cls_token = pretrained_pos_embed[:, :1]
pos_embed_interpolated = torch.cat(
(cls_token, pos_embed_interpolated), dim=1
)
return pos_embed_interpolated
def load_backbone(self, path: str, freeze: bool = False):
"""Loads pretrained weights and handles positional embedding resizing
if necessary."""
# Load the pretrained state dict
state_dict = torch.load(path)
# Expected shape for positional embeddings based on current model image size
image_size = (
self.image_size
if isinstance(self.image_size, tuple)
else (self.image_size, self.image_size)
)
expected_pos_embed_shape = (
1,
(image_size[0] // self.patch_size)
* (image_size[1] // self.patch_size)
+ 1,
self.hidden_dim,
)
# Check if positional embeddings need interpolation
if (
state_dict["encoder.pos_embedding"].shape
!= expected_pos_embed_shape
):
# Extract the positional embeddings from the state dict
pretrained_pos_embed = state_dict["encoder.pos_embedding"]
# Interpolate to match the current image size
print(
"Interpolating positional embeddings to match the new image size."
)
with torch.no_grad():
pos_embed_interpolated = self.interpolate_pos_embeddings(
pretrained_pos_embed, (image_size[0], image_size[1])
)
state_dict["encoder.pos_embedding"] = pos_embed_interpolated
# Load the (potentially modified) state dict into the encoder
self.encoder.load_state_dict(state_dict, strict=False)
# Optionally freeze parameters
if freeze:
for param in self.encoder.parameters():
param.requires_grad = False
def forward(self, x: torch.Tensor):
"""Forward pass of the Vision Transformer Backbone.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
# Reshape and permute the input tensor
x, n_h, n_w = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
if self.aux_output:
x, aux_outputs = self.encoder(x)
x = x[:, 1:]
B, _, C = x.shape
x = x.reshape(B, n_h, n_w, C).permute(0, 3, 1, 2).contiguous()
for i, aux_output in enumerate(aux_outputs):
aux_outputs[i] = aux_output[:, 1:]
B, _, C = aux_outputs[i].shape
aux_outputs[i] = (
aux_outputs[i]
.reshape(B, n_h, n_w, C)
.permute(0, 3, 1, 2)
.contiguous()
)
return x, aux_outputs
x = self.encoder(x)
# Drop the class token and reshape the patch tokens into a 2D feature map
x = x[:, 1:]
B, _, C = x.shape
x = x.reshape(B, n_h, n_w, C).permute(0, 3, 1, 2).contiguous()
return x
def load_weights(self, weights_path: str, freeze: bool = False):
state_dict = torch.load(weights_path)
# Get expected positional embedding shape based on current image size
image_size = (
self.image_size
if isinstance(self.image_size, tuple)
else (self.image_size, self.image_size)
)
expected_pos_embed_shape = (
1,
(image_size[0] // self.patch_size)
* (image_size[1] // self.patch_size)
+ 1,
self.hidden_dim,
)
# Check if positional embeddings need interpolation
if (
state_dict["encoder.pos_embedding"].shape
!= expected_pos_embed_shape
):
# Extract the positional embeddings from the state dict
pretrained_pos_embed = state_dict["encoder.pos_embedding"]
# Interpolate to match the current image size
print(
"Interpolating positional embeddings to match the new image size."
)
with torch.no_grad():
pos_embed_interpolated = self.interpolate_pos_embeddings(
pretrained_pos_embed, (image_size[0], image_size[1])
)
state_dict["encoder.pos_embedding"] = pos_embed_interpolated
# Load the (potentially modified) state dict
self.load_state_dict(state_dict, strict=False)
# Optionally freeze parameters
if freeze:
for param in self.parameters():
param.requires_grad = False
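# Usage sketch for the backbone (illustrative; the hyperparameters below are
# ViT-B/16-style assumptions, not library defaults):
#
#   backbone = _VisionTransformerBackbone(
#       image_size=224, patch_size=16, num_layers=12, num_heads=12,
#       hidden_dim=768, mlp_dim=3072,
#   )
#   feats = backbone(torch.randn(2, 3, 224, 224))   # -> (2, 768, 14, 14)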
###################################
############### SFM ###############
###################################
class MaskedAutoencoderViT(L.LightningModule):
"""
Masked Autoencoder with VisionTransformer backbone.
Args:
img_size (int): Size of input image.
patch_size (int): Size of image patch.
in_chans (int): Number of input channels.
embed_dim (int): Dimension of token embeddings.
depth (int): Number of transformer blocks.
num_heads (int): Number of attention heads.
decoder_embed_dim (int): Dimension of decoder embeddings.
decoder_depth (int): Number of decoder transformer blocks.
decoder_num_heads (int): Number of decoder attention heads.
mlp_ratio (float): Ratio of MLP hidden layer size to embedding size.
norm_layer (torch.nn.LayerNorm): Normalization layer.
norm_pix_loss (bool): Whether to normalize pixel loss.
References:
- timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
- DeiT: https://github.com/facebookresearch/deit
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=1,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.0,
norm_layer=nn.LayerNorm,
norm_pix_loss=False,
):
super().__init__()
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
) # fixed sin-cos embedding
self.in_chans = in_chans
self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for _ in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_embed_dim),
requires_grad=False,
) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList(
[
Block(
decoder_embed_dim,
decoder_num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for _ in range(decoder_depth)
]
)
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(
decoder_embed_dim, patch_size**2 * in_chans, bias=True
) # decoder to patch
self.norm_pix_loss = norm_pix_loss
self.initialize_weights()
def initialize_weights(self):
# Initialization
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5),
cls_token=True,
)
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5),
cls_token=True,
)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
torch.nn.init.normal_(self.cls_token, std=0.02)
torch.nn.init.normal_(self.mask_token, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs): # input: (32, 1, 224, 224)
"""
Extract patches from input images.
Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
Returns:
torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans).
"""
p = self.patch_embed.patch_size[0]
assert (
imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
) # only square images are supported, and the size must be divisible by the patch size
h = w = imgs.shape[2] // p
x = imgs.reshape(
(imgs.shape[0], self.in_chans, h, p, w, p)
) # Transform images into (32, 1, 14, 16, 14, 16)
x = torch.einsum(
"nchpwq->nhwpqc", x
) # reshape into (32, 14, 14, 16, 16, 1)
x = x.reshape(
(imgs.shape[0], h * w, p**2 * self.in_chans)
) # Transform into (32, 196, 256)
return x
def unpatchify(self, x):
"""
Reconstruct images from patches.
Args:
x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans).
Returns:
torch.Tensor: Reconstructed images of shape (N, C, H, W).
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape((x.shape[0], h, w, p, p, self.in_chans))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape((x.shape[0], self.in_chans, h * p, w * p))
return imgs
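# Round-trip sketch for patchify/unpatchify (illustrative; shapes assume the
# default img_size=224, patch_size=16 and in_chans=1):
#
#   mae = MaskedAutoencoderViT()
#   imgs = torch.randn(2, 1, 224, 224)
#   patches = mae.patchify(imgs)      # -> (2, 196, 256)
#   recon = mae.unpatchify(patches)   # -> (2, 1, 224, 224)
#   assert torch.allclose(recon, imgs)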
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Args:
x (torch.Tensor): Input tensor of shape (N, L, D).
mask_ratio (float): Ratio of values to mask.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Masked input,
binary mask, shuffled indices.
"""
N, L, D = x.shape
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
)
mask = torch.ones(N, L, device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
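# Worked example (illustrative): with L = 196 patches and mask_ratio = 0.75,
# len_keep = int(196 * 0.25) = 49, so x_masked has shape (N, 49, D) and mask
# is an (N, 196) tensor with 1 at the 147 masked positions, in original order.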
def forward_encoder(self, x, mask_ratio):
"""
Forward pass through the encoder.
Args:
x (torch.Tensor): Input tensor of shape (N, C, H, W).
mask_ratio (float): Ratio of values to mask.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded
representation, binary mask, shuffled indices.
"""
x = self.patch_embed(x)
x = x + self.pos_embed[:, 1:, :]
x, mask, ids_restore = self.random_masking(x, mask_ratio)
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
"""
Forward pass through the decoder.
Args:
x (torch.Tensor): Input tensor of shape (N, L, D).
ids_restore (torch.Tensor): Indices to restore the original order
of patches.
Returns:
torch.Tensor: Decoded output tensor of shape (N, L, patch_size^2 * in_chans).
"""
x = self.decoder_embed(x)
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
)
x = torch.cat([x[:, :1, :], x_], dim=1)
x = x + self.decoder_pos_embed
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
x = x[:, 1:, :]
return x
def forward_loss(self, imgs, pred, mask):
"""
Calculate the loss.
Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
pred (torch.Tensor): Predicted output of shape (N, L, patch_size^2 * in_chans).
mask (torch.Tensor): Binary mask of shape (N, L).
Returns:
torch.Tensor: Computed loss value.
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum()
return loss
def forward(self, imgs, mask_ratio=0.75):
"""
Forward pass.
Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
mask_ratio (float): Ratio of values to mask.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value,
predicted output, binary mask.
"""
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore)
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def training_step(self, batch, batch_idx):
"""
Training step.
Args:
batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels.
batch_idx (int): Index of the current batch.
Returns:
Dict[str, torch.Tensor]: Dictionary containing the loss value for
the current step.
"""
imgs, _ = batch
loss, _, _ = self(imgs)
self.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
"""
Validation step.
Args:
batch (Tuple[torch.Tensor]): Input batch of images and
corresponding labels.
batch_idx (int): Index of the current batch.
Returns:
Dict[str, torch.Tensor]: Dictionary containing the loss value for
the current step.
"""
imgs, _ = batch
loss, _, _ = self(imgs)
self.log("val_loss", loss)
return {"val_loss": loss}
def configure_optimizers(self):
"""
Configure optimizer.
Returns:
torch.optim.Optimizer: Optimizer.
"""
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
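# Pre-training forward-pass sketch (illustrative; the ViT-Base-style sizes are
# assumptions, not the class defaults):
#
#   model = MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=1,
#                                embed_dim=768, depth=12, num_heads=12)
#   loss, pred, mask = model(torch.randn(4, 1, 224, 224), mask_ratio=0.75)
#   # pred: (4, 196, 256) patch reconstructions, mask: (4, 196) with 1 = masked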
# Define model architectures
# mae_vit_small_patch16_dec512d4b
# decoder: 512 dim, 4 blocks; encoder depth: 6
mae_vit_small_patch16 = partial(
MaskedAutoencoderViT,
patch_size=16,
embed_dim=768,
depth=6,
num_heads=12,
decoder_embed_dim=512,
decoder_depth=4,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
# mae_vit_base_patch16_dec512d8b
# decoder: 512 dim, 8 blocks
mae_vit_base_patch16 = partial(
MaskedAutoencoderViT,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
# mae_vit_large_patch16_dec512d8b
# decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = partial(
MaskedAutoencoderViT,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
# mae_vit_huge_patch14_dec512d8b
# decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = partial(
MaskedAutoencoderViT,
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
# mae_vit_large_patch16_dec256d4b
# decoder: 256 dim, 4 blocks
mae_vit_large_patch16D4d256 = partial(
MaskedAutoencoderViT,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=256,
decoder_depth=4,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
# mae_vit_base_patch16_dec256d4b
mae_vit_base_patch16D4d256 = partial(
MaskedAutoencoderViT,
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
decoder_embed_dim=256,
decoder_depth=4,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
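# Usage sketch for the factory partials above (illustrative; the Trainer
# arguments and ``train_loader`` are placeholders, not part of this module):
#
#   model = mae_vit_base_patch16(img_size=224, in_chans=1)
#   trainer = L.Trainer(max_epochs=10, accelerator="auto")
#   trainer.fit(model, train_loader)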
################################################################################
# SFM DOWNSTREAM TASKS
################################################################################
class VisionTransformer(
timm.models.vision_transformer.VisionTransformer, L.LightningModule
):
"""Vision Transformer with support for global average pooling"""
def __init__(self, global_pool=False, **kwargs):
super(VisionTransformer, self).__init__(**kwargs)
self.global_pool = global_pool
self.decoder = VIT_MLAHead(
mla_channels=self.embed_dim, num_classes=self.num_classes
)
self.segmentation_head = SegmentationHead(
in_channels=16,
out_channels=self.num_classes,
kernel_size=3,
)
if self.global_pool:
norm_layer = kwargs["norm_layer"]
embed_dim = kwargs["embed_dim"]
self.fc_norm = norm_layer(embed_dim)
del self.norm # remove the original norm
self.loss_fn = nn.CrossEntropyLoss()
def forward_features(self, x):
B, C, H, W = x.shape
x = self.patch_embed(x)
_H, _W = (
H // self.patch_embed.patch_size[0],
W // self.patch_embed.patch_size[1],
)
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
featureskip = []
featureskipnum = 1
for blk in self.blocks:
x = blk(x)
if featureskipnum % (len(self.blocks) // 4) == 0:
featureskip.append(x[:, 1:, :])
# print(featureskipnum)
featureskipnum += 1
x = self.decoder(
featureskip[0],
featureskip[1],
featureskip[2],
featureskip[3],
h=_H,
w=_W,
)
return x
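# Feature-skip sketch (illustrative): with depth=12 blocks, the outputs of
# blocks 3, 6, 9 and 12 (class token dropped) are collected and passed to the
# VIT_MLAHead decoder as the four multi-level feature maps.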
class Conv2dReLU(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
use_batchnorm=True,
):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
)
relu = nn.ReLU(inplace=True)
bn = nn.BatchNorm2d(out_channels)
super(Conv2dReLU, self).__init__(conv, bn, relu)
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x, skip=None):
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.up(x)
x = self.conv1(x)
x = self.conv2(x)
return x
class SegmentationHead(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
conv2d = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
)
upsampling = (
nn.UpsamplingBilinear2d(scale_factor=upsampling)
if upsampling > 1
else nn.Identity()
)
super().__init__(conv2d, upsampling)
class DecoderCup(nn.Module):
def __init__(self):
super().__init__()
# self.config = config
head_channels = 512
self.conv_more = Conv2dReLU(
1024,
head_channels,
kernel_size=3,
padding=1,
use_batchnorm=True,
)
decoder_channels = (256, 128, 64, 16)
in_channels = [head_channels] + list(decoder_channels[:-1])
out_channels = decoder_channels
# if self.config.n_skip != 0:
# skip_channels = self.config.skip_channels
# for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip
# skip_channels[3-i]=0
# else:
# skip_channels=[0,0,0,0]
skip_channels = [512, 256, 128, 64]
self.conv_feature1 = Conv2dReLU(
1024, skip_channels[0], kernel_size=3, padding=1, use_batchnorm=True
)
self.conv_feature2 = Conv2dReLU(
1024, skip_channels[1], kernel_size=3, padding=1, use_batchnorm=True
)
self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
self.conv_feature3 = Conv2dReLU(
1024, skip_channels[2], kernel_size=3, padding=1, use_batchnorm=True
)
self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
self.conv_feature4 = Conv2dReLU(
1024, skip_channels[3], kernel_size=3, padding=1, use_batchnorm=True
)
self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)
# skip_channels=[128,64,32,8]
blocks = [
DecoderBlock(in_ch, out_ch, sk_ch)
for in_ch, out_ch, sk_ch in zip(
in_channels, out_channels, skip_channels
)
]
self.blocks = nn.ModuleList(blocks)
def TransShape(self, x, head_channels=512, up=0):
B, n_patch, hidden = (
x.size()
) # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = x.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, h, w)
if up == 0:
x = self.conv_feature1(x)
elif up == 1:
x = self.conv_feature2(x)
x = self.up2(x)
elif up == 2:
x = self.conv_feature3(x)
x = self.up3(x)
elif up == 3:
x = self.conv_feature4(x)
x = self.up4(x)
return x
def forward(self, hidden_states, features=None):
B, n_patch, hidden = (
hidden_states.size()
) # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)
x = x.contiguous().view(B, hidden, h, w)
x = self.conv_more(x)
skip_channels = [512, 256, 128, 64]
for i, decoder_block in enumerate(self.blocks):
if features is not None:
skip = self.TransShape(
features[i], head_channels=skip_channels[i], up=i
)
else:
skip = None
x = decoder_block(x, skip=skip)
return x
class MLAHead(nn.Module):
def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None):
super(MLAHead, self).__init__()
self.head2 = nn.Sequential(
nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
nn.Conv2d(
mlahead_channels, mlahead_channels, 3, padding=1, bias=False
),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
)
self.head3 = nn.Sequential(
nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
nn.Conv2d(
mlahead_channels, mlahead_channels, 3, padding=1, bias=False
),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
)
self.head4 = nn.Sequential(
nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
nn.Conv2d(
mlahead_channels, mlahead_channels, 3, padding=1, bias=False
),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
)
self.head5 = nn.Sequential(
nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
nn.Conv2d(
mlahead_channels, mlahead_channels, 3, padding=1, bias=False
),
nn.BatchNorm2d(mlahead_channels),
nn.ReLU(),
)
def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
head2 = F.interpolate(
self.head2(mla_p2),
(4 * mla_p2.shape[-2], 4 * mla_p2.shape[-1]),
mode="bilinear",
align_corners=True,
)
head3 = F.interpolate(
self.head3(mla_p3),
(4 * mla_p3.shape[-2], 4 * mla_p3.shape[-1]),
mode="bilinear",
align_corners=True,
)
head4 = F.interpolate(
self.head4(mla_p4),
(4 * mla_p4.shape[-2], 4 * mla_p4.shape[-1]),
mode="bilinear",
align_corners=True,
)
head5 = F.interpolate(
self.head5(mla_p5),
(4 * mla_p5.shape[-2], 4 * mla_p5.shape[-1]),
mode="bilinear",
align_corners=True,
)
return torch.cat([head2, head3, head4, head5], dim=1)
class VIT_MLAHead(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=768,
mla_channels=256,
mlahead_channels=128,
num_classes=6,
norm_layer=nn.BatchNorm2d,
norm_cfg=None,
**kwargs,
):
super(VIT_MLAHead, self).__init__(**kwargs)
self.img_size = img_size
self.norm_cfg = norm_cfg
self.mla_channels = mla_channels
self.BatchNorm = norm_layer
self.mlahead_channels = mlahead_channels
self.num_classes = num_classes
self.mlahead = MLAHead(
mla_channels=self.mla_channels,
mlahead_channels=self.mlahead_channels,
norm_cfg=self.norm_cfg,
)
self.cls = nn.Conv2d(
4 * self.mlahead_channels, self.num_classes, 3, padding=1
)
def forward(self, x1, x2, x3, x4, h=14, w=14):
B, n_patch, hidden = x1.size()
if h == w:
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x1 = x1.permute(0, 2, 1)
x1 = x1.contiguous().view(B, hidden, h, w)
x2 = x2.permute(0, 2, 1)
x2 = x2.contiguous().view(B, hidden, h, w)
x3 = x3.permute(0, 2, 1)
x3 = x3.contiguous().view(B, hidden, h, w)
x4 = x4.permute(0, 2, 1)
x4 = x4.contiguous().view(B, hidden, h, w)
x = self.mlahead(x1, x2, x3, x4)
x = self.cls(x)
x = F.interpolate(
x, size=(h * 16, w * 16), mode="bilinear", align_corners=True
)
return x
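# Shape sketch for VIT_MLAHead (illustrative): four token maps of shape
# (B, n_patch, mla_channels) with n_patch = h * w are reshaped to
# (B, mla_channels, h, w); MLAHead upsamples each branch 4x, the concatenated
# (B, 4 * mlahead_channels, 4h, 4w) map is classified by self.cls, and the
# logits are finally interpolated to (B, num_classes, 16h, 16w).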
def vit_base_patch16_downstream_regression(**kwargs):
model = VisionTransformer(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def vit_large_patch16_downstream_regression(**kwargs):
model = VisionTransformer(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def vit_huge_patch14_downstream_regression(**kwargs):
model = VisionTransformer(
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def interpolate_pos_embed(
model, checkpoint_model, newsize1=None, newsize2=None
):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5
)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
if newsize1 is None:
newsize1, newsize2 = new_size, new_size
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, newsize1, newsize2)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(newsize1, newsize2),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
class SFM_BasePatch16_Downstream(SimpleSupervisedModel):
def __init__(
self,
img_size: Union[int, Tuple[int, ...]] = (512, 512),
num_classes: int = 6,
in_chans: int = 1,
loss_fn: Optional[torch.nn.Module] = None,
learning_rate: float = 1e-3,
**kwargs,
):
"""Create a SFM model with a ViT base backbone. The ViT-Base-16 backbone
has the following configuration:
- Patch size: 16
- Embedding dimension: 768
- Depth: 12
- Number of heads: 12
Parameters
----------
img_size : Union[int, Tuple[int, ...]]
Size of the input image. Note that, to use default pre-trained SFM
model, the size should be (512, 512).
num_classes : int
Number of classes for segmentation head. Default is 6.
in_chans : int
Number of input channels. Default is 1.
loss_fn : Optional[torch.nn.Module], optional
Loss function, by default None
learning_rate : float, optional
Learning rate value, by default 1e-3
"""
super().__init__(
backbone=vit_base_patch16_downstream_regression(
img_size=img_size,
num_classes=num_classes,
in_chans=in_chans,
),
fc=torch.nn.Identity(),
loss_fn=loss_fn or torch.nn.CrossEntropyLoss(),
learning_rate=learning_rate,
flatten=False,
**kwargs,
)
def _single_step(
self, batch: torch.Tensor, batch_idx: int, step_name: str
) -> torch.Tensor:
x, y = batch
x = x.float()
if x.shape[1] > 1:
x = x[:, 0:1, :, :]
if y.ndim == 4:
y = y[:, 0, :, :].long()
return super()._single_step((x, y), batch_idx, step_name)
def predict_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int = 0
) -> torch.Tensor:
x, _ = batch
x = x.float()
if x.shape[1] > 1:
x = x[:, 0:1, :, :]
logits = self.backbone.model(x)
return logits
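# Fine-tuning usage sketch (illustrative; the data module and Trainer settings
# are placeholders, not part of this module):
#
#   model = SFM_BasePatch16_Downstream(img_size=(512, 512), num_classes=6,
#                                      in_chans=1, learning_rate=1e-3)
#   trainer = L.Trainer(max_epochs=50, accelerator="auto")
#   trainer.fit(model, datamodule=seismic_data_module)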