import lightning as L
from torch.nn import MSELoss
from torch.optim import Adam
import torch
from torch import nn
from typing import Optional, Callable
class Autoencoder(L.LightningModule):
    """Autoencoder trained to reconstruct its own input.

    The input is mapped to a latent code by ``encoder`` and back to input
    space by ``decoder``; training minimizes the reconstruction loss between
    the decoder output and the original input.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        loss: Optional[Callable] = None,
        learning_rate: float = 1e-3,
    ):
        """
        Autoencoder model.

        Parameters
        ----------
        encoder : torch.nn.Module
            Encoder model
        decoder : torch.nn.Module
            Decoder model
        loss : Callable, optional
            Reconstruction loss, called as ``loss(x_hat, x)`` following the
            PyTorch ``(input, target)`` convention. By default
            ``torch.nn.MSELoss()``.
        learning_rate : float, optional
            Learning rate, by default 1e-3
        """
        super().__init__()
        # Saving parameters.
        # NOTE(review): not read by any method visible here; presumably
        # consumed by a configure_optimizers defined elsewhere — confirm.
        self.learning_rate = learning_rate
        # Defining layers
        self.encoder = encoder
        self.decoder = decoder
        # Defining reconstruction loss (MSE when none is supplied)
        self.reconstruction_loss = loss if loss is not None else MSELoss()

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        z = self.encoder(x)
        return self.decoder(z)

    def _reconstruction_step(self, batch):
        """Compute the reconstruction loss for a single batch."""
        # Labelled datasets yield ``(x, y, ...)`` sequences: only the first
        # element is reconstructed and the rest is ignored. Any other batch
        # object is used as the input directly.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        x_hat = self(x)
        # PyTorch losses take (input, target); the order matters for
        # asymmetric losses such as BCELoss or KLDivLoss.
        return self.reconstruction_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log as ``train_loss``, and return the batch loss."""
        loss = self._reconstruction_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log as ``val_loss``, and return the batch loss."""
        loss = self._reconstruction_step(batch)
        self.log('val_loss', loss)
        return loss