import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import (
Mlp,
DropPath,
AttentionPoolLatent,
PatchDropout,
trunc_normal_,
lecun_normal_,
resample_patch_embed,
resample_abs_pos_embed,
use_fused_attn,
get_act_layer,
get_norm_layer,
LayerType,
)
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from .patch_embed import PatchEmbed
class Attention(nn.Module):
    """
    Multi-head self-attention layer.

    A single linear layer produces the query, key and value projections for
    all heads. Queries and keys may optionally be normalized per head. The
    attention itself is computed either with
    ``F.scaled_dot_product_attention`` or with an explicit
    ``softmax(q @ k^T) @ v``, depending on the flag returned by
    ``use_fused_attn()`` when the module is constructed.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ):
        """
        Build the attention layer.

        Parameters
        ----------
        dim : int
            Feature dimension of the input and output. Must be divisible by
            ``num_heads``.
        num_heads : int, default=8
            Number of attention heads; each head works on ``dim // num_heads``
            channels.
        qkv_bias : bool, default=False
            Whether the fused query/key/value projection has a bias.
        qk_norm : bool, default=False
            Whether to apply ``norm_layer(head_dim)`` to queries and keys.
        proj_bias : bool, default=True
            Whether the output projection has a bias.
        attn_drop : float, default=0.0
            Dropout probability applied to the attention weights.
        proj_drop : float, default=0.0
            Dropout probability applied after the output projection.
        norm_layer : Type[nn.Module], default=nn.LayerNorm
            Normalization class used for queries/keys when ``qk_norm`` is set.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.fused_attn = use_fused_attn()

        # One projection emits q, k and v back to back along the channel axis.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply self-attention to a batch of token sequences.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (B, N, C): batch, sequence length, channels.

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, N, C).
        """
        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.fused_attn:
            # The functional API has no train/eval state, so dropout is
            # switched off by hand outside of training.
            dropout_p = self.attn_drop.p if self.training else 0.0
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            x = weights @ v

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = x.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(x))
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``."""

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ):
        """
        Build the scaling layer.

        Parameters
        ----------
        dim : int
            Length of the learnable scale vector.
        init_values : float, default=1e-5
            Value every entry of ``gamma`` starts at.
        inplace : bool, default=False
            If True, ``forward`` multiplies the input tensor in place and
            returns that same tensor.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale ``x`` by ``gamma``.

        ``gamma`` has shape ``(dim,)`` and is combined with ``x`` by ordinary
        broadcasting, i.e. it lines up with the last dimension of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose trailing dimension is broadcast-compatible with
            ``gamma`` (e.g. (B, N, C) with C == dim).

        Returns
        -------
        torch.Tensor
            The scaled tensor; this is ``x`` itself when ``inplace`` is True.
        """
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Block(nn.Module):
    """
    Pre-norm Transformer block.

    Computes ``x + drop_path(ls(attn(norm1(x))))`` followed by
    ``x + drop_path(ls(mlp(norm2(x))))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
    ):
        """
        Build the block.

        Parameters
        ----------
        dim : int
            Feature dimension of the input and output.
        num_heads : int
            Number of heads of the self-attention layer.
        mlp_ratio : float, default=4.0
            MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias : bool, default=False
            Whether the query/key/value projection has a bias.
        qk_norm : bool, default=False
            Whether queries and keys are normalized inside the attention layer.
        proj_bias : bool, default=True
            Bias flag forwarded to the attention output projection and to the
            MLP (as ``bias``).
        proj_drop : float, default=0.0
            Dropout probability forwarded to the attention output projection
            and to the MLP (as ``drop``).
        attn_drop : float, default=0.0
            Dropout probability on the attention weights.
        init_values : float, optional
            When truthy, each branch is followed by a ``LayerScale`` starting
            at this value; ``None`` (or 0) leaves the branch unscaled.
        drop_path : float, default=0.0
            ``DropPath`` rate; values <= 0 disable it.
        act_layer : Type[nn.Module], default=nn.GELU
            Activation class handed to the MLP.
        norm_layer : Type[nn.Module], default=nn.LayerNorm
            Normalization class used before each branch and for qk-norm.
        mlp_layer : Type[nn.Module], default=Mlp
            Class used to build the feed-forward branch.
        """
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        def _layer_scale() -> nn.Module:
            # Truthiness test: both None and 0 turn LayerScale off.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        def _drop_path() -> nn.Module:
            if drop_path > 0.0:
                return DropPath(drop_path)
            return nn.Identity()

        # Attention branch
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = _layer_scale()
        self.drop_path1 = _drop_path()

        # Feed-forward branch
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=hidden_dim,
            act_layer=act_layer,
            bias=proj_bias,
            drop=proj_drop,
        )
        self.ls2 = _layer_scale()
        self.drop_path2 = _drop_path()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the attention branch and then the MLP branch, each with a residual.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; the attention layer unpacks it as (B, N, C).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``x``.
        """
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class ResPostBlock(nn.Module):
    """
    Residual Post-Norm Transformer block.

    Normalization is applied to the output of each sub-layer before it is
    added back to the residual stream:
    ``x + drop_path(norm1(attn(x)))`` then ``x + drop_path(norm2(mlp(x)))``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
    ):
        """
        Initialize the Residual Post-Norm Transformer block.

        Parameters
        ----------
        dim : int
            Embedding dimension of the input and output features.
        num_heads : int
            Number of attention heads in the self-attention layer.
        mlp_ratio : float, default=4.0
            MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias : bool, default=False
            If True, add bias to the query, key, and value projections.
        qk_norm : bool, default=False
            If True, apply normalization to query and key tensors.
        proj_bias : bool, default=True
            Bias flag forwarded to the attention output projection and to the
            MLP (as ``bias``).
        proj_drop : float, default=0.0
            Dropout rate forwarded to the attention output projection and to
            the MLP (as ``drop``).
        attn_drop : float, default=0.0
            Dropout rate applied to the attention weights.
        init_values : float, optional
            If not None, the ``weight`` of both normalization layers is filled
            with this constant by :meth:`init_weights`.
        drop_path : float, default=0.0
            ``DropPath`` rate; values <= 0 disable it.
        act_layer : Type[nn.Module], default=nn.GELU
            Activation class used in the MLP layer.
        norm_layer : Type[nn.Module], default=nn.LayerNorm
            Normalization class applied after attention and after the MLP
            (also used for qk-norm inside the attention layer).
        mlp_layer : Type[nn.Module], default=Mlp
            Class used to build the feed-forward network.
        """
        # The docstring must be the first statement of the function so that it
        # becomes ``__init__.__doc__``; super().__init__() follows it.
        super().__init__()
        self.init_values = init_values

        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            bias=proj_bias,
            drop=proj_drop,
        )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.init_weights()

    def init_weights(self) -> None:
        """
        Fill the ``weight`` of both norm layers with ``init_values`` when set.

        Does nothing when ``init_values`` is None. The weight-init helpers in
        this module call ``init_weights()`` on any non-Linear module that
        defines it, so this block-specific init is re-applied there.
        """
        # NOTE this init overrides that base model init with specific changes for the block type
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Residual Post-Norm Transformer block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; the attention layer unpacks it as (B, N, C).

        Returns
        -------
        torch.Tensor
            Output tensor of the same shape as ``x``.
        """
        x = x + self.drop_path1(self.norm1(self.attn(x)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        return x
class ParallelScalingBlock(nn.Module):
    """
    Parallel Scaling Vision Transformer block.

    The attention and MLP branches share one input normalization and one fused
    input projection, are evaluated side by side, and their outputs are summed
    before a single residual addition. The design follows
    "Scaling Vision Transformers to 22 Billion Parameters"
    (https://arxiv.org/abs/2302.05442).
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Optional[Type[nn.Module]] = None,
    ):
        """
        Initialize the ParallelScalingBlock.

        Parameters
        ----------
        dim : int
            Embedding dimension of the input and output features. Must be
            divisible by ``num_heads``.
        num_heads : int
            Number of attention heads.
        mlp_ratio : float, default=4.0
            MLP hidden width is ``int(mlp_ratio * dim)``.
        qkv_bias : bool, default=False
            If True, the fused input projection carries its own bias for all
            of its outputs (MLP and q/k/v). If False, the projection has no
            bias and a separate trainable ``mlp_bias`` is applied to the MLP
            part only, while q/k/v receive a constant zero bias.
        qk_norm : bool, default=False
            If True, apply ``norm_layer(head_dim)`` to queries and keys.
        proj_bias : bool, default=True
            If True, both output projections (attention and MLP) have a bias.
        proj_drop : float, default=0.0
            Dropout rate applied to the activated MLP hidden features.
        attn_drop : float, default=0.0
            Dropout rate applied to the attention weights.
        init_values : float, optional
            If not None, a ``LayerScale`` with this initial value scales the
            summed branch output.
        drop_path : float, default=0.0
            ``DropPath`` rate; values <= 0 disable it.
        act_layer : Type[nn.Module], default=nn.GELU
            Activation class used in the MLP branch.
        norm_layer : Type[nn.Module], default=nn.LayerNorm
            Normalization class for the shared input norm and for qk-norm.
        mlp_layer : Type[nn.Module], optional
            Accepted so the constructor signature matches the other block
            types; this block builds its MLP inline and ignores the argument.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = use_fused_attn()

        mlp_hidden_dim = int(mlp_ratio * dim)
        in_proj_out_dim = mlp_hidden_dim + 3 * dim

        self.in_norm = norm_layer(dim)
        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
        # Output layout of in_proj: [mlp_hidden | q | k | v]
        self.in_split = [mlp_hidden_dim] + [dim] * 3
        if qkv_bias:
            # in_proj already owns a bias covering every output column.
            self.register_buffer("qkv_bias", None)
            self.register_parameter("mlp_bias", None)
        else:
            # Constant (non-persistent) zero bias for q/k/v, trainable bias
            # for the MLP part of the fused projection.
            self.register_buffer("qkv_bias", torch.zeros(3 * dim), persistent=False)
            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)

        self.mlp_drop = nn.Dropout(proj_drop)
        self.mlp_act = act_layer()
        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)

        self.ls = (
            LayerScale(dim, init_values=init_values)
            if init_values is not None
            else nn.Identity()
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Parallel Scaling Transformer block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, N, C).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, N, C).
        """
        B, N, C = x.shape

        # Combined MLP fc1 & qkv projections
        y = self.in_norm(x)
        if self.mlp_bias is not None:
            # Concat trainable mlp_bias with the constant zero-bias for qkv.
            # The order must mirror ``in_split`` ([mlp | q | k | v]) so that
            # mlp_bias lands on the MLP columns and q/k/v stay bias-free.
            # NOTE: earlier revisions concatenated (qkv_bias, mlp_bias), which
            # shifted the trainable bias onto q/k/v columns; checkpoints
            # trained that way with qkv_bias=False will behave differently.
            fused_bias = torch.cat((self.mlp_bias, self.qkv_bias))
            y = F.linear(y, self.in_proj.weight, fused_bias)
        else:
            y = self.in_proj(y)
        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)

        # Dot product attention w/ qk norm; heads moved to dim 1
        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.fused_attn:
            x_attn = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x_attn = attn @ v
        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
        x_attn = self.attn_out_proj(x_attn)

        # MLP activation, dropout, fc2
        x_mlp = self.mlp_act(x_mlp)
        x_mlp = self.mlp_drop(x_mlp)
        x_mlp = self.mlp_out_proj(x_mlp)

        # Add residual w/ drop path & layer scale applied
        y = self.drop_path(self.ls(x_attn + x_mlp))
        x = x + y
        return x
class ParallelThingsBlock(nn.Module):
    """
    Parallel Things Vision Transformer block.

    Holds ``num_parallel`` attention branches and ``num_parallel`` MLP
    branches. The outputs of all attention branches are summed and added to
    the input; the same is then done with the MLP branches. The design follows
    "Three Things Everyone Should Know About Vision Transformers"
    (https://arxiv.org/abs/2203.09795).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_parallel: int = 2,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        """
        Build the parallel branches.

        Parameters
        ----------
        dim : int
            Feature dimension of the input and output.
        num_heads : int
            Number of heads in every attention branch.
        num_parallel : int, default=2
            How many attention branches and how many MLP branches to create.
        mlp_ratio : float, default=4.0
            MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias : bool, default=False
            Whether the query/key/value projections have a bias.
        qk_norm : bool, default=False
            Whether queries and keys are normalized inside attention.
        proj_bias : bool, default=True
            Bias flag forwarded to the attention output projection and to the
            MLP (as ``bias``).
        init_values : float, optional
            When truthy, each branch ends with a ``LayerScale`` starting at
            this value; ``None`` (or 0) leaves the branch unscaled.
        proj_drop : float, default=0.0
            Dropout probability forwarded to the attention output projection
            and to the MLP (as ``drop``).
        attn_drop : float, default=0.0
            Dropout probability on the attention weights.
        drop_path : float, default=0.0
            ``DropPath`` rate per branch; values <= 0 disable it.
        act_layer : Type[nn.Module], default=nn.GELU
            Activation class handed to the MLP.
        norm_layer : Type[nn.Module], default=nn.LayerNorm
            Normalization class at the head of each branch and for qk-norm.
        mlp_layer : Type[nn.Module], default=Mlp
            Class used to build each feed-forward branch.
        """
        super().__init__()
        self.num_parallel = num_parallel
        self.attns = nn.ModuleList()
        self.ffns = nn.ModuleList()

        def _layer_scale() -> nn.Module:
            # Truthiness test: both None and 0 turn LayerScale off.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        def _drop_path() -> nn.Module:
            if drop_path > 0.0:
                return DropPath(drop_path)
            return nn.Identity()

        for _ in range(num_parallel):
            # Sub-module names ("norm", "attn"/"mlp", "ls", "drop_path") are
            # part of the state_dict keys and must stay as they are.
            attn_branch = OrderedDict()
            attn_branch["norm"] = norm_layer(dim)
            attn_branch["attn"] = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                proj_bias=proj_bias,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                norm_layer=norm_layer,
            )
            attn_branch["ls"] = _layer_scale()
            attn_branch["drop_path"] = _drop_path()
            self.attns.append(nn.Sequential(attn_branch))

            ffn_branch = OrderedDict()
            ffn_branch["norm"] = norm_layer(dim)
            ffn_branch["mlp"] = mlp_layer(
                dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
                bias=proj_bias,
                drop=proj_drop,
            )
            ffn_branch["ls"] = _layer_scale()
            ffn_branch["drop_path"] = _drop_path()
            self.ffns.append(nn.Sequential(ffn_branch))

    def _forward_jit(self, x: torch.Tensor) -> torch.Tensor:
        """Scripting/tracing variant: stack branch outputs, then reduce over the stack."""
        attn_sum = torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
        x = x + attn_sum
        ffn_sum = torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
        x = x + ffn_sum
        return x

    @torch.jit.ignore
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Eager variant: accumulate branch outputs with Python's built-in ``sum``."""
        attn_sum = sum(attn(x) for attn in self.attns)
        x = x + attn_sum
        ffn_sum = sum(ffn(x) for ffn in self.ffns)
        x = x + ffn_sum
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply all attention branches, then all MLP branches.

        Dispatches to ``_forward_jit`` while scripting or tracing and to
        ``_forward`` otherwise.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor; the attention layers unpack it as (B, N, C).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``x``.
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return self._forward_jit(x)
        return self._forward(x)
def global_pool_nlc(
    x: torch.Tensor,
    pool_type: str = "token",
    num_prefix_tokens: int = 1,
    reduce_include_prefix: bool = False,
) -> torch.Tensor:
    """
    Pool a token sequence along dimension 1.

    Parameters
    ----------
    x : torch.Tensor
        Tensor indexed as ``x[:, token]``; dimension 1 is the one reduced
        (e.g. shape (B, N, C)).
    pool_type : str, default="token"
        Pooling mode:

        * falsy (``""`` / ``None``) -- return ``x`` unchanged;
        * ``"token"`` -- take the token at index 0;
        * ``"avg"`` -- mean over tokens;
        * ``"max"`` -- max over tokens;
        * ``"avgmax"`` -- average of the max and the mean over tokens.
    num_prefix_tokens : int, default=1
        Number of leading tokens dropped before ``avg`` / ``max`` / ``avgmax``
        pooling when ``reduce_include_prefix`` is False.
    reduce_include_prefix : bool, default=False
        If True, the leading tokens are kept in the reduction.

    Returns
    -------
    torch.Tensor
        The pooled tensor (dimension 1 removed), or ``x`` itself when
        ``pool_type`` is falsy.

    Raises
    ------
    AssertionError
        If ``pool_type`` is a non-empty value other than the modes above.
    """
    if not pool_type:
        return x

    if pool_type == "token":
        x = x[:, 0]  # class token
    else:
        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
        if pool_type == "avg":
            x = x.mean(dim=1)
        elif pool_type == "avgmax":
            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
        elif pool_type == "max":
            x = x.amax(dim=1)
        else:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type and message are
            # the same as before.
            raise AssertionError(f"Unknown pool type {pool_type}")
    return x
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
    """
    ViT weight initialization, original timm impl (for reproducibility).

    ``nn.Linear`` modules get ``trunc_normal_(std=0.02)`` weights and a zeroed
    bias (if they have one). Any other module that defines ``init_weights``
    has that method called. ``name`` is accepted but not used here.
    """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        bias = module.bias
        if bias is not None:
            nn.init.zeros_(bias)
        return

    if hasattr(module, "init_weights"):
        module.init_weights()
def init_weights_vit_jax(
    module: nn.Module, name: str = "", head_bias: float = 0.0
) -> None:
    """
    ViT weight initialization, matching JAX (Flax) impl.

    Parameters
    ----------
    module : nn.Module
        Module to initialize in place.
    name : str, default=""
        Qualified name of ``module``; selects the rule for ``nn.Linear``:
        names starting with ``"head"`` get zero weights and a constant bias,
        names containing ``"mlp"`` get a near-zero normal bias
        (``std=1e-6``), all other linears get a zero bias. Non-head linears
        use Xavier-uniform weights.
    head_bias : float, default=0.0
        Constant written to the bias of head linears.

    Notes
    -----
    ``nn.Conv2d`` modules get ``lecun_normal_`` weights and a zero bias.
    Any other module that defines ``init_weights`` has that method called.
    Every bias write is skipped when the layer was built without a bias.
    """
    if isinstance(module, nn.Linear):
        if name.startswith("head"):
            nn.init.zeros_(module.weight)
            # A head built with bias=False has module.bias == None; writing to
            # it would crash, so guard it like the other branches do.
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if "mlp" in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()
def init_weights_vit_moco(module: nn.Module, name: str = "") -> None:
    """
    ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed.

    For ``nn.Linear`` modules whose ``name`` contains ``"qkv"`` the weight is
    drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(6 / (out_features // 3 + in_features))``, i.e. the bound is
    computed as if Q, K and V were three separate matrices. Other linears use
    ``xavier_uniform_``. Linear biases, when present, are zeroed. Any other
    module that defines ``init_weights`` has that method called.
    """
    if isinstance(module, nn.Linear):
        weight = module.weight
        if "qkv" in name:
            # treat the weights of Q, K, V separately
            fan_out_per_proj = weight.shape[0] // 3
            fan_in = weight.shape[1]
            bound = math.sqrt(6.0 / float(fan_out_per_proj + fan_in))
            nn.init.uniform_(weight, -bound, bound)
        else:
            nn.init.xavier_uniform_(weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        return

    if hasattr(module, "init_weights"):
        module.init_weights()
def get_init_weights_vit(mode: str = "jax", head_bias: float = 0.0) -> Callable:
    """
    Pick the per-module weight-init function for a ViT.

    Parameters
    ----------
    mode : str, default="jax"
        Init scheme selector, matched by substring: a mode containing
        ``"jax"`` selects the JAX-style init, otherwise a mode containing
        ``"moco"`` selects the moco-v3 style init, and anything else falls
        back to the original timm init.
    head_bias : float, default=0.0
        Bound into the JAX-style initializer; not used by the other two.

    Returns
    -------
    Callable
        A function taking ``(module, name)``.
    """
    if "jax" in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    if "moco" in mode:
        return init_weights_vit_moco
    return init_weights_vit_timm
def resize_pos_embed(
    posemb: torch.Tensor,
    posemb_new: torch.Tensor,
    num_prefix_tokens: int = 1,
    gs_new: Tuple[int, int] = (),
    interpolation: str = "bicubic",
    antialias: bool = False,
) -> torch.Tensor:
    """Rescale the grid of position embeddings when loading from state_dict.

    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed

    Both grids are taken to be square: the side length is the integer part of
    the square root of the token count along dim 1 minus ``num_prefix_tokens``.
    ``posemb_new`` is only consulted for its token count, and that count only
    determines the target grid when ``gs_new`` is empty. The actual resampling
    is delegated to ``resample_abs_pos_embed`` (called with ``verbose=True``).
    """

    def _square_grid(num_tokens: int) -> List[int]:
        side = int(math.sqrt(num_tokens))
        return [side, side]

    new_grid_tokens = posemb_new.shape[1] - num_prefix_tokens
    old_grid_tokens = posemb.shape[1] - num_prefix_tokens
    old_size = _square_grid(old_grid_tokens)
    if len(gs_new) == 0:  # backwards compatibility
        gs_new = _square_grid(new_grid_tokens)

    return resample_abs_pos_embed(
        posemb,
        gs_new,
        old_size,
        num_prefix_tokens=num_prefix_tokens,
        interpolation=interpolation,
        antialias=antialias,
        verbose=True,
    )