minerva.data.data_module_tools

Classes

RandomDataModule

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

SimpleDataset

A simple dataset wrapping a data tensor and an optional label tensor.

Functions

full_dataset_from_dataloader(dataloader)

get_full_data_split(data_module, stage)

get_split_dataloader(data_module, stage)

Module Contents

class minerva.data.data_module_tools.RandomDataModule(data_shape, label_shape=None, num_classes=None, num_train_samples=128, num_val_samples=8, num_test_samples=8, num_predict_samples=8, batch_size=8, data_dtype=torch.float32, label_dtype=torch.float32)[source]

Bases: lightning.LightningDataModule

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

Example:

import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self, stage):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
Attributes:
prepare_data_per_node:

If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will call prepare_data().

allow_zero_length_dataloader_with_multiple_devices:

If True, dataloader with zero length within local rank is allowed. Default value is False.

Parameters:
  • data_shape (Tuple[int, ...])

  • label_shape (Union[int, Tuple[int, ...], None])

  • num_classes (Optional[int])

  • num_train_samples (int)

  • num_val_samples (int)

  • num_test_samples (int)

  • num_predict_samples (int)

  • batch_size (int)

  • data_dtype (torch.dtype)

  • label_dtype (torch.dtype)
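
A minimal construction sketch based on the signature above (the role of num_classes and the label-generation scheme are assumptions drawn from the parameter names, not documented behavior):

import torch

from minerva.data.data_module_tools import RandomDataModule

# Hypothetical toy setup: 3x32x32 samples with scalar class labels.
dm = RandomDataModule(
    data_shape=(3, 32, 32),   # shape of each generated sample
    label_shape=1,            # shape of each generated label
    num_classes=10,           # assumed: labels drawn from [0, num_classes)
    num_train_samples=128,
    batch_size=8,
    label_dtype=torch.long,
)
dm.setup("fit")  # generates the splits, per the setup() contract below
x, y = next(iter(dm.train_dataloader()))  # x: (8, 3, 32, 32), given batch_size=8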

_generate_data(num_samples, data_shape, label_shape, num_classes)[source]
batch_size = 8
data_dtype = torch.float32
data_shape
label_dtype = torch.float32
label_shape = None
num_classes = None
num_predict_samples = 8
num_test_samples = 8
num_train_samples = 128
num_val_samples = 8
predict_data = None
predict_dataloader()[source]

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
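
Lightning's Trainer consumes this hook when running prediction; a brief sketch (model stands for any compatible LightningModule and is not part of this module, dm is the RandomDataModule instance from the sketch above):

import lightning as L

trainer = L.Trainer(accelerator="cpu", logger=False)
# Trainer.predict requests dm.predict_dataloader() internally.
predictions = trainer.predict(model, datamodule=dm)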

setup(stage)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args:

stage: either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this (state set here is not shared across processes)
        self.something = compute_something()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
test_data = None
test_dataloader()[source]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_data = None
train_dataloader()[source]

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_data = None
val_dataloader()[source]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). Related stages and hooks:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class minerva.data.data_module_tools.SimpleDataset(data, label=None)[source]

A simple map-style dataset wrapping a data tensor and an optional label tensor.

Parameters:
  • data (torch.Tensor)

  • label (Optional[torch.Tensor])

__getitem__(idx)[source]
__len__()[source]
data
label = None
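
A minimal indexing sketch based only on the signature above (the exact return structure of __getitem__ is an assumption, not documented here):

import torch

from minerva.data.data_module_tools import SimpleDataset

data = torch.randn(16, 3)            # 16 samples, 3 features each
labels = torch.randint(0, 2, (16,))  # illustrative binary labels

ds = SimpleDataset(data, label=labels)
n = len(ds)     # 16: one entry per row of data
sample = ds[0]  # assumed: a (data, label) pair when labels are provided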
minerva.data.data_module_tools.full_dataset_from_dataloader(dataloader)[source]

Collects every batch yielded by dataloader into a single full dataset.

Parameters:

dataloader (torch.utils.data.DataLoader)

minerva.data.data_module_tools.get_full_data_split(data_module, stage)[source]

Returns the full data of the given stage split of data_module.

Parameters:
  • data_module (lightning.LightningDataModule)

  • stage (str)

minerva.data.data_module_tools.get_split_dataloader(data_module, stage)[source]

Returns the dataloader of data_module corresponding to the given stage.

Parameters:
  • data_module (lightning.LightningDataModule)

  • stage (str)

Return type:

torch.utils.data.DataLoader
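
Taken together, these helpers extract a concrete split from a LightningDataModule. A hedged sketch combining them with RandomDataModule (the accepted stage strings are an assumption based on the setup() convention above):

from minerva.data.data_module_tools import (
    RandomDataModule,
    get_full_data_split,
    get_split_dataloader,
)

dm = RandomDataModule(data_shape=(3, 32, 32), label_shape=1, num_classes=10)

loader = get_split_dataloader(dm, "test")  # a torch.utils.data.DataLoader
full = get_full_data_split(dm, "test")     # assumed: the split's full data in memory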