minerva.models.ssl.lfr
Classes
LearnFromRandomnessModel | A PyTorch Lightning model for pretraining with the technique ‘Learning From Random Projectors’.
RepeatedModuleList | A module list with the same module cls, instantiated size times.
Functions
Our proposed fast implementation of the greedy algorithm
Module Contents
- class minerva.models.ssl.lfr.LearnFromRandomnessModel(backbone, projectors, predictors, loss_fn=None, num_targets=None, adapter=None, learning_rate=0.001, weight_decay=0.0, flatten=False, predictor_training_epochs=None, max_backbone_training_steps=None, selection_batch_size=128)
Bases: lightning.LightningModule
A PyTorch Lightning model for pretraining with the technique ‘Learning From Random Projectors’ (LFR).
When using ‘predictor_training_epochs’, consider increasing the number of training epochs in your Trainer as well. Otherwise, the LFR backbone will be trained for fewer epochs than intended:
- If the Trainer runs for 100 epochs and ‘predictor_training_epochs’ is 1, the backbone is trained on epochs 0, 2, 4, 6, …, 98, so it is effectively trained for 50 epochs instead of the specified 100.
- If the Trainer runs for 100 epochs and ‘predictor_training_epochs’ is 2, the backbone is trained on epochs 0, 3, 6, 9, …, 99, so it is effectively trained for 34 epochs instead of the specified 100.
In conclusion, set your total number of training epochs to:
Total number of training epochs = (intended backbone training epochs) * (predictor_training_epochs + 1)
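The epoch arithmetic above can be checked with a small helper. This is a minimal sketch, not minerva code: backbone_epochs and required_total_epochs are hypothetical names, and the sketch assumes, as the examples imply, that the backbone is trained on epoch e exactly when e is a multiple of predictor_training_epochs + 1.

```python
from typing import List, Optional


def backbone_epochs(total_epochs: int, predictor_training_epochs: Optional[int]) -> List[int]:
    """Epoch indices on which the backbone is trained (hypothetical helper)."""
    if predictor_training_epochs is None or predictor_training_epochs <= 0:
        # No predictor-only epochs: the backbone is trained every epoch.
        return list(range(total_epochs))
    # Assumed cycle: one backbone epoch followed by `predictor_training_epochs` frozen epochs.
    cycle = predictor_training_epochs + 1
    return [epoch for epoch in range(total_epochs) if epoch % cycle == 0]


def required_total_epochs(intended_backbone_epochs: int, predictor_training_epochs: Optional[int]) -> int:
    """Trainer epochs needed for the backbone to train `intended_backbone_epochs` times."""
    if predictor_training_epochs is None or predictor_training_epochs <= 0:
        return intended_backbone_epochs
    return intended_backbone_epochs * (predictor_training_epochs + 1)


print(len(backbone_epochs(100, 1)), backbone_epochs(100, 1)[-1])  # 50 98
print(len(backbone_epochs(100, 2)), backbone_epochs(100, 2)[-1])  # 34 99
print(required_total_epochs(50, 2))  # 150
```

Passing the value returned by required_total_epochs as the Trainer's max_epochs would give the backbone the intended number of training epochs under this schedule.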
References
Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs. “Self-supervised Representation Learning From Random Data Projectors”, 2024
Initializes the LearnFromRandomnessModel and freezes the projectors. Remember to update your number of training epochs when using ‘predictor_training_epochs’:
Total number of training epochs = (intended backbone training epochs) * (predictor_training_epochs + 1)
Parameters
- backbone: torch.nn.Module
The backbone neural network for feature extraction.
- projectors: torch.nn.ModuleList
A list of projector networks.
- predictors: torch.nn.ModuleList
A list of predictor networks.
- num_targets: Optional[int]
The number of projectors and predictors to select from the lists provided, using the Fast Determinantal Point Process (DPP) algorithm. All projectors and predictors are used if the value received is None, a negative integer, or an integer greater than the length of the lists.
- loss_fn: Optional[torch.nn.Module]
The loss function to optimize, by default None. If None, the BatchWiseBarlowTwinLoss is used.
- adapter: Optional[Callable[[torch.Tensor], torch.Tensor]]
An optional adapter network to be used in the model, by default None.
- learning_rate: float
The learning rate for the optimizer, by default 1e-3.
- weight_decay: float
The weight decay for the optimizer, by default 0.0.
- flatten: bool
Whether to flatten the input tensor or not, by default False.
- predictor_training_epochs: Optional[int]
The number of epochs to train only the predictors (excluding the backbone), by default None. If None, zero, or negative, both the predictors and backbone are trained in every epoch. If a positive integer is provided, the backbone is trained for one epoch, then frozen, and the predictors are trained alone for the specified number of epochs. This cycle is repeated throughout the training phase.
- max_backbone_training_steps: Optional[int]
The maximum number of steps for which the backbone is trained, by default None. If None, zero, or negative, no limit is applied. Steps in which the backbone is frozen do not count toward the limit.
- selection_batch_size: int
The number of random samples from the dataset used by the Fast Determinantal Point Process (DPP) algorithm when selecting projectors and predictors, by default 128.
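The ‘None, zero, or negative’ conventions in the parameter list can be summarised in a few lines. This is a minimal sketch with hypothetical helper names (resolve_num_targets, is_enabled), not the module's own code; the documented rule for num_targets does not mention zero, so the sketch simply passes it through unchanged.

```python
from typing import Optional


def resolve_num_targets(num_targets: Optional[int], num_available: int) -> int:
    """How many projector/predictor pairs are kept (hypothetical helper)."""
    # None, a negative value, or more than the lists hold: use every pair.
    if num_targets is None or num_targets < 0 or num_targets > num_available:
        return num_available
    # Zero is not covered by the documented rule; it is passed through unchanged here.
    return num_targets


def is_enabled(option: Optional[int]) -> bool:
    """predictor_training_epochs and max_backbone_training_steps only take effect when positive."""
    return option is not None and option > 0


print(resolve_num_targets(None, 8), resolve_num_targets(20, 8), resolve_num_targets(3, 8))  # 8 8 3
print(is_enabled(None), is_enabled(0), is_enabled(5))  # False False True
```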
- _loss_from_targets(y_pred, y_proj)
Computes the loss for each pair of predictor and projector outputs and returns their average. It is kept separate from _single_step so that it can be tested independently.
Parameters
- y_pred: torch.Tensor
The prediction tensors.
- y_proj: torch.Tensor
The projection tensors.
- Parameters:
y_pred (torch.Tensor)
y_proj (torch.Tensor)
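The averaging performed by _loss_from_targets can be illustrated with NumPy. This is a sketch under stated assumptions: predictions and projections are taken to be stacked along the first axis as (num_targets, batch_size, dim), and a plain mean-squared error stands in for loss_fn (the default BatchWiseBarlowTwinLoss is not reimplemented). loss_from_targets and mse are hypothetical names.

```python
import numpy as np


def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Stand-in loss for illustration only (not BatchWiseBarlowTwinLoss)."""
    return float(np.mean((a - b) ** 2))


def loss_from_targets(y_pred: np.ndarray, y_proj: np.ndarray, loss_fn=mse) -> float:
    """Average loss over predictor/projector pairs.

    Assumes both inputs have shape (num_targets, batch_size, dim), so that
    y_pred[i] comes from predictor i and y_proj[i] from projector i.
    """
    assert y_pred.shape[0] == y_proj.shape[0], "expected one prediction per projection"
    losses = [loss_fn(pred, proj) for pred, proj in zip(y_pred, y_proj)]
    return sum(losses) / len(losses)


y_pred = np.zeros((2, 4, 3))
y_proj = np.stack([np.ones((4, 3)), 3 * np.ones((4, 3))])
print(loss_from_targets(y_pred, y_proj))  # (1.0 + 9.0) / 2 = 5.0
```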
- _select_targets(sample_data)
Selects projectors and predictors based on ‘num_targets’, applying the Fast Determinantal Point Process (DPP) algorithm to a batch of sample data. Code adapted from https://github.com/layer6ai-labs/lfr/blob/main/ssl_models/lfr.py
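To make the selection step concrete, the sketch below implements greedy MAP inference for a DPP with incremental Cholesky updates, the usual fast greedy formulation, in NumPy. It is not minerva's code: greedy_dpp and projector_kernel are hypothetical names, and building the kernel as the cosine-similarity Gram matrix of each projector's outputs on the same sample batch is an assumption; the adapted LFR code may construct the kernel differently.

```python
import math
from typing import List

import numpy as np


def projector_kernel(outputs: np.ndarray) -> np.ndarray:
    """Cosine-similarity Gram matrix between candidate projectors (assumed kernel).

    `outputs` has shape (num_projectors, batch_size, dim): each candidate
    projector applied to the same batch of sample data.
    """
    flat = outputs.reshape(outputs.shape[0], -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return flat @ flat.T


def greedy_dpp(kernel: np.ndarray, max_length: int, epsilon: float = 1e-10) -> List[int]:
    """Greedy MAP inference for a DPP with PSD `kernel`, using incremental Cholesky updates."""
    n = kernel.shape[0]
    cis = np.zeros((max_length, n))       # rows of the incremental Cholesky factor
    di2s = np.diag(kernel).astype(float)  # squared marginal gain of each item
    selected = [int(np.argmax(di2s))]
    while len(selected) < max_length:
        k = len(selected) - 1
        j = selected[-1]
        ci_opt = cis[:k, j]
        di_opt = math.sqrt(di2s[j])
        eis = (kernel[j, :] - ci_opt @ cis[:k, :]) / di_opt
        cis[k, :] = eis
        di2s -= np.square(eis)  # update every item's gain given the newly selected one
        di2s[j] = -np.inf       # never select the same item twice
        j = int(np.argmax(di2s))
        if di2s[j] < epsilon:   # remaining items add (almost) no volume
            break
        selected.append(j)
    return selected


# Two identical projectors (0 and 1) and one orthogonal projector (2).
outputs = np.array([[[1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
kernel = projector_kernel(outputs)
print(greedy_dpp(kernel, 2))  # [0, 2]
print(greedy_dpp(kernel, 3))  # [0, 2]: the duplicate projector adds no volume
```

The returned indices would then decide which projectors, and their paired predictors, are kept; near-duplicate candidates contribute almost nothing to the determinant and are skipped, which is what makes the selected set diverse.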
- _single_step(batch, batch_idx, step_name)
Perform a single training/validation/test step, computing and logging the loss.
Parameters
- batch: torch.Tensor
The input batch of data.
- batch_idx: int
The index of the batch.
- step_name: str
The name of the step (train, val, test).
Returns
- torch.Tensor
The loss value for the batch.
- Parameters:
batch (torch.Tensor)
batch_idx (int)
step_name (str)
- Return type:
torch.Tensor
- adapter = None
- backbone
- backbone_training_steps_counter = 1
- configure_optimizers()
Configure the optimizer for the model. This method sets up the optimizer for the model’s parameters, excluding the projectors.
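A minimal PyTorch sketch of the idea behind configure_optimizers: the random projectors are frozen and left out of the optimizer, so only the backbone and predictors are updated. The toy nn.Linear modules, their sizes, the choice of torch.optim.Adam, and the mean-squared-error loss are assumptions for illustration; which optimizer the module builds, and whether an adapter's parameters are included, are not specified here.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins with hypothetical sizes: projectors map the raw input,
# predictors map the backbone's representation to the projector space.
backbone = nn.Linear(8, 16)
projectors = nn.ModuleList([nn.Linear(8, 4) for _ in range(3)])
predictors = nn.ModuleList([nn.Linear(16, 4) for _ in range(3)])

# The random projectors are never trained: turn off their gradients.
projectors.requires_grad_(False)

# Only backbone and predictor parameters are given to the optimizer.
optimizer = torch.optim.Adam(
    list(backbone.parameters()) + list(predictors.parameters()),
    lr=1e-3,  # documented default learning_rate
    weight_decay=0.0,  # documented default weight_decay
)

# One illustrative update; mean-squared error stands in for the real loss_fn.
x = torch.randn(32, 8)
z = backbone(x)
losses = [
    nn.functional.mse_loss(pred(z), proj(x))
    for pred, proj in zip(predictors, projectors)
]
loss = torch.stack(losses).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```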
- flatten = False
- forward(x)
Forward pass through the network.
Parameters
- x: torch.Tensor
The input data.
Returns
- torch.Tensor
The predicted output and projected input.
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
- freeze_backbone = False
- learning_rate = 0.001
- loss_fn
- max_backbone_training_steps = None
- num_targets = None
- on_train_batch_end(outputs, batch, batch_idx)
Updates the backbone training steps counter only if the backbone is not frozen.
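- on_train_batch_start(batch, batch_idx)
If ‘max_backbone_training_steps’ sets a limit, checks the backbone training steps counter at the start of every training batch. Once the counter reaches the limit, it returns -1, which tells Lightning to skip the rest of the current epoch; since the check runs on every batch, training effectively stops.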
- on_train_epoch_start()
Executed at the start of each training epoch. If ‘predictor_training_epochs’ is a positive integer, this hook uses the current epoch number to freeze or unfreeze the backbone: the backbone is trained in the first epoch of each cycle and frozen for the following ‘predictor_training_epochs’ epochs. If ‘predictor_training_epochs’ is None, zero, or negative, the backbone is always trained.
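The interplay of on_train_epoch_start, on_train_batch_start and on_train_batch_end can be simulated without Lightning. This is a simplified, hypothetical model (simulate_training is not a minerva function): it counts backbone steps from zero, whereas the module's counter is initialized to 1, so the exact step at which the limit triggers may differ by one.

```python
from typing import Optional, Tuple


def simulate_training(
    total_epochs: int,
    batches_per_epoch: int,
    predictor_training_epochs: Optional[int] = None,
    max_backbone_training_steps: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (backbone_steps, executed_steps) for a simplified model of the hooks."""
    cycle_on = predictor_training_epochs is not None and predictor_training_epochs > 0
    limit_on = max_backbone_training_steps is not None and max_backbone_training_steps > 0
    backbone_steps = 0  # steps on which the backbone was trainable
    executed_steps = 0  # all training steps that actually ran
    for epoch in range(total_epochs):
        # on_train_epoch_start: the backbone trains only on the first epoch of each cycle.
        frozen = cycle_on and epoch % (predictor_training_epochs + 1) != 0
        for _ in range(batches_per_epoch):
            # on_train_batch_start: once the limit is reached, every later batch is skipped.
            if limit_on and backbone_steps >= max_backbone_training_steps:
                return backbone_steps, executed_steps
            executed_steps += 1  # training_step runs
            # on_train_batch_end: steps with a frozen backbone are not counted.
            if not frozen:
                backbone_steps += 1
    return backbone_steps, executed_steps


print(simulate_training(10, 5, predictor_training_epochs=1))  # (25, 50)
print(simulate_training(10, 5, predictor_training_epochs=1, max_backbone_training_steps=12))  # (12, 22)
```

The second call shows why frozen steps matter: 22 training steps run in total, but only the 12 with a trainable backbone count toward the limit.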
- predictor_training_epochs = None
- predictors
- projectors
- selection_batch_size = 128
- setup(stage)
Lightning setup hook. If necessary, it selects projectors and predictors based on ‘num_targets’, using the first 128 elements of the training dataset, as done in https://github.com/layer6ai-labs/lfr/blob/main/scripts/har/run_har_diet.sh.
- training_step(batch, batch_idx)
Perform a training step using the ‘_single_step’ method.
Parameters
- batch: torch.Tensor
The input batch of data.
- batch_idx: int
The index of the batch.
Returns
- torch.Tensor
The loss value for the batch.
- Parameters:
batch (torch.Tensor)
batch_idx (int)
- validation_step(batch, batch_idx)
Perform a validation step using the ‘_single_step’ method.
Parameters
- batch: torch.Tensor
The input batch of data.
- batch_idx: int
The index of the batch.
Returns
- torch.Tensor
The loss value for the batch.
- Parameters:
batch (torch.Tensor)
batch_idx (int)
- weight_decay = 0.0
- Parameters:
backbone (torch.nn.Module)
projectors (torch.nn.ModuleList)
predictors (torch.nn.ModuleList)
loss_fn (Optional[torch.nn.Module])
num_targets (Optional[int])
adapter (Optional[Callable[[torch.Tensor], torch.Tensor]])
learning_rate (float)
weight_decay (float)
flatten (bool)
predictor_training_epochs (Optional[int])
max_backbone_training_steps (Optional[int])
selection_batch_size (int)
- class minerva.models.ssl.lfr.RepeatedModuleList(size, cls, *args, **kwargs)
Bases: torch.nn.ModuleList
A module list with the same module cls, instantiated size times.
Initializes the RepeatedModuleList with multiple instances of a given module class.
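Parameters
- size: int
The number of instances of cls to create.
- cls: type
The module class to instantiate. It must be a subclass of torch.nn.Module.
- *args, **kwargs
Positional and keyword arguments passed to the cls constructor for every instance.
Raises
- AssertionError:
If cls is not a subclass of torch.nn.Module.
Example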
>>> import torch
>>> from minerva.models.ssl.lfr import RepeatedModuleList
>>> class SimpleModule(torch.nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.linear = torch.nn.Linear(in_features, out_features)
...
>>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
>>> print(repeated_modules)
RepeatedModuleList(
  (0): SimpleModule(
    (linear): Linear(in_features=10, out_features=5, bias=True)
  )
  (1): SimpleModule(
    (linear): Linear(in_features=10, out_features=5, bias=True)
  )
  (2): SimpleModule(
    (linear): Linear(in_features=10, out_features=5, bias=True)
  )
)
- Parameters:
size (int)
cls (type)
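For readers who want to see the mechanics, here is a hedged re-implementation sketch of the behaviour described above; RepeatedModuleListSketch is a hypothetical name and minerva's own code may differ in details. It calls cls(*args, **kwargs) once per entry, so each instance gets its own, independently initialized parameters, which is what a set of random projectors, for example, would need.

```python
import torch


class RepeatedModuleListSketch(torch.nn.ModuleList):
    """Hypothetical re-implementation: a ModuleList of `size` independent `cls` instances."""

    def __init__(self, size: int, cls: type, *args, **kwargs) -> None:
        # Mirror the documented contract: cls must be a torch.nn.Module subclass.
        assert isinstance(cls, type) and issubclass(cls, torch.nn.Module), (
            f"{cls!r} is not a subclass of torch.nn.Module"
        )
        # A fresh cls(*args, **kwargs) per entry, so no parameters are shared.
        super().__init__([cls(*args, **kwargs) for _ in range(size)])


modules = RepeatedModuleListSketch(3, torch.nn.Linear, 10, 5)
print(len(modules))  # 3
print(modules[0].weight is modules[1].weight)  # False
```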