minerva.models.nets.sfm

Attributes

mae_vit_base_patch16

mae_vit_base_patch16D4d256

mae_vit_huge_patch14

mae_vit_large_patch16

mae_vit_large_patch16D4d256

mae_vit_small_patch16

Classes

MaskedAutoencoderViT

Masked Autoencoder with VisionTransformer backbone.

Module Contents

class minerva.models.nets.sfm.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=1, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False)

Bases: lightning.LightningModule

Masked Autoencoder with VisionTransformer backbone.

Args:

img_size (int): Size of input image.
patch_size (int): Size of image patch.
in_chans (int): Number of input channels.
embed_dim (int): Dimension of token embeddings.
depth (int): Number of transformer blocks.
num_heads (int): Number of attention heads.
decoder_embed_dim (int): Dimension of decoder embeddings.
decoder_depth (int): Number of decoder transformer blocks.
decoder_num_heads (int): Number of decoder attention heads.
mlp_ratio (float): Ratio of MLP hidden layer size to embedding size.
norm_layer (torch.nn.LayerNorm): Normalization layer.
norm_pix_loss (bool): Whether to normalize pixel loss.
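To make the shape-related arguments concrete, the plain-Python sketch below works out the token counts implied by the default arguments. It assumes square inputs whose side is divisible by patch_size, and that the number of visible patches is int(num_patches * (1 - mask_ratio)), as in the original MAE reference code. The helper name mae_token_counts is hypothetical and not part of minerva.

```python
def mae_token_counts(img_size=224, patch_size=16, in_chans=1, mask_ratio=0.75):
    """Return the patch-grid sizes implied by the constructor arguments.

    Assumption: square images and a non-overlapping patch grid.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size                  # patches per side
    num_patches = grid * grid                      # L in the (N, L, D) shapes
    patch_dim = patch_size * patch_size * in_chans  # patch_size^2 * in_chans
    num_visible = int(num_patches * (1 - mask_ratio))  # patches the encoder sees
    return {
        "grid": grid,
        "num_patches": num_patches,
        "patch_dim": patch_dim,
        "num_visible": num_visible,
    }


print(mae_token_counts())
# {'grid': 14, 'num_patches': 196, 'patch_dim': 256, 'num_visible': 49}
```

With the defaults, each 224 x 224 single-channel image becomes 196 patches of 256 values each, of which 49 stay visible at mask_ratio=0.75.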

References:

_init_weights(m)

configure_optimizers()

Configure optimizer.

Returns:

torch.optim.Optimizer: Optimizer.

forward(imgs, mask_ratio=0.75)

Forward pass.

Args:

imgs (torch.Tensor): Input images of shape (N, C, H, W).
mask_ratio (float): Fraction of patches to mask.

Returns:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Loss value, predicted output, binary mask.

forward_decoder(x, ids_restore)

Forward pass through the decoder.

Args:

x (torch.Tensor): Input tensor of shape (N, L, D).
ids_restore (torch.Tensor): Indices to restore the original order of patches.

Returns:

torch.Tensor: Decoded output tensor of shape (N, L, patch_size^2 * in_chans).
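The role of ids_restore can be illustrated with a small list-based sketch. It assumes the convention of the original MAE reference code: the decoder pads the visible tokens with a shared mask token and then gathers them by ids_restore to recover the original patch order. The function restore_order is hypothetical, works on one sample, and ignores any class token.

```python
def restore_order(visible, ids_restore, mask_token):
    """Pad visible tokens with mask tokens, then undo the shuffle.

    visible:     tokens kept by the encoder, in shuffled order
    ids_restore: for each original position, its index in the shuffled order
    """
    num_masked = len(ids_restore) - len(visible)
    padded = list(visible) + [mask_token] * num_masked
    # Gather: original position i takes the token stored at padded[ids_restore[i]].
    return [padded[i] for i in ids_restore]


# Four patches a, b, c, d shuffled to the order c, a, d, b; the first two are kept.
visible = ["c", "a"]
ids_restore = [1, 3, 0, 2]
print(restore_order(visible, ids_restore, "[M]"))
# ['a', '[M]', 'c', '[M]']
```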

forward_encoder(x, mask_ratio)

Forward pass through the encoder.

Args:

x (torch.Tensor): Input tensor of shape (N, C, H, W).
mask_ratio (float): Fraction of patches to mask.

Returns:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Encoded representation, binary mask, shuffled indices.

forward_loss(imgs, pred, mask)

Calculate the loss.

Args:

imgs (torch.Tensor): Input images of shape (N, C, H, W).
pred (torch.Tensor): Predicted output of shape (N, L, patch_size^2 * in_chans).
mask (torch.Tensor): Binary mask of shape (N, L).

Returns:

torch.Tensor: Computed loss value.
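As a rough illustration of how a masked reconstruction loss of this kind is usually computed, the sketch below takes the mean squared error per patch and averages it over masked patches only. It assumes that a mask value of 1 marks a masked patch, which is the convention in the original MAE reference code; whether this module follows it exactly is not stated here. The function masked_mse is hypothetical, handles a single sample, and omits the norm_pix_loss option.

```python
def masked_mse(target, pred, mask):
    """Mean squared error per patch, averaged over masked patches only.

    target, pred: lists of patches, each a list of floats (shape L x D)
    mask:         list of 0/1 flags (length L); 1 is assumed to mean "masked"
    """
    per_patch = [
        sum((p - t) ** 2 for p, t in zip(pred_patch, target_patch)) / len(target_patch)
        for pred_patch, target_patch in zip(pred, target)
    ]
    masked_total = sum(loss * m for loss, m in zip(per_patch, mask))
    return masked_total / sum(mask)


target = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
pred = [[9.0, 9.0], [3.0, 3.0], [2.0, 4.0]]
mask = [0, 1, 1]  # the first patch was visible, so its error is ignored
print(masked_mse(target, pred, mask))
# 3.0
```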

initialize_weights()

patchify(imgs)

Extract patches from input images.

Args:

imgs (torch.Tensor): Input images of shape (N, C, H, W).

Returns:

torch.Tensor: Patches of shape (N, num_patches, patch_size^2 * in_chans).
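The patch layout can be pictured with a framework-free sketch for one single-channel image stored as nested lists. The row-major ordering of patches and of values inside each patch is an assumption borrowed from the original MAE reference code; patchify_2d is a hypothetical helper, not the method documented above.

```python
def patchify_2d(img, patch_size):
    """Split a square H x W image (list of rows) into flattened patches.

    Returns a list of num_patches lists, each of length patch_size ** 2.
    Assumption: patches and the values inside them are read in row-major order.
    """
    p = patch_size
    height, width = len(img), len(img[0])
    if height != width or height % p != 0:
        raise ValueError("image must be square and divisible by patch_size")
    patches = []
    for top in range(0, height, p):
        for left in range(0, width, p):
            patches.append(
                [img[top + i][left + j] for i in range(p) for j in range(p)]
            )
    return patches


image = [[4 * r + c for c in range(4)] for r in range(4)]  # values 0..15
print(patchify_2d(image, 2))
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```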

random_masking(x, mask_ratio)

Perform per-sample random masking by per-sample shuffling.

Args:

x (torch.Tensor): Input tensor of shape (N, L, D).
mask_ratio (float): Fraction of patches to mask.

Returns:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Masked input, binary mask, shuffled indices.
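Per-sample shuffling is commonly implemented by sorting random noise, and the sketch below shows that idea for one sample using only the standard library. It assumes the conventions of the original MAE reference code: the first int(L * (1 - mask_ratio)) tokens of the shuffled order are kept, the mask uses 1 for removed tokens, and the returned indices are the ones that restore the original order. The name random_masking_1d is hypothetical.

```python
import random


def random_masking_1d(tokens, mask_ratio, rng):
    """Randomly keep a subset of tokens for a single sample.

    Returns (kept_tokens, mask, ids_restore), where mask[i] is 0 for a kept
    token and 1 for a removed one, both given in the original token order.
    """
    length = len(tokens)
    len_keep = int(length * (1 - mask_ratio))

    # Sorting random noise yields a random permutation (an argsort).
    noise = [rng.random() for _ in range(length)]
    ids_shuffle = sorted(range(length), key=noise.__getitem__)
    # Inverse permutation: position of each original token in the shuffled order.
    ids_restore = sorted(range(length), key=ids_shuffle.__getitem__)

    kept = [tokens[i] for i in ids_shuffle[:len_keep]]

    # The mask is trivial in shuffled order; gather it back to the original order.
    mask_shuffled = [0] * len_keep + [1] * (length - len_keep)
    mask = [mask_shuffled[ids_restore[i]] for i in range(length)]
    return kept, mask, ids_restore


kept, mask, ids_restore = random_masking_1d(list("abcdefgh"), 0.75, random.Random(0))
print(kept, mask, ids_restore)
```

With eight tokens and mask_ratio=0.75, two tokens are kept and six mask entries are set to 1. Which two survive depends on the random generator.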

training_step(batch, batch_idx)

Training step.

Args:

batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels.
batch_idx (int): Index of the current batch.

Returns:

Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step.

unpatchify(x)

Reconstruct images from patches.

Args:

x (torch.Tensor): Patches of shape (N, L, patch_size^2 * in_chans).

Returns:

torch.Tensor: Reconstructed images of shape (N, C, H, W).
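The inverse mapping can be sketched in the same framework-free way for one single-channel image. It assumes a square patch grid and row-major ordering of patches and of values within a patch, as in the original MAE reference code; unpatchify_2d is a hypothetical helper, not the method documented above.

```python
import math


def unpatchify_2d(patches, patch_size):
    """Rebuild a square image (list of rows) from flattened patches.

    Assumption: len(patches) is a perfect square and patches are stored in
    row-major order, each holding patch_size ** 2 values.
    """
    p = patch_size
    grid = math.isqrt(len(patches))
    if grid * grid != len(patches):
        raise ValueError("number of patches must form a square grid")
    side = grid * p
    img = [[None] * side for _ in range(side)]
    for index, patch in enumerate(patches):
        patch_row, patch_col = divmod(index, grid)
        for k, value in enumerate(patch):
            i, j = divmod(k, p)  # offset inside the patch
            img[patch_row * p + i][patch_col * p + j] = value
    return img


patches = [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
for row in unpatchify_2d(patches, 2):
    print(row)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9, 10, 11]
# [12, 13, 14, 15]
```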

validation_step(batch, batch_idx)

Validation step.

Args:

batch (Tuple[torch.Tensor]): Input batch of images and corresponding labels.
batch_idx (int): Index of the current batch.

Returns:

Dict[str, torch.Tensor]: Dictionary containing the loss value for the current step.

minerva.models.nets.sfm.mae_vit_base_patch16
minerva.models.nets.sfm.mae_vit_base_patch16D4d256
minerva.models.nets.sfm.mae_vit_huge_patch14
minerva.models.nets.sfm.mae_vit_large_patch16
minerva.models.nets.sfm.mae_vit_large_patch16D4d256
minerva.models.nets.sfm.mae_vit_small_patch16