Source code for minerva.data.datasets.binary_tree_subset

from torch.utils.data import Dataset, Subset
from math import ceil, floor


def build_indices(size: int, start: int, end: int):
    """
    Select up to `size` indices spread roughly evenly over [start, end).

    The range is bisected recursively: its midpoint is always selected, and
    the remaining ``size - 1`` picks are shared between the sub-range below
    the midpoint and the sub-range above it. When that remainder is odd, the
    lower sub-range receives the extra pick.

    Parameters
    ----------
    size : int
        How many indices are requested.
    start : int
        Lower bound of the range (inclusive).
    end : int
        Upper bound of the range (exclusive).

    Returns
    -------
    List[int]
        The selected indices in ascending order. The list is empty when the
        range is empty or `size` is not positive, and it can be shorter than
        `size` when the range holds fewer than `size` positions.
    """
    # Base cases: nothing requested, or nothing left to choose from.
    if size <= 0:
        return []
    if end <= start:
        return []

    pivot = (start + end) // 2

    # The pivot consumes one pick; divide the rest between both sides.
    pending = size - 1
    lower_quota = ceil(pending / 2)
    upper_quota = floor(pending / 2)

    # The pivot itself is excluded from both sub-ranges.
    lower_half = build_indices(lower_quota, start, pivot)
    upper_half = build_indices(upper_quota, pivot + 1, end)

    return [*lower_half, pivot, *upper_half]
class BinaryTreeSubset(Subset):
    """
    A `Subset` of a PyTorch dataset whose members are chosen by recursive
    midpoint (binary-tree style) sampling of the index range, so the selected
    samples are spread approximately evenly across the base dataset.
    """

    def __init__(self, dataset: Dataset, size: int):
        """
        Build the subset by picking `size` indices from ``[0, len(dataset))``
        with `build_indices`.

        Parameters
        ----------
        dataset : Dataset
            The base dataset to draw samples from.
        size : int
            Number of samples to keep. Must be positive and must not exceed
            the length of `dataset`.

        Raises
        ------
        ValueError
            If `size` is not positive, or if it is larger than the base
            dataset.
        """
        # Reject a non-positive size before touching the dataset at all.
        if size <= 0:
            raise ValueError(f"`size` must be a positive integer, but got {size=}")

        base_length = len(dataset)  # type: ignore
        if base_length < size:
            message = (
                f"Cannot create a subset of size {size} "
                f"because the base dataset has a size of {base_length}"
            )
            raise ValueError(message)

        selected = build_indices(size, 0, base_length)
        super().__init__(dataset, selected)

    def __str__(self):
        """Describe the subset as the base dataset plus its sample count."""
        sample_count = len(self.indices)
        return f"{self.dataset} Binary Tree Subset with {sample_count} samples"