minerva.models.nets.sfm
Attributes
Classes
MaskedAutoencoderViT: Masked Autoencoder with Vision Transformer backbone.
Module Contents
- class minerva.models.nets.sfm.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=1, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False)
Bases: lightning.LightningModule
Masked Autoencoder with VisionTransformer backbone.
- Args:
img_size (int): Size of the input image.
patch_size (int): Size of each image patch.
in_chans (int): Number of input channels.
embed_dim (int): Dimension of the encoder token embeddings.
depth (int): Number of encoder transformer blocks.
num_heads (int): Number of encoder attention heads.
decoder_embed_dim (int): Dimension of the decoder embeddings.
decoder_depth (int): Number of decoder transformer blocks.
decoder_num_heads (int): Number of decoder attention heads.
mlp_ratio (float): Ratio of the MLP hidden layer size to the embedding size.
norm_layer (type[torch.nn.Module]): Normalization layer class (default: nn.LayerNorm).
norm_pix_loss (bool): If True, normalize each target patch by its own mean and variance before computing the loss.
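The shapes quoted throughout this page follow from the constructor arguments. The plain-Python sketch below computes them; the helper name `mae_token_shapes` is hypothetical, and it assumes square inputs and that the number of visible patches is `int(L * (1 - mask_ratio))`, as in the reference MAE code.

```python
def mae_token_shapes(img_size=224, patch_size=16, in_chans=1, mask_ratio=0.75):
    """Hypothetical helper: return (num_patches, patch_dim, num_visible).

    Assumes a square image whose side is divisible by patch_size.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size                       # patches per side
    num_patches = grid * grid                           # L in the (N, L, D) shapes
    patch_dim = patch_size ** 2 * in_chans              # last dim of patchify / decoder output
    num_visible = int(num_patches * (1 - mask_ratio))   # tokens kept after masking (assumed rule)
    return num_patches, patch_dim, num_visible


print(mae_token_shapes())  # (196, 256, 49) for the default arguments
```

With the defaults, the encoder therefore sees 49 of 196 patch tokens, and the decoder predicts 256 values per patch.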
- References:
- _init_weights(m)
- configure_optimizers()
Configure optimizer.
- Returns:
torch.optim.Optimizer: Optimizer.
- forward(imgs, mask_ratio=0.75)
Forward pass.
- Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
mask_ratio (float): Fraction of patches to mask (default: 0.75).
- Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value, predicted output, binary mask.
- forward_decoder(x, ids_restore)
Forward pass through the decoder.
- Args:
x (torch.Tensor): Input tensor of shape (N, L, D).
ids_restore (torch.Tensor): Indices to restore the original order of patches.
- Returns:
torch.Tensor: Decoded output tensor of shape (N, L, patch_size^2 * in_chans).
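A key step in an MAE decoder is padding the visible tokens with mask tokens and putting everything back in the original patch order using `ids_restore`. The NumPy sketch below illustrates that step only, assuming the decoder follows the reference MAE implementation; `restore_with_mask_tokens` is a hypothetical name, the class token is omitted, and the mask token is a fixed vector rather than a learned parameter.

```python
import numpy as np


def restore_with_mask_tokens(visible, ids_restore, mask_token):
    """Hypothetical sketch: (N, len_keep, D) visible tokens -> (N, L, D) in original order.

    visible:     tokens kept by random masking, in shuffled order
    ids_restore: (N, L) indices that undo the shuffle
    mask_token:  (D,) vector used for every masked position (learned in a real model)
    """
    n, len_keep, d = visible.shape
    l = ids_restore.shape[1]
    # One mask token for every dropped patch, appended after the visible tokens.
    fill = np.broadcast_to(mask_token, (n, l - len_keep, d))
    shuffled = np.concatenate([visible, fill], axis=1)
    # Gather along the token axis so position j receives token ids_restore[:, j].
    return np.take_along_axis(shuffled, ids_restore[:, :, None], axis=1)
```

For example, if patches 2 and 0 were kept (in that shuffled order), `ids_restore` is `[[1, 3, 0, 2]]` and masked positions 1 and 3 receive the mask token.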
- forward_encoder(x, mask_ratio)
Forward pass through the encoder.
- Args:
x (torch.Tensor): Input tensor of shape (N, C, H, W).
mask_ratio (float): Fraction of patches to mask.
- Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded representation, binary mask, and the indices that restore the original patch order (passed to forward_decoder as ids_restore).
- forward_loss(imgs, pred, mask)
Compute the reconstruction loss between the predicted patches and the input images.
- Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
pred (torch.Tensor): Predicted output of shape (N, L, patch_size^2 * in_chans).
mask (torch.Tensor): Binary mask of shape (N, L).
- Returns:
torch.Tensor: Computed loss value.
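In the reference MAE implementation, the loss is the per-patch mean squared error averaged over masked patches only, with optional per-patch target normalization when `norm_pix_loss=True`. The NumPy sketch below assumes this module does the same; `masked_mse_loss` is a hypothetical name, it takes already patchified targets, and it uses the sample variance (`ddof=1`) to mirror PyTorch's default `var`.

```python
import numpy as np


def masked_mse_loss(target, pred, mask, norm_pix_loss=False, eps=1e-6):
    """Hypothetical sketch of an MAE-style loss.

    target, pred: (N, L, patch_dim) patchified images and predictions
    mask:         (N, L) with 1 = masked (loss counted), 0 = visible (ignored)
    """
    if norm_pix_loss:
        # Normalize each target patch by its own mean and (sample) variance.
        mean = target.mean(axis=-1, keepdims=True)
        var = target.var(axis=-1, ddof=1, keepdims=True)
        target = (target - mean) / np.sqrt(var + eps)
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # (N, L) MSE per patch
    # Average only over masked patches.
    return (per_patch * mask).sum() / mask.sum()
```

Restricting the average to masked patches means the model gets no credit for copying the visible input.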
- initialize_weights()
- patchify(imgs)
Extract patches from input images.
- Args:
imgs (torch.Tensor): Input images of shape (N, C, H, W).
- Returns:
torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans).
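Patch extraction is a reshape followed by an axis permutation. The NumPy sketch below is a hypothetical stand-in (`patchify_np`), not minerva's code; it assumes square images and the row-major patch order with (row, column, channel) ordering inside each patch used by the reference MAE implementation.

```python
import numpy as np


def patchify_np(imgs: np.ndarray, patch_size: int) -> np.ndarray:
    """Hypothetical sketch: (N, C, H, W) -> (N, num_patches, patch_size**2 * C)."""
    n, c, h, w = imgs.shape
    p = patch_size
    assert h == w and h % p == 0, "expects square images divisible by patch_size"
    g = h // p                              # patches per side
    x = imgs.reshape(n, c, g, p, g, p)      # split H and W into (grid, within-patch)
    x = x.transpose(0, 2, 4, 3, 5, 1)       # -> (N, grid_h, grid_w, p, p, C)
    return x.reshape(n, g * g, p * p * c)   # flatten each patch into one row
```

For a 4 x 4 single-channel image with values 0..15 and `patch_size=2`, the first patch is `[0, 1, 4, 5]` (the top-left 2 x 2 block).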
- random_masking(x, mask_ratio)
Perform per-sample random masking by per-sample shuffling.
- Args:
x (torch.Tensor): Input tensor of shape (N, L, D).
mask_ratio (float): Fraction of patches to mask.
- Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Masked input, binary mask, and the indices that restore the original patch order.
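Per-sample random masking by shuffling can be sketched as follows: draw uniform noise per patch, sort it to obtain a random permutation, keep the first `len_keep` patches, and record the inverse permutation. This NumPy version (`random_masking_np`, hypothetical) mirrors the reference MAE algorithm and assumes the mask uses 1 for removed patches and 0 for kept ones.

```python
import numpy as np


def random_masking_np(x, mask_ratio, rng=None):
    """Hypothetical sketch: (N, L, D) -> (x_masked, mask, ids_restore)."""
    rng = np.random.default_rng() if rng is None else rng
    n, l, d = x.shape
    len_keep = int(l * (1 - mask_ratio))
    noise = rng.random((n, l))                       # one random score per patch
    ids_shuffle = np.argsort(noise, axis=1)          # ascending: smallest noise is kept
    ids_restore = np.argsort(ids_shuffle, axis=1)    # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = np.take_along_axis(x, ids_keep[:, :, None], axis=1)  # (N, len_keep, D)
    # Build the mask in shuffled order, then unshuffle it: 0 = keep, 1 = remove.
    mask = np.ones((n, l))
    mask[:, :len_keep] = 0
    mask = np.take_along_axis(mask, ids_restore, axis=1)
    return x_masked, mask, ids_restore
```

Because each sample gets its own noise, different images in the batch keep different patches, while every sample keeps the same number of tokens.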
- training_step(batch, batch_idx)
Training step.
- Args:
batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels.
batch_idx (int): Index of the current batch.
- Returns:
Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step.
- unpatchify(x)
Reconstruct images from patches.
- Args:
x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans).
- Returns:
torch.Tensor: Reconstructed images of shape (N, C, H, W).
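Reconstruction inverts the patch layout: reshape each row back into a patch, then permute the axes back to (N, C, H, W). This NumPy sketch (`unpatchify_np`, hypothetical) assumes a square patch grid and the (row, column, channel) ordering inside each patch used by the reference MAE implementation.

```python
import numpy as np


def unpatchify_np(x: np.ndarray, patch_size: int, in_chans: int) -> np.ndarray:
    """Hypothetical sketch: (N, L, patch_size**2 * in_chans) -> (N, C, H, W)."""
    n, l, d = x.shape
    p = patch_size
    g = int(round(l ** 0.5))                     # patches per side
    assert g * g == l, "expects a square grid of patches"
    assert d == p * p * in_chans, "last dim must equal patch_size**2 * in_chans"
    x = x.reshape(n, g, g, p, p, in_chans)       # (N, grid_h, grid_w, p, p, C)
    x = x.transpose(0, 5, 1, 3, 2, 4)            # -> (N, C, grid_h, p, grid_w, p)
    return x.reshape(n, in_chans, g * p, g * p)  # merge grid and within-patch axes
```

Applied to the output of a matching patchify, this returns the original images unchanged.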
- validation_step(batch, batch_idx)
Validation step.
- Args:
batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels.
batch_idx (int): Index of the current batch.
- Returns:
Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step.
- minerva.models.nets.sfm.mae_vit_base_patch16
- minerva.models.nets.sfm.mae_vit_base_patch16D4d256
- minerva.models.nets.sfm.mae_vit_huge_patch14
- minerva.models.nets.sfm.mae_vit_large_patch16
- minerva.models.nets.sfm.mae_vit_large_patch16D4d256
- minerva.models.nets.sfm.mae_vit_small_patch16