Source code for minerva.models.nets.diet_linear
from typing import Callable
import torch
[docs]
class AdaptedHead(torch.nn.Module):
def __init__(self, model: torch.nn.Module, adapter: Callable):
super().__init__()
self.model = model
self.adapter = adapter
[docs]
def forward(self, x):
x = self.adapter(x)
return self.model.forward(x)
[docs]
class DIETLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fc = torch.nn.Linear(in_features, out_features)
[docs]
def forward(self, x):
x = self.fc(x)
return x