minerva.models.ssl.byol
Classes
BYOL: Bootstrap Your Own Latent (BYOL) model for self-supervised representation learning.
Module Contents
- class minerva.models.ssl.byol.BYOL(backbone=None, projection_head=None, prediction_head=None, learning_rate=0.001, schedule=90000, criterion=None, optimizer=torch.optim.Adam, optimizer_kwargs=None)
Bases: minerva.models.nets.base.SimpleSupervisedModel
Bootstrap Your Own Latent (BYOL) model for self-supervised representation learning.
This class implements the BYOL framework [1], built on top of SimpleSupervisedModel to reuse its optimizer, logging, and training utilities. Unlike typical supervised models, BYOL does not require labeled data; instead, it learns representations by predicting one augmented view of an image from another, using both an online and a momentum encoder.
- The model consists of:
An online encoder: backbone + projection head + prediction head.
A momentum encoder: backbone + projection head (no prediction head), updated using an exponential moving average of the online encoder parameters.
- Key features:
Self-supervised loss via NegativeCosineSimilarity.
Momentum update schedule using cosine decay.
Default optimizer: Adam with weight_decay=1e-6.
Built-in hooks for momentum update and loss computation.
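The exponential moving average that keeps the momentum encoder in step with the online encoder can be sketched in plain PyTorch. This is a minimal illustration, not the code used by this class: the helper name ema_update, the toy nn.Linear encoder, and the momentum value 0.99 are all assumptions made for the example.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, m: float) -> None:
    """Hypothetical helper: target <- m * target + (1 - m) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(m).add_(p_online, alpha=1.0 - m)


# Toy stand-in for an encoder; the real backbone is much larger.
online = nn.Linear(4, 2)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad = False  # the momentum copy is never trained by gradients

# Simulate one optimizer step that shifts every online weight by +1.
with torch.no_grad():
    online.weight.add_(1.0)

ema_update(online, target, m=0.99)
# The target moved only 1% of the way toward the new online weights.
```

With a momentum close to 1 the target network changes slowly, which is what makes it a stable source of regression targets for the online network.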
Parameters
- backbone : nn.Module, optional
Feature extractor network. Defaults to DeepLabV3Backbone.
- projection_head : nn.Module, optional
Projection head mapping encoder features to latent space. If None, a default 3-layer MLP is used.
- prediction_head : nn.Module, optional
Prediction head mapping projected features to target space. If None, a default 2-layer MLP is used.
- learning_rate : float, default=1e-3
Learning rate for optimizer.
- schedule : int, default=90000
Number of training steps over which to apply cosine momentum schedule.
- criterion : nn.Module, optional
Loss function. Defaults to NegativeCosineSimilarity.
- optimizer : type, optional
Optimizer class. Defaults to torch.optim.Adam if not provided.
- optimizer_kwargs : dict, optional
Extra keyword arguments for the optimizer. By default, uses {"weight_decay": 1e-6}.
Notes
Metrics are disabled by default since BYOL is self-supervised.
The fc layer from SimpleSupervisedModel is replaced with nn.Identity() because BYOL uses its own projection/prediction heads.
The forward pass returns predictions from the online encoder; the momentum encoder is used internally for target computation only.
References
- [1] Grill, J.B., Strub, F., Altché, F., Tallec, C., Richemond, P.H., Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent - A New Approach to Self-Supervised Learning. Advances in Neural Information Processing Systems, 33, 21271–21284.
Initializes the supervised model with training components and configs.
Parameters
- backbone : torch.nn.Module or LoadableModule
The backbone (feature extractor) model.
- fc : torch.nn.Module or LoadableModule
The fully connected head. Use nn.Identity() if not required.
- loss_fn : torch.nn.Module
Loss function to optimize during training.
- adapter : Callable, optional
Function to transform backbone outputs before feeding into fc.
- learning_rate : float, default=1e-3
Learning rate used for optimization.
- flatten : bool, default=True
If True, flattens backbone outputs before fc.
- train_metrics : dict, optional
TorchMetrics dictionary for training evaluation.
- val_metrics : dict, optional
TorchMetrics dictionary for validation evaluation.
- test_metrics : dict, optional
TorchMetrics dictionary for test evaluation.
- freeze_backbone : bool, default=False
If True, backbone parameters are frozen during training.
- optimizer : type
Optimizer class to be instantiated. By default, it is set to torch.optim.Adam. Should be a subclass of torch.optim.Optimizer (e.g., torch.optim.SGD).
- optimizer_kwargs : dict, optional
Additional kwargs passed to the optimizer constructor.
- lr_scheduler : type, optional
Learning rate scheduler class to be instantiated. By default, it is set to None, which means no scheduler will be used. Should be a subclass of torch.optim.lr_scheduler.LRScheduler (e.g., torch.optim.lr_scheduler.StepLR).
- lr_scheduler_kwargs : dict, optional
Additional kwargs passed to the scheduler constructor.
- _default_prediction_head()
Creates the default prediction head used in BYOL.
- Return type:
torch.nn.Module
- _default_projection_head()
Creates the default projection head used in BYOL.
- Return type:
torch.nn.Module
- _loss_func(outputs, targets=None)
Calculate the loss between the output and the input data.
Parameters
- y_hat : torch.Tensor
The output data from the forward pass.
- y : torch.Tensor
The input data/label.
Returns
- torch.Tensor
The loss value.
- Return type:
torch.Tensor
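For orientation, a negative cosine similarity loss of the kind BYOL uses can be sketched as follows. This is an illustration under stated assumptions, not the body of _loss_func: the function names are hypothetical, and whether this class symmetrizes the loss over both views in exactly this way is not stated in this page.

```python
import torch
import torch.nn.functional as F


def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Mean negative cosine similarity between predictions p and targets z.

    The target is detached so no gradient flows into the momentum branch.
    """
    return -F.cosine_similarity(p, z.detach(), dim=1).mean()


def symmetric_loss(p0, z1, p1, z0) -> torch.Tensor:
    """Average the loss over both view orderings (assumed, as in the BYOL paper)."""
    return 0.5 * (
        negative_cosine_similarity(p0, z1) + negative_cosine_similarity(p1, z0)
    )


# Identical directions give the minimum value; orthogonal directions give zero.
a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
b = torch.tensor([[0.0, 1.0], [3.0, 0.0]])
loss_same = negative_cosine_similarity(a, a)
loss_orthogonal = negative_cosine_similarity(a, b)
loss_symmetric = symmetric_loss(a, a, a, b)
```

The loss is bounded between -1 (perfectly aligned) and 1 (opposite directions), so it depends only on the direction of the embeddings, not their norm.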
- backbone
- backbone_momentum
- cosine_schedule(step, max_steps, start_value, end_value, period=None)
Uses cosine decay to gradually modify start_value to reach end_value.
Parameters
- step : int
Current step number.
- max_steps : int
Total number of steps.
- start_value : float
Starting value.
- end_value : float
Target value.
- period : Optional[int]
Steps over which cosine decay completes a full cycle. Defaults to max_steps.
Returns
- float
Cosine decay value.
- Parameters:
step (int)
max_steps (int)
start_value (float)
end_value (float)
period (Optional[int])
- Return type:
float
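A half-cosine interpolation from start_value to end_value can be sketched as below. This is a simplified stand-in, not the implementation of cosine_schedule: the function name is hypothetical, the period argument is omitted, the step is clamped to the valid range, and the momentum range 0.996 to 1.0 is only an illustrative choice.

```python
import math


def cosine_schedule_sketch(
    step: int, max_steps: int, start_value: float, end_value: float
) -> float:
    """Move from start_value (step 0) to end_value (step max_steps) along a half cosine."""
    if max_steps < 1:
        raise ValueError("max_steps must be >= 1")
    step = min(max(step, 0), max_steps)  # clamp out-of-range steps
    # The weight goes smoothly from 1 at step 0 to 0 at step max_steps.
    weight = (math.cos(math.pi * step / max_steps) + 1.0) / 2.0
    return end_value - (end_value - start_value) * weight


# Momentum at the start, middle, and end of a 90000-step schedule.
values = [cosine_schedule_sketch(s, 90000, 0.996, 1.0) for s in (0, 45000, 90000)]
```

The curve changes slowly near both ends and fastest in the middle, so the momentum stays close to its starting value early in training and settles gently at the end value.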
- criterion
- deactivate_requires_grad(model)
Freezes the weights of the model.
Parameters
- model : nn.Module
Model to freeze.
- Parameters:
model (torch.nn.Module)
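Freezing a module in PyTorch usually means turning off gradient tracking on each of its parameters. The sketch below shows that pattern; the helper name freeze is hypothetical, and it is an assumption that deactivate_requires_grad works this way.

```python
from torch import nn


def freeze(model: nn.Module) -> None:
    """Hypothetical helper: stop gradient tracking for every parameter."""
    for param in model.parameters():
        param.requires_grad = False


# Toy module standing in for the momentum backbone.
momentum_copy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
freeze(momentum_copy)
trainable = [p for p in momentum_copy.parameters() if p.requires_grad]
```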
- forward(x)
Forward pass for the BYOL model.
Parameters
- x : Tensor
Input image tensor.
Returns
- Tensor
Output tensor after passing through the backbone, projection, and prediction heads.
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
- forward_momentum(x)
Forward pass using momentum encoder.
Parameters
- x : Tensor
Input image tensor.
Returns
- Tensor
Output tensor after passing through the momentum backbone and projection head.
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
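The two paths described by forward and forward_momentum can be sketched with small stand-in modules. Everything here is an assumption made for illustration: the tiny linear layers replace the real backbone and heads, the function names are hypothetical, and running the momentum path under torch.no_grad() is a common choice that this page does not confirm.

```python
import copy

import torch
from torch import nn

# Online branch: backbone -> projection head -> prediction head.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
projection_head = nn.Linear(32, 16)
prediction_head = nn.Linear(16, 16)

# Momentum branch: copies of backbone and projection head, no prediction head.
backbone_momentum = copy.deepcopy(backbone)
projection_head_momentum = copy.deepcopy(projection_head)


def forward_online(x: torch.Tensor) -> torch.Tensor:
    return prediction_head(projection_head(backbone(x)))


def forward_target(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():  # targets carry no gradient
        return projection_head_momentum(backbone_momentum(x))


# Two augmented views of the same batch (random tensors stand in for images).
view0 = torch.randn(4, 3, 8, 8)
view1 = torch.randn(4, 3, 8, 8)
prediction = forward_online(view0)
target = forward_target(view1)
```

The prediction from one view is compared against the target from the other view, so only the online branch receives gradients.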
- prediction_head
- projection_head
- projection_head_momentum
- schedule_length = 90000
- Parameters:
backbone (Optional[torch.nn.Module])
projection_head (Optional[torch.nn.Module])
prediction_head (Optional[torch.nn.Module])
learning_rate (float)
schedule (int)
criterion (Optional[torch.nn.Module])
optimizer (type)
optimizer_kwargs (Optional[Dict[str, Any]])