minerva.utils.data

Classes

RandomDataModule

SimpleDataset

Functions

full_dataset_from_dataloader(dataloader)

get_full_data_split(data_module, stage)

get_split_dataloader(stage, data_module)

Module Contents

class minerva.utils.data.RandomDataModule(data_shape, label_shape=None, num_classes=None, num_train_samples=128, num_val_samples=8, num_test_samples=8, num_predict_samples=8, batch_size=8)

Bases: lightning.LightningDataModule

Parameters:
  • data_shape (Tuple[int, ...])

  • label_shape (int | Tuple[int, ...], optional)

  • num_classes (int, optional)

  • num_train_samples (int)

  • num_val_samples (int)

  • num_test_samples (int)

  • num_predict_samples (int)

  • batch_size (int)

_generate_data(num_samples, data_shape, label_shape, num_classes)

predict_dataloader()

setup(stage)

test_dataloader()

train_dataloader()

val_dataloader()
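
This page does not say how `_generate_data` fills its tensors. The following is a minimal sketch of one plausible behavior, not the library's code. It assumes three things: inputs are standard-normal noise; labels are integer class indices when `num_classes` is given; otherwise labels are real-valued targets of shape `label_shape`. The standalone helper name `generate_random_data` is hypothetical.

```python
from typing import Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


def generate_random_data(
    num_samples: int,
    data_shape: Tuple[int, ...],
    label_shape: Optional[Union[int, Tuple[int, ...]]] = None,
    num_classes: Optional[int] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical stand-in for RandomDataModule._generate_data."""
    # Accept a bare int for label_shape, as the documented type allows.
    if isinstance(label_shape, int):
        label_shape = (label_shape,)

    # Inputs: standard-normal noise, one tensor of shape data_shape per sample.
    data = torch.randn(num_samples, *data_shape)

    if num_classes is not None:
        # Assumed classification case: integer class indices in [0, num_classes),
        # one per sample, or one per position of label_shape (e.g. a mask).
        size = (num_samples, *(label_shape or ()))
        labels = torch.randint(0, num_classes, size)
    elif label_shape is not None:
        # Assumed regression case: real-valued targets of shape label_shape.
        labels = torch.randn(num_samples, *label_shape)
    else:
        # Assumed unlabeled case.
        labels = None
    return data, labels


# Usage: 16 random 3x32x32 inputs with 10 classes, batched with batch_size=8.
x, y = generate_random_data(16, (3, 32, 32), num_classes=10)
loader = DataLoader(TensorDataset(x, y), batch_size=8)
print(x.shape, y.shape, len(loader))
# torch.Size([16, 3, 32, 32]) torch.Size([16]) 2
```

Under this sketch, each `*_dataloader()` hook would wrap its split's tensors in a `DataLoader` using the module's `batch_size`. With the default `num_train_samples=128` and `batch_size=8`, that gives 16 training batches.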

class minerva.utils.data.SimpleDataset(data, label=None)
Parameters:
  • data (torch.Tensor)

  • label (torch.Tensor, optional)

__getitem__(idx)

__len__()
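
The signature describes a dataset that wraps a data tensor and an optional label tensor. The sketch below shows that interface under one assumption: `__getitem__` returns `(sample, label)` when labels are present and the bare sample otherwise. The class name `SimpleDatasetSketch` is hypothetical, and the length check is an added safeguard that is not documented here.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


class SimpleDatasetSketch(Dataset):
    """Hypothetical re-implementation of the SimpleDataset interface."""

    def __init__(self, data: torch.Tensor, label: Optional[torch.Tensor] = None):
        # Assumed invariant: data and label are aligned along dimension 0.
        if label is not None and len(label) != len(data):
            raise ValueError("data and label must have the same length")
        self.data = data
        self.label = label

    def __len__(self) -> int:
        # One sample per entry along the first dimension.
        return len(self.data)

    def __getitem__(self, idx: int):
        # Assumed: return (sample, label) when labels exist, else the sample alone.
        if self.label is None:
            return self.data[idx]
        return self.data[idx], self.label[idx]


# Usage: five 2-feature samples with binary labels, served two at a time.
ds = SimpleDatasetSketch(torch.arange(10.0).reshape(5, 2), torch.tensor([0, 1, 0, 1, 1]))
xb, yb = next(iter(DataLoader(ds, batch_size=2)))
print(len(ds), xb.shape, yb.tolist())
# 5 torch.Size([2, 2]) [0, 1]
```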

minerva.utils.data.full_dataset_from_dataloader(dataloader)
Parameters:
  • dataloader (torch.utils.data.DataLoader)
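
The name suggests that `full_dataset_from_dataloader` collects every batch a `DataLoader` yields back into whole tensors. Its return type is not documented on this page. The sketch below makes two assumptions: each batch is either a single tensor or a tuple/list of tensors, and fields are concatenated along the batch dimension. The helper name `collect_batches` is hypothetical.

```python
from typing import List, Tuple, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


def collect_batches(dataloader: DataLoader) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Concatenate all batches from `dataloader` along dimension 0 (sketch)."""
    columns: List[List[torch.Tensor]] = []
    single = False
    for batch in dataloader:
        if isinstance(batch, torch.Tensor):
            # Unlabeled batches arrive as a bare tensor; treat them as a 1-tuple.
            single = True
            batch = (batch,)
        if not columns:
            # One accumulator per field (e.g. inputs, labels).
            columns = [[] for _ in batch]
        for column, part in zip(columns, batch):
            column.append(part)
    if not columns:
        raise ValueError("dataloader yielded no batches")
    merged = tuple(torch.cat(column, dim=0) for column in columns)
    return merged[0] if single else merged


# Usage: 10 samples in batches of 4, 4 and 2, rebuilt into full tensors.
loader = DataLoader(
    TensorDataset(torch.arange(20.0).reshape(10, 2), torch.arange(10)),
    batch_size=4,
    shuffle=False,  # keep sample order stable so rows line up with the source
)
x, y = collect_batches(loader)
print(x.shape, y.tolist())
# torch.Size([10, 2]) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

One design point applies to any approach of this kind. A loader built with `shuffle=True` yields samples in a new order on each pass. The recovered tensors therefore only match the underlying dataset's order when the loader does not shuffle.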

minerva.utils.data.get_full_data_split(data_module, stage)
Parameters:
  • data_module (lightning.LightningDataModule)

  • stage (str)
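
The signature suggests that `get_full_data_split` prepares a data module for one stage and returns that split's complete data. This sketch rests on several assumptions. The accepted `stage` values are not listed here, so it assumes names like `"train"`, `"val"`, `"test"` and `"predict"`. It assumes the module is set up with `setup(stage)` and exposes the matching `<stage>_dataloader()` hook. It also assumes each batch is an `(inputs, labels)` pair. The names `full_split` and `ToyDataModule` are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def full_split(data_module, stage: str):
    """Set up `data_module` for `stage` and return that split's data as whole tensors (sketch)."""
    data_module.setup(stage)
    # Assumed naming convention: stage "test" -> data_module.test_dataloader(), etc.
    loader = getattr(data_module, f"{stage}_dataloader")()
    batches = list(loader)
    # Each batch is assumed to be an (inputs, labels) pair; concatenate per field.
    return tuple(torch.cat(parts, dim=0) for parts in zip(*batches))


class ToyDataModule:
    """Hypothetical duck-typed stand-in exposing only the hooks the sketch uses."""

    def setup(self, stage: str) -> None:
        self.test_set = TensorDataset(torch.arange(24.0).reshape(12, 2), torch.arange(12))

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_set, batch_size=5)


x, y = full_split(ToyDataModule(), "test")
print(x.shape, y.shape)
# torch.Size([12, 2]) torch.Size([12])
```

This approach materializes the entire split in memory at once. That is convenient for small or synthetic splits, such as those from `RandomDataModule`, but may not suit large datasets.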

minerva.utils.data.get_split_dataloader(stage, data_module)
Parameters:
  • stage (str)

  • data_module (lightning.LightningDataModule)

Return type:

torch.utils.data.DataLoader
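
Note that `get_split_dataloader` takes `(stage, data_module)`, which reverses the argument order of `get_full_data_split`. The sketch below shows one way to resolve a stage name to a data module's loader hook and reject unknown names. The stage names, the call to `setup(stage)` before requesting the loader, and the names `split_dataloader` and `TwoSplitModule` are all assumptions, not documented behavior.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed stage names; the accepted values are not documented on this page.
_LOADER_HOOKS = {
    "train": "train_dataloader",
    "val": "val_dataloader",
    "test": "test_dataloader",
    "predict": "predict_dataloader",
}


def split_dataloader(stage: str, data_module) -> DataLoader:
    """Return the DataLoader that `data_module` exposes for `stage` (sketch)."""
    try:
        hook = _LOADER_HOOKS[stage]
    except KeyError:
        raise ValueError(f"unknown stage {stage!r}; expected one of {sorted(_LOADER_HOOKS)}") from None
    # Assumed: build the datasets first, since Lightning data modules
    # conventionally create them in setup().
    data_module.setup(stage)
    return getattr(data_module, hook)()


class TwoSplitModule:
    """Hypothetical duck-typed stand-in for a LightningDataModule."""

    def setup(self, stage: str) -> None:
        self.data = TensorDataset(torch.randn(12, 3), torch.randint(0, 2, (12,)))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.data, batch_size=4, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.data, batch_size=4)


loader = split_dataloader("val", TwoSplitModule())
print(len(loader))  # 12 samples / batch_size 4
# 3
```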