import torch.nn as nn
from typing import Sequence, Optional, List
class MLP(nn.Sequential):
    """
    A flexible multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.

    This class allows you to quickly build an MLP with:

    - Custom layer sizes
    - A configurable activation function
    - Optional intermediate operations (e.g., BatchNorm, Dropout) after each linear layer
    - An optional final operation (e.g., a normalization or final activation)

    Parameters
    ----------
    layer_sizes : Sequence[int]
        A list of integers specifying the sizes of each layer. Must contain at least
        two values: the input and output dimensions. Each consecutive pair defines one
        linear layer.
    activation_cls : Optional[type], optional
        The activation function class (must inherit from nn.Module) to use between
        layers; no activation is added after the final linear layer. Defaults to
        nn.ReLU. Pass `None` to build the MLP without activations.
    intermediate_ops : Optional[List[Optional[nn.Module]]], optional
        A list of modules (e.g., nn.BatchNorm1d, nn.Dropout) to apply after each linear
        layer and before its activation. Each item corresponds to one linear layer; use
        `None` to skip the operation for that layer. Must have the same length as the
        number of linear layers.
    final_op : Optional[nn.Module], optional
        A module to apply after the last layer (e.g., a final activation or
        normalization).
    *args, **kwargs :
        Additional arguments passed to the activation function constructor.

    Example
    -------
    >>> from torch import nn
    >>> mlp = MLP(
    ...     [128, 256, 64, 10],
    ...     activation_cls=nn.ReLU,
    ...     intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],
    ...     final_op=nn.Sigmoid(),
    ... )
    >>> print(mlp)
    MLP(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=64, out_features=10, bias=True)
      (7): Sigmoid()
    )
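
    Extra positional and keyword arguments are forwarded to the activation
    constructor; for instance, a leaky ReLU with a custom negative slope
    (a small illustrative sketch):

    >>> leaky_mlp = MLP([128, 64, 10], activation_cls=nn.LeakyReLU, negative_slope=0.1)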
"""
    def __init__(
        self,
        layer_sizes: Sequence[int],
        activation_cls: Optional[type] = nn.ReLU,
        intermediate_ops: Optional[List[Optional[nn.Module]]] = None,
        final_op: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        assert (
            len(layer_sizes) >= 2
        ), "layer_sizes must contain at least two values (input and output dimensions)"
        assert all(
            isinstance(ls, int) and ls > 0 for ls in layer_sizes
        ), "All layer sizes must be positive integers"
        assert activation_cls is None or issubclass(
            activation_cls, nn.Module
        ), "activation_cls must be None or inherit from torch.nn.Module"
        # One linear layer per consecutive pair of sizes.
        num_layers = len(layer_sizes) - 1
        if intermediate_ops is not None:
            if len(intermediate_ops) != num_layers:
                raise ValueError(
                    f"Length of intermediate_ops ({len(intermediate_ops)}) must match "
                    f"the number of linear layers ({num_layers})"
                )
        # Assemble Linear -> (optional intermediate op) -> activation for each layer.
        # The last linear layer gets no activation, so final_op (if any) acts directly
        # on the network output, matching the docstring example.
        layers: List[nn.Module] = []
        for i in range(num_layers):
            in_dim, out_dim = layer_sizes[i], layer_sizes[i + 1]
            layers.append(nn.Linear(in_dim, out_dim))
            if intermediate_ops is not None and intermediate_ops[i] is not None:
                layers.append(intermediate_ops[i])
            if activation_cls is not None and i < num_layers - 1:
                layers.append(activation_cls(*args, **kwargs))
        if final_op is not None:
            layers.append(final_op)
        super().__init__(*layers)
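

# Minimal usage sketch: build the MLP from the docstring example and run a forward
# pass on a random batch (names here are illustrative).
if __name__ == "__main__":
    import torch

    model = MLP(
        [128, 256, 64, 10],
        activation_cls=nn.ReLU,
        intermediate_ops=[nn.BatchNorm1d(256), nn.BatchNorm1d(64), None],
        final_op=nn.Sigmoid(),
    )
    x = torch.randn(32, 128)  # batch of 32 samples with 128 input features
    out = model(x)            # forward pass through the Sequential stack
    print(out.shape)          # expected: torch.Size([32, 10])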