Source code for minerva.models.finetune_adapters

import math

import torch
import torch.nn as nn

class LoRA(nn.Module):
    """
    LoRA (Low-Rank Adaptation) for linear layers.

    This module applies low-rank adaptation to an existing linear layer.
    LoRA enables efficient fine-tuning of pre-trained models by introducing
    learnable low-rank matrices that adapt the weights of the original layer
    while keeping its parameters frozen (a fine-tuning setup sketch follows
    the class definition).

    Parameters
    ----------
    original_module : torch.nn.Module
        The original linear or transformer layer (e.g., ``torch.nn.Linear``)
        to which LoRA is applied. It must have ``in_features`` and
        ``out_features`` attributes.
    bias : bool, optional
        Whether to include a bias term in the LoRA adaptation layers.
        Default is True.
    alpha : float, optional
        The numerator of the scaling factor (``alpha / r``) applied to the
        LoRA output. Default is 1.
    r : int, optional
        The rank of the low-rank matrices used for adaptation. Default is 4.

    Attributes
    ----------
    original_module : torch.nn.Module
        The original module that LoRA adapts.
    matrix_A : torch.nn.Linear
        The low-rank projection ``A`` mapping ``in_features`` to ``r``.
    matrix_B : torch.nn.Linear
        The low-rank projection ``B`` mapping ``r`` to ``out_features``.
    scaling : float
        The scaling factor (``alpha / r``) applied to the LoRA adaptation
        output.

    Methods
    -------
    init_weights():
        Initializes the weights of the low-rank matrices ``A`` and ``B``.
        Matrix ``A`` uses Kaiming uniform initialization and matrix ``B``
        is initialized with zeros.
    forward(x):
        Computes the forward pass through the adapted module.

    Examples
    --------
    >>> import torch
    >>> import torch.nn as nn
    >>> from minerva.models.finetune_adapters import LoRA
    >>> # Original linear layer
    >>> original_layer = nn.Linear(128, 64)
    >>> # Wrap the original layer with LoRA
    >>> lora_layer = LoRA(original_layer, alpha=2, r=8)
    >>> # Input tensor (batch size of 16)
    >>> x = torch.randn(16, 128)
    >>> # Forward pass
    >>> output = lora_layer(x)
    >>> print(output.shape)
    torch.Size([16, 64])
    """

    def __init__(
        self,
        original_module: nn.Module,
        bias: bool = True,
        alpha: float = 1.0,
        r: int = 4,
    ):
        super().__init__()
        self.original_module = original_module
        self.matrix_A = nn.Linear(original_module.in_features, r, bias=bias)
        self.matrix_B = nn.Linear(r, original_module.out_features, bias=bias)
        self.scaling = alpha / r
        self.init_weights()
    def init_weights(self):
        """
        Initialize weights for the low-rank matrices.

        Matrix ``A`` is initialized with Kaiming uniform initialization,
        which is suitable for layers with ReLU activations. Matrix ``B``
        (weight and, if present, bias) is initialized with zeros so that
        the adaptation term is zero at the start and the original module's
        behavior is not perturbed (see the check sketch after the class
        definition).
        """
        nn.init.kaiming_uniform_(self.matrix_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.matrix_B.weight)
        if self.matrix_B.bias is not None:
            # Zero the bias as well so B(A(x)) is exactly zero at
            # initialization, matching the guarantee documented above.
            nn.init.zeros_(self.matrix_B.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Forward pass of the LoRA module.

        Computes the output as the sum of the original module's output and
        the low-rank adaptation output, scaled by the ``scaling`` factor.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor with shape ``(batch_size, in_features)``.

        Returns
        -------
        torch.Tensor
            The output tensor with shape ``(batch_size, out_features)``.

        Notes
        -----
        The output is computed as:

        .. math::

            y = \text{original\_module}(x) + \text{scaling} \cdot B(A(x)),

        where :math:`A` and :math:`B` are the learnable low-rank matrices.
        """
        return self.original_module(x) + self.scaling * self.matrix_B(
            self.matrix_A(x)
        )
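

# A minimal check sketch (not part of the original API) of the
# zero-initialization property documented in ``init_weights``: with
# ``bias=False`` the adaptation term B(A(x)) is exactly zero right after
# construction, so the wrapped layer initially reproduces the original
# layer's output. The function name and shapes are illustrative assumptions.
def _check_zero_init_sketch():
    base = nn.Linear(128, 64)
    lora = LoRA(base, bias=False, alpha=2, r=8)
    x = torch.randn(16, 128)
    # matrix_B starts at zero, so the LoRA term vanishes at initialization.
    assert torch.allclose(lora(x), base(x))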
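

# A hedged usage sketch (not part of the original module): the typical LoRA
# fine-tuning setup freezes the wrapped layer and trains only ``matrix_A``
# and ``matrix_B``. The optimizer choice, learning rate, and random data
# below are illustrative assumptions, not library defaults.
def _finetune_setup_sketch():
    base = nn.Linear(128, 64)
    lora = LoRA(base, alpha=2, r=8)
    # Freeze the original layer; only the low-rank matrices stay trainable.
    for p in lora.original_module.parameters():
        p.requires_grad = False
    trainable = [p for p in lora.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3)
    # One illustrative training step on random data.
    x, target = torch.randn(16, 128), torch.randn(16, 64)
    loss = torch.nn.functional.mse_loss(lora(x), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()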