minerva.losses.ntxent_loss_poly

Classes

NTXentLoss_poly

Loss function used in the pretraining of the TFC model. It is based on NTXentLoss, but adds a polynomial loss term.

Module Contents

class minerva.losses.ntxent_loss_poly.NTXentLoss_poly(device, batch_size, temperature, use_cosine_similarity)

Bases: torch.nn.modules.loss._Loss

Loss function used in the pretraining of the TFC (Time-Frequency Consistency) model. It is based on NTXentLoss (the normalized temperature-scaled cross-entropy loss), but adds a polynomial loss term.
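For reference, the quantities involved can be written as follows, assuming the standard NT-Xent formulation and the Poly-1 form of a polynomial loss (both are assumptions; this page does not state the exact formula). With N = batch_size, τ = temperature, sim the selected similarity function and z_1 to z_{2N} the embeddings of zis and zjs stacked together, the term for a positive pair (i, j) is:

```latex
p_{i,j} = \frac{\exp\big(\operatorname{sim}(z_i, z_j)/\tau\big)}
               {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\operatorname{sim}(z_i, z_k)/\tau\big)},
\qquad
\ell_{i,j} = -\log p_{i,j},
\qquad
\ell^{\text{poly}}_{i,j} = -\log p_{i,j} + \varepsilon\,\big(1 - p_{i,j}\big)
```

Here p_{i,j} is the softmax probability assigned to the positive pair and ε is a placeholder coefficient; the value of ε and the reduction over the 2N anchors used by NTXentLoss_poly are not documented here.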

The constructor of the NTXentLoss_poly class.

Parameters

  • device: str

    The device (e.g. "cpu" or "cuda") on which the loss is computed

  • batch_size: int

    The number of samples in each batch, i.e. in each of zis and zjs

  • temperature: float

    The temperature used to scale the similarities before the softmax

  • use_cosine_similarity: bool

    If True, the cosine similarity is used. If False, the dot product is used

_cosine_simililarity(x, y)

Function to calculate the cosine similarity between two tensors.

Parameters

  • x: torch.Tensor

    The first tensor

  • y: torch.Tensor

    The second tensor

Returns

  • torch.Tensor

    The cosine similarity between the two tensors

Return type:

torch.Tensor
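A minimal sketch of a pairwise cosine similarity in PyTorch, assuming the method compares every row of x with every row of y, which is what an NT-Xent loss needs to build its similarity matrix. The helper name pairwise_cosine_similarity is hypothetical and not part of minerva.

```python
import torch
import torch.nn.functional as F


def pairwise_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every row of x and every row of y.

    x: (N, C), y: (M, C) -> result: (N, M)
    """
    # Broadcast x to (N, 1, C) and y to (1, M, C); cosine_similarity
    # reduces over the last (feature) dimension.
    return F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0), dim=-1)


x = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
y = torch.tensor([[3.0, 0.0], [1.0, 1.0]])
sim = pairwise_cosine_similarity(x, y)  # shape (2, 2)
```

Because each row is normalized, the result lies in [-1, 1] regardless of the embedding scale.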

static _dot_simililarity(x, y)

Function to calculate the dot similarity between two tensors.

Parameters

  • x: torch.Tensor

    The first tensor

  • y: torch.Tensor

    The second tensor

Returns

  • torch.Tensor

    The dot similarity between the two tensors

Return type:

torch.Tensor
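A minimal sketch of the dot-product alternative, assuming the same pairwise (row against row) semantics as the cosine variant. The helper name pairwise_dot_similarity is hypothetical. It needs no instance state, which is consistent with the method being static.

```python
import torch


def pairwise_dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Dot product between every row of x and every row of y.

    x: (N, C), y: (M, C) -> result: (N, M)
    """
    return x @ y.T


x = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
y = torch.tensor([[3.0, 0.0], [1.0, 1.0]])
sim = pairwise_dot_similarity(x, y)  # shape (2, 2)
```

Unlike cosine similarity, the dot product is unbounded, so the scale of the embeddings interacts with the temperature.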

_get_correlated_mask()

Get the mask of correlated samples.

Returns

  • torch.Tensor

    The mask of correlated samples

Return type:

torch.Tensor
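A sketch of such a mask, assuming zis and zjs are stacked into one tensor of 2 × batch_size rows, so that row k and row k ± batch_size hold the two views of the same sample. Under that assumption, the "correlated" entries are each sample's similarity with itself and with its positive partner, and the mask keeps only the negatives. The function name correlated_mask is hypothetical.

```python
import torch


def correlated_mask(batch_size: int) -> torch.Tensor:
    """Boolean (2B, 2B) mask: False on the main diagonal and on the
    positive-pair diagonals at offsets +B and -B, True elsewhere."""
    n = 2 * batch_size
    eye = torch.eye(n, dtype=torch.bool)
    # Rolling the identity by B columns marks column (k + B) mod 2B in row k,
    # which covers both the +B and the -B diagonal.
    positives = torch.roll(eye, shifts=batch_size, dims=1)
    return ~(eye | positives)


mask = correlated_mask(2)
# Applied to a (2B, 2B) similarity matrix, sim[mask].view(2 * B, -1)
# yields the 2B - 2 negatives of each row.
```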

_get_similarity_function(use_cosine_similarity)

Define the similarity function to be used in the loss calculation.

Parameters

  • use_cosine_similarity: bool

    If True, the cosine similarity is used. If False, the dot product is used

Returns

  • function

    The similarity function to be used in the loss calculation

Parameters:
  • use_cosine_similarity (bool)

batch_size
criterion
device
forward(zis, zjs)

The forward method of the NTXentLoss_poly class. It receives two batches of embeddings, one per view of the same samples, and returns the loss.

Parameters

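  • zis: torch.Tensor

    Embeddings of the first view of each sample in the batch

  • zjs: torch.Tensor

    Embeddings of the second view of each sample. zis[k] and zjs[k] form a positive pair; all other embeddings in the batch act as negatives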

Returns

  • torch.Tensor

    The loss value (a scalar tensor)

Parameters:
  • zis (torch.Tensor)

  • zjs (torch.Tensor)

Return type:

torch.Tensor
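The pieces above can be combined into a standalone sketch of an NT-Xent loss with a polynomial term. This is not minerva's implementation: it assumes the Poly-1 form (cross-entropy plus epsilon × (1 - p_t), with p_t the softmax probability of the positive) and a mean reduction, and the function name ntxent_poly_loss and its epsilon parameter are hypothetical. The coefficient and normalization used by NTXentLoss_poly may differ.

```python
import torch
import torch.nn.functional as F


def ntxent_poly_loss(
    zis: torch.Tensor,
    zjs: torch.Tensor,
    temperature: float = 0.2,
    use_cosine_similarity: bool = True,
    epsilon: float = 1.0,  # hypothetical Poly-1 coefficient
) -> torch.Tensor:
    """NT-Xent loss plus a Poly-1 term; zis[k] and zjs[k] are a positive pair."""
    batch_size = zis.shape[0]
    n = 2 * batch_size
    reps = torch.cat([zis, zjs], dim=0)  # (2B, C)

    if use_cosine_similarity:
        sim = F.cosine_similarity(reps.unsqueeze(1), reps.unsqueeze(0), dim=-1)
    else:
        sim = reps @ reps.T  # (2B, 2B)

    # The positive of row k sits at column (k + B) mod 2B.
    eye = torch.eye(n, dtype=torch.bool, device=reps.device)
    pos_mask = torch.roll(eye, shifts=batch_size, dims=1)
    positives = sim[pos_mask].view(n, 1)
    negatives = sim[~(eye | pos_mask)].view(n, n - 2)

    # Column 0 holds the positive logit, so every target label is 0.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(n, dtype=torch.long, device=reps.device)
    ce = F.cross_entropy(logits, labels)  # mean over the 2B anchors

    # Poly-1 term: epsilon * (1 - p_t), averaged over the anchors.
    p_t = F.softmax(logits, dim=-1)[:, 0]
    return ce + epsilon * (1.0 - p_t).mean()


# Usage: orthogonal embeddings with identical views, so each positive pair
# scores 1 and every negative scores 0.
z = torch.eye(4)
loss = ntxent_poly_loss(z, z.clone(), temperature=0.2)
```

The class receives batch_size in its constructor, which suggests its mask is built once for a fixed size; if so, zis and zjs would need exactly batch_size rows. The sketch builds its masks on every call instead, so it also accepts a smaller final batch.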

mask_samples_from_same_repr
similarity_function
softmax
temperature
Parameters:
  • device (str)

  • batch_size (int)

  • temperature (float)

  • use_cosine_similarity (bool)