minerva.samplers.domain_sampler

Classes

RandomDomainSampler

Sample data from multiple domains in a balanced way.

Module Contents

class minerva.samplers.domain_sampler.RandomDomainSampler(dataset, domain_labels, batch_size=1, n_domains_per_sample=2, shuffle=True, consistent_iterating=False)[source]

Bases: torch.utils.data.sampler.Sampler

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Args:
data_source (Dataset): This argument is not used and will be removed in 2.2.0.

You may still have a custom implementation that utilizes it.

Example:
>>> # xdoctest: +SKIP
>>> from typing import Iterator, List
>>> import torch
>>> from torch.utils.data import Sampler
>>>
>>> class AscendingSequenceLengthSampler(Sampler[int]):
...     def __init__(self, data: List[str]) -> None:
...         self.data = data
...
...     def __len__(self) -> int:
...         return len(self.data)
...
...     def __iter__(self) -> Iterator[int]:
...         # Yield indices ordered by increasing sequence length.
...         sizes = torch.tensor([len(x) for x in self.data])
...         yield from torch.argsort(sizes).tolist()
>>>
>>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
...     def __init__(self, data: List[str], batch_size: int) -> None:
...         self.data = data
...         self.batch_size = batch_size
...
...     def __len__(self) -> int:
...         # Number of batches, rounding up.
...         return (len(self.data) + self.batch_size - 1) // self.batch_size
...
...     def __iter__(self) -> Iterator[List[int]]:
...         # Yield batches of indices, each grouping sequences of
...         # similar length.
...         sizes = torch.tensor([len(x) for x in self.data])
...         for batch in torch.chunk(torch.argsort(sizes), len(self)):
...             yield batch.tolist()

Note

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
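For instance, len(DataLoader(...)) is derived from the length of the underlying sampler; a brief self-contained sketch (the toy dataset below is only illustrative):

>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>>
>>> dataset = TensorDataset(torch.arange(8))
>>> loader = DataLoader(dataset, batch_size=3)
>>> len(loader)  # ceil(8 / 3), computed from the default sampler's length
3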

Sample data from multiple domains in a balanced way. If the domains have different numbers of samples, the number of samples drawn from each domain is capped at the size of the smallest domain.

Parameters

dataset : Dataset

The dataset to sample from.

domain_labels : List[int]

The domain label for each sample in the dataset.

batch_size : int, optional

The number of samples drawn from each domain in a batch, by default 1. The effective batch size is batch_size * n_domains_per_sample (see the sketch after this list).

n_domains_per_sample : int, optional

The number of domains to sample from in each batch, by default 2. Note that domain_labels must contain at least n_domains_per_sample distinct domains.

shuffle : bool, optional

Shuffle the samples within each domain before sampling, by default True.

consistent_iterating : bool, optional

Because domains may have different numbers of samples, different iterations may otherwise draw different subsets of the data. If True, the same samples are used in every iteration, by default False.
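A purely illustrative calculation of these two quantities (the numbers below are hypothetical):

>>> batch_size, n_domains_per_sample = 4, 2
>>> batch_size * n_domains_per_sample  # effective batch size
8
>>> domain_sizes = [100, 60]
>>> min(domain_sizes)  # per-domain sample count, capped by the smallest domain
60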

__iter__()[source]
__len__()[source]
_select_samples()[source]
batch_size = 1
cached = None
consistent_iterating = False
dataset
domain_labels
domains
min_batches
n_domains_per_sample = 2
rng
seed
shuffle = True
Parameters:
  • dataset (torch.utils.data.Dataset)

  • domain_labels (List[int])

  • batch_size (int)

  • n_domains_per_sample (int)

  • shuffle (bool)

  • consistent_iterating (bool)
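A minimal usage sketch, under stated assumptions: the toy tensors and domain labels below are hypothetical, and __iter__() is assumed here to yield individual dataset indices, so the sampler is passed to a DataLoader via its sampler argument together with the effective batch size (if it instead yielded lists of indices per batch, it would go through batch_sampler).

Example:
>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from minerva.samplers.domain_sampler import RandomDomainSampler
>>>
>>> # Hypothetical toy dataset: 10 samples split across two domains.
>>> dataset = TensorDataset(torch.randn(10, 3))
>>> domain_labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
>>>
>>> sampler = RandomDomainSampler(
...     dataset,
...     domain_labels,
...     batch_size=2,            # samples per domain per batch
...     n_domains_per_sample=2,  # domains mixed into each batch
... )
>>> # Assumption: the sampler yields single indices, so the effective
>>> # batch size (2 * 2 = 4) is given to the DataLoader.
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=4)
>>> for (batch,) in loader:
...     print(batch.shape)  # torch.Size([4, 3]) for full batches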