from typing import Optional
import yaml
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
class MinervaDataModule(LightningDataModule):
    """Lightning data module built from already-instantiated datasets."""

    def __init__(
        self,
        # Datasets
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        predict_split: Optional[str] = "test",
        # DataLoader
        dataloader_cls: type = DataLoader,
        batch_size: int = 1,
        num_workers: int = 0,
        drop_last: bool = False,
        # Per-split dataloader keyword overrides
        additional_train_dataloader_kwargs: Optional[dict] = None,
        additional_val_dataloader_kwargs: Optional[dict] = None,
        additional_test_dataloader_kwargs: Optional[dict] = None,
        shuffle_train: bool = True,
        # Metadata
        name: str = "",
    ):
        """A fully-featured data module for PyTorch Lightning with support for
        accessing train, val, test, and predict datasets and dataloaders. This
        class is a generalization of the LightningDataModule class and is
        designed to be used with the Minerva framework.

        Parameters
        ----------
        train_dataset : Optional[Dataset], optional
            The training dataset, by default None
        val_dataset : Optional[Dataset], optional
            The validation dataset, by default None
        test_dataset : Optional[Dataset], optional
            The test dataset, by default None
        predict_split : Optional[str], optional
            Set the split to predict on (using the predict_dataloader method),
            by default "test". The predict dataloader reuses both the dataset
            and the dataloader kwargs of the selected split (so "train" also
            inherits ``shuffle_train``). None disables prediction.
        dataloader_cls : type, optional
            The dataloader class to use. The datasets will be wrapped in this
            class when creating the dataloaders, by default DataLoader
        batch_size : int, optional
            Default batch_size for all dataloaders, by default 1
        num_workers : int, optional
            Default num_workers for all dataloaders, by default 0
        drop_last : bool, optional
            Default drop_last for all dataloaders, by default False
        additional_train_dataloader_kwargs : Optional[dict], optional
            Extra train dataloader kwargs; entries take precedence over the
            defaults above, by default None
        additional_val_dataloader_kwargs : Optional[dict], optional
            Extra val dataloader kwargs; entries take precedence over the
            defaults above, by default None
        additional_test_dataloader_kwargs : Optional[dict], optional
            Extra test dataloader kwargs; entries take precedence over the
            defaults above, by default None
        shuffle_train : bool, optional
            If True, shuffle the training dataset. If False, do not shuffle the
            training dataset, by default True. By default, only the training
            dataloader is shuffled.
        name : str, optional
            Name of the data module, by default ""

        Raises
        ------
        ValueError
            If ``predict_split`` is not one of "train", "val", "test" or None.
        """
        super().__init__()
        self._name = name

        self._train_dataset = train_dataset
        self._val_dataset = val_dataset
        self._test_dataset = test_dataset
        self._predict_split = predict_split

        self._batch_size = batch_size
        self._num_workers = num_workers
        self._shuffle_train = shuffle_train
        self._dataloader_cls = dataloader_cls

        # Per-split kwargs: shared defaults first, then the user overrides.
        # Only the training split honours `shuffle_train`.
        self._train_dataloader_kwargs = self.__update_dataloader_kwargs(
            additional_train_dataloader_kwargs,
            batch_size,
            num_workers,
            drop_last,
            shuffle=shuffle_train,
        )
        self._val_dataloader_kwargs = self.__update_dataloader_kwargs(
            additional_val_dataloader_kwargs,
            batch_size,
            num_workers,
            drop_last,
            shuffle=False,
        )
        self._test_dataloader_kwargs = self.__update_dataloader_kwargs(
            additional_test_dataloader_kwargs,
            batch_size,
            num_workers,
            drop_last,
            shuffle=False,
        )

        # The predict split aliases the dataset and the kwargs dict of the
        # split it points to; both are chosen together to keep them in sync.
        if predict_split == "train":
            self._predict_dataset = train_dataset
            self._predict_dataloader_kwargs = self._train_dataloader_kwargs
        elif predict_split == "val":
            self._predict_dataset = val_dataset
            self._predict_dataloader_kwargs = self._val_dataloader_kwargs
        elif predict_split == "test":
            self._predict_dataset = test_dataset
            self._predict_dataloader_kwargs = self._test_dataloader_kwargs
        elif predict_split is None:
            self._predict_dataset = None
            self._predict_dataloader_kwargs = {}
        else:
            raise ValueError(
                f"predict_split must be one of 'train', 'val', 'test', or None. Got {predict_split}."
            )

        # Monkey patch the dataloaders if the datasets are not provided.
        # It is applied at instance level to avoid breaking the class signature.
        # NOTE: this is a truthiness test, so a dataset whose __len__ returns 0
        # is treated the same as a missing (None) dataset.
        if not self._train_dataset:
            self.train_dataloader = None  # type: ignore
        if not self._val_dataset:
            self.val_dataloader = None  # type: ignore
        if not self._test_dataset:
            self.test_dataloader = None  # type: ignore
        if not self._predict_dataset:
            self.predict_dataloader = None  # type: ignore

    @property
    def dataset_name(self) -> str:
        """Name given to this data module at construction time."""
        return self._name

    @staticmethod
    def __update_dataloader_kwargs(
        additional_kwargs, batch_size, num_workers, drop_last, shuffle
    ):
        """Return a new kwargs dict: the defaults overlaid with `additional_kwargs`."""
        kwargs = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "shuffle": shuffle,
            "drop_last": drop_last,
        }
        if additional_kwargs:
            kwargs.update(additional_kwargs)
        return kwargs

    @property
    def train_dataset(self) -> Optional[Dataset]:
        """The training dataset passed to the constructor (may be None)."""
        return self._train_dataset

    @property
    def val_dataset(self) -> Optional[Dataset]:
        """The validation dataset passed to the constructor (may be None)."""
        return self._val_dataset

    @property
    def test_dataset(self) -> Optional[Dataset]:
        """The test dataset passed to the constructor (may be None)."""
        return self._test_dataset

    @property
    def predict_dataset(self) -> Optional[Dataset]:
        """The dataset selected by `predict_split` (None when it is None)."""
        return self._predict_dataset

    def train_dataloader(self):
        """Wrap the training dataset in `dataloader_cls` with the train kwargs.

        Replaced by None on the instance when no training dataset was given.
        """
        return self._dataloader_cls(
            self.train_dataset, **self._train_dataloader_kwargs
        )

    def val_dataloader(self):
        """Wrap the validation dataset in `dataloader_cls` with the val kwargs.

        Replaced by None on the instance when no validation dataset was given.
        """
        return self._dataloader_cls(
            self.val_dataset, **self._val_dataloader_kwargs
        )

    def test_dataloader(self):
        """Wrap the test dataset in `dataloader_cls` with the test kwargs.

        Replaced by None on the instance when no test dataset was given.
        """
        return self._dataloader_cls(
            self.test_dataset, **self._test_dataloader_kwargs
        )

    def predict_dataloader(self):
        """Wrap the predict dataset in `dataloader_cls` with the predict kwargs.

        Replaced by None on the instance when there is no predict dataset.
        """
        return self._dataloader_cls(
            self.predict_dataset, **self._predict_dataloader_kwargs
        )

    def __str__(self) -> str:
        """Render a multi-line, tree-style summary of datasets and dataloaders."""

        def indent_text(text, spaces=6, add_line_breaks=True):
            """Indent each line of a string by a given number of spaces."""
            if not text:
                return "No data."
            # Left gutter: a vertical bar continues the tree, a blank ends it.
            gutter = " " * int(spaces // 2) + ("│" if add_line_breaks else " ")
            return "\n".join(
                gutter + (" " * spaces + line if line.strip() else line)
                for line in text.split("\n")
            )

        def pretty_yaml(d, indent=6):
            """Pretty-print dictionary in YAML format with '├──' for each key."""
            if not d:
                return "No data."
            try:
                yaml_str = yaml.dump(
                    d, default_flow_style=False, sort_keys=False
                ).strip()  # Remove trailing newlines
            except (yaml.YAMLError, TypeError):
                # Some kwarg values cannot be serialised by PyYAML; a summary
                # string must never raise, so fall back to plain repr lines.
                yaml_str = "\n".join(
                    f"{key}: {value!r}" for key, value in d.items()
                )
            return "\n".join(
                f"{' ' * indent}├── {line}"
                for line in yaml_str.split("\n")
                if line.strip()
            )

        return (
            f"{'=' * 50}\n"
            f"{' ' * ((50 - len(self._name)) // 2)}🆔 {self._name}\n"
            f"{'=' * 50}\n"
            f"└── Predict Split: {self._predict_split}\n"
            f"📂 Datasets:\n"
            f"   ├── Train Dataset:\n{indent_text(str(self.train_dataset))}\n"
            f"   ├── Val Dataset:\n{indent_text(str(self.val_dataset))}\n"
            f"   └── Test Dataset:\n{indent_text(str(self.test_dataset), add_line_breaks=False)}\n"
            f"\n🛠 **Dataloader Configurations:**\n"
            f"   ├── Dataloader class: {self._dataloader_cls}\n"
            f"   ├── Train Dataloader Kwargs:\n{pretty_yaml(self._train_dataloader_kwargs, indent=9)}\n"
            f"   ├── Val Dataloader Kwargs:\n{pretty_yaml(self._val_dataloader_kwargs, indent=9)}\n"
            f"   └── Test Dataloader Kwargs:\n{pretty_yaml(self._test_dataloader_kwargs, indent=9)}\n"
            f"{'=' * 50}"
        )

    def __repr__(self) -> str:
        """Same multi-line summary as `__str__`."""
        return self.__str__()