Source code for minerva.transforms.split_transform
import numpy as np
import torch
from .transform import _Transform
from typing import Union, Tuple
[docs]
class SplitTransform(_Transform):
    """Transform that splits an array or tensor into equal parts along one axis.

    Both NumPy arrays and PyTorch tensors are accepted. The size of the
    split dimension must be an exact multiple of ``num_splits``; otherwise
    the call raises instead of producing uneven parts.
    """

    def __init__(
        self,
        num_splits: int = 2,
        split_dimension: int = 0
    ):
        """A transform that splits the input data along some dimension.

        When applied to a dataset, this transform will split the input data
        into the specified number of splits.

        Parameters
        ----------
        num_splits : int
            The number of splits to divide the input into. Must be greater
            than 0.
        split_dimension : int
            The dimension along which to split the input data. Must be
            greater than or equal to 0 (negative axes are rejected).

        Raises
        ------
        ValueError
            If ``num_splits`` is not greater than 0 or ``split_dimension``
            is negative.
        """
        super().__init__()

        # Validate before storing so the instance never carries invalid
        # configuration.
        if num_splits <= 0:
            raise ValueError(f"Expected input 'num_splits' to be a positive integer greater than 0, but received {num_splits}.")
        if split_dimension < 0:
            raise ValueError(f"Expected input 'split_dimension' to be a positive integer greater than or equal to 0, but received {split_dimension}.")

        self.num_splits = num_splits
        self.split_dimension = split_dimension

    def __call__(self, x: Union[np.ndarray, torch.Tensor]) -> Tuple:
        """Split the input data into the specified number of splits.

        Parameters
        ----------
        x : Union[np.ndarray, torch.Tensor]
            The input data to split.

        Returns
        -------
        Tuple
            The split data, as a tuple of ``num_splits`` parts, regardless
            of whether the input was a NumPy array or a PyTorch tensor.

        Raises
        ------
        TypeError
            If ``x`` is neither a NumPy array nor a PyTorch tensor.
        ValueError
            If ``split_dimension`` is not a valid axis of ``x``, or if the
            size of that axis is not divisible by ``num_splits``.
        """
        if not isinstance(x, (np.ndarray, torch.Tensor)):
            raise TypeError(f"Expected input 'x' to be either a numpy array or a Pytorch tensor, but received an object of type {type(x)}.")

        if self.split_dimension >= len(x.shape):
            raise ValueError(f"Invalid split dimension: expected the split dimension to be less than {len(x.shape)}, but received {self.split_dimension}.")

        dim_size = x.shape[self.split_dimension]
        if dim_size % self.num_splits != 0:
            raise ValueError(f"Invalid split: expected {self.num_splits} to divide equally the dimension {dim_size}.")

        if isinstance(x, np.ndarray):
            # np.split takes the number of sections directly.
            splits = np.split(x, self.num_splits, axis=self.split_dimension)
        else:
            # The type guard above leaves torch.Tensor as the only other
            # option. torch.split takes the size of each chunk, not the
            # number of chunks.
            # NOTE(review): for a zero-length split dimension the chunk size
            # is 0, and the two backends may not return the same number of
            # parts in that case -- confirm if empty inputs matter.
            splits = torch.split(x, dim_size // self.num_splits, dim=self.split_dimension)

        # np.split returns a list; normalise so the declared Tuple return
        # type holds for both backends.
        return tuple(splits)