Source code for minerva.models.nets.mlp

import torch.nn as nn
from typing import Sequence, Optional, List


class MLP(nn.Sequential):
    """
    A flexible multilayer perceptron (MLP) implemented as a subclass of
    nn.Sequential.

    This class allows you to quickly build an MLP with:

    - Custom layer sizes
    - Configurable activation functions
    - Optional intermediate operations (e.g., BatchNorm, Dropout) after each
      linear layer
    - An optional final operation (e.g., normalization or a final activation)

    Parameters
    ----------
    layer_sizes : Sequence[int]
        A sequence of integers specifying the size of each layer. Must contain
        at least two values: the input and output dimensions.
    activation_cls : type, optional
        The activation function class (must inherit from nn.Module) to use
        between layers. Defaults to nn.ReLU.
    intermediate_ops : Optional[List[Optional[nn.Module]]], optional
        A list of modules (e.g., nn.BatchNorm1d, nn.Dropout) to apply after
        each linear layer and before the activation. Each item corresponds to
        one linear layer; use None to skip the operation for that layer. Must
        have the same length as the number of linear layers.
    final_op : Optional[nn.Module], optional
        A module to apply after the last layer (e.g., a final activation or
        normalization).
    *args, **kwargs
        Additional arguments passed to the activation function constructor.

    Example
    -------
    >>> from torch import nn
    >>> mlp = MLP(
    ...     [128, 256, 64, 10],
    ...     activation_cls=nn.ReLU,
    ...     intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],
    ...     final_op=nn.Sigmoid()
    ... )
    >>> print(mlp)
    MLP(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=64, out_features=10, bias=True)
      (7): Sigmoid()
    )
    """

    def __init__(
        self,
        layer_sizes: Sequence[int],
        activation_cls: type = nn.ReLU,
        intermediate_ops: Optional[List[Optional[nn.Module]]] = None,
        final_op: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        assert (
            len(layer_sizes) >= 2
        ), "Multilayer perceptron must have at least 2 layers"
        assert all(
            isinstance(ls, int) and ls > 0 for ls in layer_sizes
        ), "All layer sizes must be positive integers"
        assert issubclass(
            activation_cls, nn.Module
        ), "activation_cls must inherit from torch.nn.Module"

        num_layers = len(layer_sizes) - 1

        if intermediate_ops is not None and len(intermediate_ops) != num_layers:
            raise ValueError(
                f"Length of intermediate_ops ({len(intermediate_ops)}) must "
                f"match the number of linear layers ({num_layers})"
            )

        layers = []
        for i in range(num_layers):
            in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
            layers.append(nn.Linear(in_dim, out_dim))
            # Optional per-layer operation (e.g., BatchNorm, Dropout) applied
            # right after the linear layer.
            if intermediate_ops is not None and intermediate_ops[i] is not None:
                layers.append(intermediate_ops[i])
            # The activation is placed between layers only; final_op (if given)
            # provides the output non-linearity, matching the docstring example.
            if activation_cls is not None and i < num_layers - 1:
                layers.append(activation_cls(*args, **kwargs))

        if final_op is not None:
            layers.append(final_op)

        super().__init__(*layers)
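

# A minimal usage sketch: it rebuilds the MLP from the docstring example and
# pushes a random batch through it to illustrate the expected tensor shapes.
# The batch size of 32 is an arbitrary choice for the illustration.
if __name__ == "__main__":
    import torch

    mlp = MLP(
        [128, 256, 64, 10],
        activation_cls=nn.ReLU,
        intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],
        final_op=nn.Sigmoid(),
    )
    x = torch.randn(32, 128)  # (batch_size, input_dim)
    y = mlp(x)                # (batch_size, output_dim)
    print(y.shape)            # torch.Size([32, 10])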