minerva.models.ssl.barlowtwins
==============================

.. py:module:: minerva.models.ssl.barlowtwins


Classes
-------

.. autoapisummary::

   minerva.models.ssl.barlowtwins.BarlowTwins


Module Contents
---------------

.. py:class:: BarlowTwins(backbone, projection_head, loss_fn = None, learning_rate = 0.0001, weight_decay = 1e-06)

   Bases: :py:obj:`lightning.LightningModule`

   PyTorch Lightning module implementing the Barlow Twins self-supervised
   learning framework. It takes a backbone that extracts features and a
   projection head that maps those features to embeddings, computes a
   self-supervised loss between the embeddings of two augmented views
   (``BarlowTwinsLoss`` by default), implements the standard PyTorch Lightning
   training and validation steps, and optimizes with Adam.

   Initialize the BarlowTwins module.

   Parameters
   ----------
   backbone : nn.Module
       Neural network that extracts features from the input images.
   projection_head : nn.Module
       Network that maps backbone outputs to the embedding space.
   loss_fn : nn.Module, optional
       Loss function applied to the two projected views. If None (the
       default), BarlowTwinsLoss is used.
   learning_rate : float, optional
       Learning rate for the Adam optimizer. Default is 0.0001.
   weight_decay : float, optional
       Weight decay (L2 penalty) for the Adam optimizer. Default is 1e-6.


   .. py:method:: _single_step(batch)

      Compute the loss for a single batch.

      Parameters
      ----------
      batch : tuple of ((torch.Tensor, torch.Tensor), Any)
          A tuple containing a pair of augmented views ``(x0, x1)`` and
          labels, which are unused.

      Returns
      -------
      torch.Tensor
          The computed loss for the batch.


   .. py:attribute:: backbone


   .. py:method:: configure_optimizers()

      Configure the Adam optimizer with the module's learning rate and
      weight decay.

      Returns
      -------
      torch.optim.Optimizer
          Configured Adam optimizer.


   .. py:method:: forward(x)

      Forward pass through the model.

      Parameters
      ----------
      x : torch.Tensor
          Input tensor, e.g., a batch of images.

      Returns
      -------
      torch.Tensor
          Projected features in the embedding space.


   .. py:attribute:: learning_rate
      :value: 0.0001


   .. py:attribute:: projection_head


   .. py:method:: training_step(batch, batch_idx)

      Perform one training step.

      Parameters
      ----------
      batch : tuple of ((torch.Tensor, torch.Tensor), Any)
          Batch containing two augmented views and labels (unused).
      batch_idx : int
          Index of the batch.

      Returns
      -------
      torch.Tensor
          Training loss for the batch.


   .. py:method:: validation_step(batch, batch_idx)

      Perform one validation step.

      Parameters
      ----------
      batch : tuple of ((torch.Tensor, torch.Tensor), Any)
          Batch containing two augmented views and labels (unused).
      batch_idx : int
          Index of the batch.

      Returns
      -------
      torch.Tensor
          Validation loss for the batch.


   .. py:attribute:: weight_decay
      :value: 1e-06
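
The loss computed by ``_single_step`` can be sketched in plain PyTorch. This is
a minimal, hypothetical illustration, not minerva's implementation:
``barlow_twins_loss`` follows the objective from the original Barlow Twins
paper (Zbontar et al., 2021) with an assumed weight ``lambd = 5e-3``, and
``TinyBarlowTwins``, the toy backbone, and the projection head are illustrative
stand-ins. It also assumes that ``forward`` feeds backbone features directly
into the projection head. Whether ``BarlowTwinsLoss`` uses the same
normalization and weighting is not confirmed here. The optimizer line mirrors
the documented ``configure_optimizers`` defaults (Adam, learning rate 0.0001,
weight decay 1e-6).

.. code-block:: python

   import torch
   from torch import nn


   def barlow_twins_loss(z0: torch.Tensor, z1: torch.Tensor, lambd: float = 5e-3) -> torch.Tensor:
       # Assumed objective from the Barlow Twins paper; minerva's BarlowTwinsLoss may differ.
       n = z0.shape[0]
       # Standardize each embedding dimension over the batch.
       z0_norm = (z0 - z0.mean(dim=0)) / z0.std(dim=0)
       z1_norm = (z1 - z1.mean(dim=0)) / z1.std(dim=0)
       # Cross-correlation matrix between the two views, shape (d, d).
       c = z0_norm.T @ z1_norm / n
       diag = torch.diagonal(c)
       # Invariance term: pull diagonal entries toward 1.
       on_diag = (diag - 1).pow(2).sum()
       # Redundancy-reduction term: push off-diagonal entries toward 0.
       off_diag = (c - torch.diag(diag)).pow(2).sum()
       return on_diag + lambd * off_diag


   class TinyBarlowTwins(nn.Module):
       # Hypothetical stand-in for the documented forward/_single_step flow.
       def __init__(self, backbone: nn.Module, projection_head: nn.Module, loss_fn=barlow_twins_loss):
           super().__init__()
           self.backbone = backbone
           self.projection_head = projection_head
           self.loss_fn = loss_fn

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           # Assumed: backbone features go straight into the projection head.
           return self.projection_head(self.backbone(x))

       def single_step(self, batch) -> torch.Tensor:
           (x0, x1), _labels = batch  # labels are unused
           return self.loss_fn(self(x0), self(x1))


   torch.manual_seed(0)
   backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
   projection_head = nn.Linear(32, 16)
   model = TinyBarlowTwins(backbone, projection_head)
   # Mirrors the documented configure_optimizers defaults.
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)

   x0 = torch.randn(8, 3, 8, 8)
   x1 = x0 + 0.1 * torch.randn_like(x0)  # a second "augmented" view
   batch = ((x0, x1), torch.zeros(8))

   loss = model.single_step(batch)
   optimizer.zero_grad()
   loss.backward()
   optimizer.step()

Both views pass through the same weights, so the diagonal term rewards
agreement between the two embeddings while the off-diagonal term discourages
redundant embedding dimensions; as in the documented ``_single_step``, the
labels in the batch are ignored.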