minerva.models.ssl.byol
=======================

.. py:module:: minerva.models.ssl.byol


Classes
-------

.. autoapisummary::

   minerva.models.ssl.byol.BYOL


Module Contents
---------------

.. py:class:: BYOL(backbone = None, projection_head = None, prediction_head = None, learning_rate = 0.001, schedule = 90000, criterion = None, optimizer = torch.optim.Adam, optimizer_kwargs = None)

   Bases: :py:obj:`minerva.models.nets.base.SimpleSupervisedModel`


   Bootstrap Your Own Latent (BYOL) model for self-supervised representation learning.

   This class implements the BYOL framework [1], built on top of
   :class:`SimpleSupervisedModel` to reuse its optimizer, logging, and training
   utilities. Unlike typical supervised models, BYOL does not require labeled
   data; instead, it learns representations by predicting one augmented view of
   an image from another, using both an online and a momentum encoder.

   The model consists of:

   - An **online encoder**: backbone + projection head + prediction head.
   - A **momentum encoder**: backbone + projection head (no prediction head),
     updated using an exponential moving average of the online encoder
     parameters.

   Key features:

   - Self-supervised loss via
     :class:`~minerva.losses.negative_cossine_similatiry.NegativeCosineSimilarity`
   - Momentum update schedule using cosine decay.
   - Default optimizer: Adam with ``weight_decay=1e-6``.
   - Built-in hooks for momentum update and loss computation.

   Parameters
   ----------
   backbone : nn.Module, optional
       Feature extractor network. Defaults to
       :class:`~minerva.models.nets.image.deeplabv3.DeepLabV3Backbone`.
   projection_head : nn.Module, optional
       Projection head mapping encoder features to latent space. If None, a
       default 3-layer MLP is used.
   prediction_head : nn.Module, optional
       Prediction head mapping projected features to target space. If None, a
       default 2-layer MLP is used.
   learning_rate : float, default=1e-3
       Learning rate for the optimizer.
   schedule : int, default=90000
       Number of training steps over which to apply the cosine momentum
       schedule.
   criterion : nn.Module, optional
       Loss function. Defaults to
       :class:`~minerva.losses.negative_cossine_similatiry.NegativeCosineSimilarity`.
   optimizer : type, optional
       Optimizer class. Defaults to :class:`torch.optim.Adam` if not provided.
   optimizer_kwargs : dict, optional
       Extra keyword arguments for the optimizer. By default, uses
       ``{"weight_decay": 1e-6}``.

   Notes
   -----
   - Metrics are disabled by default since BYOL is self-supervised.
   - The ``fc`` layer from :class:`SimpleSupervisedModel` is replaced with
     ``nn.Identity()`` because BYOL uses its own projection/prediction heads.
   - The forward pass returns predictions from the online encoder; the momentum
     encoder is used internally for target computation only.

   References
   ----------
   [1] Grill, J.B., Strub, F., Altché, F., Tallec, C., Richemond, P.H.,
   Buchatskaya, E., Doersch, C., Pires, B.A., Guo, Z.D., Azar, M.G., Piot, B.,
   Kavukcuoglu, K., Munos, R., & Valko, M. (2020). Bootstrap Your Own Latent -
   A New Approach to Self-Supervised Learning. Advances in Neural Information
   Processing Systems, 33, 21271–21284.

   Initializes the supervised model with training components and configs.

   Parameters
   ----------
   backbone : torch.nn.Module or LoadableModule
       The backbone (feature extractor) model.
   fc : torch.nn.Module or LoadableModule
       The fully connected head. Use nn.Identity() if not required.
   loss_fn : torch.nn.Module
       Loss function to optimize during training.
   adapter : Callable, optional
       Function to transform backbone outputs before feeding into `fc`.
   learning_rate : float, default=1e-3
       Learning rate used for optimization.
   flatten : bool, default=True
       If True, flattens backbone outputs before `fc`.
   train_metrics : dict, optional
       TorchMetrics dictionary for training evaluation.
   val_metrics : dict, optional
       TorchMetrics dictionary for validation evaluation.
   test_metrics : dict, optional
       TorchMetrics dictionary for test evaluation.
   freeze_backbone : bool, default=False
       If True, backbone parameters are frozen during training.
   optimizer : type
       Optimizer class to be instantiated. By default, it is set to
       `torch.optim.Adam`. Should be a subclass of `torch.optim.Optimizer`
       (e.g., `torch.optim.SGD`).
   optimizer_kwargs : dict, optional
       Additional kwargs passed to the optimizer constructor.
   lr_scheduler : type, optional
       Learning rate scheduler class to be instantiated. By default, it is set
       to None, which means no scheduler will be used. Should be a subclass of
       `torch.optim.lr_scheduler.LRScheduler` (e.g.,
       `torch.optim.lr_scheduler.StepLR`).
   lr_scheduler_kwargs : dict, optional
       Additional kwargs passed to the scheduler constructor.


   .. py:method:: _default_prediction_head()

      Creates the default prediction head used in BYOL.



   .. py:method:: _default_projection_head()

      Creates the default projection head used in BYOL.



   .. py:method:: _loss_func(outputs, targets=None)

      Calculate the loss between the output and the input data.

      Parameters
      ----------
      outputs : torch.Tensor
          The output data from the forward pass.
      targets : torch.Tensor, optional
          The input data/label.

      Returns
      -------
      torch.Tensor
          The loss value.



   .. py:attribute:: backbone


   .. py:attribute:: backbone_momentum


   .. py:method:: cosine_schedule(step, max_steps, start_value, end_value, period = None)

      Uses cosine decay to gradually modify `start_value` to reach `end_value`.

      Parameters
      ----------
      step : int
          Current step number.
      max_steps : int
          Total number of steps.
      start_value : float
          Starting value.
      end_value : float
          Target value.
      period : Optional[int]
          Steps over which cosine decay completes a full cycle. Defaults to
          max_steps.

      Returns
      -------
      float
          Cosine decay value.



   .. py:attribute:: criterion


   .. py:method:: deactivate_requires_grad(model)

      Freezes the weights of the model.

      Parameters
      ----------
      model : nn.Module
          Model to freeze.



   .. py:method:: forward(x)

      Forward pass for the BYOL model.

      Parameters
      ----------
      x : Tensor
          Input image tensor.

      Returns
      -------
      Tensor
          Output tensor after passing through the backbone, projection, and
          prediction heads.



   .. py:method:: forward_momentum(x)

      Forward pass using the momentum encoder.

      Parameters
      ----------
      x : Tensor
          Input image tensor.

      Returns
      -------
      Tensor
          Output tensor after passing through the momentum backbone and
          projection head.



   .. py:attribute:: prediction_head


   .. py:attribute:: projection_head


   .. py:attribute:: projection_head_momentum


   .. py:attribute:: schedule_length
      :value: 90000



   .. py:method:: training_step(batch, batch_idx)

      Overrides SimpleSupervisedModel's step for BYOL.



   .. py:method:: update_momentum(model, model_ema, m)

      Updates model weights using momentum.

      Parameters
      ----------
      model : nn.Module
          Original model.
      model_ema : nn.Module
          Momentum model.
      m : float
          Momentum factor.
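
The cosine momentum schedule and the exponential-moving-average update documented above can be illustrated with a short sketch. This is not Minerva's implementation: the helper names ``cosine_schedule_sketch`` and ``ema_update_sketch`` are hypothetical, plain Python lists of floats stand in for model parameters, the optional ``period`` argument is omitted, and the momentum values 0.996 and 1.0 are illustrative assumptions.

```python
import math


def cosine_schedule_sketch(step, max_steps, start_value, end_value):
    """Cosine interpolation from start_value (step 0) to end_value (step max_steps).

    Hypothetical helper for illustration; it ignores the ``period`` argument
    that the documented ``cosine_schedule`` accepts.
    """
    if max_steps <= 0 or step >= max_steps:
        return end_value
    # cos_term goes smoothly from 1 (at step 0) to 0 (at step max_steps).
    cos_term = (math.cos(math.pi * step / max_steps) + 1.0) / 2.0
    return end_value - (end_value - start_value) * cos_term


def ema_update_sketch(online_params, momentum_params, m):
    """Exponential moving average: new = m * momentum + (1 - m) * online.

    Lists of floats stand in for the parameter tensors of the two encoders.
    """
    return [m * t + (1.0 - m) * o for o, t in zip(online_params, momentum_params)]


# Assumed, illustrative values: momentum rising from 0.996 to 1.0 over 90000 steps.
online = [1.0, 2.0]
target = [0.0, 0.0]
m = cosine_schedule_sketch(step=0, max_steps=90000, start_value=0.996, end_value=1.0)
target = ema_update_sketch(online, target, m)
```

Under these assumed values the momentum factor rises from 0.996 toward 1.0 over the schedule, so the momentum encoder tracks the online encoder more slowly as training proceeds.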