minerva.losses.ntxent_loss_poly
Classes
Loss function used for pretraining the TFC model. It is based on NTXentLoss but adds a polynomial loss term.
Module Contents
- class minerva.losses.ntxent_loss_poly.NTXentLoss_poly(device, batch_size, temperature, use_cosine_similarity)
Bases: torch.nn.modules.loss._Loss
Loss function used for pretraining the TFC model. It is based on NTXentLoss but adds a polynomial loss term.
The constructor of the NTXentLoss_poly class.
Parameters
- device: str
The device used for training the model, for example "cpu" or "cuda".
- batch_size: int
The number of samples in each batch.
- temperature: float
The temperature used to scale the similarities before the softmax.
- use_cosine_similarity: bool
If True, cosine similarity is used. If False, the dot product is used.
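The effect of the temperature parameter can be illustrated with a small standalone sketch. It is not taken from the class itself: the softmax helper and the similarity scores are hypothetical, and plain Python lists stand in for tensors. It only shows that dividing by a smaller temperature concentrates the softmax on the largest similarity.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits after dividing each by the temperature."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]

# Hypothetical similarity scores, for illustration only.
similarities = [0.9, 0.5, 0.1]
sharp = softmax(similarities, temperature=0.1)   # low temperature: peaked distribution
smooth = softmax(similarities, temperature=1.0)  # high temperature: flatter distribution
```

With the low temperature, almost all of the probability mass goes to the first score; with a temperature of 1.0 the three scores receive comparable weight.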
- _cosine_simililarity(x, y)
Function to calculate the cosine similarity between two tensors.
Parameters
- x: torch.Tensor
The first tensor
- y: torch.Tensor
The second tensor
Returns
- torch.Tensor
The cosine similarity between the two tensors
- Return type:
torch.Tensor
- static _dot_simililarity(x, y)
Function to calculate the dot-product similarity between two tensors.
Parameters
- x: torch.Tensor
The first tensor
- y: torch.Tensor
The second tensor
Returns
- torch.Tensor
The dot-product similarity between the two tensors
- Return type:
torch.Tensor
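The difference between the two similarity options can be seen in a minimal sketch. The output shape is not stated on this page; the sketch assumes a pairwise comparison (every row of x against every row of y), as is common in NT-Xent implementations. The helper names are hypothetical and nested lists stand in for tensors.

```python
import math

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def pairwise_dot(x, y):
    """Matrix whose entry [i][j] is the dot product of x[i] and y[j]."""
    return [[dot(u, v) for v in y] for u in x]

def pairwise_cosine(x, y):
    """Matrix whose entry [i][j] is the cosine of the angle between x[i] and y[j]."""
    return [
        [dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v))) for v in y]
        for u in x
    ]

x = [[3.0, 4.0], [1.0, 0.0]]
y = [[6.0, 8.0], [0.0, 2.0]]
dot_matrix = pairwise_dot(x, y)     # grows with the length of the vectors
cos_matrix = pairwise_cosine(x, y)  # depends only on direction, bounded by [-1, 1]
```

Here x[0] and y[0] point in the same direction, so their cosine similarity is 1 even though their dot product is 50.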
Get the mask of correlated samples.
Returns
- torch.Tensor
The mask of correlated samples
- Return type:
torch.Tensor
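The layout of the mask is not described here. As a sketch, assuming the usual NT-Xent arrangement in which the two batches are stacked into 2N rows and row k is paired with row k + N, a mask that keeps only the negative pairs could be built as follows. The function name and the list-of-lists representation are hypothetical.

```python
def correlated_mask(batch_size):
    """Boolean 2N x 2N mask that is True where a pair counts as a negative."""
    size = 2 * batch_size
    mask = []
    for i in range(size):
        row = []
        for j in range(size):
            is_self = i == j                         # a sample compared with itself
            is_positive = abs(i - j) == batch_size   # the two views of one sample
            row.append(not (is_self or is_positive))
        mask.append(row)
    return mask

mask = correlated_mask(2)
```

Under this assumption each row keeps 2N - 2 entries, one for every sample that is neither the anchor itself nor its positive partner.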
- _get_similarity_function(use_cosine_similarity)
Select the similarity function used in the loss calculation.
Parameters
- use_cosine_similarity: bool
If True, the cosine similarity is used. If False, the dot product is used
Returns
- function
The similarity function to be used in the loss calculation
- batch_size
- criterion
- device
- forward(zis, zjs)
The forward method of the NTXentLoss_poly class. It receives the two sets of samples and returns the loss.
Parameters
- zis: torch.Tensor
The positive samples
- zjs: torch.Tensor
The negative samples
Returns
- _Loss
The loss of the model
- Return type:
torch.nn.modules.loss._Loss
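The kind of computation a forward pass like this performs can be sketched in plain Python. Several points are assumptions rather than facts taken from this page: the sketch follows the usual NT-Xent convention in which row k of each input forms a positive pair and all other rows act as negatives, it uses cosine similarity, and it models the polynomial term in the Poly-1 style, epsilon * (1 - p_t), where p_t is the probability assigned to the positive pair. The function name, the epsilon value, and the averaging are hypothetical and may differ from what the class does.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def ntxent_poly(zis, zjs, temperature=0.5, epsilon=1.0):
    """NT-Xent loss plus an assumed Poly-1 term, averaged over all 2N anchors."""
    n = len(zis)
    reps = zis + zjs  # 2N embeddings; reps[k] and reps[k + n] form a positive pair
    cross_entropy = 0.0
    positive_prob = 0.0
    for i in range(2 * n):
        partner = (i + n) % (2 * n)
        # Scaled similarity to every other embedding (self-similarity is excluded).
        logits = {
            j: cosine(reps[i], reps[j]) / temperature
            for j in range(2 * n)
            if j != i
        }
        peak = max(logits.values())  # subtract the maximum for numerical stability
        denom = sum(math.exp(value - peak) for value in logits.values())
        p_t = math.exp(logits[partner] - peak) / denom
        cross_entropy += -math.log(p_t)
        positive_prob += p_t
    cross_entropy /= 2 * n
    positive_prob /= 2 * n
    # Assumed Poly-1 correction: cross-entropy plus epsilon * (1 - p_t).
    return cross_entropy + epsilon * (1.0 - positive_prob)

zis = [[1.0, 0.0], [0.0, 1.0]]
aligned = ntxent_poly(zis, [[1.0, 0.1], [0.1, 1.0]])  # each row close to its partner
swapped = ntxent_poly(zis, [[0.1, 1.0], [1.0, 0.1]])  # partners deliberately mismatched
```

In this sketch the loss is lower when each embedding lies close to its partner than when the partners are mismatched, which is the behaviour a contrastive pretraining loss is expected to reward.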
- mask_samples_from_same_repr
- similarity_function
- softmax
- temperature