from typing import Optional, Tuple, Union
import lightning as L
import torch
from torch.utils.data import DataLoader
[docs]
class SimpleDataset:
    """Map-style dataset over a data tensor and an optional label tensor.

    Args:
        data: Tensor indexed along its first dimension to obtain samples.
        label: Optional tensor indexed with the same index as ``data``.
            When omitted, items are returned without a label.
    """

    def __init__(self, data: torch.Tensor, label: Optional[torch.Tensor] = None):
        self.data = data
        self.label = label

    def __getitem__(self, idx):
        """Return ``data[idx]``, paired with ``label[idx]`` when labels exist."""
        if self.label is None:
            return self.data[idx]
        return self.data[idx], self.label[idx]

    def __len__(self):
        """Return the number of samples, i.e. ``len(data)``."""
        return len(self.data)
 
[docs]
class RandomDataModule(L.LightningDataModule):
    """Lightning data module that serves randomly generated tensors.

    Every call to :meth:`setup` draws fresh random data (and labels, if
    configured) for the requested stage and wraps it in a ``SimpleDataset``.

    Args:
        data_shape: Shape of one sample, without the leading sample dimension.
        label_shape: Shape of one label, without the sample dimension. A bare
            int is treated as a one-element shape.
        num_classes: Exclusive upper bound for integer labels. Label rules:
            ``label_shape`` and ``num_classes`` -> random integers in
            ``[0, num_classes)`` with shape ``label_shape`` per sample;
            only ``label_shape`` -> uniform random values from ``torch.rand``;
            only ``num_classes`` -> one random integer per sample;
            neither -> no labels.
        num_train_samples: Number of training samples (must be > 0).
        num_val_samples: Number of validation samples (> 0), or ``None`` to
            provide no validation split.
        num_test_samples: Number of test samples (> 0), or ``None`` to provide
            no test split.
        num_predict_samples: Number of prediction samples, or ``None`` to
            generate no prediction data.
        batch_size: Batch size used by every dataloader.
        data_dtype: dtype of the generated data tensor.
        label_dtype: dtype the generated labels are cast to.
    """

    def __init__(
        self,
        data_shape: Tuple[int, ...],
        label_shape: Union[int, Tuple[int, ...], None] = None,
        num_classes: Optional[int] = None,
        num_train_samples: int = 128,
        num_val_samples: Optional[int] = 8,
        num_test_samples: Optional[int] = 8,
        num_predict_samples: Optional[int] = 8,
        batch_size: int = 8,
        data_dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.data_shape = data_shape
        self.label_shape = label_shape
        self.num_classes = num_classes
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_predict_samples = num_predict_samples
        self.batch_size = batch_size
        self.data_dtype = data_dtype
        self.label_dtype = label_dtype
        # Datasets are created lazily by ``setup``.
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.predict_data = None
        assert num_train_samples > 0, "num_train_samples must be greater than 0"
        if num_val_samples is not None:
            assert num_val_samples > 0, "num_val_samples must be greater than 0"
        else:
            # ``delattr(self, ...)`` cannot remove a method that is defined on
            # the class (it raises AttributeError), so shadow the override on
            # this instance with the inherited base-class hook instead.
            # NOTE(review): intended to make Lightning treat the validation
            # split as not provided -- confirm against the installed version.
            self.val_dataloader = super().val_dataloader
        if num_test_samples is not None:
            assert num_test_samples > 0, "num_test_samples must be greater than 0"
        else:
            # Same reasoning as for the validation split above.
            self.test_dataloader = super().test_dataloader

    def _generate_data(self, num_samples, data_shape, label_shape, num_classes):
        """Draw ``num_samples`` random samples and, if configured, labels.

        Args:
            num_samples: Size of the leading dimension of the result.
            data_shape: Per-sample data shape.
            label_shape: Per-sample label shape, an int, or ``None``.
            num_classes: Exclusive upper bound for integer labels, or ``None``.

        Returns:
            A ``(data, label)`` tuple. ``label`` is ``None`` when neither
            ``label_shape`` nor ``num_classes`` is given; otherwise it is cast
            to ``self.label_dtype``.
        """
        data = torch.rand((num_samples, *data_shape), dtype=self.data_dtype)

        # The signature advertises a bare int; unpacking needs a tuple.
        if isinstance(label_shape, int):
            label_shape = (label_shape,)

        label = None
        if label_shape is not None and num_classes is not None:
            label = torch.randint(0, num_classes, (num_samples, *label_shape))
        elif label_shape is not None:
            label = torch.rand((num_samples, *label_shape))
        elif num_classes is not None:
            label = torch.randint(0, num_classes, (num_samples,))

        # Unlabelled configuration: there is nothing to cast.
        if label is not None:
            label = label.to(dtype=self.label_dtype)
        return data, label

    def _make_dataset(self, num_samples):
        """Generate ``num_samples`` random samples wrapped in a ``SimpleDataset``."""
        data, label = self._generate_data(
            num_samples,
            self.data_shape,
            self.label_shape,
            self.num_classes,
        )
        return SimpleDataset(data, label)

    def setup(self, stage):
        """Generate the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` (train plus validation data), ``"validate"``
                (validation data only), ``"test"`` or ``"predict"``.

        Raises:
            ValueError: If ``stage`` is none of the values above.
        """
        if stage == "fit":
            self.train_data = self._make_dataset(self.num_train_samples)
            if self.num_val_samples is not None:
                self.val_data = self._make_dataset(self.num_val_samples)
        elif stage == "validate":
            # Validation-only runs need the validation split but no train data.
            if self.num_val_samples is not None:
                self.val_data = self._make_dataset(self.num_val_samples)
        elif stage == "test":
            if self.num_test_samples is not None:
                self.test_data = self._make_dataset(self.num_test_samples)
        elif stage == "predict":
            if self.num_predict_samples is not None:
                self.predict_data = self._make_dataset(self.num_predict_samples)
        else:
            raise ValueError(f"Invalid stage: {stage}")

    def train_dataloader(self):
        """Return a shuffling dataloader over the training dataset."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling dataloader over the validation dataset."""
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling dataloader over the test dataset."""
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        """Return a non-shuffling dataloader over the prediction dataset."""
        return DataLoader(self.predict_data, batch_size=self.batch_size, shuffle=False)
 
[docs]
def get_split_dataloader(data_module: L.LightningDataModule, stage: str) -> DataLoader:
    """Set up ``data_module`` for a split and return that split's dataloader.

    Args:
        data_module: Object providing ``setup`` and the ``*_dataloader`` hooks.
        stage: One of ``"train"``, ``"validation"``, ``"test"`` or
            ``"predict"``.

    Returns:
        The result of the dataloader hook that belongs to ``stage``.

    Raises:
        ValueError: If ``stage`` is not one of the supported split names.
    """
    # (split name, stage handed to ``setup``, dataloader hook to call)
    routes = (
        ("train", "fit", "train_dataloader"),
        ("validation", "fit", "val_dataloader"),
        ("test", "test", "test_dataloader"),
        ("predict", "predict", "predict_dataloader"),
    )
    for split_name, setup_stage, hook_name in routes:
        if stage == split_name:
            # ``setup`` must run before the hook so the dataset exists.
            data_module.setup(setup_stage)
            return getattr(data_module, hook_name)()
    raise ValueError(f"Invalid stage: {stage}")
[docs]
def full_dataset_from_dataloader(dataloader: DataLoader):
    """Collect every sample of the dataloader's dataset, grouped by field.

    The dataset is read directly through ``dataloader.dataset`` by index, so
    the loader's batching and shuffling are not involved.

    Args:
        dataloader: Object whose ``dataset`` attribute supports ``len()`` and
            integer indexing.

    Returns:
        A list with one tuple per sample field. For samples that are tuples
        or lists (e.g. ``(data, label)``) this is ``[all_data, all_labels]``.
        For samples that are single objects (no label) it is a one-element
        list ``[all_samples]``. An empty dataset yields ``[]``.
    """
    dataset = dataloader.dataset
    samples = [dataset[i] for i in range(len(dataset))]
    # A bare sample (e.g. an unlabelled tensor) must not be unpacked by
    # ``zip``: that would iterate over the sample's own contents instead of
    # grouping samples by field.
    if samples and not isinstance(samples[0], (tuple, list)):
        return [tuple(samples)]
    # Transpose [(x0, y0), (x1, y1), ...] into [(x0, x1, ...), (y0, y1, ...)].
    return list(zip(*samples))
[docs]
def get_full_data_split(
    data_module: L.LightningDataModule,
    stage: str,
):
    """Return the complete contents of one split of ``data_module``.

    Obtains the split's dataloader via ``get_split_dataloader`` and passes it
    to ``full_dataset_from_dataloader``.

    Args:
        data_module: Data module to read from.
        stage: Split name, forwarded unchanged to ``get_split_dataloader``.

    Returns:
        Whatever ``full_dataset_from_dataloader`` returns for that dataloader.
    """
    return full_dataset_from_dataloader(get_split_dataloader(data_module, stage))