Source code for minerva.transforms.split_transform
import numpy as np
import torch
from .transform import _Transform
from typing import Union, Tuple
[docs]
class SplitTransform(_Transform):
    """Split a numpy array or a PyTorch tensor into equal parts along one dimension.

    The transform is configured once with the number of parts and the
    dimension to cut along, and is then applied by calling the instance on
    the data.
    """

    def __init__(self, num_splits: int = 2, split_dimension: int = 0):
        """A transform that splits the input data along some dimension.

        When applied to a dataset, this transform will split the input data
        into the specified number of splits.

        Parameters
        ----------
        num_splits : int
            The number of splits to divide the input into. Must be greater
            than 0.
        split_dimension : int
            The dimension along which to split the input data. Must be
            greater than or equal to 0 (negative indices are rejected).

        Raises
        ------
        ValueError
            If ``num_splits`` is not greater than 0, or if
            ``split_dimension`` is negative.
        """
        super().__init__()

        # Validate first so that an invalid configuration is never stored.
        if num_splits <= 0:
            raise ValueError(
                f"Expected input 'num_splits' to be a positive integer greater than 0, but received {num_splits}."
            )
        if split_dimension < 0:
            raise ValueError(
                f"Expected input 'split_dimension' to be a positive integer greater than or equal to 0, but received {split_dimension}."
            )

        self.num_splits = num_splits
        self.split_dimension = split_dimension

    def __call__(
        self, x: Union[np.ndarray, torch.Tensor]
    ) -> Tuple[Union[np.ndarray, torch.Tensor], ...]:
        """Split the input data into the specified number of splits.

        Parameters
        ----------
        x : Union[np.ndarray, torch.Tensor]
            The input data to split. Its size along ``split_dimension`` must
            be divisible by ``num_splits``.

        Returns
        -------
        Tuple
            The equally sized pieces of ``x``, in order along
            ``split_dimension``. A tuple is returned for both numpy and
            PyTorch inputs.

        Raises
        ------
        TypeError
            If ``x`` is neither a numpy array nor a PyTorch tensor.
        ValueError
            If ``split_dimension`` is not a valid dimension of ``x``, or if
            the size of that dimension is not divisible by ``num_splits``.
        """
        if not isinstance(x, (np.ndarray, torch.Tensor)):
            raise TypeError(
                f"Expected input 'x' to be either a numpy array or a Pytorch tensor, but received an object of type {type(x)}."
            )

        if self.split_dimension >= len(x.shape):
            raise ValueError(
                f"Invalid split dimension: expected the split dimension to be less than {len(x.shape)}, but received {self.split_dimension}."
            )

        dim_size = x.shape[self.split_dimension]
        if dim_size % self.num_splits != 0:
            raise ValueError(
                f"Invalid split: expected {self.num_splits} to divide equally the dimension {dim_size}."
            )

        if isinstance(x, np.ndarray):
            # np.split returns a list; convert it so the return type matches
            # the annotation and the PyTorch branch.
            return tuple(
                np.split(x, self.num_splits, axis=self.split_dimension)
            )

        # Only a torch.Tensor can reach this point (guarded by the type check
        # above). torch.split takes the size of each chunk rather than the
        # number of chunks, hence the integer division.
        return tuple(
            torch.split(
                x,
                dim_size // self.num_splits,
                dim=self.split_dimension,
            )
        )