import warnings
from typing import Dict, List, Optional, Tuple, Union
import lightning.pytorch as L
import torch
from torch import nn
from torch.optim.adam import Adam
from torchmetrics import Metric
from minerva.engines.engine import _Engine
from minerva.models.nets.image.vit import _VisionTransformerBackbone
from minerva.utils.upsample import Upsample
[docs]
class _SETRUPHead(nn.Module):
    """Naive upsampling head and Progressive upsampling head of SETR
    (as in https://arxiv.org/pdf/2012.15840.pdf).
    """
    def __init__(
        self,
        channels: int,
        in_channels: int,
        num_classes: int,
        norm_layer: nn.Module,
        conv_norm: nn.Module,
        conv_act: nn.Module,
        num_convs: int,
        up_scale: int,
        kernel_size: int,
        align_corners: bool,
        dropout: float,
        interpolate_mode: str,
    ):
        """The SETR PUP Head.
        Parameters
        ----------
        channels : int
            Number of output channels.
        in_channels : int
            Number of input channels.
        num_classes : int
            Number of output classes.
        norm_layer : nn.Module
            Normalization layer.
        conv_norm : nn.Module
            Convolutional normalization layer.
        conv_act : nn.Module
            Convolutional activation layer.
        num_convs : int
            Number of convolutional layers.
        up_scale : int
            Upsampling scale factor.
        kernel_size : int
            Kernel size for convolutional layers.
        align_corners : bool
            Whether to align corners during upsampling.
        dropout : float
            Dropout rate.
        interpolate_mode : str
            Interpolation mode for upsampling.
        Raises
        ------
        AssertionError
            If kernel_size is not 1 or 3.
        """
        assert kernel_size in [1, 3], "kernel_size must be 1 or 3."
        super().__init__()
        self.num_classes = num_classes
        self.out_channels = channels
        self.cls_seg = nn.Conv2d(channels, self.num_classes, 1)
        self.norm = norm_layer
        conv_norm = conv_norm
        conv_act = conv_act
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None
        self.up_convs = nn.ModuleList()
        for _ in range(num_convs):
            self.up_convs.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels,
                        self.out_channels,
                        kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    conv_norm,
                    conv_act,
                    Upsample(
                        scale_factor=up_scale,
                        mode=interpolate_mode,
                        align_corners=align_corners,
                    ),
                )
            )
            in_channels = self.out_channels
[docs]
    def forward(self, x):
        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).transpose(1, 2).contiguous()
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(n, c, h, w).contiguous()
        for up_conv in self.up_convs:
            x = up_conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        out = self.cls_seg(x)
        return out 
 
[docs]
class _SETRMLAHead(nn.Module):
    """Multi level feature aggretation head of SETR (as in
    https://arxiv.org/pdf/2012.15840.pdf)
    Note: This has not been tested yet!
    """
    def __init__(
        self,
        channels: int,
        conv_norm: Optional[nn.Module],
        conv_act: Optional[nn.Module],
        in_channels: List[int],
        out_channels: int,
        num_classes: int,
        mla_channels: int = 128,
        up_scale: int = 4,
        kernel_size: int = 3,
        align_corners: bool = True,
        dropout: float = 0.1,
        threshold: Optional[float] = None,
    ):
        super().__init__()
        if out_channels is None:
            if num_classes == 2:
                warnings.warn(
                    "For binary segmentation, we suggest using"
                    "`out_channels = 1` to define the output"
                    "channels of segmentor, and use `threshold`"
                    "to convert `seg_logits` into a prediction"
                    "applying a threshold"
                )
            out_channels = num_classes
        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                "out_channels should be equal to num_classes,"
                "except binary segmentation set out_channels == 1 and"
                f"num_classes == 2, but got out_channels={out_channels}"
                f"and num_classes={num_classes}"
            )
        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn("threshold is not defined for binary, and defaults to 0.3")
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold
        conv_norm = (
            conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels)
        )
        conv_act = conv_act if conv_act is not None else nn.ReLU()
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None
        self.cls_seg = nn.Conv2d(channels, out_channels, 1)
        num_inputs = len(in_channels)
        self.up_convs = nn.ModuleList()
        for i in range(num_inputs):
            self.up_convs.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels[i],
                        mla_channels,
                        kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    conv_norm,
                    conv_act,
                    nn.Conv2d(
                        mla_channels,
                        mla_channels,
                        kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    conv_norm,
                    conv_act,
                    Upsample(
                        scale_factor=up_scale,
                        mode="bilinear",
                        align_corners=align_corners,
                    ),
                )
            )
