from typing import Any, Iterable, List, Optional, Tuple, Union
from torch.utils.data import Dataset
from minerva.data.readers.reader import _Reader
from minerva.transforms.transform import _Transform
class SimpleDataset(Dataset):
    """Dataset that loads data from one or more readers and applies the
    specified transforms.

    It is a generic implementation that can be used to create different
    datasets, from supervised to unsupervised ones. This class implements
    the common pipeline for reading and transforming data.

    The ``__getitem__`` pipeline is as follows. For each reader ``R`` and its
    paired transform ``T``:

    1. Read the data from reader ``R`` at index ``idx``.
    2. Apply transform ``T`` to the data (skipped when ``T`` is None).
    3. Append the transformed data to the list of data.

    Finally, return the tuple of transformed data, or the single sample when
    ``return_single`` is True.
    """

    def __init__(
        self,
        readers: Union[_Reader, List[_Reader]],
        transforms: Optional[Union[_Transform, List[_Transform]]] = None,
        return_single: bool = False,
    ):
        """Load data from multiple sources and apply specified transforms.

        Parameters
        ----------
        readers : Union[_Reader, List[_Reader]]
            The readers to load data from. It can be a single reader or a
            list (or other iterable) of readers. It is stored internally as a
            list.
        transforms : Optional[Union[_Transform, List[_Transform]]], optional
            The transforms to apply to each sample. This can be:

            - None, in which case no transform is applied.
            - A single (callable) transform, in which case the same transform
              is applied to data from all readers.
            - A list of transforms, in which case each transform is applied
              to the corresponding reader. That is, the first transform is
              applied to the first reader, the second transform is applied to
              the second reader, and so on. Entries may be None to skip the
              transform for that reader.
        return_single : bool, optional
            If True, the __getitem__ method will return a single sample when
            a single reader is used. This is useful for unsupervised datasets,
            where we usually have a single reader. If False, the __getitem__
            method will return a tuple of samples, where each sample is from a
            different reader, from the same index. This is useful for
            supervised datasets, where the data from different readers are
            related and should be returned together. The default is False.

        Raises
        ------
        AssertionError
            If no reader is given, if the number of readers and transforms
            differ, or if ``return_single`` is True with more than one reader.

        Examples
        --------
        1. Supervised Dataset:

        ```python
        from minerva.data.readers import ImageReader, LabelReader
        from minerva.transforms import ImageTransform, LabelTransform
        from minerva.data.datasets import SimpleDataset

        # Create the readers
        image_reader = ImageReader("path/to/images")
        label_reader = LabelReader("path/to/labels")

        # Create the transforms
        image_transform = ImageTransform()
        label_transform = None  # No transform for the labels

        # Create the dataset
        dataset = SimpleDataset(
            readers=[image_reader, label_reader],
            transforms=[image_transform, label_transform]
        )

        dataset[0]  # Returns (image, label)
        ```

        2. Unsupervised Dataset:

        ```python
        from minerva.data.readers import ImageReader
        from minerva.transforms import ImageTransform
        from minerva.data.datasets import SimpleDataset

        # Create the reader
        image_reader = ImageReader("path/to/images")

        # Create the transform
        image_transform = ImageTransform()

        # Create the dataset
        dataset = SimpleDataset(
            readers=[image_reader],
            transforms=image_transform,
            return_single=True
        )

        dataset[0]  # Returns image
        ```
        """
        # ---------------- Parsing readers ----------------
        self.readers = self._as_reader_list(readers)

        # ---------------- Parsing transforms ----------------
        # None -> one "no-op" slot per reader; single transform -> replicated
        # once per reader; collection -> copied as-is (validated below).
        self.transforms = self._as_transform_list(transforms, len(self.readers))
        self.return_single = return_single

        # ---------------- Validating objects ----------------
        # Without at least one reader, __len__ (which reads readers[0]) fails.
        assert len(self.readers) > 0, "At least one reader must be provided."
        assert len(self.readers) == len(
            self.transforms
        ), "The number of readers and transforms must be the same."
        # If return_single is True, there must be only one reader.
        assert (
            not self.return_single or len(self.readers) == 1
        ), "If return_single is True, there must be only one reader."

    @staticmethod
    def _as_reader_list(readers) -> List[_Reader]:
        """Normalize ``readers`` (one reader or an iterable of them) to a list."""
        # Check for _Reader first: a single reader may itself define
        # ``__iter__`` and must not be mistaken for a collection of readers.
        if isinstance(readers, _Reader) or not isinstance(readers, Iterable):
            return [readers]
        return list(readers)

    @staticmethod
    def _as_transform_list(
        transforms, num_readers: int
    ) -> List[Optional[_Transform]]:
        """Normalize ``transforms`` to a list with one entry per reader slot."""
        if transforms is None:
            return [None] * num_readers
        # Transforms are invoked as ``transform(sample)``, so a callable is a
        # single transform even if it is also iterable (lists/tuples of
        # transforms are not callable).
        if callable(transforms) or not isinstance(transforms, Iterable):
            return [transforms] * num_readers
        return list(transforms)

    def __len__(self) -> int:
        """The length of the dataset is the length of the first reader.

        All readers are expected to hold the same number of samples; this is
        not verified here.

        Returns
        -------
        int
            The number of samples in the dataset.
        """
        return len(self.readers[0])

    def __getitem__(self, idx: int) -> Union[Any, Tuple[Any, ...]]:
        """Load the sample at ``idx`` from every reader and apply transforms.

        Parameters
        ----------
        idx : int
            The index of the sample to load.

        Returns
        -------
        Union[Any, Tuple[Any, ...]]
            A tuple with the transformed data from each reader, in reader
            order; or the single transformed sample when ``return_single``
            is True.
        """
        data = []
        for reader, transform in zip(self.readers, self.transforms):
            sample = reader[idx]
            # A None transform means "use the sample as read".
            if transform is not None:
                sample = transform(sample)
            data.append(sample)

        return data[0] if self.return_single else tuple(data)

    def __str__(self) -> str:
        # __init__ guarantees both attributes are lists of equal length.
        readers_info = "\n".join(
            f" └── Reader {i}: {reader}\n │ └── Transform: {transform}"
            for i, (reader, transform) in enumerate(
                zip(self.readers, self.transforms)
            )
        )
        return (
            f"{'=' * 50}\n"
            f"{'📂 SimpleDataset Information':^50}\n"
            f"{'=' * 50}\n"
            f"📌 Dataset Type: {self.__class__.__name__}\n"
            f"{readers_info}\n"
            f" │\n"
            f" └── Total Readers: {len(self.readers)}\n"
            f"{'=' * 50}"
        )

    def __repr__(self) -> str:
        return self.__str__()