"""Source code for minerva.samplers.domain_sampler."""

import random
from typing import List, Optional
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset
from dataclasses import dataclass


class RandomDomainSampler(Sampler):
    """Yield batches of dataset indices built from equal-sized per-domain chunks.

    Each item produced by iteration is a whole batch (a ``list`` of indices),
    so this presumably needs to be handed to a ``DataLoader`` as
    ``batch_sampler`` rather than ``sampler`` -- confirm against callers.
    """

    def __init__(
        self,
        dataset: Dataset,
        domain_labels: List[int],
        batch_size: int = 1,
        n_domains_per_sample: int = 2,
        shuffle: bool = True,
        consistent_iterating: bool = False,
    ):
        """Sample data from multiple domains in a balanced way.

        If domains have different numbers of samples, every domain is
        truncated to ``min_batches * batch_size`` samples, where
        ``min_batches`` is the smallest ``count // batch_size`` over all
        domains. Each domain therefore contributes exactly ``min_batches``
        chunks of ``batch_size`` indices per iteration.

        Parameters
        ----------
        dataset : Dataset
            The dataset to sample from. It is only stored on the instance;
            the sampling logic itself reads ``domain_labels``.
        domain_labels : List[int]
            The domain label of each sample; position ``i`` holds the label
            of dataset index ``i``.
        batch_size : int, optional
            The number of samples taken from a domain per draw (one chunk),
            by default 1. The effective batch size is
            ``batch_size * n_domains_per_sample``.
        n_domains_per_sample : int, optional
            The number of domain chunks in each batch, by default 2.
            Domains are drawn with replacement, so one batch may contain
            several chunks from the same domain.
            NOTE(review): because of the replacement, a batch can cover fewer
            than ``n_domains_per_sample`` distinct domains, and having at
            least that many distinct domains is not required; confirm this
            is intended.
        shuffle : bool, optional
            Shuffle each domain's indices (with the global ``random`` module)
            before truncating them, by default True.
        consistent_iterating : bool, optional
            If True, the per-domain selection is computed once and cached,
            and the domain draw order is replayed from a fixed seed, so every
            iteration yields the same batches. If False, the selection and
            the draw order are redone on every iteration. By default False.

        Raises
        ------
        ValueError
            If ``domain_labels`` is empty, or ``batch_size`` or
            ``n_domains_per_sample`` is smaller than 1.
        AssertionError
            If some domain has fewer than ``batch_size`` samples (checked
            with ``assert``, so it is skipped under ``python -O``).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # A value of 0 would make every (empty) batch look "complete" and
        # __iter__ would append empty batches forever.
        if n_domains_per_sample < 1:
            raise ValueError(
                "n_domains_per_sample must be >= 1, "
                f"got {n_domains_per_sample}"
            )
        # len() rather than truthiness so array-like labels are accepted.
        if len(domain_labels) == 0:
            raise ValueError("domain_labels must not be empty")

        self.dataset = dataset
        self.domain_labels = domain_labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.consistent_iterating = consistent_iterating
        self.domains = set(domain_labels)
        # One O(N) grouping pass instead of rescanning every label per domain.
        groups = self._group_indices()
        self.min_batches = min(
            len(groups.get(d, [])) // batch_size for d in self.domains
        )
        self.n_domains_per_sample = n_domains_per_sample
        assert self.min_batches > 0, "Not enough samples for a batch"
        self.cached = None
        # Seed drawn from the global RNG, so seeding ``random`` globally makes
        # the sampler reproducible.
        self.seed = random.random()
        self.rng = random.Random(self.seed)

    def __len__(self) -> int:
        """Return the number of complete batches one iteration yields.

        Every domain contributes ``min_batches`` chunks and each batch
        consumes ``n_domains_per_sample`` of them; a trailing incomplete
        batch is dropped, hence the floor division.
        """
        return (
            self.min_batches * len(self.domains)
        ) // self.n_domains_per_sample

    def _group_indices(self) -> dict:
        """Map each domain label to its positions in ``domain_labels``, in order."""
        groups = {}
        for position, label in enumerate(self.domain_labels):
            groups.setdefault(label, []).append(position)
        return groups

    def _select_samples(self) -> dict:
        """Pick ``min_batches * batch_size`` indices from every domain.

        Returns a dict mapping each domain (in ``self.domains`` iteration
        order) to its list of selected indices. When ``shuffle`` is set, the
        domain's indices are shuffled with the global ``random`` module
        before truncation, so different samples may survive each call.
        """
        groups = self._group_indices()
        limit = self.min_batches * self.batch_size
        indices = {}
        for d in self.domains:
            idxs = groups.get(d, [])
            if self.shuffle:
                random.shuffle(idxs)
            indices[d] = idxs[:limit]
        return indices

    def __iter__(self):
        """Yield batches, each a list of ``batch_size * n_domains_per_sample`` indices.

        A batch is assembled from ``n_domains_per_sample`` chunks; for each
        chunk a domain is chosen uniformly among those that still have a full
        chunk left (with replacement within the batch), and its next
        ``batch_size`` indices are taken. All batches are built before the
        first one is yielded.
        """
        if self.consistent_iterating:
            if self.cached is None:
                self.cached = self._select_samples()
            # The cached lists are only sliced below, never mutated.
            indices = self.cached
            # Same seed every time -> same draw order on every iteration.
            rng = random.Random(self.seed)
        else:
            indices = self._select_samples()
            rng = self.rng

        # Read offset per domain that still has a full chunk. Advancing an
        # offset avoids re-copying each domain's remaining list per draw.
        # Dict order follows ``indices`` and deletion keeps the relative
        # order of the rest, so ``rng.choice`` sees the same candidates.
        cursors = {d: 0 for d in indices}
        target = self.batch_size * self.n_domains_per_sample
        batches = []
        while True:
            batch = []
            for _ in range(self.n_domains_per_sample):
                if not cursors:
                    break
                domain = rng.choice(list(cursors))
                start = cursors[domain]
                stop = start + self.batch_size
                idxs = indices[domain]
                batch += idxs[start:stop]
                if len(idxs) - stop < self.batch_size:
                    del cursors[domain]
                else:
                    cursors[domain] = stop
            if len(batch) != target:
                break
            batches.append(batch)
        yield from batches