Source code for minerva.schedulers.warmup_cosine_annealing
from torch.optim.lr_scheduler import _LRScheduler
import math
from torch.optim.optimizer import Optimizer
[docs]
class WarmupCosineAnnealingLR(_LRScheduler):
"""
A custom learning rate scheduler that combines linear warmup with cosine
annealing. The learning rate increases linearly over the first 'warmup_epochs',
and then decreases until 'total_epochs' following a cosine curve.
"""
def __init__(
self,
optimizer: Optimizer,
warmup_epochs: int,
total_epochs: int,
min_lr: int = 0,
last_epoch: int = -1,
):
"""
Initializes the scheduler.
Parameters
----------
optimizer : torch.optim.optimizer.Optimizer
Wrapped optimizer.
warmup_epochs : int
Number of epochs for linear warmup.
total_epochs : int
Total number of training epochs.
min_lr : float
Minimum learning rate expected at the end of the cosine annealing.
last_epoch: int
Index of the last epoch. Used for resuming training.
"""
super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.min_lr = min_lr
if warmup_epochs > total_epochs:
raise ValueError("total_epochs must be greater than warmup_epochs.")
[docs]
def get_lr(self):
if self.last_epoch < self.warmup_epochs:
# Linear warmup
warmup_factor = (self.last_epoch + 1) / self.warmup_epochs
return [base_lr * warmup_factor for base_lr in self.base_lrs]
else:
# Cosine annealing
progress = (self.last_epoch - self.warmup_epochs) / (
self.total_epochs - self.warmup_epochs
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return [
self.min_lr + (base_lr - self.min_lr) * cosine_decay
for base_lr in self.base_lrs
]