import lightning as L
from torch.nn import MSELoss
from torch.optim import Adam
from minerva.losses.topological_loss import TopologicalLoss
from typing import Callable, List, Optional
import torch.nn as nn
class TopologicalAutoencoder(L.LightningModule):
    """Autoencoder trained with a reconstruction term plus a weighted
    topological term.

    With ``z = encoder(x)`` and ``x_hat = decoder(z)``, the optimized
    objective is::

        reconstruction_loss(x, x_hat) + lambda_param * topological_loss(x, z)
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        topological_loss: Optional[Callable] = None,
        reconstruction_loss: Optional[Callable] = None,
        lambda_param: float = 1e-3,
        learning_rate: float = 1e-3,
    ):
        """
        Topological autoencoder model.

        Parameters
        ----------
        encoder : torch.nn.Module
            Encoder model, maps an input batch to its latent code.
        decoder : torch.nn.Module
            Decoder model, maps a latent code back to the input space.
        topological_loss : Callable, optional
            Callable invoked as ``topological_loss(x, z)``. When None, a
            default-constructed ``TopologicalLoss`` is used.
        reconstruction_loss : Callable, optional
            Callable invoked as ``reconstruction_loss(x, x_hat)``. When
            None, ``torch.nn.MSELoss()`` is used.
        lambda_param : float, optional
            Weight of the topological loss, by default 1e-3
        learning_rate : float, optional
            Learning rate handed to the Adam optimizer, by default 1e-3
        """
        super().__init__()

        # Hyper-parameters
        self.lambda_param = lambda_param
        self.learning_rate = learning_rate

        # Sub-networks
        self.encoder = encoder
        self.decoder = decoder

        # Loss terms: fall back to the defaults only when nothing was given.
        if topological_loss is None:
            topological_loss = TopologicalLoss()
        self.topological_loss = topological_loss

        if reconstruction_loss is None:
            reconstruction_loss = MSELoss()
        self.reconstruction_loss = reconstruction_loss

    def forward(self, x):
        """Encode ``x`` and return its reconstruction ``decoder(encoder(x))``."""
        z = self.encoder(x)
        return self.decoder(z)

    def _compute_loss(self, batch):
        """Return the combined loss for one ``(input, target)`` batch.

        The batch is unpacked as a 2-tuple; its second element is ignored
        because the autoencoder is trained on the inputs alone.
        """
        x, _ = batch
        x_encoded = self.encoder(x)
        x_hat = self.decoder(x_encoded)

        # The reconstruction term must compare the input with the decoder
        # output, not with the latent code. ``x`` stays in first position,
        # as in the previous call, so user-supplied callables keep seeing
        # the same argument order.
        reconstruction = self.reconstruction_loss(x, x_hat)

        # The topological term relates the input space to the latent space.
        # NOTE(review): assumes TopologicalLoss takes (input, latent);
        # confirm against minerva.losses.topological_loss.
        topological = self.topological_loss(x, x_encoded)

        return reconstruction + self.lambda_param * topological

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``learning_rate``."""
        return Adam(self.parameters(), lr=self.learning_rate)