minerva.data.data_module_tools
Classes
- RandomDataModule: A DataModule standardizes the training, val, test splits, data preparation and transforms.
- SimpleDataset
Functions
- full_dataset_from_dataloader(dataloader)
Module Contents
- class minerva.data.data_module_tools.RandomDataModule(data_shape, label_shape=None, num_classes=None, num_train_samples=128, num_val_samples=8, num_test_samples=8, num_predict_samples=8, batch_size=8, data_dtype=torch.float32, label_dtype=torch.float32)[source]
Bases:
lightning.LightningDataModule
A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Example:
import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
- Attributes:
- prepare_data_per_node:
If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices:
If True, dataloader with zero length within local rank is allowed. Default value is False.
- Parameters:
data_shape (Tuple[int, ...])
label_shape (Union[int, Tuple[int, ...], None])
num_classes (Optional[int])
num_train_samples (int)
num_val_samples (int)
num_test_samples (int)
num_predict_samples (int)
batch_size (int)
data_dtype (torch.dtype)
label_dtype (torch.dtype)
- batch_size = 8
- data_dtype = torch.float32
- data_shape
- label_dtype = torch.float32
- label_shape = None
- num_classes = None
- num_predict_samples = 8
- num_test_samples = 8
- num_train_samples = 128
- num_val_samples = 8
- predict_data = None
- predict_dataloader()[source]
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: predict(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- setup(stage)[source]
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Args:
stage: either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_data = None
- test_dataloader()[source]
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
See also: test(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- train_data = None
- train_dataloader()[source]
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
See also: fit(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_data = None
- val_dataloader()[source]
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: fit(), validate(), prepare_data().
- Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Note:
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
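A minimal usage sketch for RandomDataModule, based only on the constructor signature and hooks documented above. The concrete values (shapes, class count, dtypes) are illustrative, and the exact structure of each yielded batch is an assumption to verify against the implementation:

import torch
from minerva.data.data_module_tools import RandomDataModule

# Illustrative configuration: random 3x32x32 inputs with integer class labels.
datamodule = RandomDataModule(
    data_shape=(3, 32, 32),
    label_shape=1,
    num_classes=10,
    num_train_samples=128,
    batch_size=8,
    label_dtype=torch.long,
)

# Build the random datasets for the "fit" stage and pull one training batch.
datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
# Assumed: a (data, label) pair, with data of shape [8, 3, 32, 32].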
- class minerva.data.data_module_tools.SimpleDataset(data, label=None)[source]
- Parameters:
data (torch.Tensor)
label (Optional[torch.Tensor])
- data
- label = None
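An illustrative sketch of wrapping in-memory tensors with SimpleDataset; whether items come back as bare tensors or as (data, label) pairs is an assumption to check against the implementation:

import torch
from torch.utils.data import DataLoader
from minerva.data.data_module_tools import SimpleDataset

# Illustrative tensors: 16 samples of shape (3, 8, 8) with integer labels.
data = torch.randn(16, 3, 8, 8)
labels = torch.randint(0, 4, (16,))

labeled = SimpleDataset(data, labels)   # data plus labels
unlabeled = SimpleDataset(data)         # label defaults to None

loader = DataLoader(labeled, batch_size=4)
batch = next(iter(loader))              # assumed: a (data, label) pair per batch when labels are given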
- minerva.data.data_module_tools.full_dataset_from_dataloader(dataloader)[source]
- Parameters:
dataloader (torch.utils.data.DataLoader)
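The entry above documents only the signature. As a hedged illustration of the behaviour the name suggests (materializing every batch of a DataLoader back into a single dataset), the snippet below re-implements that idea locally rather than asserting what the minerva function actually returns:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical re-implementation for illustration only; the real
# minerva.data.data_module_tools.full_dataset_from_dataloader may behave differently.
def collect_full_dataset(dataloader: DataLoader) -> TensorDataset:
    xs, ys = [], []
    for x, y in dataloader:  # assumes (data, label) batches
        xs.append(x)
        ys.append(y)
    return TensorDataset(torch.cat(xs), torch.cat(ys))

loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.arange(32)), batch_size=8)
full = collect_full_dataset(loader)
print(len(full))  # 32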