Source code for minerva.models.ssl.topological_autoencoder

import lightning as L
from torch.nn import MSELoss
from torch.optim import Adam
from minerva.losses.topological_loss import TopologicalLoss
from typing import Callable, Optional
import torch.nn as nn


class TopologicalAutoencoder(L.LightningModule):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        topological_loss: Optional[Callable] = None,
        reconstruction_loss: Optional[Callable] = None,
        lambda_param: float = 1e-3,
        learning_rate: float = 1e-3,
    ):
        """Topological autoencoder model.

        Parameters
        ----------
        encoder : torch.nn.Module
            Encoder model.
        decoder : torch.nn.Module
            Decoder model.
        topological_loss : torch.nn.Module, optional
            Topological loss, by default None (falls back to TopologicalLoss).
        reconstruction_loss : torch.nn.Module, optional
            Reconstruction loss, by default None (falls back to MSELoss).
        lambda_param : float, optional
            Weight of the topological loss term, by default 1e-3.
        learning_rate : float, optional
            Learning rate, by default 1e-3.
        """
        super().__init__()

        # Save hyperparameters
        self.lambda_param = lambda_param
        self.learning_rate = learning_rate

        # Encoder and decoder modules
        self.encoder = encoder
        self.decoder = decoder

        # Topological loss (defaults to TopologicalLoss)
        self.topological_loss = (
            topological_loss if topological_loss is not None else TopologicalLoss()
        )

        # Reconstruction loss (defaults to MSELoss)
        self.reconstruction_loss = (
            reconstruction_loss if reconstruction_loss is not None else MSELoss()
        )
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_encoded = self.encoder(x)
        x_hat = self.decoder(x_encoded)
        # Reconstruction term compares the input with its reconstruction;
        # the topological term compares the input batch with its latent encoding.
        loss = self.reconstruction_loss(
            x, x_hat
        ) + self.lambda_param * self.topological_loss(x, x_encoded)
        self.log("train_loss", loss)
        return loss
    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x_encoded = self.encoder(x)
        x_hat = self.decoder(x_encoded)
        loss = self.reconstruction_loss(
            x, x_hat
        ) + self.lambda_param * self.topological_loss(x, x_encoded)
        self.log("val_loss", loss)
        return loss
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)
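
A minimal usage sketch follows. The MLP encoder/decoder, the random tensors, and the trainer settings are illustrative assumptions, not part of minerva; only TopologicalAutoencoder (and the default TopologicalLoss/MSELoss it wires in) comes from the module above.

    # Usage sketch (hypothetical encoder/decoder and dummy data).
    import torch
    import torch.nn as nn
    import lightning as L
    from torch.utils.data import DataLoader, TensorDataset
    from minerva.models.ssl.topological_autoencoder import TopologicalAutoencoder

    # Toy encoder/decoder pair mapping 784-dim inputs to a 2-dim latent space.
    encoder = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 2))
    decoder = nn.Sequential(nn.Linear(2, 128), nn.ReLU(), nn.Linear(128, 784))

    model = TopologicalAutoencoder(
        encoder=encoder,
        decoder=decoder,
        lambda_param=1e-3,
        learning_rate=1e-3,
    )

    # Dummy (x, y) pairs; the label is ignored, matching the (x, _) unpacking
    # in training_step / validation_step.
    x = torch.rand(256, 784)
    y = torch.zeros(256)
    loader = DataLoader(TensorDataset(x, y), batch_size=32)

    trainer = L.Trainer(max_epochs=1, accelerator="cpu")
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)

Since no custom losses are passed, the model optimizes the default MSE reconstruction term plus lambda_param times the default TopologicalLoss term with Adam at the given learning rate.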