"""minerva.transforms.split_transform: transform that splits arrays/tensors into equal pieces."""
import numpy as np
import torch
from .transform import _Transform
from typing import Union, Tuple
class SplitTransform(_Transform):
    """Split a NumPy array or PyTorch tensor into equal pieces along one dimension.

    Both backends produce a tuple of exactly ``num_splits`` pieces.
    """

    def __init__(self, num_splits: int = 2, split_dimension: int = 0):
        """A transform that splits the input data along some dimension.

        When applied to a dataset, this transform will split the input data into
        the specified number of splits.

        Parameters
        ----------
        num_splits : int
            The number of splits to divide the input into. Must be greater than 0.
        split_dimension : int
            The dimension along which to split the input data. Must be greater
            than or equal to 0.

        Raises
        ------
        ValueError
            If ``num_splits`` is not positive or ``split_dimension`` is negative.
        """
        super().__init__()
        if num_splits <= 0:
            raise ValueError(
                f"Expected input 'num_splits' to be a positive integer greater than 0, but received {num_splits}."
            )
        if split_dimension < 0:
            raise ValueError(
                f"Expected input 'split_dimension' to be a positive integer greater than or equal to 0, but received {split_dimension}."
            )
        self.num_splits = num_splits
        self.split_dimension = split_dimension

    def __call__(self, x: Union[np.ndarray, torch.Tensor]) -> Tuple:
        """Split the input data into the specified number of splits.

        Parameters
        ----------
        x : Union[np.ndarray, torch.Tensor]
            The input data to split.

        Returns
        -------
        Tuple
            ``num_splits`` pieces of ``x``, all of equal size along
            ``split_dimension`` and of the same type as ``x``.

        Raises
        ------
        TypeError
            If ``x`` is neither a NumPy array nor a PyTorch tensor.
        ValueError
            If ``split_dimension`` is not a valid dimension of ``x``, or the size
            of that dimension is not divisible by ``num_splits``.
        """
        if not isinstance(x, (np.ndarray, torch.Tensor)):
            raise TypeError(
                f"Expected input 'x' to be either a numpy array or a Pytorch tensor, but received an object of type {type(x)}."
            )
        if self.split_dimension >= len(x.shape):
            raise ValueError(
                f"Invalid split dimension: expected the split dimension to be less than {len(x.shape)}, but received {self.split_dimension}."
            )
        dim_size = x.shape[self.split_dimension]
        if dim_size % self.num_splits != 0:
            raise ValueError(
                f"Invalid split: expected {self.num_splits} to divide equally the dimension {dim_size}."
            )

        if isinstance(x, np.ndarray):
            # np.split returns a list; convert it so both backends honour the
            # declared ``Tuple`` return type.
            return tuple(np.split(x, self.num_splits, axis=self.split_dimension))

        # Pass explicit per-piece sizes instead of a single scalar chunk size:
        # for an empty dimension the scalar size would be 0, and torch.split
        # then yields a single piece rather than ``num_splits`` pieces (unlike
        # np.split). An explicit size list always yields ``num_splits`` pieces.
        piece_size = dim_size // self.num_splits
        return torch.split(
            x,
            [piece_size] * self.num_splits,
            dim=self.split_dimension,
        )