import lightning as L
import torch
import torch.nn.functional as F
from minerva.losses.batchwise_barlowtwins_loss import BatchWiseBarlowTwinLoss
from typing import Optional, Callable
import numpy as np
import math
# -------------------------------------------------------------------------------------------------
# This function was extracted from the LFR implementation in
# https://github.com/layer6ai-labs/lfr/blob/main/utils/dpp.py
[docs]
def dpp(kernel_matrix, max_length, epsilon=1e-10):
"""
Our proposed fast implementation of the greedy algorithm
:param kernel_matrix: 2-d array
:param max_length: positive int
:param epsilon: small positive scalar
:return: list
"""
item_size = kernel_matrix.shape[0]
cis = np.zeros((max_length, item_size))
di2s = np.copy(np.diag(kernel_matrix))
selected_items = list()
selected_item = np.argmax(di2s)
selected_items.append(selected_item)
while len(selected_items) < max_length:
k = len(selected_items) - 1
ci_optimal = cis[:k, selected_item]
di_optimal = math.sqrt(di2s[selected_item])
elements = kernel_matrix[selected_item, :]
eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
cis[k, :] = eis
di2s -= np.square(eis)
di2s[selected_item] = -np.inf
selected_item = np.argmax(di2s)
if di2s[selected_item] < epsilon:
break
selected_items.append(selected_item)
return selected_items
# -------------------------------------------------------------------------------------------------
[docs]
class RepeatedModuleList(torch.nn.ModuleList):
"""
A module list with the same module `cls`, instantiated `size` times.
"""
def __init__(self, size: int, cls: type, *args, **kwargs):
"""
Initializes the RepeatedModuleList with multiple instances of a given module class.
Parameters
----------
size: int
The number of instances to create.
cls: type
The module class to instantiate. Must be a subclass of `torch.nn.Module`.
*args:
Positional arguments to pass to the module class constructor.
**kwargs:
Keyword arguments to pass to the module class constructor.
Raises
------
AssertionError:
If `cls` is not a subclass of `torch.nn.Module`.
Example
-------
>>> class SimpleModule(torch.nn.Module):
>>> def __init__(self, in_features, out_features):
>>> super().__init__()
>>> self.linear = torch.nn.Linear(in_features, out_features)
>>>
>>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
>>> print(repeated_modules)
RepeatedModuleList(
(0): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
(1): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
(2): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
)
"""
assert issubclass(
cls, torch.nn.Module
), f"{cls} does not derive from torch.nn.Module"
super().__init__([cls(*args, **kwargs) for _ in range(size)])
[docs]
class LearnFromRandomnessModel(L.LightningModule):
"""
A PyTorch Lightning model for pretraining with the technique 'Learning From Random Projectors'.
When using 'predictor_training_epochs', please consider updating your number of training epochs as well.
Otherwise, the LFR backbone will be trained for less epochs:
- If the total training epochs in your Trainer is 100, and 'predictor_training_epochs' is 1, then the
backbone will be trained on the epochs 0, 2, 4, 6, ... and 98, resulting in the backbone being effectively
trained for 50 epochs instead of the specified 100.
- If the total training epochs in your Trainer is 100, and 'predictor_training_epochs' is 2, then the
backbone will be trained on the epochs 0, 3, 6, 9, ... and 99, resulting in the backbone being effectively
trained for 34 epochs instead of the specified 100.
In conclusion, consider updating your total number of training epochs to:
Total number of training epochs = (intended backbone training epochs) * (predictor_training_epochs + 1)
References
----------
Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs.
"Self-supervised Representation Learning From Random Data Projectors", 2024
"""
def __init__(
self,
backbone: torch.nn.Module,
projectors: torch.nn.ModuleList,
predictors: torch.nn.ModuleList,
loss_fn: Optional[torch.nn.Module] = None,
num_targets: Optional[int] = None,
adapter: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
learning_rate: float = 1e-3,
weight_decay: float = 0.0,
flatten: bool = False,
predictor_training_epochs: Optional[int] = None,
max_backbone_training_steps: Optional[int] = None,
selection_batch_size: int = 128,
):
"""
Initialize the LFR_Model, freezing the projectors. Remember to update your number of training epochs when
using 'predictor_training_epochs':
Total number of training epochs = (intended backbone training epochs) * (predictor_training_epochs + 1)
Parameters
----------
backbone: torch.nn.Module
The backbone neural network for feature extraction.
projectors: torch.nn.ModuleList
A list of projector networks.
predictors: torch.nn.ModuleList
A list of predictor networks.
num_targets: Optional[int]
The number of projectors and predictors to select from the lists provided, using the Fast
Determinantal Point Process (DPP) algorithm. All projectors and predictors are used if the
value received is None, a negative integer, or an integer greater than the length of the lists.
loss_fn: Optional[torch.nn.Module]
The loss function to optimize, by default None. If None, the BatchWiseBarlowTwinLoss is used.
adapter: Optional[Callable[[torch.Tensor], torch.Tensor]]
An optional adapter network to be used in the model, by default None.
learning_rate: float
The learning rate for the optimizer, by default 1e-3.
weight_decay: float
The weight decay for the optimizer, by default 0.0.
flatten: bool
Whether to flatten the input tensor or not, by default False.
predictor_training_epochs: Optional[int]
The number of epochs to train only the predictors (excluding the backbone), by default None.
If None, zero, or negative, both the predictors and backbone are trained in every epoch. If
a positive integer is provided, the backbone is trained for one epoch, then frozen, and the
predictors are trained alone for the specified number of epochs. This cycle is repeated
throughout the training phase.
max_backbone_training_steps: Optional[int]
The number of steps the backbone will be trained, by default None. If None, zero, or
negative, no limit is applied. The steps where the backbone is frozen are ignored.
selection_batch_size: int
By default 128. When selecting projectors and predictors, this variable decides how many
random samples from the dataset are used in the Fast Determinantal Point Process (DPP)
algorithm.
"""
super().__init__()
self.backbone = backbone
self.projectors = projectors
self.predictors = predictors
self.num_targets = num_targets
self.loss_fn = loss_fn if loss_fn is not None else BatchWiseBarlowTwinLoss()
self.adapter = adapter
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.flatten = flatten
self.selection_batch_size = selection_batch_size
self.predictor_training_epochs = predictor_training_epochs
if self.predictor_training_epochs is None or self.predictor_training_epochs < 0:
self.predictor_training_epochs = 0
self.max_backbone_training_steps = max_backbone_training_steps
if (
self.max_backbone_training_steps is None
or self.max_backbone_training_steps < 0
):
self.max_backbone_training_steps = 0
# Helper variables
self.backbone_training_steps_counter = 1 # Because steps start at 1
self.freeze_backbone = False
for param in self.projectors.parameters():
param.requires_grad = False
for proj in self.projectors:
proj.eval()
# Parameter restrictions
if self.num_targets is None and len(self.projectors) != len(self.predictors):
raise ValueError(
"When num_targets is None, the number of projectors and predictors must be equal."
)
if self.num_targets is not None and (
self.num_targets <= 0
or self.num_targets > min(len(self.projectors), len(self.predictors))
):
raise ValueError(
"num_targets must be None or a positive integer less than or equal to the number of projectors and predictors."
)
[docs]
def setup(self, stage):
"""
Setup function. If necessary, it picks projectors and predictors based on 'num_targets'
using the first 128 elements of the training dataset, as used in
https://github.com/layer6ai-labs/lfr/blob/main/scripts/har/run_har_diet.sh.
"""
if stage != "fit":
return
if self.num_targets is None:
return
# Get the training dataset
training_dataset = self.trainer.datamodule.train_dataloader().dataset
# Pick random values from the dataset
random_idxs = np.random.choice(
range(len(training_dataset)),
size=self.selection_batch_size,
replace=self.selection_batch_size > len(training_dataset),
)
sample_batch = []
for random_idx in random_idxs:
random_sample = training_dataset[random_idx]
if isinstance(random_sample, (tuple, list)):
random_sample = random_sample[0]
sample_batch.append(random_sample)
sample_batch = np.array(sample_batch)
sample_batch = torch.from_numpy(sample_batch)
self._select_targets(sample_batch)
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network.
Parameters
----------
x : torch.Tensor
The input data.
Returns
-------
torch.Tensor
The predicted output and projected input.
"""
z: torch.Tensor = self.backbone(x)
if self.adapter:
z = self.adapter(z)
if self.flatten:
z = z.view(z.size(0), -1)
x = x.view(x.size(0), -1)
y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1)
y_proj = torch.stack([projector(x) for projector in self.projectors], 1)
return y_pred, y_proj
[docs]
def _loss_from_targets(self, y_pred: torch.Tensor, y_proj: torch.Tensor):
"""
Computes the average loss between each pair of predictor and projector outputs.
This function is isolated from `_single_step` to make it easier to test
independently.
Parameters
----------
y_pred : torch.Tensor
The predictions tensors.
y_proj : torch.Tensor
The projections tensors.
"""
loss = torch.tensor(0)
for i in range(y_pred.shape[1]):
loss = loss + self.loss_fn(y_pred[:, i, :], y_proj[:, i, :])
loss /= y_pred.shape[1]
return loss
[docs]
def _single_step(
self, batch: torch.Tensor, batch_idx: int, step_name: str
) -> torch.Tensor:
"""
Perform a single training/validation/test step, computing and logging the loss.
Parameters
----------
batch : torch.Tensor
The input batch of data.
batch_idx : int
The index of the batch.
step_name : str
The name of the step (train, val, test).
Returns
-------
torch.Tensor
The loss value for the batch.
"""
if isinstance(batch, (tuple, list)):
batch = batch[0]
x = batch
y_pred, y_proj = self.forward(x)
loss = self._loss_from_targets(y_pred=y_pred, y_proj=y_proj)
self.log(
f"{step_name}_loss",
loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
[docs]
def _select_targets(self, sample_data):
"""
Select projectors and predictors based on 'num_targets' using the Fast Determinantal Point
Process (DPP) algorithm and some sample data. Code adapted from
https://github.com/layer6ai-labs/lfr/blob/main/ssl_models/lfr.py
"""
with torch.no_grad():
sims = []
for projector in self.projectors:
# (bs, dim)
rep = projector(sample_data)
if rep.shape[0] > 1000:
rep = rep[
np.random.RandomState(seed=42).permutation(
np.arange(rep.shape[0])
)[:1000]
]
rep_normalized = F.normalize(rep, dim=1)
# (bs, bs) cosine similarity
sim = rep_normalized @ rep_normalized.T
sims.append(sim.view(-1))
# N, bs^2
sims = torch.stack(sims)
sims_normalized = F.normalize(sims, dim=1)
# N,N
sims_targets = sims_normalized @ sims_normalized.T
result = dpp(sims_targets.cpu().numpy(), self.num_targets)
# Select the projectors and predictors based on the result from DPP
self.projectors = torch.nn.ModuleList([self.projectors[idx] for idx in result])
if len(self.predictors) != self.num_targets:
self.predictors = torch.nn.ModuleList(
[self.predictors[idx] for idx in range(self.num_targets)]
)
[docs]
def training_step(self, batch: torch.Tensor, batch_idx: int):
"""
Perform a training step using the '_single_step' method.
Parameters
----------
batch : torch.Tensor
The input batch of data.
batch_idx : int
The index of the batch.
Returns
-------
torch.Tensor
The loss value for the batch.
"""
return self._single_step(batch, batch_idx, step_name="train")
[docs]
def validation_step(self, batch: torch.Tensor, batch_idx: int):
"""
Perform a validation step using the '_single_step' method.
Parameters
----------
batch : torch.Tensor
The input batch of data.
batch_idx : int
The index of the batch.
Returns
-------
torch.Tensor
The loss value for the batch.
"""
return self._single_step(batch, batch_idx, step_name="val")
[docs]
def on_train_epoch_start(self):
"""
Executed at the start of each training epoch. If the predictor training epochs is valid, this
function evaluates the current epoch number and freeze or unfreeze the backbone based on it.
If the predictor training epochs is None, zero, or negative, the backbone is always trained.
In the first epoch, the backbone is trained. In the subsequent 'predictor_training_epochs'
epochs, it is frozen.
"""
if self.predictor_training_epochs > 0:
self.freeze_backbone = (
self.current_epoch % (self.predictor_training_epochs + 1) != 0
)
self.backbone.train(mode=(not self.freeze_backbone))
for param in self.backbone.parameters():
param.requires_grad = not self.freeze_backbone
[docs]
def on_train_batch_start(self, batch, batch_idx):
"""
If a training steps limit is set, it checks the training step counter at the start of every
training batch. If the counter reached the limit, it returns -1, stopping the training.
"""
if (
self.max_backbone_training_steps > 0
and self.backbone_training_steps_counter > self.max_backbone_training_steps
):
return -1
[docs]
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
Updates the backbone training steps counter only if the backbone is not frozen.
"""
if not self.freeze_backbone:
self.backbone_training_steps_counter += 1