from typing import Callable, Optional
import lightning as L
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from minerva.schedulers.warmup_cosine_annealing import WarmupCosineAnnealingLR
[docs]
class DIET(L.LightningModule):
def __init__(
    self,
    backbone: nn.Module,
    linear_head: Optional[torch.nn.Module] = None,
    num_data: Optional[int] = None,
    flatten: bool = True,
    adapter: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    loss: Optional[Callable] = None,
    learning_rate: float = 3e-4,
    weight_decay: float = 3e-4,
    wca_scheduler_total_epochs: Optional[int] = None,
):
    """
    DIET model.

    Parameters
    ----------
    backbone : torch.nn.Module
        Backbone model.
    linear_head : torch.nn.Module, optional
        Linear head that computes logits from embeddings of the data input, by default None.
        If None, the linear head is automatically defined before training. The lengths of
        both training dataset and linear head output must match.
    num_data : int, optional
        Total number of samples in the training dataset, by default None. If None, the length
        of the training dataset is computed before the training in the setup() function.
    flatten : bool, optional
        If True, the output of the backbone is flattened before the linear layer,
        by default True.
    adapter : Optional[Callable[[torch.Tensor], torch.Tensor]], optional
        If not None, an adapter is added after the backbone and before the flatten process,
        by default None.
    loss : Callable, optional
        Loss function, by default CrossEntropyLoss with label smoothing 0.8. Any value
        other than None is used as given.
    learning_rate : float, optional
        Learning rate used in the optimizer, by default 3e-4.
    weight_decay : float, optional
        Weight decay used in the optimizer, by default 3e-4.
    wca_scheduler_total_epochs : int, optional
        Total number of epochs for the WarmupCosineAnnealing scheduler, by default None.
        Must be None or an integer greater than 10. If None, no scheduler is used.

    Raises
    ------
    ValueError
        If ``wca_scheduler_total_epochs`` is not None and is less than or equal to 10.
    """
    # Validate first so an invalid configuration fails before any module state is built.
    if wca_scheduler_total_epochs is not None and wca_scheduler_total_epochs <= 10:
        raise ValueError(
            "Total number of epochs for the WarmupCosineAnnealing scheduler must be greater than 10."
        )
    super().__init__()
    # Defining layers
    self.backbone = backbone
    # May still be None here; setup() creates it during the "fit" stage.
    self.linear_head = linear_head
    self.num_data = num_data
    # Defining adapter
    self.adapter = adapter
    self.flatten = flatten
    # Explicit None check: a caller-supplied loss is never swapped for the default
    # merely because the object happens to be falsy.
    self.loss = loss if loss is not None else CrossEntropyLoss(label_smoothing=0.8)
    # Defining other hyperparameters
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.wca_scheduler_total_epochs = wca_scheduler_total_epochs
[docs]
def setup(self, stage):
"""
Setup function. If the model lacks a linear head, this function computes the length
of the training dataset, the encoding size, and creates a linear head accordingly. Also
checks whether the linear head output matches the length of the training dataset,
raising an error in case of mismatch.
"""
if stage != "fit":
return
# Get the training dataset
training_dataset = self.trainer.datamodule.train_dataloader().dataset
# Update num_data if None
if self.num_data is None:
self.num_data = len(training_dataset)
# Define a linear head if None
if self.linear_head is None:
# Simulated input for encoding_size calculation
random_input = torch.rand(training_dataset[:5][0].shape)
# Compute the encoding size
with torch.no_grad():
# Obtain the embeddings from the random data
out = self.backbone(random_input)
if self.adapter:
out = self.adapter(out)
if self.flatten:
out = out.flatten(start_dim=1)
# Computes the encoding size
encoding_size = out.size(1)
# Defines the linear head
self.linear_head = nn.Linear(encoding_size, self.num_data)
else:
# Check if the linear head provided matches the length of the training dataset
assert (
self.num_data == self.linear_head.out_features
), f"Number of samples({self.num_data}) and output of linear head({self.linear_head.out_features}) do not match."
[docs]
def forward(self, x):
    """
    Map a batch of inputs to linear-head outputs.

    The input goes through the backbone, then the adapter (only when ``self.adapter``
    is truthy), then an optional flatten of every dimension after the first, and
    finally the linear head.

    Parameters
    ----------
    x : torch.Tensor
        Input batch, passed unchanged to the backbone.

    Returns
    -------
    torch.Tensor
        Whatever ``self.linear_head`` returns for the processed features.
    """
    features = self.backbone(x)
    # The adapter is optional; a None (falsy) adapter leaves the features untouched.
    features = self.adapter(features) if self.adapter else features
    if self.flatten:
        # Keep dim 0 and collapse all remaining dimensions into one.
        features = features.flatten(start_dim=1)
    return self.linear_head(features)
[docs]
def training_step(self, batch, batch_idx):
    """
    Compute and log the training loss for one batch.

    Parameters
    ----------
    batch : tuple
        A pair ``(inputs, targets)``. The targets are passed directly to ``self.loss``
        together with the model outputs. For DIET they are presumably the per-sample
        dataset indices; confirm against the datamodule.
    batch_idx : int
        Index of the batch within the epoch (not used).

    Returns
    -------
    torch.Tensor
        The loss value, also logged as ``"train_loss"`` with ``on_epoch=True`` and
        ``on_step=False``.
    """
    inputs, targets = batch
    logits = self(inputs)
    train_loss = self.loss(logits, targets)
    self.log("train_loss", train_loss, on_epoch=True, on_step=False)
    return train_loss