[docs]
    def forward(self, x):
        outs = []
        for x, up_conv in zip(x, self.up_convs):
            outs.append(up_conv(x))
        out = torch.cat(outs, dim=1)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.cls_seg(out)
        return out 
 
[docs]
class _SetR_PUP(nn.Module):
    def __init__(
        self,
        image_size: Union[int, Tuple[int, int]],
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        num_convs: int,
        num_classes: int,
        decoder_channels: int,
        up_scale: int,
        encoder_dropout: float,
        kernel_size: int,
        decoder_dropout: float,
        norm_layer: nn.Module,
        interpolate_mode: str,
        conv_norm: nn.Module,
        conv_act: nn.Module,
        align_corners: bool,
        aux_output: bool,
        aux_output_layers: Optional[List[int]],
        original_resolution: Optional[Tuple[int, int]],
    ):
        """Initializes the SETR PUP head.
        Parameters
        ----------
        image_size : int or Tuple[int, int]
            The size of the input image.
        patch_size : int
            The size of each patch in the input image.
        num_layers : int
            The number of layers in the transformer encoder.
        num_heads : int
            The number of attention heads in the transformer encoder.
        hidden_dim : int
            The hidden dimension of the transformer encoder.
        mlp_dim : int
            The dimension of the feed-forward network in the transformer encoder
        num_convs : int
            The number of convolutional layers in the decoder.
        num_classes : int
            The number of output classes.
        decoder_channels : int
            The number of channels in the decoder.
        up_scale : int
            The scale factor for upsampling in the decoder.
        encoder_dropout : float
            The dropout rate for the transformer encoder.
        kernel_size : int
            The kernel size for the convolutional layers in the decoder.
        decoder_dropout : float
            The dropout rate for the decoder.
        norm_layer : nn.Module
            The normalization layer to be used.
        interpolate_mode : str
            The mode for interpolation during upsampling.
        conv_norm : nn.Module
            The normalization layer to be used in the decoder convolutional
            layers.
        conv_act : nn.Module
            The activation function to be used in the decoder convolutional
            layers.
        align_corners : bool
            Whether to align corners during upsampling.
        aux_output: bool
            Whether to use auxiliary outputs. If True, aux_output_layers must
            be provided.
        aux_output_layers: List[int], optional
            The layers to use for auxiliary outputs. Must have exacly 3 values.
        original_resolution: Tuple[int, int], optional
            The original resolution of the input image in the pre-training
            weights. When None, positional embeddings will not be interpolated.
        """
        super().__init__()
        if aux_output:
            assert aux_output_layers is not None, "aux_output_layers must be provided."
            assert (
                len(aux_output_layers) == 3
            ), "aux_output_layers must have 3 values. Only 3 aux heads are supported."
        self.aux_output = aux_output
        self.aux_output_layers = aux_output_layers
        self.encoder = _VisionTransformerBackbone(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            num_classes=num_classes,
            dropout=encoder_dropout,
            aux_output=aux_output,
            aux_output_layers=aux_output_layers,
            original_resolution=original_resolution,
        )
        self.decoder = _SETRUPHead(
            channels=decoder_channels,
            in_channels=hidden_dim,
            num_classes=num_classes,
            num_convs=num_convs,
            up_scale=up_scale,
            kernel_size=kernel_size,
            align_corners=align_corners,
            dropout=decoder_dropout,
            conv_norm=conv_norm,
            conv_act=conv_act,
            interpolate_mode=interpolate_mode,
            norm_layer=norm_layer,
        )
        self.aux_head1 = _SETRUPHead(
            channels=decoder_channels,
            in_channels=hidden_dim,
            num_classes=num_classes,
            num_convs=num_convs,
            up_scale=up_scale,
            kernel_size=kernel_size,
            align_corners=align_corners,
            dropout=decoder_dropout,
            conv_norm=conv_norm,
            conv_act=conv_act,
            interpolate_mode=interpolate_mode,
            norm_layer=norm_layer,
        )
        self.aux_head2 = _SETRUPHead(
            channels=decoder_channels,
            in_channels=hidden_dim,
            num_classes=num_classes,
            num_convs=num_convs,
            up_scale=up_scale,
            kernel_size=kernel_size,
            align_corners=align_corners,
            dropout=decoder_dropout,
            conv_norm=conv_norm,
            conv_act=conv_act,
            interpolate_mode=interpolate_mode,
            norm_layer=norm_layer,
        )
        self.aux_head3 = _SETRUPHead(
            channels=decoder_channels,
            in_channels=hidden_dim,
            num_classes=num_classes,
            num_convs=num_convs,
            up_scale=up_scale,
            kernel_size=kernel_size,
            align_corners=align_corners,
            dropout=decoder_dropout,
            conv_norm=conv_norm,
            conv_act=conv_act,
            interpolate_mode=interpolate_mode,
            norm_layer=norm_layer,
        )
