# Source code for minerva.data.data_module_tools

from typing import Optional, Tuple, Union

import lightning as L
import torch
from torch.utils.data import DataLoader


class SimpleDataset:
    """Minimal map-style dataset over a data tensor and optional labels.

    When ``label`` is provided, indexing yields ``(data[idx], label[idx])``
    pairs; otherwise indexing yields ``data[idx]`` alone.
    """

    def __init__(self, data: torch.Tensor, label: Optional[torch.Tensor] = None):
        # Tensors are stored as-is; no copy is made.
        self.data = data
        self.label = label

    def __len__(self):
        # Number of samples along the first dimension of ``data``.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.label is None:
            return sample
        return sample, self.label[idx]
class RandomDataModule(L.LightningDataModule):
    """Lightning data module that serves randomly generated tensors.

    Intended for smoke-testing models and training loops without real data.
    Depending on ``label_shape`` and ``num_classes``, labels are:

    - ``label_shape`` and ``num_classes`` set: integer class labels of shape
      ``(num_samples, *label_shape)`` drawn from ``[0, num_classes)``;
    - only ``label_shape`` set: uniform floats of that shape;
    - only ``num_classes`` set: one class index per sample;
    - neither set: no labels (datasets yield data tensors only).

    Args:
        data_shape: Per-sample shape of the generated data tensors.
        label_shape: Per-sample shape of the labels, or None.
            NOTE(review): the annotation admits a bare ``int``, but
            ``_generate_data`` unpacks it with ``*label_shape`` — pass a
            tuple; confirm whether ``int`` was ever intended.
        num_classes: Number of classes for integer labels, or None.
        num_train_samples: Number of training samples (must be > 0).
        num_val_samples: Number of validation samples, or None to disable
            the validation dataloader.
        num_test_samples: Number of test samples, or None to disable the
            test dataloader.
        num_predict_samples: Number of predict samples, or None to skip
            building the predict dataset.
        batch_size: Batch size for every dataloader.
        data_dtype: dtype of generated data.
        label_dtype: dtype labels are cast to (class indices included —
            the default float32 preserves the original behavior).

    Raises:
        AssertionError: if a sample count is not greater than 0.
    """

    def __init__(
        self,
        data_shape: Tuple[int, ...],
        label_shape: Union[int, Tuple[int, ...], None] = None,
        num_classes: Optional[int] = None,
        num_train_samples: int = 128,
        # Annotations corrected to Optional: the original declared plain
        # ``int`` yet explicitly compared these against None below.
        num_val_samples: Optional[int] = 8,
        num_test_samples: Optional[int] = 8,
        num_predict_samples: Optional[int] = 8,
        batch_size: int = 8,
        data_dtype: torch.dtype = torch.float32,
        label_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.data_shape = data_shape
        self.label_shape = label_shape
        self.num_classes = num_classes
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_predict_samples = num_predict_samples
        self.batch_size = batch_size
        self.data_dtype = data_dtype
        self.label_dtype = label_dtype
        # Datasets are created lazily in ``setup``.
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.predict_data = None

        assert num_train_samples > 0, "num_train_samples must be greater than 0"
        if num_val_samples is not None:
            assert num_val_samples > 0, "num_val_samples must be greater than 0"
        else:
            # BUGFIX: the original called ``delattr(self, "val_dataloader")``,
            # which raises AttributeError — ``delattr`` on an instance cannot
            # remove a method defined on the class. Shadowing the hook with an
            # instance attribute of None makes Lightning treat it as not
            # implemented, which is the intended effect.
            self.val_dataloader = None  # type: ignore[assignment]
        if num_test_samples is not None:
            assert num_test_samples > 0, "num_test_samples must be greater than 0"
        else:
            # Same fix as for ``val_dataloader`` above.
            self.test_dataloader = None  # type: ignore[assignment]

    def _generate_data(self, num_samples, data_shape, label_shape, num_classes):
        """Return ``(data, label)`` random tensors; ``label`` may be None."""
        data = torch.rand((num_samples, *data_shape), dtype=self.data_dtype)
        label = None
        if label_shape is not None and num_classes is not None:
            label = torch.randint(0, num_classes, (num_samples, *label_shape))
        elif label_shape is not None:
            label = torch.rand((num_samples, *label_shape))
        elif num_classes is not None:
            label = torch.randint(0, num_classes, (num_samples,))
        # BUGFIX: the original unconditionally called ``label.to(...)`` and
        # crashed with AttributeError when both label_shape and num_classes
        # were None.
        if label is not None:
            label = label.to(dtype=self.label_dtype)
        return data, label

    def _build_val_data(self):
        """Create the validation dataset (shared by 'fit' and 'validate')."""
        data, label = self._generate_data(
            self.num_val_samples,
            self.data_shape,
            self.label_shape,
            self.num_classes,
        )
        self.val_data = SimpleDataset(data, label)

    def setup(self, stage):
        """Create the random datasets for the requested Lightning stage.

        Raises:
            ValueError: for an unrecognized ``stage``.
        """
        if stage == "fit":
            data, label = self._generate_data(
                self.num_train_samples,
                self.data_shape,
                self.label_shape,
                self.num_classes,
            )
            self.train_data = SimpleDataset(data, label)
            if self.num_val_samples is not None:
                self._build_val_data()
        elif stage == "validate":
            # BUGFIX: ``Trainer.validate()`` calls ``setup("validate")``;
            # the original fell through to the ValueError below.
            if self.num_val_samples is not None:
                self._build_val_data()
        elif stage == "test":
            if self.num_test_samples is not None:
                data, label = self._generate_data(
                    self.num_test_samples,
                    self.data_shape,
                    self.label_shape,
                    self.num_classes,
                )
                self.test_data = SimpleDataset(data, label)
        elif stage == "predict":
            if self.num_predict_samples is not None:
                data, label = self._generate_data(
                    self.num_predict_samples,
                    self.data_shape,
                    self.label_shape,
                    self.num_classes,
                )
                self.predict_data = SimpleDataset(data, label)
        else:
            raise ValueError(f"Invalid stage: {stage}")

    def train_dataloader(self):
        """Shuffled dataloader over the training dataset."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Deterministic dataloader over the validation dataset."""
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Deterministic dataloader over the test dataset."""
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False)

    def predict_dataloader(self):
        """Deterministic dataloader over the predict dataset."""
        return DataLoader(self.predict_data, batch_size=self.batch_size, shuffle=False)
[docs] def get_split_dataloader(data_module: L.LightningDataModule, stage: str) -> DataLoader: if stage == "train": data_module.setup("fit") return data_module.train_dataloader() elif stage == "validation": data_module.setup("fit") return data_module.val_dataloader() elif stage == "test": data_module.setup("test") return data_module.test_dataloader() elif stage == "predict": data_module.setup("predict") return data_module.predict_dataloader() else: raise ValueError(f"Invalid stage: {stage}")
def full_dataset_from_dataloader(dataloader: DataLoader):
    """Materialize every sample of ``dataloader``'s underlying dataset.

    Iterates the dataset directly (bypassing batching) and transposes the
    per-sample tuples, e.g. a dataset of ``(data, label)`` pairs becomes
    ``[(data_0, ..., data_n), (label_0, ..., label_n)]``.
    """
    dataset = dataloader.dataset
    samples = [dataset[index] for index in range(len(dataset))]
    # zip(*...) transposes "list of samples" into "tuple per field".
    return list(zip(*samples))
def get_full_data_split(
    data_module: L.LightningDataModule,
    stage: str,
):
    """Return the full, unbatched content of one split of ``data_module``.

    Convenience wrapper: resolves the dataloader for ``stage`` (one of
    "train", "validation", "test", "predict") and materializes its whole
    dataset via :func:`full_dataset_from_dataloader`.
    """
    return full_dataset_from_dataloader(get_split_dataloader(data_module, stage))