minerva.data.data_modules
Submodules
Classes
A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Package Contents
- class minerva.data.data_modules.MinervaDataModule(train_dataset=None, val_dataset=None, test_dataset=None, predict_split='test', dataloader_cls=DataLoader, batch_size=1, num_workers=0, drop_last=False, additional_train_dataloader_kwargs=None, additional_val_dataloader_kwargs=None, additional_test_dataloader_kwargs=None, shuffle_train=True, name='')[source]
Bases:
lightning.LightningDataModule
A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Example:
import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
A fully-featured data module for PyTorch Lightning with support for accessing train, val, test, and predict datasets and dataloaders. This class is a generalization of the LightningDataModule class and is designed to be used with the Minerva framework.
Parameters
- train_dataset : Optional[Dataset], optional
The training dataset, by default None
- val_dataset : Optional[Dataset], optional
The validation dataset, by default None
- test_dataset : Optional[Dataset], optional
The test dataset, by default None
- predict_split : Optional[str], optional
Set the split to predict on (using the predict_dataloader method), by default “test”
- dataloader_cls : type, optional
The dataloader class to use. The datasets will be wrapped in this class when creating the dataloaders, by default DataLoader
- batch_size : int, optional
Default batch_size for all dataloaders, by default 1
- num_workers : int, optional
Default num_workers for all dataloaders, by default 0
- drop_last : bool, optional
Default drop_last for all dataloaders, by default False
- additional_train_dataloader_kwargs : Optional[dict], optional
Override the default train dataloader kwargs, by default None
- additional_val_dataloader_kwargs : Optional[dict], optional
Override the default val dataloader kwargs, by default None
- additional_test_dataloader_kwargs : Optional[dict], optional
Override the default test dataloader kwargs, by default None
- shuffle_trainbool, optional
If True, shuffle the training dataset. If False, do not shuffle the training dataset, by default True. By default, only the training dataloader is shuffled.
- name : str, optional
Name of the data module, by default “”
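The dataset parameters accept any map-style dataset, i.e. any object implementing __len__ and __getitem__. A minimal illustrative sketch (SquaresDataset is a hypothetical stand-in, not part of Minerva or torch):

```python
# A minimal map-style dataset: any object with __len__ and __getitem__.
# SquaresDataset is a hypothetical example for illustration only.
class SquaresDataset:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx * idx

train = SquaresDataset(80)
val = SquaresDataset(10)

# Such objects can then be passed to the data module, e.g.:
# dm = MinervaDataModule(train_dataset=train, val_dataset=val, batch_size=8)
```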
- __str__()[source]
Return a string representation of the datasets that are set up.
- Returns:
A string representation of the datasets that are set up.
- Return type:
str
- static __update_dataloader_kwargs(additional_kwargs, batch_size, num_workers, drop_last, shuffle)
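Each split's dataloader kwargs are plausibly built by layering the shared defaults (batch_size, num_workers, drop_last), the per-split shuffle flag, and then the corresponding additional_*_dataloader_kwargs override. A pure-Python sketch of that merge (an assumption based on the parameter descriptions, not the helper's actual implementation):

```python
def update_dataloader_kwargs(additional_kwargs, batch_size, num_workers,
                             drop_last, shuffle):
    # Start from the shared defaults, then let caller-supplied kwargs win.
    # This mirrors the documented intent ("override the default ... kwargs");
    # the real private helper may differ in detail.
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "drop_last": drop_last,
        "shuffle": shuffle,
    }
    kwargs.update(additional_kwargs or {})
    return kwargs

# Train split: shuffle defaults to True, batch_size overridden to 32.
train_kwargs = update_dataloader_kwargs({"batch_size": 32}, 1, 0, False, True)
# Val split: no overrides, shuffle stays False.
val_kwargs = update_dataloader_kwargs(None, 1, 0, False, False)
```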
- _batch_size = 1
- _dataloader_cls
- _name = ''
- _num_workers = 0
- _predict_split = 'test'
- _shuffle_train = True
- _test_dataloader_kwargs
- _test_dataset = None
- _train_dataloader_kwargs
- _train_dataset = None
- _val_dataloader_kwargs
- _val_dataset = None
- property dataset_name
- predict_dataloader()[source]
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see the Lightning documentation.
It’s recommended that all data downloads and preparation happen in prepare_data(). The related hooks, in the order Lightning calls them during prediction, are:
- predict()
- prepare_data()
- setup()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- property predict_dataset
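The predict_split parameter selects which of the configured datasets backs prediction. A hedged sketch of that routing (an assumption based on the parameter description, not the actual source):

```python
def resolve_predict_dataset(predict_split, datasets):
    # datasets maps split names to dataset objects.
    # With the default predict_split="test", prediction reuses the
    # test dataset; "train" or "val" would reuse those instead.
    if predict_split not in datasets:
        raise ValueError(f"Unknown predict split: {predict_split!r}")
    return datasets[predict_split]

datasets = {"train": "train_ds", "val": "val_ds", "test": "test_ds"}
predict_ds = resolve_predict_dataset("test", datasets)  # -> "test_ds"
```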
- test_dataloader()[source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the Lightning documentation.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
The related hooks, in the order Lightning calls them during testing, are:
- test()
- prepare_data()
- setup()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- property test_dataset
- train_dataloader()[source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set the Trainer’s reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
The related hooks, in the order Lightning calls them during fitting, are:
- fit()
- prepare_data()
- setup()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- property train_dataset
- val_dataloader()[source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set the Trainer’s reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data(). The related hooks, in the order Lightning calls them, are:
- fit()
- validate()
- prepare_data()
- setup()
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- property val_dataset
- Parameters:
train_dataset (Optional[torch.utils.data.Dataset])
val_dataset (Optional[torch.utils.data.Dataset])
test_dataset (Optional[torch.utils.data.Dataset])
predict_split (Optional[str])
dataloader_cls (type)
batch_size (int)
num_workers (int)
drop_last (bool)
additional_train_dataloader_kwargs (Optional[dict])
additional_val_dataloader_kwargs (Optional[dict])
additional_test_dataloader_kwargs (Optional[dict])
shuffle_train (bool)
name (str)