minerva.models.ssl.lfr

Classes

LearnFromRandomnessModel

A PyTorch Lightning model for pretraining with the technique 'Learning From Random Projectors'.

RepeatedModuleList

A module list containing size instances of the same module class cls.

Functions

dpp(kernel_matrix, max_length[, epsilon])

A fast implementation of the greedy algorithm for determinantal point process (DPP) MAP inference.

Module Contents

class minerva.models.ssl.lfr.LearnFromRandomnessModel(backbone, projectors, predictors, loss_fn=None, num_targets=None, adapter=None, learning_rate=0.001, weight_decay=0.0, flatten=False, predictor_training_epochs=None, max_backbone_training_steps=None, selection_batch_size=128)

Bases: lightning.LightningModule

A PyTorch Lightning model for pretraining with the technique ‘Learning From Random Projectors’. When using ‘predictor_training_epochs’, consider increasing your number of training epochs as well; otherwise, the LFR backbone will be trained for fewer epochs than intended:

  • If the total number of training epochs in your Trainer is 100 and ‘predictor_training_epochs’ is 1, the backbone is trained on epochs 0, 2, 4, 6, …, 98, so it is effectively trained for 50 epochs instead of the specified 100.

  • If the total number of training epochs in your Trainer is 100 and ‘predictor_training_epochs’ is 2, the backbone is trained on epochs 0, 3, 6, 9, …, 99, so it is effectively trained for 34 epochs instead of the specified 100.

To compensate, set the total number of training epochs to:

Total number of training epochs = (intended backbone training epochs) * (predictor_training_epochs + 1)
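The epoch arithmetic above can be checked with a small sketch. It assumes, as described, that the backbone is trained only in the first epoch of each cycle of (‘predictor_training_epochs’ + 1) epochs; backbone_epochs and required_total_epochs are hypothetical helpers, not part of minerva.

```python
from typing import List, Optional


def backbone_epochs(total_epochs: int, predictor_training_epochs: Optional[int]) -> List[int]:
    """Epoch indices in which the backbone is trained (hypothetical helper)."""
    if predictor_training_epochs is None or predictor_training_epochs <= 0:
        # No predictor-only epochs: the backbone is trained in every epoch.
        return list(range(total_epochs))
    cycle = predictor_training_epochs + 1
    # Each cycle: one backbone epoch, then `predictor_training_epochs` frozen epochs.
    return [epoch for epoch in range(total_epochs) if epoch % cycle == 0]


def required_total_epochs(intended_backbone_epochs: int, predictor_training_epochs: Optional[int]) -> int:
    """Total Trainer epochs needed for the intended number of backbone epochs."""
    if predictor_training_epochs is None or predictor_training_epochs <= 0:
        return intended_backbone_epochs
    return intended_backbone_epochs * (predictor_training_epochs + 1)


print(len(backbone_epochs(100, 1)))   # 50
print(len(backbone_epochs(100, 2)))   # 34
print(required_total_epochs(100, 2))  # 300
```

With the scaled total of 300 epochs and ‘predictor_training_epochs’ = 2, backbone_epochs(300, 2) contains exactly 100 epochs, matching the intended budget.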

References

Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs. “Self-supervised Representation Learning From Random Data Projectors”, 2024

Initialize the LearnFromRandomnessModel, freezing the projectors. When using ‘predictor_training_epochs’, remember to scale the total number of training epochs with the formula given above.

Parameters

backbone: torch.nn.Module

The backbone neural network for feature extraction.

projectors: torch.nn.ModuleList

A list of projector networks.

predictors: torch.nn.ModuleList

A list of predictor networks.

num_targets: Optional[int]

The number of projectors and predictors to select from the lists provided, using the Fast Determinantal Point Process (DPP) algorithm. All projectors and predictors are used if the value received is None, a negative integer, or an integer greater than the length of the lists.

loss_fn: Optional[torch.nn.Module]

The loss function to optimize, by default None. If None, the BatchWiseBarlowTwinLoss is used.

adapter: Optional[Callable[[torch.Tensor], torch.Tensor]]

An optional adapter network to be used in the model, by default None.

learning_rate: float

The learning rate for the optimizer, by default 1e-3.

weight_decay: float

The weight decay for the optimizer, by default 0.0.

flatten: bool

Whether to flatten the input tensor or not, by default False.

predictor_training_epochs: Optional[int]

The number of epochs to train only the predictors (excluding the backbone), by default None. If None, zero, or negative, both the predictors and backbone are trained in every epoch. If a positive integer is provided, the backbone is trained for one epoch, then frozen, and the predictors are trained alone for the specified number of epochs. This cycle is repeated throughout the training phase.

max_backbone_training_steps: Optional[int]

The maximum number of steps for which the backbone is trained, by default None. If None, zero, or negative, no limit is applied. Steps during which the backbone is frozen are not counted.

selection_batch_size: int

The number of samples from the dataset used by the Fast Determinantal Point Process (DPP) algorithm when selecting projectors and predictors, by default 128.

_loss_from_targets(y_pred, y_proj)

Computes the average loss between each pair of predictor and projector outputs. This function is isolated from _single_step to make it easier to test independently.

Parameters

y_pred: torch.Tensor

The prediction tensors.

y_proj: torch.Tensor

The projection tensors.

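A minimal sketch of the averaging, assuming y_pred and y_proj stack one output per predictor/projector pair along the first dimension, i.e. shape (num_targets, batch_size, dim). That layout is an assumption, and torch.nn.MSELoss stands in for the default BatchWiseBarlowTwinLoss only to keep the example self-contained; loss_from_targets is a hypothetical name.

```python
import torch


def loss_from_targets(y_pred: torch.Tensor, y_proj: torch.Tensor, loss_fn: torch.nn.Module) -> torch.Tensor:
    """Average loss over predictor/projector pairs (hypothetical stand-in)."""
    assert y_pred.shape[0] == y_proj.shape[0], "expected one prediction per projection"
    # Loss of the i-th predictor output against the i-th projector output.
    losses = [loss_fn(y_pred[i], y_proj[i]) for i in range(y_pred.shape[0])]
    # Every pair contributes equally to the final scalar.
    return torch.stack(losses).mean()


y_pred = torch.zeros(3, 4, 8)
y_proj = torch.ones(3, 4, 8)
print(loss_from_targets(y_pred, y_proj, torch.nn.MSELoss()))  # tensor(1.)
```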

_select_targets(sample_data)

Select projectors and predictors based on ‘num_targets’ using the Fast Determinantal Point Process (DPP) algorithm and some sample data. Code adapted from https://github.com/layer6ai-labs/lfr/blob/main/ssl_models/lfr.py
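The selection idea can be illustrated with a naive reference version under explicitly assumed choices: each projector is summarized by its flattened, L2-normalized output on the sample batch, the DPP kernel is the Gram matrix of those vectors, and projectors are added greedily by maximizing the determinant of the selected kernel submatrix. The helper names are hypothetical, and the kernel actually built by _select_targets (and by the linked LFR code) may differ; the dpp function documented in this module is meant to make greedy choices of this kind far more cheaply.

```python
import math
from typing import List

import torch


def projector_kernel(projectors: torch.nn.ModuleList, sample_data: torch.Tensor) -> torch.Tensor:
    """Gram matrix of projector outputs on a sample batch (hypothetical kernel choice)."""
    with torch.no_grad():
        # One feature vector per projector: its flattened output on the batch.
        feats = torch.stack([p(sample_data).flatten() for p in projectors]).double()
    feats = torch.nn.functional.normalize(feats, dim=1)
    return feats @ feats.T  # positive semi-definite by construction


def greedy_logdet_select(kernel: torch.Tensor, k: int, eps: float = 1e-10) -> List[int]:
    """Naive greedy DPP MAP: repeatedly add the item that maximizes det(K_S)."""
    selected: List[int] = []
    for _ in range(min(k, kernel.shape[0])):
        # Starting at log(eps) means only candidates with det > eps are accepted.
        best, best_logdet = None, math.log(eps)
        for j in range(kernel.shape[0]):
            if j in selected:
                continue
            idx = selected + [j]
            sign, logdet = torch.linalg.slogdet(kernel[idx][:, idx])
            if sign > 0 and logdet.item() > best_logdet:
                best, best_logdet = j, logdet.item()
        if best is None:  # every remaining item is (nearly) redundant
            break
        selected.append(best)
    return selected


torch.manual_seed(0)
projectors = torch.nn.ModuleList([torch.nn.Linear(4, 3) for _ in range(3)])
projectors[1].load_state_dict(projectors[0].state_dict())  # projector 1 duplicates projector 0
sample = torch.randn(8, 4)
# Two indices: 2 plus exactly one of the duplicated projectors 0 and 1.
print(greedy_logdet_select(projector_kernel(projectors, sample), k=3))
```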

_single_step(batch, batch_idx, step_name)

Perform a single training/validation/test step, computing and logging the loss.

Parameters

batch: torch.Tensor

The input batch of data.

batch_idx: int

The index of the batch.

step_name: str

The name of the step (train, val, test).

Returns

torch.Tensor

The loss value for the batch.


adapter = None
backbone
backbone_training_steps_counter = 1
configure_optimizers()

Configure the optimizer for the model. This method sets up the optimizer for the model’s parameters, excluding the projectors.
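A sketch of the idea, assuming Adam as the optimizer (the description does not say which optimizer is used) and a hypothetical TinyLFR container that only mirrors the attribute names above: the projectors are frozen and left out of the optimizer, while the backbone and predictors are optimized with ‘learning_rate’ and ‘weight_decay’.

```python
import torch
from torch import nn


class TinyLFR(nn.Module):
    """Hypothetical container mirroring the attribute names used above."""

    def __init__(self, learning_rate: float = 1e-3, weight_decay: float = 0.0):
        super().__init__()
        self.backbone = nn.Linear(16, 8)
        self.projectors = nn.ModuleList([nn.Linear(16, 4) for _ in range(2)])
        self.predictors = nn.ModuleList([nn.Linear(8, 4) for _ in range(2)])
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # The random projectors stay fixed for the whole run.
        for p in self.projectors.parameters():
            p.requires_grad_(False)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Only the backbone and predictors are handed to the optimizer (Adam is assumed).
        params = list(self.backbone.parameters()) + list(self.predictors.parameters())
        return torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)


model = TinyLFR()
opt = model.configure_optimizers()
optimized = {id(p) for group in opt.param_groups for p in group["params"]}
print(any(id(p) in optimized for p in model.projectors.parameters()))  # False
```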

flatten = False
forward(x)

Forward pass through the network.

Parameters

x: torch.Tensor

The input data.

Returns

torch.Tensor

The predicted output and projected input.

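Assuming forward returns the stacked predictor outputs together with the stacked projector outputs for the same input (the description only says "the predicted output and projected input"), the data flow can be sketched as follows. lfr_forward is a hypothetical function, and where ‘flatten’ and ‘adapter’ are applied here is an assumption.

```python
from typing import Callable, Optional, Tuple

import torch
from torch import nn


def lfr_forward(
    x: torch.Tensor,
    backbone: nn.Module,
    projectors: nn.ModuleList,
    predictors: nn.ModuleList,
    adapter: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    flatten: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical LFR forward pass returning (y_pred, y_proj)."""
    if flatten:
        x = x.flatten(start_dim=1)  # assumed: keep the batch dimension, flatten the rest
    z = backbone(x)
    if adapter is not None:
        z = adapter(z)  # assumed: adapter maps backbone features to predictor inputs
    # Each predictor tries to match the output of its paired random projector.
    y_pred = torch.stack([predictor(z) for predictor in predictors])
    with torch.no_grad():  # the projectors are frozen targets
        y_proj = torch.stack([projector(x) for projector in projectors])
    return y_pred, y_proj


backbone = nn.Linear(16, 8)
projectors = nn.ModuleList([nn.Linear(16, 4) for _ in range(3)])
predictors = nn.ModuleList([nn.Linear(8, 4) for _ in range(3)])
y_pred, y_proj = lfr_forward(torch.randn(5, 16), backbone, projectors, predictors)
print(y_pred.shape, y_proj.shape)  # torch.Size([3, 5, 4]) torch.Size([3, 5, 4])
```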

freeze_backbone = False
learning_rate = 0.001
loss_fn
max_backbone_training_steps = None
num_targets = None
on_train_batch_end(outputs, batch, batch_idx)

Updates the backbone training steps counter only if the backbone is not frozen.

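on_train_batch_start(batch, batch_idx)

If a limit on backbone training steps is set, checks the backbone training steps counter at the start of every training batch. Once the counter reaches the limit, it returns -1, which makes Lightning skip the rest of the current epoch; since the counter no longer advances, every later epoch is skipped as well, effectively stopping training.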
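The interplay of the two batch hooks can be modeled without Lightning. This is a simplified, hypothetical sketch: BackboneStepLimiter and simulate are illustrative names, the counter here starts at 0 (the real attribute ‘backbone_training_steps_counter’ starts at 1, so the exact off-by-one convention may differ), and breaking out of the inner loop mimics Lightning skipping the rest of an epoch when -1 is returned.

```python
from typing import List, Optional, Set, Tuple


class BackboneStepLimiter:
    """Hypothetical, framework-free model of the step-limit hooks described above."""

    def __init__(self, max_steps: Optional[int] = None):
        self.max_steps = max_steps
        self.steps = 0  # backbone training steps taken so far
        self.freeze_backbone = False

    def on_train_batch_start(self) -> Optional[int]:
        limited = self.max_steps is not None and self.max_steps > 0
        if limited and self.steps >= self.max_steps:
            return -1  # in Lightning, -1 skips the rest of the current epoch
        return None

    def on_train_batch_end(self) -> None:
        if not self.freeze_backbone:  # steps with a frozen backbone are not counted
            self.steps += 1


def simulate(limiter: BackboneStepLimiter, epochs: int, batches: int, frozen: Set[int]) -> List[Tuple[int, int]]:
    ran = []
    for epoch in range(epochs):
        limiter.freeze_backbone = epoch in frozen
        for batch in range(batches):
            if limiter.on_train_batch_start() == -1:
                break  # skip the rest of this epoch
            ran.append((epoch, batch))
            limiter.on_train_batch_end()
    return ran


limiter = BackboneStepLimiter(max_steps=6)
ran = simulate(limiter, epochs=4, batches=4, frozen={1})
print(len(ran), limiter.steps)  # 10 6
```

Epoch 1 is frozen, so its four batches run without advancing the counter; the limit is hit midway through epoch 2, and epoch 3 is skipped entirely.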

on_train_epoch_start()

Executed at the start of each training epoch. If ‘predictor_training_epochs’ is a positive integer, this method checks the current epoch number and freezes or unfreezes the backbone accordingly: the backbone is trained in the first epoch of each cycle and frozen for the following ‘predictor_training_epochs’ epochs. If ‘predictor_training_epochs’ is None, zero, or negative, the backbone is trained in every epoch.
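One common way to freeze and unfreeze a module in PyTorch is to toggle requires_grad on its parameters; whether minerva freezes the backbone exactly this way is an assumption. The sketch below combines that with the per-epoch rule described above; backbone_is_frozen and set_backbone_trainable are hypothetical helpers.

```python
from typing import Optional

import torch
from torch import nn


def backbone_is_frozen(epoch: int, predictor_training_epochs: Optional[int]) -> bool:
    """True if the backbone should be frozen in this epoch (hypothetical helper)."""
    if predictor_training_epochs is None or predictor_training_epochs <= 0:
        return False
    # Cycle: one backbone epoch followed by `predictor_training_epochs` predictor-only epochs.
    return epoch % (predictor_training_epochs + 1) != 0


def set_backbone_trainable(backbone: nn.Module, trainable: bool) -> None:
    # With requires_grad=False, backward() computes no gradients for these weights.
    for param in backbone.parameters():
        param.requires_grad_(trainable)


backbone = nn.Linear(4, 2)
for epoch in range(4):
    set_backbone_trainable(backbone, not backbone_is_frozen(epoch, 1))
    print(epoch, all(p.requires_grad for p in backbone.parameters()))
# 0 True / 1 False / 2 True / 3 False
```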

predictor_training_epochs = None
predictors
projectors
selection_batch_size = 128
setup(stage)

Setup function. If necessary, it picks projectors and predictors based on ‘num_targets’ using the first 128 elements of the training dataset, as used in https://github.com/layer6ai-labs/lfr/blob/main/scripts/har/run_har_diet.sh.
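Gathering the first n elements of a dataset into a single batch, as described for setup, can be sketched with standard torch.utils.data utilities. first_n_batch is a hypothetical helper, and returning only the inputs of tuple-style samples is an assumption about how the selection data is used.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def first_n_batch(dataset, n: int = 128) -> torch.Tensor:
    """Stack the first `n` samples of a dataset into one batch (hypothetical helper)."""
    n = min(n, len(dataset))
    if n <= 0:
        raise ValueError("dataset is empty")
    # Subset keeps the original order; a single batch of size n covers it exactly.
    loader = DataLoader(Subset(dataset, range(n)), batch_size=n, shuffle=False)
    batch = next(iter(loader))
    # Tuple-style samples (e.g. TensorDataset) collate to a list; keep only the inputs.
    return batch[0] if isinstance(batch, (list, tuple)) else batch


dataset = TensorDataset(torch.arange(300, dtype=torch.float32).unsqueeze(1))
print(first_n_batch(dataset).shape)  # torch.Size([128, 1])
```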

training_step(batch, batch_idx)

Perform a training step using the ‘_single_step’ method.

Parameters

batch: torch.Tensor

The input batch of data.

batch_idx: int

The index of the batch.

Returns

torch.Tensor

The loss value for the batch.


validation_step(batch, batch_idx)

Perform a validation step using the ‘_single_step’ method.

Parameters

batch: torch.Tensor

The input batch of data.

batch_idx: int

The index of the batch.

Returns

torch.Tensor

The loss value for the batch.


weight_decay = 0.0

class minerva.models.ssl.lfr.RepeatedModuleList(size, cls, *args, **kwargs)

Bases: torch.nn.ModuleList

A module list containing size instances of the same module class cls.

Initializes the RepeatedModuleList with multiple instances of a given module class.

Parameters

size: int

The number of instances to create.

cls: type

The module class to instantiate. Must be a subclass of torch.nn.Module.

*args:

Positional arguments to pass to the module class constructor.

**kwargs:

Keyword arguments to pass to the module class constructor.

Raises

AssertionError:

If cls is not a subclass of torch.nn.Module.

Example

>>> import torch
>>> from minerva.models.ssl.lfr import RepeatedModuleList
>>> class SimpleModule(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.linear = torch.nn.Linear(in_features, out_features)
...
>>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
>>> print(repeated_modules)
RepeatedModuleList(
    (0): SimpleModule(
        (linear): Linear(in_features=10, out_features=5, bias=True)
    )
    (1): SimpleModule(
        (linear): Linear(in_features=10, out_features=5, bias=True)
    )
    (2): SimpleModule(
        (linear): Linear(in_features=10, out_features=5, bias=True)
    )
)
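The documented behavior (size independent instances, an AssertionError for a non-module cls) can be reproduced in a few lines. RepeatedModuleListSketch is a hypothetical re-implementation for illustration, not minerva's source.

```python
import torch
from torch import nn


class RepeatedModuleListSketch(nn.ModuleList):
    """Hypothetical re-implementation matching the documented behavior."""

    def __init__(self, size: int, cls: type, *args, **kwargs):
        # Check isinstance first so a non-class `cls` also raises AssertionError.
        assert isinstance(cls, type) and issubclass(cls, nn.Module), (
            f"{cls} is not a subclass of torch.nn.Module"
        )
        # Each call builds an independent instance with its own parameters.
        super().__init__([cls(*args, **kwargs) for _ in range(size)])


heads = RepeatedModuleListSketch(3, nn.Linear, 10, 5)
print(len(heads), heads[0] is heads[1])  # 3 False
```

Because every element is constructed separately, each one gets its own random initialization, which is convenient for building a set of differently initialized random projectors.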

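minerva.models.ssl.lfr.dpp(kernel_matrix, max_length, epsilon=1e-10)

A fast implementation of the greedy algorithm for determinantal point process (DPP) MAP inference.

Parameters

kernel_matrix: 2-d array

The DPP kernel matrix.

max_length: positive int

The maximum number of items to select.

epsilon: small positive scalar

Numerical threshold; selection stops early when the largest remaining marginal gain falls below this value.

Returns

list

The indices of the selected items.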
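A sketch of the fast greedy algorithm in NumPy, following the incremental Cholesky-update scheme of Chen et al. (2018) for fast greedy DPP MAP inference. dpp_sketch is a hypothetical name, and minerva's dpp may differ in details such as edge-case handling; here an empty list is returned when no item has a gain above epsilon, and selected items are masked with -inf so they cannot be picked twice.

```python
import math
from typing import List

import numpy as np


def dpp_sketch(kernel_matrix: np.ndarray, max_length: int, epsilon: float = 1e-10) -> List[int]:
    """Fast greedy DPP MAP inference via incremental Cholesky updates (sketch)."""
    kernel = np.asarray(kernel_matrix, dtype=float)
    item_size = kernel.shape[0]
    if max_length <= 0 or item_size == 0:
        return []
    cis = np.zeros((max_length, item_size))  # rows of the incremental Cholesky factor
    di2s = np.diag(kernel).copy()            # marginal gain d_i^2 of adding each item
    selected_item = int(np.argmax(di2s))
    if di2s[selected_item] < epsilon:
        return []
    selected_items = [selected_item]
    while len(selected_items) < max_length:
        k = len(selected_items) - 1
        ci_optimal = cis[:k, selected_item]
        di_optimal = math.sqrt(di2s[selected_item])
        elements = kernel[selected_item, :]
        # New Cholesky row for every item given the newly selected one.
        eis = (elements - ci_optimal @ cis[:k, :]) / di_optimal
        cis[k, :] = eis
        di2s -= np.square(eis)
        di2s[selected_item] = -np.inf        # never pick the same item twice
        selected_item = int(np.argmax(di2s))
        if di2s[selected_item] < epsilon:    # remaining items add (almost) no volume
            break
        selected_items.append(selected_item)
    return selected_items


print(dpp_sketch(np.eye(3), 3))  # [0, 1, 2]
# Items 0 and 1 are identical, so only one of them is kept.
print(dpp_sketch(np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]]), 3))  # [0, 2]
```

Each iteration costs O(k·n) instead of recomputing determinants, because di2s[j] always equals det(K[S ∪ {j}]) / det(K[S]) for the current selection S.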