Source code for minerva.optimizers.lars

import torch
from torch.optim.optimizer import Optimizer
from typing import Set, Optional, Callable, Any


class LARS(Optimizer):
    """Layer-wise Adaptive Rate Scaling (LARS) optimizer.

    This optimizer implements the LARS algorithm, which adapts the learning
    rate for each layer based on the ratio of the weight norm to the gradient
    norm. This helps stabilize training and allows for larger learning rates.

    For every parameter ``w`` with gradient ``g`` the update is::

        trust = eta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
        v     = momentum * v + lr * trust * (g + weight_decay * w)
        w     = w - v

    ``trust`` falls back to 1 when ``||w|| == 0`` or the denominator is zero.
    It is fixed to 1 for parameters excluded from layer-wise adaptation.
    """

    def __init__(
        self,
        params: Any,
        lr: float,
        momentum: float = 0.9,
        weight_decay: float = 1e-6,
        eta: float = 0.001,
        epsilon: float = 1e-8,
        exclude_from_layer_adaptation: Optional[Set[str]] = None,
    ):
        """Layer-wise Adaptive Rate Scaling (LARS) optimizer.

        Parameters
        ----------
        params : Any
            Parameters to optimize (iterable of tensors or of param-group
            dicts, as accepted by ``torch.optim.Optimizer``).
        lr : float
            Base learning rate.
        momentum : float, optional, default: 0.9
            Momentum factor.
        weight_decay : float, optional, default: 1e-6
            Weight decay (L2 penalty) coefficient. It is added to the gradient
            and also enters the trust-ratio denominator.
        eta : float, optional, default: 0.001
            Trust coefficient for layer-wise rate scaling.
        epsilon : float, optional, default: 1e-8
            Small constant added to the trust-ratio denominator for numerical
            stability.
        exclude_from_layer_adaptation : Set[str], optional
            Names of parameters that use the plain learning rate (trust ratio
            fixed to 1), e.g. batch-normalization weights and biases. Any
            iterable of names is accepted. A parameter is matched through a
            ``param_name`` attribute that the caller must set on the tensor
            (``p.param_name = "..."``). Parameters without that attribute are
            never excluded. Weight decay is still applied to excluded
            parameters.

        Attributes
        ----------
        exclude_set : Set[str]
            Set of parameter names excluded from layer-wise adaptation.

        Raises
        ------
        ValueError
            If ``lr``, ``momentum``, ``weight_decay``, ``eta`` or ``epsilon``
            is negative.

        References
        ----------
        .. [1] You, Yang, et al. "Large batch training of convolutional
           networks." arXiv preprint arXiv:1708.03888 (2017).
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if eta < 0.0:
            raise ValueError(f"Invalid eta value: {eta}")
        if epsilon < 0.0:
            raise ValueError(f"Invalid epsilon value: {epsilon}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eta=eta,
            epsilon=epsilon,
        )
        super().__init__(params, defaults)
        # Copy into a real set so lists/tuples work and lookups are O(1).
        self.exclude_set: Set[str] = set(exclude_from_layer_adaptation or ())

    @staticmethod
    def _trust_ratio(
        param: torch.Tensor,
        grad: torch.Tensor,
        eta: float,
        weight_decay: float,
        epsilon: float,
    ) -> torch.Tensor:
        """Return the layer-wise trust ratio for one parameter tensor.

        Computes ``eta * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)``
        using the raw (undecayed) gradient norm. It returns 1 when ``||w||``
        is zero or the denominator is zero. ``torch.where`` keeps the result
        as a 0-dim tensor on the parameter's device, avoiding a per-parameter
        host-device synchronisation.
        """
        w_norm = torch.linalg.vector_norm(param)
        g_norm = torch.linalg.vector_norm(grad)
        denom = g_norm + weight_decay * w_norm + epsilon
        ratio = eta * w_norm / denom
        # The unselected branch may hold inf/nan (0/0); torch.where discards it.
        usable = (w_norm > 0) & (denom > 0)
        return torch.where(usable, ratio, torch.ones_like(ratio))

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None):
        """Perform a single optimization step.

        Parameters
        ----------
        closure : callable, optional
            A closure that reevaluates the model and returns the loss.

        Returns
        -------
        loss : torch.Tensor or None
            Loss from the closure if provided, otherwise None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hyperparameters are per group; read them once per group.
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            eta = group["eta"]
            epsilon = group["epsilon"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]

                # Exclusion relies on a caller-provided `param_name` attribute.
                param_name = getattr(p, "param_name", None)
                if param_name is not None and param_name in self.exclude_set:
                    trust_ratio = 1.0
                else:
                    trust_ratio = self._trust_ratio(
                        p, grad, eta, weight_decay, epsilon
                    )

                effective_lr = lr * trust_ratio

                # L2 penalty is folded into the gradient after the trust ratio
                # (which uses the raw gradient norm) has been computed.
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                # The learning rate lives inside the buffer (LARS paper form):
                # v = momentum * v + effective_lr * grad ;  w = w - v
                buf.mul_(momentum).add_(grad.mul(effective_lr))
                p.sub_(buf)

        return loss