[docs]
    def forward(self, x: torch.Tensor):
        if self.aux_output:
            x, aux_results = self.encoder(x)
            x_aux1 = self.aux_head1(aux_results[0])
            x_aux2 = self.aux_head2(aux_results[1])
            x_aux3 = self.aux_head3(aux_results[2])
            x = self.decoder(x)
            return x, x_aux1, x_aux2, x_aux3
        x = self.encoder(x)
        x = self.decoder(x)
        return x 
[docs]
    def load_backbone(self, path: str, freeze: bool = False):
        self.encoder.load_backbone(path)
        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False 
 
[docs]
class SETR_PUP(L.LightningModule):
    """SET-R model with PUP head for image segmentation.
    Methods
    -------
    forward(x: torch.Tensor) -> torch.Tensor
        Forward pass of the model.
    _compute_metrics(y_hat: torch.Tensor, y: torch.Tensor, step_name: str)
        Compute metrics for the given step.
    _loss_func(y_hat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], y: torch.Tensor) -> torch.Tensor
        Calculate the loss between the output and the input data.
    _single_step(batch: torch.Tensor, batch_idx: int, step_name: str)
        Perform a single step of the training/validation loop.
    training_step(batch: torch.Tensor, batch_idx: int)
        Perform a single training step.
    validation_step(batch: torch.Tensor, batch_idx: int)
        Perform a single validation step.
    test_step(batch: torch.Tensor, batch_idx: int)
        Perform a single test step.
    predict_step(batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None)
        Perform a single prediction step.
    load_backbone(path: str, freeze: bool = False)
        Load a pre-trained backbone.
    configure_optimizers()
        Configure the optimizer for the model.
    create_from_dict(config: Dict) -> "SETR_PUP"
        Create an instance of SETR_PUP from a configuration dictionary.
    """
    def __init__(
        self,
        image_size: Union[int, Tuple[int, int]] = 512,
        patch_size: int = 16,
        num_layers: int = 24,
        num_heads: int = 16,
        hidden_dim: int = 1024,
        mlp_dim: int = 4096,
        encoder_dropout: float = 0.1,
        num_classes: int = 1000,
        norm_layer: Optional[nn.Module] = None,
        decoder_channels: int = 256,
        num_convs: int = 4,
        up_scale: int = 2,
        kernel_size: int = 3,
        align_corners: bool = False,
        decoder_dropout: float = 0.1,
        conv_norm: Optional[nn.Module] = None,
        conv_act: Optional[nn.Module] = None,
        interpolate_mode: str = "bilinear",
        loss_fn: Optional[nn.Module] = None,
        optimizer_type: Optional[type] = None,
        optimizer_params: Optional[Dict] = None,
        train_metrics: Optional[Dict[str, Metric]] = None,
        val_metrics: Optional[Dict[str, Metric]] = None,
        test_metrics: Optional[Dict[str, Metric]] = None,
        aux_output: bool = True,
        aux_output_layers: Optional[list[int]] = None,
        aux_weights: Optional[list[float]] = None,
        load_backbone_path: Optional[str] = None,
        freeze_backbone_on_load: bool = True,
        learning_rate: float = 1e-3,
        loss_weights: Optional[list[float]] = None,
        original_resolution: Optional[Tuple[int, int]] = None,
        head_lr_factor: float = 1.0,
        test_engine: Optional[_Engine] = None,
    ):
        """Initialize the SETR model with Progressive Upsampling Head.
        Parameters
        ----------
        image_size : Union[int, Tuple[int, int]], optional
            Size of the input image, by default 512.
        patch_size : int, optional
            Size of the patches to be extracted from the input image, by
            default 16.
        num_layers : int, optional
            Number of transformer layers, by default 24.
        num_heads : int, optional
            Number of attention heads, by default 16.
        hidden_dim : int, optional
            Dimension of the hidden layer, by default 1024.
        mlp_dim : int, optional
            Dimension of the MLP layer, by default 4096.
        encoder_dropout : float, optional
            Dropout rate for the encoder, by default 0.1.
        num_classes : int, optional
            Number of output classes, by default 1000.
        norm_layer : Optional[nn.Module], optional
            Normalization layer, by default None.
        decoder_channels : int, optional
            Number of channels in the decoder, by default 256.
        num_convs : int, optional
            Number of convolutional layers in the decoder, by default 4.
        up_scale : int, optional
            Upscaling factor for the decoder, by default 2.
        kernel_size : int, optional
            Kernel size for the convolutional layers, by default 3.
        align_corners : bool, optional
            Whether to align corners when interpolating, by default False.
        decoder_dropout : float, optional
            Dropout rate for the decoder, by default 0.1.
        conv_norm : Optional[nn.Module], optional
            Normalization layer for the convolutional layers, by default None.
        conv_act : Optional[nn.Module], optional
            Activation function for the convolutional layers, by default None.
        interpolate_mode : str, optional
            Interpolation mode, by default "bilinear".
        loss_fn : Optional[nn.Module], optional
            Loss function, when None defaults to nn.CrossEntropyLoss, by
            default None.
        optimizer_type : Optional[type], optional
            Type of optimizer, by default None.
        optimizer_params : Optional[Dict], optional
            Parameters for the optimizer, by default None.
        train_metrics : Optional[Dict[str, Metric]], optional
            Metrics for training, by default None.
        val_metrics : Optional[Dict[str, Metric]], optional
            Metrics for validation, by default None.
        test_metrics : Optional[Dict[str, Metric]], optional
            Metrics for testing, by default None.
        aux_output : bool, optional
            Whether to use auxiliary outputs, by default True.
        aux_output_layers : list[int], optional
            Layers for auxiliary outputs, when None it defaults to [9, 14, 19].
        aux_weights : list[float], optional
            Weights for auxiliary outputs, when None it defaults [0.3, 0.3, 0.3].
        load_backbone_path : Optional[str], optional
            Path to load the backbone model, by default None.
        freeze_backbone_on_load : bool, optional
            Whether to freeze the backbone model on load, by default True.
        learning_rate : float, optional
            Learning rate, by default 1e-3.
        loss_weights : Optional[list[float]], optional
            Weights for the loss function, by default None.
        original_resolution : Optional[Tuple[int, int]], optional
            The original resolution of the input image in the pre-training
            weights. When None, positional embeddings will not be interpolated.
            Defaults to None.
        head_lr_factor : float, optional
            Learning rate factor for the head. used if you need different
            learning rates for backbone and prediction head, by default 1.0.
        test_engine : Optional[_Engine], optional
            Engine used for test and validation steps. When None, behavior of
            all steps, training, testing and validation is the same, by default None.
        """
        super().__init__()
        if head_lr_factor != 1:
            self.automatic_optimization = False
            self.multiple_optimizers = True
        else:
            self.automatic_optimization = True
            self.multiple_optimizers = False
        self.loss_fn = (
            loss_fn
            if loss_fn is not None
            else nn.CrossEntropyLoss(
                weight=(
                    torch.tensor(loss_weights) if loss_weights is not None else None
                )
            )
        )
        norm_layer = norm_layer if norm_layer is not None else nn.LayerNorm(hidden_dim)
        conv_norm = (
            conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels)
        )
        conv_act = conv_act if conv_act is not None else nn.ReLU()
        if aux_output:
            if aux_output_layers is None:
                aux_output_layers = [9, 14, 19]
                warnings.warn(
                    "aux_output_layers not provided. Using default values [9, 14, 19]."
                )
            if aux_weights is None:
                aux_weights = [0.3, 0.3, 0.3]
                warnings.warn(
                    "aux_weights not provided. Using default values [0.3, 0.3, 0.3]."
                )
            assert (
                len(aux_output_layers) == 3
            ), "aux_output_layers must have 3 values. Only 3 aux heads are supported."
        self.optimizer_type = optimizer_type
        if optimizer_type is not None:
            assert optimizer_params is not None, "optimizer_params must be provided."
            self.optimizer_params = optimizer_params
        self.num_classes = num_classes
        self.aux_weights = aux_weights
        self.head_lr_factor = head_lr_factor
        self.metrics = {
            "train": train_metrics,
            "val": val_metrics,
            "test": test_metrics,
        }
        self.model = _SetR_PUP(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            num_classes=num_classes,
            num_convs=num_convs,
            up_scale=up_scale,
            kernel_size=kernel_size,
            conv_norm=conv_norm,
            conv_act=conv_act,
            decoder_channels=decoder_channels,
            encoder_dropout=encoder_dropout,
            decoder_dropout=decoder_dropout,
            norm_layer=norm_layer,
            interpolate_mode=interpolate_mode,
            align_corners=align_corners,
            aux_output=aux_output,
            aux_output_layers=aux_output_layers,
            original_resolution=original_resolution,
        )
        if load_backbone_path is not None:
            self.model.load_backbone(load_backbone_path, freeze_backbone_on_load)
        self.learning_rate = learning_rate
        self.test_engine = test_engine
