"""Source code for ``minerva.models.ssl.byol``: the BYOL self-supervised learning model."""

import copy
import torch
import numpy as np
from torch import nn, Tensor
from collections import OrderedDict
from typing import Optional, Sequence, Dict, Any

from minerva.losses.negative_cossine_similatiry import NegativeCosineSimilarity
from minerva.models.nets.mlp import MLP
from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone
from minerva.models.nets.base import SimpleSupervisedModel


class BYOL(SimpleSupervisedModel):
    """
    Bootstrap Your Own Latent (BYOL) model for self-supervised representation
    learning.

    This class implements the BYOL framework [1], built on top of
    :class:`SimpleSupervisedModel` to reuse its optimizer, logging, and
    training utilities. Unlike typical supervised models, BYOL does not require
    labeled data. Instead, it learns representations by predicting the
    representation of one augmented view of an image from another. It does
    this using both an online encoder and a momentum (target) encoder.

    The model consists of:

    - An **online encoder**: backbone + projection head + prediction head.
    - A **momentum encoder**: backbone + projection head (no prediction head).
      Its parameters are an exponential moving average (EMA) of the online
      encoder parameters and never receive gradients.

    Key features:

    - Self-supervised loss via
      :class:`~minerva.losses.negative_cossine_similatiry.NegativeCosineSimilarity`,
      symmetrized over the two augmented views.
    - EMA momentum that follows a cosine schedule from 0.996 to 1 over
      ``schedule`` training (optimizer) steps.
    - Default optimizer: Adam with ``weight_decay=1e-6``.
    - Built-in hooks for momentum update and loss computation.

    Parameters
    ----------
    backbone : nn.Module, optional
        Feature extractor network. Defaults to
        :class:`~minerva.models.nets.image.deeplabv3.DeepLabV3Backbone`.
        If the backbone returns a mapping (e.g. an ``OrderedDict``), its
        ``"out"`` entry is used as the feature tensor.
    projection_head : nn.Module, optional
        Projection head mapping encoder features to latent space. If None, a
        default head is used: global average pooling, flatten, then an MLP
        with layer sizes ``[2048, 4096, 256]``.
    prediction_head : nn.Module, optional
        Prediction head mapping projected features to target space. If None,
        a default MLP with layer sizes ``[256, 4096, 256]`` is used.
    learning_rate : float, default=1e-3
        Learning rate for the optimizer.
    schedule : int, default=90000
        Number of training steps over which the cosine momentum schedule
        goes from 0.996 to 1. After that, the momentum stays at 1.
    criterion : nn.Module, optional
        Loss function. Defaults to
        :class:`~minerva.losses.negative_cossine_similatiry.NegativeCosineSimilarity`.
    optimizer : type, optional
        Optimizer class. Defaults to :class:`torch.optim.Adam`.
    optimizer_kwargs : dict, optional
        Keyword arguments for the optimizer. If None or empty, uses
        ``{"lr": learning_rate, "weight_decay": 1e-6}``. A non-empty dict
        replaces these defaults entirely.

    Notes
    -----
    - Metrics are disabled by default since BYOL is self-supervised.
    - The ``fc`` layer from :class:`SimpleSupervisedModel` is replaced with
      ``nn.Identity()`` because BYOL uses its own projection/prediction heads.
    - The forward pass returns predictions from the online encoder. The
      momentum encoder is used internally for target computation only.
    - ``training_step`` expects ``batch`` to be a pair ``(x0, x1)`` of
      augmented views of the same images.

    References
    ----------
    [1] Grill, J.B., Strub, F., Altché, F., Tallec, C., Richemond, P.H.,
        Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G.,
        Piot, B., Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap
        Your Own Latent - A New Approach to Self-Supervised Learning.
        Advances in Neural Information Processing Systems, 33, 21271–21284.
    """

    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
        projection_head: Optional[nn.Module] = None,
        prediction_head: Optional[nn.Module] = None,
        learning_rate: float = 1e-3,
        schedule: int = 90000,
        criterion: Optional[nn.Module] = None,
        optimizer: type = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Use explicit ``is None`` checks rather than ``or``. Containers such as
        # ``nn.Sequential`` define ``__len__``, so an empty (identity-like)
        # module is falsy and would otherwise be silently replaced by the
        # default.
        backbone_model = backbone if backbone is not None else DeepLabV3Backbone()
        projection_head_model = (
            projection_head
            if projection_head is not None
            else self._default_projection_head()
        )
        prediction_head_model = (
            prediction_head
            if prediction_head is not None
            else self._default_prediction_head()
        )
        loss_criterion = (
            criterion if criterion is not None else NegativeCosineSimilarity()
        )
        optimizer = optimizer if optimizer is not None else torch.optim.Adam

        # A non-empty user-supplied dict replaces the defaults entirely.
        default_optimizer_kwargs = {"lr": learning_rate, "weight_decay": 1e-6}
        if optimizer_kwargs:
            default_optimizer_kwargs = optimizer_kwargs

        super().__init__(
            backbone=backbone_model,
            fc=nn.Identity(),
            loss_fn=loss_criterion,
            adapter=None,
            learning_rate=learning_rate,
            flatten=False,
            train_metrics=None,
            val_metrics=None,
            test_metrics=None,
            freeze_backbone=False,
            optimizer=optimizer,
            optimizer_kwargs=default_optimizer_kwargs,
        )

        # Online encoder.
        self.backbone = backbone_model
        self.projection_head = projection_head_model
        self.prediction_head = prediction_head_model

        # Momentum (target) encoder: a frozen copy of the online encoder whose
        # weights are changed only by ``update_momentum``.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        self.deactivate_requires_grad(self.backbone_momentum)
        self.deactivate_requires_grad(self.projection_head_momentum)

        self.criterion = loss_criterion
        self.schedule_length = schedule

    def _default_projection_head(self) -> nn.Module:
        """Create the default BYOL projection head.

        It pools spatially to 1x1, flattens, and then applies an MLP with
        layer sizes ``[2048, 4096, 256]``. It uses ReLU activations and a
        ``BatchNorm1d(4096)`` after the hidden layer.
        """
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1),
            MLP(
                layer_sizes=[2048, 4096, 256],
                activation_cls=nn.ReLU,
                intermediate_ops=[nn.BatchNorm1d(4096), None],
            ),
        )

    def _default_prediction_head(self) -> nn.Module:
        """Create the default BYOL prediction head.

        It is an MLP with layer sizes ``[256, 4096, 256]``, using ReLU
        activations and a ``BatchNorm1d(4096)`` after the hidden layer.
        """
        return MLP(
            layer_sizes=[256, 4096, 256],
            activation_cls=nn.ReLU,
            intermediate_ops=[nn.BatchNorm1d(4096), None],
        )

    @staticmethod
    def _select_output(features: Any) -> Tensor:
        """Return ``features["out"]`` when a backbone returns a mapping
        (such as torchvision's ``OrderedDict`` outputs). Otherwise return
        ``features`` unchanged."""
        if isinstance(features, dict):
            return features["out"]
        return features

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the online encoder.

        Parameters
        ----------
        x : Tensor
            Input image tensor.

        Returns
        -------
        Tensor
            Output of backbone -> projection head -> prediction head.
        """
        y = self._select_output(self.backbone(x))
        z = self.projection_head(y)
        return self.prediction_head(z)

    def forward_momentum(self, x: Tensor) -> Tensor:
        """
        Forward pass through the momentum (target) encoder.

        Parameters
        ----------
        x : Tensor
            Input image tensor.

        Returns
        -------
        Tensor
            Detached output of momentum backbone -> momentum projection head.
        """
        y = self._select_output(self.backbone_momentum(x))
        z = self.projection_head_momentum(y)
        # Targets must not propagate gradients.
        return z.detach()

    def _loss_func(self, outputs, targets=None) -> torch.Tensor:
        """Compute the symmetrized BYOL loss for a pair of views.

        Parameters
        ----------
        outputs : Sequence[Tensor]
            A pair ``(x0, x1)`` of augmented views of the same images.
        targets : Any, optional
            Ignored. Accepted for interface compatibility with the base class.

        Returns
        -------
        torch.Tensor
            ``0.5 * (criterion(p0, z1) + criterion(p1, z0))``. Here ``p`` is
            the online prediction and ``z`` is the momentum projection of the
            other view.
        """
        x0, x1 = outputs
        p0 = self.forward(x0)
        z0 = self.forward_momentum(x0)
        p1 = self.forward(x1)
        z1 = self.forward_momentum(x1)
        return 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))

    def training_step(self, batch: Sequence[Tensor], batch_idx: int) -> torch.Tensor:
        """Override ``SimpleSupervisedModel``'s step for BYOL.

        First update the momentum encoder by EMA. Then compute and log the
        symmetrized BYOL loss on the two views in ``batch``.
        """
        # The target decay rate follows a cosine schedule over *training
        # steps* (as documented for ``schedule`` and as in the BYOL paper).
        # ``global_step`` counts optimizer steps. ``current_epoch`` would
        # barely move the schedule over a 90000-long horizon.
        momentum = self.cosine_schedule(
            self.global_step, self.schedule_length, 0.996, 1
        )
        self.update_momentum(self.backbone, self.backbone_momentum, m=momentum)
        self.update_momentum(
            self.projection_head, self.projection_head_momentum, m=momentum
        )
        loss = self._loss_func(batch)
        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    @torch.no_grad()
    def update_momentum(self, model: nn.Module, model_ema: nn.Module, m: float):
        """
        Update the EMA model's parameters in place.

        Each EMA parameter becomes ``m * ema + (1 - m) * param``.

        Parameters
        ----------
        model : nn.Module
            Original (online) model.
        model_ema : nn.Module
            Momentum model. Its parameters are assumed to correspond
            one-to-one, in order, with ``model.parameters()``.
        m : float
            Momentum factor.
        """
        for ema_param, param in zip(model_ema.parameters(), model.parameters()):
            ema_param.mul_(m).add_(param, alpha=1.0 - m)

    def deactivate_requires_grad(self, model: nn.Module):
        """
        Freeze the weights of the model by setting ``requires_grad=False``.

        Parameters
        ----------
        model : nn.Module
            Model to freeze.
        """
        for param in model.parameters():
            param.requires_grad = False

    def cosine_schedule(
        self,
        step: int,
        max_steps: int,
        start_value: float,
        end_value: float,
        period: Optional[int] = None,
    ) -> float:
        """
        Use a cosine curve to move gradually from ``start_value`` to
        ``end_value``.

        Parameters
        ----------
        step : int
            Current step number.
        max_steps : int
            Total number of steps. Without ``period``, the value reaches
            ``end_value`` at ``step = max_steps - 1`` and stays there for
            any later step.
        start_value : float
            Starting value.
        end_value : float
            Target value.
        period : Optional[int]
            If given, the value instead oscillates periodically. It goes
            ``start_value`` -> ``end_value`` -> ``start_value`` over every
            ``period`` steps.

        Returns
        -------
        float
            Scheduled value.

        Raises
        ------
        ValueError
            If ``step`` is negative, ``max_steps < 1``, or ``period <= 0``.
        """
        if step < 0:
            raise ValueError(f"Current step number {step} can't be negative")
        if max_steps < 1:
            raise ValueError(f"Total step number {max_steps} must be >= 1")
        if period is not None and period <= 0:
            raise ValueError(f"Period {period} must be >= 1")

        if period is not None:
            # One full cosine cycle every ``period`` steps.
            cos_term = np.cos(2 * np.pi * step / period)
        elif step >= max_steps - 1:
            # Schedule completed. This also covers max_steps == 1, which
            # would otherwise divide by zero. Clamping also prevents the
            # cosine from wrapping around past the end.
            return float(end_value)
        else:
            cos_term = np.cos(np.pi * step / (max_steps - 1))

        return float(end_value - (end_value - start_value) * (cos_term + 1) / 2)