minerva.models.ssl.byol

Classes

BYOL

Bootstrap Your Own Latent (BYOL) model for self-supervised representation learning.

Module Contents

class minerva.models.ssl.byol.BYOL(backbone=None, projection_head=None, prediction_head=None, learning_rate=0.001, schedule=90000, criterion=None, optimizer=torch.optim.Adam, optimizer_kwargs=None)[source]

Bases: minerva.models.nets.base.SimpleSupervisedModel

Bootstrap Your Own Latent (BYOL) model for self-supervised representation learning.

This class implements the BYOL framework [1], built on top of SimpleSupervisedModel to reuse its optimizer, logging, and training utilities. Unlike typical supervised models, BYOL does not require labeled data; instead, it learns representations by predicting one augmented view of an image from another, using both an online and a momentum encoder.

The model consists of:
  • An online encoder: backbone + projection head + prediction head.

  • A momentum encoder: backbone + projection head (no prediction head), updated using an exponential moving average of the online encoder parameters.

Key features:
  • Self-supervised loss via NegativeCosineSimilarity.

  • Momentum update schedule using cosine decay.

  • Default optimizer: Adam with weight_decay=1e-6.

  • Built-in hooks for momentum update and loss computation.

Parameters

backbone : nn.Module, optional

Feature extractor network. Defaults to DeepLabV3Backbone.

projection_head : nn.Module, optional

Projection head mapping encoder features to latent space. If None, a default 3-layer MLP is used.

prediction_head : nn.Module, optional

Prediction head mapping projected features to target space. If None, a default 2-layer MLP is used.

learning_rate : float, default=1e-3

Learning rate for the optimizer.

schedule : int, default=90000

Number of training steps over which the cosine momentum schedule is applied.

criterion : nn.Module, optional

Loss function. Defaults to NegativeCosineSimilarity.

optimizer : type, optional

Optimizer class. Defaults to torch.optim.Adam.

optimizer_kwargs : dict, optional

Extra keyword arguments for the optimizer. Defaults to {"weight_decay": 1e-6}.

Notes

  • Metrics are disabled by default since BYOL is self-supervised.

  • The fc layer from SimpleSupervisedModel is replaced with nn.Identity() because BYOL uses its own projection/prediction heads.

  • The forward pass returns predictions from the online encoder; the momentum encoder is used internally for target computation only.

References

[1] Grill, J.B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent - A New Approach to Self-Supervised Learning. Advances in Neural Information Processing Systems, 33, 21271–21284.

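Initializes the underlying supervised model with training components and configs. The parameters listed below are those of the parent class initializer, SimpleSupervisedModel; several of them (for example fc and loss_fn) do not appear in the BYOL constructor signature.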

Parameters

backbone : torch.nn.Module or LoadableModule

The backbone (feature extractor) model.

fc : torch.nn.Module or LoadableModule

The fully connected head. Use nn.Identity() if not required.

loss_fn : torch.nn.Module

Loss function to optimize during training.

adapter : Callable, optional

Function to transform backbone outputs before feeding into fc.

learning_rate : float, default=1e-3

Learning rate used for optimization.

flatten : bool, default=True

If True, flattens backbone outputs before fc.

train_metrics : dict, optional

TorchMetrics dictionary for training evaluation.

val_metrics : dict, optional

TorchMetrics dictionary for validation evaluation.

test_metrics : dict, optional

TorchMetrics dictionary for test evaluation.

freeze_backbone : bool, default=False

If True, backbone parameters are frozen during training.

optimizer : type

Optimizer class to be instantiated. Defaults to torch.optim.Adam. Should be a subclass of torch.optim.Optimizer (e.g., torch.optim.SGD).

optimizer_kwargs : dict, optional

Additional kwargs passed to the optimizer constructor.

lr_scheduler : type, optional

Learning rate scheduler class to be instantiated. Defaults to None, in which case no scheduler is used. Should be a subclass of torch.optim.lr_scheduler.LRScheduler (e.g., torch.optim.lr_scheduler.StepLR).

lr_scheduler_kwargs : dict, optional

Additional kwargs passed to the scheduler constructor.

_default_prediction_head()[source]

Creates the default prediction head used in BYOL.

Return type:

torch.nn.Module

_default_projection_head()[source]

Creates the default projection head used in BYOL.

Return type:

torch.nn.Module

_loss_func(outputs, targets=None)[source]

Calculate the loss between the output and the input data.

Parameters

outputs : torch.Tensor

The output data from the forward pass.

targets : torch.Tensor, optional

The input data/label.

Returns

torch.Tensor

The loss value.

Return type:

torch.Tensor

backbone
backbone_momentum
cosine_schedule(step, max_steps, start_value, end_value, period=None)[source]

Uses cosine decay to gradually modify start_value to reach end_value.

Parameters

step : int

Current step number.

max_steps : int

Total number of steps.

start_value : float

Starting value.

end_value : float

Target value.

period : Optional[int]

Steps over which cosine decay completes a full cycle. Defaults to max_steps.

Returns

float

Cosine decay value.

Parameters:
  • step (int)

  • max_steps (int)

  • start_value (float)

  • end_value (float)

  • period (Optional[int])

Return type:

float
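As a rough illustration of the cosine decay described above, the following sketch interpolates from start_value to end_value along half a cosine wave. It is not the library's implementation: the function name cosine_decay is hypothetical, the period argument is left out, and the values 0.996 and 1.0 are the momentum range used in the BYOL paper [1], chosen here only as an example.

```python
import math


def cosine_decay(step: int, max_steps: int, start_value: float, end_value: float) -> float:
    """Hypothetical helper: move from start_value to end_value along a half cosine."""
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = min(max(step / max_steps, 0.0), 1.0)
    # Goes smoothly from 1 (at step 0) to 0 (at step max_steps).
    cos_factor = (math.cos(math.pi * progress) + 1.0) / 2.0
    return end_value - (end_value - start_value) * cos_factor


# Illustrative momentum schedule over 90000 steps (the default schedule length).
m_start = cosine_decay(0, 90000, 0.996, 1.0)
m_mid = cosine_decay(45000, 90000, 0.996, 1.0)
m_end = cosine_decay(90000, 90000, 0.996, 1.0)
```

With this shape the momentum stays close to start_value early in training and approaches end_value as the schedule finishes, so the momentum encoder changes more slowly over time.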

criterion
deactivate_requires_grad(model)[source]

Freezes the weights of the model.

Parameters

model : nn.Module

Model to freeze.

Parameters:

model (torch.nn.Module)
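Freezing a module in PyTorch usually means turning off gradient tracking on each of its parameters. The sketch below shows that pattern on a small stand-in network; the helper name freeze and the encoder are illustrative, and whether deactivate_requires_grad does exactly this is an assumption.

```python
import torch.nn as nn


def freeze(model: nn.Module) -> None:
    """Hypothetical helper: disable gradient tracking for every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = False


# Small stand-in for a momentum encoder.
encoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
freeze(encoder)

# Count parameters that would still receive gradients.
n_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
```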

forward(x)[source]

Forward pass for the BYOL model.

Parameters

x : Tensor

Input image tensor.

Returns

Tensor

Output tensor after passing through the backbone, projection, and prediction heads.

Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

forward_momentum(x)[source]

Forward pass using momentum encoder.

Parameters

x : Tensor

Input image tensor.

Returns

Tensor

Output tensor after passing through the momentum backbone and projection head.

Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

prediction_head
projection_head
projection_head_momentum
schedule_length = 90000
training_step(batch, batch_idx)[source]

Overrides SimpleSupervisedModel's training step for BYOL.

Parameters:
  • batch (Sequence[torch.Tensor])

  • batch_idx (int)

Return type:

torch.Tensor
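The entry above does not say how the step combines the two encoders. The sketch below shows the symmetric loss from the BYOL paper [1]: each view is predicted by the online branch and compared against the momentum branch's projection of the other view. Tiny linear layers stand in for the real backbone and heads, and the helper names (forward_online, forward_target, neg_cos) are illustrative. That the class follows exactly this recipe is an assumption.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Online branch: backbone + projection head + prediction head.
backbone = nn.Linear(8, 16)
projection_head = nn.Linear(16, 4)
prediction_head = nn.Linear(4, 4)

# Momentum branch: frozen copies of backbone and projection head only.
backbone_momentum = copy.deepcopy(backbone)
projection_head_momentum = copy.deepcopy(projection_head)
for p in list(backbone_momentum.parameters()) + list(projection_head_momentum.parameters()):
    p.requires_grad = False


def forward_online(x: torch.Tensor) -> torch.Tensor:
    return prediction_head(projection_head(backbone(x)))


def forward_target(x: torch.Tensor) -> torch.Tensor:
    # Targets carry no gradient.
    with torch.no_grad():
        return projection_head_momentum(backbone_momentum(x))


def neg_cos(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Negative cosine similarity, averaged over the batch.
    return -F.cosine_similarity(p, z, dim=1).mean()


# Two random tensors stand in for two augmented views of the same batch.
x0, x1 = torch.randn(5, 8), torch.randn(5, 8)

loss = 0.5 * (
    neg_cos(forward_online(x0), forward_target(x1))
    + neg_cos(forward_online(x1), forward_target(x0))
)
loss.backward()
```

Only the online branch receives gradients here. The momentum branch would be refreshed separately through the exponential moving average update.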

update_momentum(model, model_ema, m)[source]

Updates model weights using momentum.

Parameters

model : nn.Module

Original model.

model_ema : nn.Module

Momentum model.

m : float

Momentum factor.

Parameters:
  • model (torch.nn.Module)

  • model_ema (torch.nn.Module)

  • m (float)
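The class description states that the momentum encoder follows an exponential moving average (EMA) of the online encoder. A common way to write that update is sketched below. The function name ema_update is hypothetical, and the exact form used by update_momentum is an assumption.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(model: nn.Module, model_ema: nn.Module, m: float) -> None:
    """Hypothetical helper: in place, set p_ema = m * p_ema + (1 - m) * p."""
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.mul_(m).add_(p.detach(), alpha=1.0 - m)


online = nn.Linear(2, 2)
target = copy.deepcopy(online)

# Use simple known weights so the effect of one update is easy to follow.
with torch.no_grad():
    online.weight.fill_(1.0)
    target.weight.fill_(0.0)

ema_update(online, target, m=0.9)
# Each target weight is now 0.9 * 0.0 + 0.1 * 1.0.
```

A momentum factor close to 1 makes the momentum model change slowly, and m = 1 leaves it unchanged.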

Parameters:
  • backbone (Optional[torch.nn.Module])

  • projection_head (Optional[torch.nn.Module])

  • prediction_head (Optional[torch.nn.Module])

  • learning_rate (float)

  • schedule (int)

  • criterion (Optional[torch.nn.Module])

  • optimizer (type)

  • optimizer_kwargs (Optional[Dict[str, Any]])