Source code for minerva.optimizers.lr_schedulers

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class PolyLRScheduler(_LRScheduler):
    """Polynomial-decay learning-rate scheduler with a lower bound ``min_lr``."""

    def __init__(
        self,
        optimizer: Optimizer,
        max_iter: int,
        power: float = 0.9,
        min_lr: float = 1e-4,
        last_epoch: int = -1,
    ):
        r"""
        Polynomial decay LR scheduler.

        The learning rate decays from the initial ``base_lr`` to at least
        ``min_lr`` following a polynomial schedule:

        :math:`lr = \max(\text{min\_lr}, \text{base\_lr} \cdot (1 - \tfrac{t}{T})^{power})`

        where :math:`t` is the current step (``last_epoch``) and :math:`T` is
        ``max_iter``.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Wrapped optimizer.
        max_iter : int
            Total number of iterations (epochs or steps, depending on usage).
            Must be strictly greater than 0. When ``last_epoch`` reaches or
            exceeds ``max_iter``, the learning rate is clamped to ``min_lr``.
        power : float, default=0.9
            Power factor for polynomial decay, controlling how fast the
            learning rate decays.
        min_lr : float, default=1e-4
            Minimum learning rate allowed. Once the polynomial decay falls
            below this value, the learning rate is fixed at ``min_lr``. Note
            that the floor is applied at every step, so a ``base_lr`` smaller
            than ``min_lr`` is raised to ``min_lr``.
        last_epoch : int, default=-1
            The index of the last epoch. If set to ``-1`` (default), the
            scheduler starts from the optimizer's current learning rates; the
            base-class constructor performs an initial step, so right after
            construction ``last_epoch`` is 0 and the learning rate is
            ``max(min_lr, base_lr)`` (no decay applied yet). The first
            explicit call to ``step()`` then applies the decay for ``t = 1``.

        Raises
        ------
        ValueError
            If ``max_iter`` is not strictly greater than 0.
        """
        # Validate with an explicit exception instead of ``assert``: asserts
        # are stripped under ``python -O``, which would let ``max_iter == 0``
        # reach ``get_lr`` and fail there with a ZeroDivisionError.
        if max_iter <= 0:
            raise ValueError("max_iter must be greater than 0")

        # These must be set before ``super().__init__``, because the base
        # class calls ``get_lr`` during its initial step.
        self.max_iter = max_iter
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Compute the learning rate of every parameter group for the current step.

        Returns
        -------
        list of float
            One value per entry of ``self.base_lrs``, each equal to
            ``max(min_lr, base_lr * (1 - t / max_iter) ** power)`` with
            ``t = min(last_epoch, max_iter)``.
        """
        # Clamp the step to max_iter so the base of the power never goes
        # negative (a negative base with a fractional power would be complex).
        step = min(self.last_epoch, self.max_iter)
        remaining = max(0, 1 - step / self.max_iter)
        factor = remaining ** self.power

        # Ensure the LR never goes below min_lr.
        return [max(self.min_lr, base_lr * factor) for base_lr in self.base_lrs]