[docs]
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x) 
[docs]
    def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str):
        if self.metrics[step_name] is None:
            return {}
        return {
            f"{step_name}_{metric_name}": metric.to(self.device)(
                torch.argmax(y_hat, dim=1, keepdim=True), y
            )
            for metric_name, metric in self.metrics[step_name].items()
        } 
[docs]
    def _loss_func(
        self,
        y_hat: Union[
            torch.Tensor,
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        ],
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the loss between the output and the input data.
        Parameters
        ----------
        y_hat : torch.Tensor
            The output data from the forward pass.
        y : torch.Tensor
            The input data/label.
        Returns
        -------
        torch.Tensor
            The loss value.
        """
        if isinstance(y_hat, tuple):
            y_hat, y_aux1, y_aux2, y_aux3 = y_hat
            loss = self.loss_fn(y_hat, y.long())
            loss_aux1 = self.loss_fn(y_aux1, y.long())
            loss_aux2 = self.loss_fn(y_aux2, y.long())
            loss_aux3 = self.loss_fn(y_aux3, y.long())
            return (
                loss
                + (loss_aux1 * self.aux_weights[0])
                + (loss_aux2 * self.aux_weights[1])
                + (loss_aux3 * self.aux_weights[2])
            )
        loss = self.loss_fn(y_hat, y.long())
        return loss 
[docs]
    def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
        """Perform a single step of the training/validation loop.
        Parameters
        ----------
        batch : torch.Tensor
            The input data.
        batch_idx : int
            The index of the batch.
        step_name : str
            The name of the step, either "train" or "val".
        Returns
        -------
        torch.Tensor
            The loss value.
        """
        x, y = batch
        if self.test_engine and (step_name == "test" or step_name == "val"):
            y_hat = self.test_engine(self.model, x)
        else:
            y_hat = self.model(x)
        metrics = self._compute_metrics(y_hat[0], y, step_name)
        loss = self._loss_func(y_hat, y.squeeze(1))
        for metric_name, metric_value in metrics.items():
            self.log(
                metric_name,
                metric_value,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
        self.log(
            f"{step_name}_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss 
[docs]
    def training_step(self, batch: torch.Tensor, batch_idx: int):
        if self.multiple_optimizers:
            optimizers_list = self.optimizers()
            for opt in optimizers_list:
                opt.zero_grad()
            loss = self._single_step(batch, batch_idx, "train")
            self.manual_backward(loss)
            for opt in optimizers_list:
                opt.step()
        else:
            return self._single_step(batch, batch_idx, "train") 
[docs]
    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, "val") 
[docs]
    def test_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, "test") 
[docs]
    def predict_step(
        self,
        batch: torch.Tensor,
        batch_idx: int,
        dataloader_idx: Optional[int] = None,
    ):
        x, _ = batch
        return self.model(x)[0] 
[docs]
    def load_backbone(self, path: str, freeze: bool = False):
        self.model.load_backbone(path, freeze) 
[docs]
    @staticmethod
    def create_from_dict(config: Dict) -> "SETR_PUP":
        return SETR_PUP(**config)