minerva.models.ssl.barlowtwins

Classes

BarlowTwins

PyTorch Lightning module implementing the Barlow Twins self-supervised learning framework.

Module Contents

class minerva.models.ssl.barlowtwins.BarlowTwins(backbone, projection_head, loss_fn=None, learning_rate=0.0001, weight_decay=1e-06)[source]

Bases: lightning.LightningModule

PyTorch Lightning module implementing the Barlow Twins self-supervised learning framework.

It takes a backbone and a projection head for feature encoding and embedding, uses a redundancy-reduction loss on the embeddings of two augmented views (BarlowTwinsLoss when no loss is provided), implements the standard PyTorch Lightning training and validation steps, and optimizes with the Adam optimizer.

Initialize the BarlowTwins module.

Parameters

backbone : nn.Module

Neural network used to extract features from input images.

projection_head : nn.Module

Network that maps backbone outputs to a latent space.

loss_fn : nn.Module, optional

Custom loss function. Defaults to BarlowTwinsLoss.

learning_rate : float, optional

Learning rate for the optimizer. Default is 0.0001.

weight_decay : float, optional

Weight decay (L2 regularization). Default is 1e-6.

_single_step(batch)[source]

Compute the loss for a single batch.

Parameters

batch : tuple of ((torch.Tensor, torch.Tensor), Any)

A tuple containing a pair of augmented views (x0, x1) and labels (unused).

Returns

torch.Tensor

The computed loss for the batch.

Parameters:

batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

Return type:

torch.Tensor
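The step described above can be sketched as follows. This is a minimal illustration, not the library's code: the internals of BarlowTwinsLoss are not shown on this page, so `barlow_twins_loss` below follows the commonly published formulation, and `lambda_param`, `single_step`, and the toy `encoder` are hypothetical names introduced only for this example.

```python
import torch


def barlow_twins_loss(z0: torch.Tensor, z1: torch.Tensor,
                      lambda_param: float = 5e-3) -> torch.Tensor:
    """Redundancy-reduction loss on two (N, D) batches of embeddings.

    Assumption: standard formulation; lambda_param is a hypothetical default.
    """
    n = z0.shape[0]
    # Standardize every embedding dimension across the batch.
    z0 = (z0 - z0.mean(dim=0)) / (z0.std(dim=0) + 1e-6)
    z1 = (z1 - z1.mean(dim=0)) / (z1.std(dim=0) + 1e-6)
    # Cross-correlation matrix between the two views, shape (D, D).
    c = (z0.T @ z1) / n
    # Diagonal entries are pushed toward 1, off-diagonal entries toward 0.
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
    return on_diag + lambda_param * off_diag


def single_step(encoder: torch.nn.Module, batch) -> torch.Tensor:
    """Unpack ((x0, x1), labels), embed both views, and return the loss."""
    (x0, x1), _ = batch  # labels are unused
    return barlow_twins_loss(encoder(x0), encoder(x1))


torch.manual_seed(0)
# Toy encoder standing in for backbone + projection head (assumption).
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 16))
x0 = torch.randn(32, 3, 8, 8)
x1 = x0 + 0.1 * torch.randn_like(x0)  # second "augmented" view
labels = torch.zeros(32, dtype=torch.long)
loss = single_step(encoder, ((x0, x1), labels))
```

The result is a scalar tensor, which matches the torch.Tensor return type documented for this method.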

backbone
configure_optimizers()[source]

Configure the Adam optimizer with the provided learning rate and weight decay.

Returns

torch.optim.Optimizer

Configured Adam optimizer.
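As a rough sketch of what this amounts to in plain PyTorch, the snippet below builds an Adam optimizer with the documented defaults. The `build_optimizer` helper and the small linear model are hypothetical stand-ins; the actual method body is not shown on this page.

```python
import torch


def build_optimizer(model: torch.nn.Module,
                    learning_rate: float = 1e-4,
                    weight_decay: float = 1e-6) -> torch.optim.Optimizer:
    """Return Adam over all model parameters (hypothetical helper)."""
    return torch.optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,  # added to the gradient as an L2 penalty
    )


model = torch.nn.Linear(8, 4)  # stand-in for the full module
optimizer = build_optimizer(model)
```

Note that `torch.optim.Adam` applies weight decay as an L2 term on the gradient, which differs from the decoupled decay used by `torch.optim.AdamW`.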

forward(x)[source]

Forward pass through the model.

Parameters

x : torch.Tensor

Input tensor (e.g., batch of images).

Returns

torch.Tensor

Projected features in the embedding space.

Parameters:

x (torch.Tensor)
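One way to picture the forward pass is a backbone followed by a projection head. The sketch below assumes the backbone output is flattened before projection, which this page does not confirm; `TwoStageEncoder` and the layer sizes are hypothetical and chosen only to make the example run.

```python
import torch
from torch import nn


class TwoStageEncoder(nn.Module):
    """Hypothetical stand-in: backbone features -> projection head."""

    def __init__(self, backbone: nn.Module, projection_head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.projection_head = projection_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        # Assumption: flatten (N, C, 1, 1) feature maps to (N, C).
        features = features.flatten(start_dim=1)
        return self.projection_head(features)


# Small convolutional backbone that pools to one vector per image.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
)
# MLP projection head mapping 8 features to a 16-dimensional embedding.
projection_head = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 16))

encoder = TwoStageEncoder(backbone, projection_head)
embeddings = encoder(torch.randn(4, 3, 32, 32))  # shape (4, 16)
```

Components like these could then be passed to the documented constructor as `BarlowTwins(backbone=backbone, projection_head=projection_head)`.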

learning_rate = 0.0001
projection_head
training_step(batch, batch_idx)[source]

Defines one training step.

Parameters

batch : tuple of ((torch.Tensor, torch.Tensor), Any)

Batch containing two augmented views and labels (unused).

batch_idx : int

Index of the batch.

Returns

torch.Tensor

Training loss for the batch.

Parameters:
  • batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

  • batch_idx (int)

Return type:

torch.Tensor

validation_step(batch, batch_idx)[source]

Defines one validation step.

Parameters

batch : tuple of ((torch.Tensor, torch.Tensor), Any)

Batch containing two augmented views and labels (unused).

batch_idx : int

Index of the batch.

Returns

torch.Tensor

Validation loss for the batch.

Parameters:
  • batch (Tuple[Tuple[torch.Tensor, torch.Tensor], Any])

  • batch_idx (int)

Return type:

torch.Tensor

weight_decay = 1e-06
Parameters:
  • backbone (torch.nn.Module)

  • projection_head (torch.nn.Module)

  • loss_fn (Optional[torch.nn.Module])

  • learning_rate (float)

  • weight_decay (float)