minerva.models.ssl.simsiam
Classes
- SimSiam: SimSiam implementation using PyTorch Lightning.
Module Contents
- class minerva.models.ssl.simsiam.SimSiam(backbone, projection_head=None, prediction_head=None, loss_fn=None, learning_rate=0.0001, weight_decay=1e-06)
Bases: lightning.LightningModule
SimSiam implementation using PyTorch Lightning.
This class implements the SimSiam self-supervised learning framework, which is designed to learn useful representations without using negative samples. It employs a backbone encoder, a projection head, and a prediction head to train the backbone.
Initialize the SimSiam module.
Parameters
- backbone : nn.Module
The feature extractor network (e.g., a ResNet encoder).
- projection_head : nn.Module, optional
The network that maps backbone outputs to the projection space. If None, a default 3-layer MLP designed to work with ResNet50 is used.
- prediction_head : nn.Module, optional
The network that maps projection vectors to the prediction space. If None, a default 2-layer MLP is used.
- loss_fn : Callable, optional
Loss function used for training. Default is cosine similarity loss.
- learning_rate : float, optional
Learning rate for the optimizer. Default is 0.0001.
- weight_decay : float, optional
Weight decay for the optimizer. Default is 1e-6.
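The default heads are described above only as a 3-layer and a 2-layer MLP. As a rough sketch of what such heads might look like, the following uses the layer layout from the original SimSiam paper. The helper names `make_projection_head` and `make_prediction_head`, and all dimensions, are assumptions for illustration and are not taken from the minerva source.

```python
import torch
from torch import nn


def make_projection_head(in_dim: int = 2048, hidden_dim: int = 2048, out_dim: int = 2048) -> nn.Sequential:
    """Hypothetical 3-layer projection MLP (dimensions are assumptions)."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim, bias=False),
        nn.BatchNorm1d(out_dim, affine=False),  # output layer: BN only, no ReLU
    )


def make_prediction_head(in_dim: int = 2048, hidden_dim: int = 512, out_dim: int = 2048) -> nn.Sequential:
    """Hypothetical 2-layer prediction MLP with a bottleneck hidden layer."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),  # plain linear output
    )
```

Heads built this way could be passed through the `projection_head` and `prediction_head` arguments when the backbone's output size differs from what the defaults expect.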
- _single_step(batch)
Compute the loss for a single batch.
Parameters
- batch : Tuple[Tuple[Tensor, Tensor], Any]
A tuple containing a pair of augmented views (x0, x1) and labels (unused).
Returns
- torch.Tensor
The computed loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
- Return type:
torch.Tensor
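For orientation, the loss on a pair of views can be sketched as a symmetric negative cosine similarity with a stop-gradient on the projections, as in the SimSiam paper. This is an illustrative sketch only: the function names below are hypothetical, and the symmetric form with the 0.5 factor is an assumption about how `_single_step` combines the two views, not something this page confirms.

```python
import torch
import torch.nn.functional as F


def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Mean negative cosine similarity; z is detached (stop-gradient)."""
    return -F.cosine_similarity(p, z.detach(), dim=1).mean()


def symmetric_simsiam_loss(
    p0: torch.Tensor, z0: torch.Tensor, p1: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """Average the loss over both directions: p0 vs z1 and p1 vs z0 (assumed form)."""
    return 0.5 * (
        negative_cosine_similarity(p0, z1) + negative_cosine_similarity(p1, z0)
    )
```

With this formulation the loss is bounded between -1 (predictions perfectly aligned with the other view's projections) and 1.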
- backbone
- configure_optimizers()
Configure the Adam optimizer with the provided learning rate and weight decay.
Returns
- torch.optim.Optimizer
The optimizer used for training.
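A minimal sketch of the equivalent optimizer setup in plain PyTorch, using the documented default values. The `nn.Linear` module is only a stand-in for the SimSiam model and is an assumption for illustration.

```python
import torch
from torch import nn

# Stand-in module (hypothetical); in practice the parameters come from the SimSiam model.
model = nn.Linear(8, 4)

# Adam with the documented defaults: learning_rate=0.0001, weight_decay=1e-06.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
```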
- forward(x)
Forward pass through the backbone, projection, and prediction heads.
Parameters
- x : torch.Tensor
Input tensor of shape (batch_size, channels, height, width).
Returns
- Tuple[torch.Tensor, torch.Tensor]
The detached projection vector z and prediction vector p.
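The data flow described here can be sketched as follows. `TinySimSiamForward` is a hypothetical stand-in, not the minerva class, and flattening the backbone output before the projection head is an assumption.

```python
import torch
from torch import nn


class TinySimSiamForward(nn.Module):
    """Hypothetical sketch: backbone -> projection head -> prediction head."""

    def __init__(self, backbone: nn.Module, projection_head: nn.Module, prediction_head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.projection_head = projection_head
        self.prediction_head = prediction_head

    def forward(self, x: torch.Tensor):
        f = self.backbone(x).flatten(start_dim=1)  # assumed: flatten to (batch, features)
        z = self.projection_head(f)                # projection vector
        p = self.prediction_head(z)                # prediction vector
        return z.detach(), p                       # z is returned without gradient


# Toy components, chosen only to keep the example small.
backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.AdaptiveAvgPool2d(1))
model = TinySimSiamForward(backbone, nn.Linear(4, 6), nn.Linear(6, 6))
z, p = model(torch.randn(2, 3, 8, 8))
```

Detaching `z` means gradients reach the backbone only through `p`, which is the stop-gradient mechanism SimSiam relies on.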
- learning_rate = 0.0001
- training_step(batch, batch_idx)
Defines one training step.
Parameters
- batch : Tuple[Tuple[Tensor, Tensor], Any]
Batch containing two augmented views and labels (unused).
- batch_idx : int
Index of the batch.
Returns
- torch.Tensor
Training loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
batch_idx (int)
- Return type:
torch.Tensor
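To illustrate the expected batch layout `((x0, x1), labels)`, here is a small sketch of a dataset that yields two augmented views per sample. `TwoViewDataset` and `random_augment` are hypothetical names, and the toy augmentation (random flip plus noise) is an assumption; real pipelines typically use stronger image augmentations.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def random_augment(x: torch.Tensor) -> torch.Tensor:
    """Toy augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(()) < 0.5:
        x = torch.flip(x, dims=[-1])
    return x + 0.05 * torch.randn_like(x)


class TwoViewDataset(Dataset):
    """Hypothetical wrapper returning ((view0, view1), label) per sample."""

    def __init__(self, images: torch.Tensor, labels: torch.Tensor):
        self.images = images
        self.labels = labels

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        x = self.images[idx]
        # Two independent augmentations of the same image.
        return (random_augment(x), random_augment(x)), self.labels[idx]


images = torch.rand(8, 3, 16, 16)
labels = torch.zeros(8, dtype=torch.long)  # labels are unused by SimSiam
loader = DataLoader(TwoViewDataset(images, labels), batch_size=4)

(x0, x1), _ = next(iter(loader))  # same unpacking pattern as the documented batch type
```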
- validation_step(batch, batch_idx)
Defines one validation step.
Parameters
- batch : Tuple[Tuple[Tensor, Tensor], Any]
Batch containing two augmented views and labels (unused).
- batch_idx : int
Index of the batch.
Returns
- torch.Tensor
Validation loss for the batch.
- Parameters:
batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])
batch_idx (int)
- Return type:
torch.Tensor
- weight_decay = 1e-06
- Parameters:
backbone (torch.nn.Module)
projection_head (Optional[torch.nn.Module])
prediction_head (Optional[torch.nn.Module])
loss_fn (Optional[torch.nn.Module])
learning_rate (float)
weight_decay (float)