minerva.models.ssl.simsiam
==========================

.. py:module:: minerva.models.ssl.simsiam


Classes
-------

.. autoapisummary::

   minerva.models.ssl.simsiam.SimSiam


Module Contents
---------------

.. py:class:: SimSiam(backbone, projection_head = None, prediction_head = None, loss_fn = None, learning_rate = 0.0001, weight_decay = 1e-06)

   Bases: :py:obj:`lightning.LightningModule`


   SimSiam implementation using PyTorch Lightning.

   This class implements the SimSiam self-supervised learning framework, which is
   designed to learn useful representations without using negative samples. It employs
   a backbone encoder, a projection head, and a prediction head to train the backbone.

   Initialize the SimSiam module.

   Parameters
   ----------
   backbone : nn.Module
       The feature extractor network (e.g., a ResNet encoder).
   projection_head : nn.Module, optional
       The network that maps backbone outputs to the projection space. If None,
       a default 3-layer MLP designed to work with ResNet50 is used.
   prediction_head : nn.Module, optional
       The network that maps projection vectors to the prediction space. If None,
       a default 2-layer MLP is used.
   loss_fn : Callable, optional
       Loss function used for training. Default is cosine similarity loss.
   learning_rate : float, optional
       Learning rate for the optimizer. Default is 0.0001.
   weight_decay : float, optional
       Weight decay for the optimizer. Default is 1e-6.


   .. py:method:: _single_step(batch)

      Compute the loss for a single batch.

      Parameters
      ----------
      batch : Tuple[Tuple[Tensor, Tensor], Any]
          A tuple containing a pair of augmented views (x0, x1) and labels (unused).

      Returns
      -------
      torch.Tensor
          The computed loss for the batch.



   .. py:attribute:: backbone


   .. py:method:: configure_optimizers()

      Configures the Adam optimizer with provided learning rate and weight decay.

      Returns
      -------
      torch.optim.Optimizer
          The optimizer used for training.



   .. py:method:: forward(x)

      Forward pass through the backbone, projection, and prediction heads.

      Parameters
      ----------
      x : torch.Tensor
          Input tensor of shape (batch_size, channels, height, width).

      Returns
      -------
      Tuple[torch.Tensor, torch.Tensor]
          The detached projection vector `z` and prediction vector `p`.



   .. py:attribute:: learning_rate
      :value: 0.0001



   .. py:method:: training_step(batch, batch_idx)

      Defines one training step.

      Parameters
      ----------
      batch : Tuple[Tuple[Tensor, Tensor], Any]
          Batch containing two augmented views and labels (unused).
      batch_idx : int
          Index of the batch.

      Returns
      -------
      torch.Tensor
          Training loss for the batch.



   .. py:method:: validation_step(batch, batch_idx)

      Defines one validation step.

      Parameters
      ----------
      batch : Tuple[Tuple[Tensor, Tensor], Any]
          Batch containing two augmented views and labels (unused).
      batch_idx : int
          Index of the batch.

      Returns
      -------
      torch.Tensor
          Validation loss for the batch.



   .. py:attribute:: weight_decay
      :value: 1e-06
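
The ``loss_fn`` parameter above defaults to a cosine similarity loss, and ``forward`` returns a detached projection ``z`` together with a prediction ``p``. The sketch below illustrates how such a loss is commonly symmetrized across the two augmented views, with each prediction compared against the projection of the other view. It is an illustration only and rests on assumptions: the source does not confirm the exact form of the default loss used by ``SimSiam``, the helper names ``cosine_similarity`` and ``symmetric_negative_cosine`` are hypothetical and not part of ``minerva``, and plain Python lists stand in for tensors, so the stop-gradient on ``z`` is described in comments rather than modeled.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def symmetric_negative_cosine(
    p0: Sequence[float],
    z0: Sequence[float],
    p1: Sequence[float],
    z1: Sequence[float],
) -> float:
    """Hypothetical symmetrized loss for one pair of augmented views.

    p0, p1 are prediction vectors and z0, z1 are projection vectors.
    Each prediction is scored against the projection of the *other* view.
    In a tensor-based training loop, z0 and z1 would be detached so that
    no gradient flows through them (assumption, not shown here).
    """
    return -0.5 * (cosine_similarity(p0, z1) + cosine_similarity(p1, z0))


# Perfectly aligned views reach the minimum value of the loss.
aligned = symmetric_negative_cosine([1.0, 0.0], [0.0, 1.0], [0.0, 2.0], [3.0, 0.0])

# Orthogonal predictions and projections give a loss of zero.
orthogonal = symmetric_negative_cosine([1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0])
```

Under these assumptions the loss lies in the range [-1, 1], and training drives it toward -1 as the prediction for one view aligns with the projection of the other.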