minerva.data.data_module_tools
==============================

.. py:module:: minerva.data.data_module_tools


Classes
-------

.. autoapisummary::

   minerva.data.data_module_tools.RandomDataModule
   minerva.data.data_module_tools.SimpleDataset


Functions
---------

.. autoapisummary::

   minerva.data.data_module_tools.full_dataset_from_dataloader
   minerva.data.data_module_tools.get_full_data_split
   minerva.data.data_module_tools.get_split_dataloader


Module Contents
---------------

.. py:class:: RandomDataModule(data_shape, label_shape = None, num_classes = None, num_train_samples = 128, num_val_samples = 8, num_test_samples = 8, num_predict_samples = 8, batch_size = 8, data_dtype = torch.float32, label_dtype = torch.float32)

   Bases: :py:obj:`lightning.LightningDataModule`

   A DataModule standardizes the training, val, test splits, data preparation and transforms.
   The main advantage is consistent data splits, data preparation and transforms across models.

   Example::

       import lightning as L
       import torch.utils.data as data
       from lightning.pytorch.demos.boring_classes import RandomDataset

       class MyDataModule(L.LightningDataModule):
           def prepare_data(self):
               # download, IO, etc. Useful with shared filesystems
               # only called on 1 GPU/TPU in distributed
               ...

           def setup(self, stage):
               # make assignments here (val/train/test split)
               # called on every process in DDP
               dataset = RandomDataset(1, 100)
               self.train, self.val, self.test = data.random_split(
                   dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
               )

           def train_dataloader(self):
               return data.DataLoader(self.train)

           def val_dataloader(self):
               return data.DataLoader(self.val)

           def test_dataloader(self):
               return data.DataLoader(self.test)

           def on_exception(self, exception):
               # clean up state after the trainer faced an exception
               ...

           def teardown(self):
               # clean up state after the trainer stops, delete files...
               # called on every process in DDP
               ...

   Attributes:
       prepare_data_per_node:
           If True, each LOCAL_RANK=0 will call prepare data.
           Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
       allow_zero_length_dataloader_with_multiple_devices:
           If True, dataloader with zero length within local rank is allowed.
           Default value is False.


   .. py:method:: _generate_data(num_samples, data_shape, label_shape, num_classes)


   .. py:attribute:: batch_size
      :value: 8


   .. py:attribute:: data_dtype
      :value: Ellipsis


   .. py:attribute:: data_shape


   .. py:attribute:: label_dtype
      :value: Ellipsis


   .. py:attribute:: label_shape
      :value: None


   .. py:attribute:: num_classes
      :value: None


   .. py:attribute:: num_predict_samples
      :value: 8


   .. py:attribute:: num_test_samples
      :value: 8


   .. py:attribute:: num_train_samples
      :value: 128


   .. py:attribute:: num_val_samples
      :value: 8


   .. py:attribute:: predict_data
      :value: None


   .. py:method:: predict_dataloader()

      An iterable or collection of iterables specifying prediction samples.

      For more information about multiple dataloaders, see the section on multiple
      dataloaders in the Lightning documentation.

      It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`
      - :meth:`prepare_data`
      - :meth:`setup`

      Note:
          Lightning tries to add the correct sampler for distributed and arbitrary hardware.
          There is no need to set it yourself.

      Return:
          A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.


   .. py:method:: setup(stage)

      Called at the beginning of fit (train + validate), validate, test, or predict.
      This is a good hook when you need to build models dynamically or adjust something
      about them.
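
   The snippet below is a minimal usage sketch rather than part of the library's own
   documentation: it instantiates ``RandomDataModule`` with image-shaped random inputs and
   scalar labels, and indicates (in comments only) where a Lightning ``Trainer`` smoke test
   would plug in; the ``model`` mentioned there is hypothetical::

       import lightning as L
       import torch

       from minerva.data.data_module_tools import RandomDataModule

       # Synthetic (1, 28, 28) inputs with one float32 label per sample.
       dm = RandomDataModule(
           data_shape=(1, 28, 28),
           label_shape=(1,),
           num_train_samples=128,
           num_val_samples=8,
           batch_size=8,
           data_dtype=torch.float32,
           label_dtype=torch.float32,
       )

       # Quick pipeline check with any LightningModule (hypothetical `model`):
       # trainer = L.Trainer(fast_dev_run=True)
       # trainer.fit(model, datamodule=dm)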
      This hook is called on every process when using DDP.

      Args:
          stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

      Example::

          class LitModel(...):
              def __init__(self):
                  self.l1 = None

              def prepare_data(self):
                  download_data()
                  tokenize()

                  # don't do this
                  self.something = else

              def setup(self, stage):
                  data = load_data(...)
                  self.l1 = nn.Linear(28, data.num_classes)


   .. py:attribute:: test_data
      :value: None


   .. py:method:: test_dataloader()

      An iterable or collection of iterables specifying test samples.

      For more information about multiple dataloaders, see the section on multiple
      dataloaders in the Lightning documentation.

      For data processing use the following pattern:

      - download in :meth:`prepare_data`
      - process and split in :meth:`setup`

      However, the above are only necessary for distributed processing.

      .. warning:: do not assign state in prepare_data

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`
      - :meth:`prepare_data`
      - :meth:`setup`

      Note:
          Lightning tries to add the correct sampler for distributed and arbitrary hardware.
          There is no need to set it yourself.

      Note:
          If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
          this method.


   .. py:attribute:: train_data
      :value: None


   .. py:method:: train_dataloader()

      An iterable or collection of iterables specifying training samples.

      For more information about multiple dataloaders, see the section on multiple
      dataloaders in the Lightning documentation.

      The dataloader you return will not be reloaded unless you set
      :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs`
      to a positive integer.

      For data processing use the following pattern:

      - download in :meth:`prepare_data`
      - process and split in :meth:`setup`

      However, the above are only necessary for distributed processing.

      .. warning:: do not assign state in prepare_data

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
      - :meth:`prepare_data`
      - :meth:`setup`

      Note:
          Lightning tries to add the correct sampler for distributed and arbitrary hardware.
          There is no need to set it yourself.


   .. py:attribute:: val_data
      :value: None


   .. py:method:: val_dataloader()

      An iterable or collection of iterables specifying validation samples.

      For more information about multiple dataloaders, see the section on multiple
      dataloaders in the Lightning documentation.

      The dataloader you return will not be reloaded unless you set
      :paramref:`~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs`
      to a positive integer.

      It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`
      - :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`
      - :meth:`prepare_data`
      - :meth:`setup`

      Note:
          Lightning tries to add the correct sampler for distributed and arbitrary hardware.
          There is no need to set it yourself.

      Note:
          If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
          implement this method.


.. py:class:: SimpleDataset(data, label = None)

   .. py:method:: __getitem__(idx)


   .. py:method:: __len__()


   .. py:attribute:: data


   .. py:attribute:: label
      :value: None


.. py:function:: full_dataset_from_dataloader(dataloader)


.. py:function:: get_full_data_split(data_module, stage)


.. py:function:: get_split_dataloader(data_module, stage)
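
The three helper functions above carry no docstrings on this page, so the following sketch is
an assumption-laden illustration rather than a reference: it assumes ``get_split_dataloader``
returns the dataloader that the data module serves for the requested stage, that
``get_full_data_split`` materializes that stage's data in memory, and that ``"train"`` is an
accepted stage name; none of this is confirmed by the signatures alone::

    from minerva.data.data_module_tools import (
        RandomDataModule,
        get_full_data_split,
        get_split_dataloader,
    )

    dm = RandomDataModule(data_shape=(16,), label_shape=(1,), batch_size=8)
    dm.setup("fit")  # called explicitly in case the helpers do not do it themselves

    # Assumed: returns the DataLoader backing the requested split.
    loader = get_split_dataloader(dm, "train")
    print(type(next(iter(loader))))

    # Assumed: loads the whole split into memory (e.g. as arrays or tensors).
    full_split = get_full_data_split(dm, "train")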