minerva.models.nets.siamese_network_wrapper

Classes

SiameseNetworkWrapper

A simple wrapper for a Siamese Network, inspired by the Siamese network example in the PyTorch examples repository.

Module Contents

class minerva.models.nets.siamese_network_wrapper.SiameseNetworkWrapper(backbone)

Bases: torch.nn.Module

A simple wrapper for a Siamese Network. The code was inspired by the Siamese network example in the PyTorch examples repository (https://github.com/pytorch/examples/blob/main/siamese_network/main.py). It passes both inputs (x1 and x2) through the same backbone, so the two branches share weights, and concatenates the resulting representations.
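The pattern described above can be sketched in plain PyTorch as follows. This is a minimal illustration, not Minerva's source: the class name `SiameseSketch` and the toy convolutional backbone are hypothetical, and flattening from dimension 1 and concatenating along dimension 1 are assumptions borrowed from the PyTorch example rather than details confirmed by this page.

```python
import torch
from torch import nn


class SiameseSketch(nn.Module):
    """Hypothetical re-implementation of the wrapper pattern (not Minerva's code)."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        # A single backbone instance, so both branches share the same weights.
        self.backbone = backbone

    def forward_once(self, x: torch.Tensor) -> torch.Tensor:
        out = self.backbone(x)
        # Assumption: keep the batch dimension and flatten everything else.
        return torch.flatten(out, start_dim=1)

    def forward(self, x):
        # x is expected to hold exactly two inputs (a list or a tuple).
        x1, x2 = x
        # Assumption: concatenate along the feature dimension (dim=1).
        return torch.cat((self.forward_once(x1), self.forward_once(x2)), dim=1)


# Toy backbone: (N, 3, 8, 8) -> (N, 4, 8, 8), i.e. 256 features per sample.
backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
model = SiameseSketch(backbone)
x1 = torch.randn(2, 3, 8, 8)
x2 = torch.randn(2, 3, 8, 8)
out = model((x1, x2))
print(out.shape)  # torch.Size([2, 512])
```

The first half of each output row comes from x1 and the second half from x2, both computed by the same `backbone` module.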

Initializes the wrapper.

Parameters

backbone : nn.Module

The backbone of the Siamese Network.

backbone
forward(x)

Passes the inputs through the backbone and concatenates the representations. x must be a list or a tuple containing two inputs, namely x1 and x2.

Parameters

x : Union[list, tuple]

A list or a tuple containing the two inputs.

Returns

torch.Tensor

The concatenated representations.
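To illustrate the input contract and the shape of the result, the sketch below applies one shared module to both elements of a pair and concatenates the flattened outputs. The function `siamese_concat` and the `nn.Linear(5, 3)` backbone are hypothetical stand-ins, and concatenation along dimension 1 is an assumption. With a per-input representation of size 3, the output has 6 features per sample.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical shared backbone: maps 5 input features to 3.
shared = nn.Linear(5, 3)


def siamese_concat(pair):
    # Mirrors the documented contract: pair is a list or tuple of two inputs.
    x1, x2 = pair
    z1 = torch.flatten(shared(x1), start_dim=1)
    z2 = torch.flatten(shared(x2), start_dim=1)
    # Assumed: concatenate along the feature dimension.
    return torch.cat((z1, z2), dim=1)


a = torch.randn(4, 5)
b = torch.randn(4, 5)

ab = siamese_concat((a, b))  # tuple input
ba = siamese_concat([b, a])  # list input works the same way
print(ab.shape)  # torch.Size([4, 6])

# With shared weights, swapping the inputs swaps the two halves of the output.
print(torch.allclose(ab[:, :3], ba[:, 3:]) and torch.allclose(ab[:, 3:], ba[:, :3]))  # True
```

Because the result is an ordered concatenation, any head placed on top of it sees (x1, x2) and (x2, x1) as different inputs.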

forward_once(x)

Passes the input through the backbone and flattens the output.

Parameters

x : torch.Tensor

The input tensor.

Returns

torch.Tensor

The backbone's output for the given input, flattened.
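The flattening step can be illustrated in isolation. This sketch assumes that `forward_once` keeps the batch dimension and flattens all remaining dimensions; the 4 x 8 x 8 feature map is a made-up example. It also shows that `torch.flatten(..., start_dim=1)` matches the common `view(batch, -1)` idiom.

```python
import torch

# Hypothetical backbone output: batch of 2, 4 channels, 8x8 spatial map.
feats = torch.randn(2, 4, 8, 8)

# Flatten everything except the batch dimension (assumed behaviour).
flat = torch.flatten(feats, start_dim=1)
print(flat.shape)  # torch.Size([2, 256])

# Same result as reshaping with view(batch_size, -1).
same = feats.view(feats.size(0), -1)
print(torch.equal(flat, same))  # True
```

Keeping the batch dimension intact is what lets the two per-input representations be concatenated sample by sample.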