minerva.models.ssl.simsiam

Classes

SimSiam

SimSiam implementation using PyTorch Lightning.

Module Contents

class minerva.models.ssl.simsiam.SimSiam(backbone, projection_head=None, prediction_head=None, loss_fn=None, learning_rate=0.0001, weight_decay=1e-06)

Bases: lightning.LightningModule

SimSiam implementation using PyTorch Lightning.

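This class implements the SimSiam self-supervised learning framework, which learns useful representations without using negative samples. It trains the backbone encoder together with a projection head and a prediction head: the prediction computed for one augmented view is pulled toward the projection of the other view, which is detached so that no gradient flows through it.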
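For orientation, the symmetric objective published in the original SimSiam paper (Chen and He, "Exploring Simple Siamese Representation Learning", CVPR 2021) is reproduced below. Here z_i and p_i are the projection and prediction for view x_i, following the (x0, x1) naming used by `_single_step`. That this class combines the two directions exactly as written, rather than using another variant, is an assumption.

```latex
\mathcal{D}(p, z) = -\frac{p}{\lVert p \rVert_2} \cdot \frac{z}{\lVert z \rVert_2},
\qquad
\mathcal{L} = \tfrac{1}{2}\,\mathcal{D}\big(p_0, \operatorname{stopgrad}(z_1)\big)
            + \tfrac{1}{2}\,\mathcal{D}\big(p_1, \operatorname{stopgrad}(z_0)\big)
```

The stop-gradient on z is what lets SimSiam avoid collapse without negative pairs; in PyTorch it corresponds to calling `.detach()` on the projection.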

Initialize the SimSiam module.

Parameters

backbone : nn.Module

The feature extractor network (e.g., a ResNet encoder).

projection_head : nn.Module, optional

The network that maps backbone outputs to the projection space. If None, a default 3-layer MLP designed to work with ResNet50 is used.

prediction_head : nn.Module, optional

The network that maps projection vectors to the prediction space. If None, a default 2-layer MLP is used.

loss_fn : Callable, optional

Loss function used for training. If None, a cosine similarity loss is used.

learning_rate : float, optional

Learning rate for the optimizer. Default is 0.0001.

weight_decay : float, optional

Weight decay for the optimizer. Default is 1e-6.
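The default `loss_fn` is described only as a cosine similarity loss. The sketch below assumes it is the negative cosine similarity used in the SimSiam paper and writes it in plain Python on lists so the arithmetic is visible. The function name `negative_cosine_similarity` is hypothetical; the actual default presumably operates on batched tensors and averages over the batch.

```python
import math
from typing import Sequence


def negative_cosine_similarity(
    p: Sequence[float], z: Sequence[float], eps: float = 1e-8
) -> float:
    """Return -cos(p, z): -1.0 when p and z point the same way, 0 when orthogonal."""
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    # eps guards against division by zero when either vector is all zeros
    return -dot / max(norm_p * norm_z, eps)


parallel = negative_cosine_similarity([1.0, 2.0], [2.0, 4.0])
orthogonal = negative_cosine_similarity([1.0, 0.0], [0.0, 3.0])
print(parallel, orthogonal)  # near -1.0 for parallel vectors, zero for orthogonal ones
```

Because the loss depends only on direction, scaling either vector leaves it unchanged; minimizing it aligns predictions with projections rather than matching their magnitudes.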

_single_step(batch)

Compute the loss for a single batch.

Parameters

batch : Tuple[Tuple[Tensor, Tensor], Any]

A tuple containing a pair of augmented views (x0, x1) and labels (unused).

Returns

torch.Tensor

The computed loss for the batch.

Parameters:

batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

Return type:

torch.Tensor
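How a batch of the form ((x0, x1), labels) could turn into a single loss value is sketched below for one sample. It assumes the model's forward pass returns (z, p) as documented for `forward`, and that the loss is the symmetric negative cosine similarity from the SimSiam paper. `single_step`, `neg_cos`, and `toy_forward` are hypothetical stand-ins written in plain Python, not the class's actual code.

```python
import math
from typing import Any, Callable, List, Sequence, Tuple

Vector = List[float]


def neg_cos(p: Sequence[float], z: Sequence[float], eps: float = 1e-8) -> float:
    # Negative cosine similarity: -(p . z) / (|p| |z|)
    dot = sum(a * b for a, b in zip(p, z))
    norms = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in z))
    return -dot / max(norms, eps)


def single_step(
    batch: Tuple[Tuple[Vector, Vector], Any],
    forward: Callable[[Vector], Tuple[Vector, Vector]],
) -> float:
    # Hypothetical stand-in for SimSiam._single_step on a single sample
    (x0, x1), _labels = batch  # labels are unused
    z0, p0 = forward(x0)  # in PyTorch, z0 and z1 would already be detached
    z1, p1 = forward(x1)
    # Each view's prediction is compared with the other view's projection
    return 0.5 * neg_cos(p0, z1) + 0.5 * neg_cos(p1, z0)


def toy_forward(x: Vector) -> Tuple[Vector, Vector]:
    # Hypothetical model: projection is x itself, prediction doubles it
    z = list(x)
    p = [2.0 * v for v in x]
    return z, p


loss = single_step((([1.0, 0.0], [1.0, 0.0]), None), toy_forward)
print(loss)  # -1.0
```

Averaging the two directions keeps the loss symmetric in the two views, so swapping x0 and x1 yields the same value.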

backbone
configure_optimizers()

Configures the Adam optimizer with the provided learning rate and weight decay.

Returns

torch.optim.Optimizer

The optimizer used for training.
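Assuming the optimizer is `torch.optim.Adam` with its default betas (0.9, 0.999) and eps (1e-8), its `weight_decay` acts by default as an L2 penalty added to the gradient, not as the decoupled decay of AdamW. The plain-Python sketch below replays one such update for a list of scalar parameters; the function name `adam_l2_step` is hypothetical.

```python
import math
from typing import List, Tuple


def adam_l2_step(
    params: List[float],
    grads: List[float],
    m: List[float],
    v: List[float],
    t: int,
    lr: float = 1e-4,
    weight_decay: float = 1e-6,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
) -> None:
    """One in-place Adam update with L2 weight decay folded into the gradient."""
    b1, b2 = betas
    for i, (w, g) in enumerate(zip(params, grads)):
        g = g + weight_decay * w  # L2 penalty enters through the gradient
        m[i] = b1 * m[i] + (1 - b1) * g  # first-moment estimate
        v[i] = b2 * v[i] + (1 - b2) * g * g  # second-moment estimate
        m_hat = m[i] / (1 - b1 ** t)  # bias correction; t counts from 1
        v_hat = v[i] / (1 - b2 ** t)
        params[i] = w - lr * m_hat / (math.sqrt(v_hat) + eps)


params, m, v = [1.0], [0.0], [0.0]
adam_l2_step(params, [0.5], m, v, t=1, lr=0.1, weight_decay=0.0)
print(params)  # roughly [0.9]
```

On the first step the bias-corrected update has magnitude close to `lr` whatever the gradient's scale, which is why the parameter moves by about 0.1 here.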

forward(x)

Forward pass through the backbone, projection, and prediction heads.

Parameters

x : torch.Tensor

Input tensor of shape (batch_size, channels, height, width).

Returns

Tuple[torch.Tensor, torch.Tensor]

The detached projection vector z and prediction vector p.

learning_rate = 0.0001
training_step(batch, batch_idx)

Defines one training step.

Parameters

batch : Tuple[Tuple[Tensor, Tensor], Any]

Batch containing two augmented views and labels (unused).

batch_idx : int

Index of the batch.

Returns

torch.Tensor

Training loss for the batch.

Parameters:
  • batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

  • batch_idx (int)

Return type:

torch.Tensor

validation_step(batch, batch_idx)

Defines one validation step.

Parameters

batch : Tuple[Tuple[Tensor, Tensor], Any]

Batch containing two augmented views and labels (unused).

batch_idx : int

Index of the batch.

Returns

torch.Tensor

Validation loss for the batch.

Parameters:
  • batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

  • batch_idx (int)

Return type:

torch.Tensor

weight_decay = 1e-06
Parameters:
  • backbone (torch.nn.Module)

  • projection_head (Optional[torch.nn.Module])

  • prediction_head (Optional[torch.nn.Module])

  • loss_fn (Optional[torch.nn.Module])

  • learning_rate (float)

  • weight_decay (float)