minerva.data.data_modules.base

Classes

MinervaDataModule

A DataModule standardizes the training, val, test splits, data preparation and transforms.

Module Contents

class minerva.data.data_modules.base.MinervaDataModule(train_dataset=None, val_dataset=None, test_dataset=None, predict_split='test', dataloader_cls=DataLoader, batch_size=1, num_workers=0, drop_last=False, additional_train_dataloader_kwargs=None, additional_val_dataloader_kwargs=None, additional_test_dataloader_kwargs=None, shuffle_train=True, name='')[source]

Bases: lightning.LightningDataModule

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

Example:

import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...

A fully-featured data module for PyTorch Lightning with support for accessing train, val, test, and predict datasets and dataloaders. This class is a generalization of the LightningDataModule class and is designed to be used with the Minerva framework.

Parameters

train_dataset : Optional[Dataset], optional

The training dataset, by default None

val_dataset : Optional[Dataset], optional

The validation dataset, by default None

test_dataset : Optional[Dataset], optional

The test dataset, by default None

predict_split : Optional[str], optional

The split to predict on (used by the predict_dataloader method), by default “test”

dataloader_cls : type, optional

The dataloader class to use. The datasets will be wrapped in this class when creating the dataloaders, by default DataLoader

batch_size : int, optional

Default batch_size for all dataloaders, by default 1

num_workers : int, optional

Default num_workers for all dataloaders, by default 0

drop_last : bool, optional

Default drop_last for all dataloaders, by default False

additional_train_dataloader_kwargs : Optional[dict], optional

Override the default train dataloader kwargs, by default None

additional_val_dataloader_kwargs : Optional[dict], optional

Override the default val dataloader kwargs, by default None

additional_test_dataloader_kwargs : Optional[dict], optional

Override the default test dataloader kwargs, by default None

shuffle_train : bool, optional

If True, shuffle the training dataset; if False, do not. By default True. Only the training dataloader is ever shuffled.

name : str, optional

Name of the data module, by default “”
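The defaults above (batch_size, num_workers, drop_last, shuffle_train) apply to every dataloader, and the additional_*_dataloader_kwargs dicts override them per split. A minimal pure-Python sketch of that merging, assuming the class simply layers the per-split overrides on top of the shared defaults (the helper name here is illustrative, not the real API):

```python
def build_split_kwargs(
    batch_size=1,
    num_workers=0,
    drop_last=False,
    shuffle_train=True,
    additional_train_kwargs=None,
    additional_val_kwargs=None,
    additional_test_kwargs=None,
):
    """Hypothetical sketch: merge shared defaults with per-split overrides."""
    base = {"batch_size": batch_size, "num_workers": num_workers, "drop_last": drop_last}
    # Only the train split is shuffled by default; val/test keep dataset order.
    train = {**base, "shuffle": shuffle_train, **(additional_train_kwargs or {})}
    val = {**base, "shuffle": False, **(additional_val_kwargs or {})}
    test = {**base, "shuffle": False, **(additional_test_kwargs or {})}
    return train, val, test


train_kw, val_kw, test_kw = build_split_kwargs(
    batch_size=16, additional_val_kwargs={"batch_size": 64}
)
print(train_kw["shuffle"], val_kw["batch_size"])  # True 64
```

Note that an override dict replaces individual keys only, so passing additional_val_dataloader_kwargs={"batch_size": 64} changes the validation batch size while keeping the shared num_workers and drop_last defaults.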

__repr__()[source]
Return type:

str

__str__()[source]

Return a string representation of the datasets that are set up.

Returns:

A string representation of the datasets that are setup.

Return type:

str

static __update_dataloader_kwargs(additional_kwargs, batch_size, num_workers, drop_last, shuffle)
_batch_size = 1
_dataloader_cls
_name = ''
_num_workers = 0
_predict_split = 'test'
_shuffle_train = True
_test_dataloader_kwargs
_test_dataset = None
_train_dataloader_kwargs
_train_dataset = None
_val_dataloader_kwargs
_val_dataset = None
property dataset_name
predict_dataloader()[source]

An iterable or collection of iterables specifying prediction samples.

For more information about using multiple dataloaders, see the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

When predicting, these are called in order:

  • predict()

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
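predict_dataloader() serves whichever split the constructor's predict_split parameter names (“test” by default). A hypothetical pure-Python sketch of that dispatch, assuming it is a simple lookup over the three configured datasets (the function name is illustrative, not the real API):

```python
def resolve_predict_dataset(predict_split, train_dataset=None, val_dataset=None, test_dataset=None):
    """Hypothetical sketch: pick the dataset that predict_split refers to."""
    datasets = {"train": train_dataset, "val": val_dataset, "test": test_dataset}
    if predict_split not in datasets:
        raise ValueError(f"Unknown predict_split: {predict_split!r}")
    dataset = datasets[predict_split]
    if dataset is None:
        raise ValueError(f"No dataset registered for split {predict_split!r}")
    return dataset


# With the default predict_split="test", prediction reuses the test dataset.
print(resolve_predict_dataset("test", test_dataset=[1, 2, 3]))  # [1, 2, 3]
```

This is why predict_split only makes sense when the corresponding dataset was actually passed to the constructor; predicting on an unset split should fail early rather than silently yield nothing.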

property predict_dataset
test_dataloader()[source]

An iterable or collection of iterables specifying test samples.

For more information about using multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning:

Do not assign state in prepare_data(); it runs on a single process, so state assigned there is not available to the other processes.

When testing, these are called in order:

  • test()

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

property test_dataset
train_dataloader()[source]

An iterable or collection of iterables specifying training samples.

For more information about using multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning:

Do not assign state in prepare_data(); it runs on a single process, so state assigned there is not available to the other processes.

When fitting, these are called in order:

  • fit()

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

property train_dataset
val_dataloader()[source]

An iterable or collection of iterables specifying validation samples.

For more information about using multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

When fitting or validating, these are called in order:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

property val_dataset