minerva.models.ssl.barlowtwins
Classes
BarlowTwins: PyTorch Lightning module implementing the Barlow Twins self-supervised learning framework.
Module Contents
- class minerva.models.ssl.barlowtwins.BarlowTwins(backbone, projection_head, loss_fn=None, learning_rate=0.0001, weight_decay=1e-06)
Bases: lightning.LightningModule
PyTorch Lightning module implementing the Barlow Twins self-supervised learning framework.
It takes a backbone and a projection head for feature encoding and embedding, and trains them with BarlowTwinsLoss (a non-contrastive, redundancy-reduction objective) unless a custom loss_fn is provided. It supports the standard PyTorch Lightning training and validation loops and optimizes with the Adam optimizer.
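For reference, the original Barlow Twins method standardizes the embeddings of both views along the batch, forms their cross-correlation matrix, and penalizes its deviation from the identity: diagonal entries are pulled toward 1 (invariance) and off-diagonal entries toward 0 (redundancy reduction). The sketch below illustrates that objective in plain PyTorch. It is not minerva's BarlowTwinsLoss; the function name barlow_twins_loss and the weight lambda_offdiag=5e-3 (the value reported in the original paper) are assumptions, and the library's normalization or defaults may differ.

```python
import torch


def barlow_twins_loss(
    z0: torch.Tensor, z1: torch.Tensor, lambda_offdiag: float = 5e-3
) -> torch.Tensor:
    """Sketch of the Barlow Twins objective (not minerva's BarlowTwinsLoss).

    z0, z1: embeddings of two augmented views, shape (batch_size, dim).
    """
    n = z0.shape[0]
    # Standardize each embedding dimension across the batch.
    z0 = (z0 - z0.mean(dim=0)) / z0.std(dim=0, unbiased=False)
    z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0, unbiased=False)
    # Empirical cross-correlation matrix between the two views, shape (dim, dim).
    c = z0.T @ z1 / n
    diag = torch.diagonal(c)
    # Invariance term: pull the diagonal toward 1.
    on_diag = (diag - 1).pow(2).sum()
    # Redundancy-reduction term: push the off-diagonal entries toward 0.
    off_diag = (c - torch.diag(diag)).pow(2).sum()
    return on_diag + lambda_offdiag * off_diag
```

With identical views the invariance term is numerically zero, so only the small weighted redundancy penalty remains; independent views leave the diagonal near 0 and the loss much larger.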
Initialize the BarlowTwins module.
Parameters
- backbone : nn.Module
Neural network used to extract features from input images.
- projection_head : nn.Module
Network that maps backbone outputs to the latent (embedding) space.
- loss_fn : nn.Module, optional
Custom loss function. If None (the default), BarlowTwinsLoss is used.
- learning_rate : float, optional
Learning rate for the optimizer. Default is 0.0001.
- weight_decay : float, optional
Weight decay (L2 regularization). Default is 1e-6.
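Any pair of nn.Module objects whose shapes chain together can serve as backbone and projection_head. The sketch below builds a hypothetical pair for 3x32x32 images: a tiny CNN backbone producing 128-dimensional features and a scaled-down projection head in the Linear/BatchNorm/ReLU style of the original Barlow Twins projector (which used much wider 8192-unit layers). All layer sizes are assumptions chosen only for illustration.

```python
import torch
from torch import nn

# Hypothetical backbone: maps (N, 3, 32, 32) images to (N, 128) features.
backbone = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # (N, 32, 1, 1)
    nn.Flatten(),             # (N, 32)
    nn.Linear(32, 128),       # (N, 128)
)

# Hypothetical projection head: Linear-BatchNorm-ReLU block, then a final Linear.
projection_head = nn.Sequential(
    nn.Linear(128, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 256),      # (N, 256) embeddings
)
```

The two modules would then be passed as BarlowTwins(backbone, projection_head); leaving loss_fn as None selects BarlowTwinsLoss. Because the head contains BatchNorm1d, training batches need more than one sample.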
- _single_step(batch)
Compute the loss for a single batch.
Parameters
- batch : tuple of ((torch.Tensor, torch.Tensor), Any)
A tuple containing a pair of augmented views (x0, x1) and labels (unused).
Returns
- torch.Tensor
The computed loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
- Return type:
torch.Tensor
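A sketch of what _single_step plausibly does with the documented batch layout, assuming both views are encoded with forward and the two embeddings are passed to the loss as loss_fn(z0, z1). single_step is a hypothetical free function, not the module's method, and nn.MSELoss in the check is only a placeholder loss.

```python
from typing import Any, Callable, Tuple

import torch
from torch import nn

Batch = Tuple[Tuple[torch.Tensor, torch.Tensor], Any]


def single_step(
    forward: Callable[[torch.Tensor], torch.Tensor],
    loss_fn: nn.Module,
    batch: Batch,
) -> torch.Tensor:
    """Hypothetical stand-in for BarlowTwins._single_step."""
    (x0, x1), _labels = batch  # labels are unpacked but ignored
    z0 = forward(x0)  # embedding of the first augmented view
    z1 = forward(x1)  # embedding of the second augmented view
    return loss_fn(z0, z1)
```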
- backbone
- configure_optimizers()
Configures the Adam optimizer with the provided learning rate and weight decay.
Returns
- torch.optim.Optimizer
Configured Adam optimizer.
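Given the description, configure_optimizers presumably reduces to a single torch.optim.Adam call over the module's parameters. A minimal sketch, with configure_adam as a hypothetical stand-alone helper using the documented defaults:

```python
import torch
from torch import nn


def configure_adam(
    model: nn.Module, learning_rate: float = 1e-4, weight_decay: float = 1e-6
) -> torch.optim.Optimizer:
    """Hypothetical helper mirroring the documented optimizer setup."""
    return torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
```

torch.optim.Adam applies weight_decay as an L2 penalty added to the gradients; decoupled weight decay would require torch.optim.AdamW instead.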
- forward(x)
Forward pass through the model.
Parameters
- x : torch.Tensor
Input tensor (e.g., batch of images).
Returns
- torch.Tensor
Projected features in the embedding space.
- Parameters:
x (torch.Tensor)
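The parameter descriptions (backbone features mapped into the embedding space by the projection head) suggest that forward is a simple composition. A minimal sketch, assuming forward computes projection_head(backbone(x)) with no extra reshaping; EncoderSketch is a hypothetical stand-in, not the real class.

```python
import torch
from torch import nn


class EncoderSketch(nn.Module):
    """Hypothetical stand-in for BarlowTwins.forward (not the real class)."""

    def __init__(self, backbone: nn.Module, projection_head: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.projection_head = projection_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)            # backbone features
        return self.projection_head(features)  # projection into the embedding space
```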
- learning_rate = 0.0001
- projection_head
- training_step(batch, batch_idx)
Defines one training step.
Parameters
- batch : tuple of ((torch.Tensor, torch.Tensor), Any)
Batch containing two augmented views and labels (unused).
- batch_idx : int
Index of the batch.
Returns
- torch.Tensor
Training loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
batch_idx (int)
- Return type:
torch.Tensor
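Under Lightning's default (automatic) optimization, the Trainer takes the loss returned by training_step and runs the backward pass and optimizer step itself. The plain PyTorch loop below sketches roughly what that amounts to for batches of two views. It rests on assumptions: train_one_epoch is a hypothetical helper, the model is applied to each view separately, and nn.MSELoss in the check is only a short placeholder for BarlowTwinsLoss (on its own it would let the embeddings collapse).

```python
from typing import Any, Iterable, List, Tuple

import torch
from torch import nn

Batch = Tuple[Tuple[torch.Tensor, torch.Tensor], Any]


def train_one_epoch(
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    batches: Iterable[Batch],
) -> List[float]:
    """Hypothetical manual loop mirroring Lightning's automatic optimization."""
    model.train()
    losses: List[float] = []
    for (x0, x1), _labels in batches:
        loss = loss_fn(model(x0), model(x1))  # the value training_step would return
        optimizer.zero_grad()
        loss.backward()   # done by Lightning after training_step
        optimizer.step()  # also done by Lightning
        losses.append(loss.item())
    return losses
```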
- validation_step(batch, batch_idx)
Defines one validation step.
Parameters
- batch : tuple of ((torch.Tensor, torch.Tensor), Any)
Batch containing two augmented views and labels (unused).
- batch_idx : int
Index of the batch.
Returns
- torch.Tensor
Validation loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
batch_idx (int)
- Return type:
torch.Tensor
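During validation, Lightning puts the module in eval mode and disables gradient tracking before calling validation_step, so BatchNorm layers in the projection head use their running statistics and no gradients are stored. A plain PyTorch sketch of computing the mean validation loss over batches, with evaluate as a hypothetical helper and nn.MSELoss in the check standing in for the configured loss:

```python
from typing import Any, Iterable, Tuple

import torch
from torch import nn

Batch = Tuple[Tuple[torch.Tensor, torch.Tensor], Any]


def evaluate(model: nn.Module, loss_fn: nn.Module, batches: Iterable[Batch]) -> float:
    """Hypothetical helper: mean loss over validation batches."""
    model.eval()  # BatchNorm/Dropout switch to inference behavior
    total, count = 0.0, 0
    with torch.no_grad():  # no autograd graph is built, no gradients are stored
        for (x0, x1), _labels in batches:
            total += loss_fn(model(x0), model(x1)).item()
            count += 1
    return total / max(count, 1)
```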
- weight_decay = 1e-06
- Parameters:
backbone (torch.nn.Module)
projection_head (torch.nn.Module)
loss_fn (Optional[torch.nn.Module])
learning_rate (float)
weight_decay (float)