minerva.transforms.split_transform

Classes

SplitTransform

A transform that splits the input data along a given dimension.

Module Contents

class minerva.transforms.split_transform.SplitTransform(num_splits=2, split_dimension=0)

Bases: minerva.transforms.transform._Transform

SplitTransform inherits from _Transform, the base class for all transforms. A transform is simply a function that takes an input and returns an output, and both the input and the output can be anything. However, a transform operates on a single sample of data and requires no additional information to perform the transformation. Subclasses override the __call__ method to define the transformation logic.
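The transform contract described above can be illustrated with a minimal sketch. `BaseTransform` and `Scale` are hypothetical names introduced only for this example; they stand in for, and are not, minerva's `_Transform` and its subclasses.

```python
from typing import Any


class BaseTransform:
    """Hypothetical stand-in for minerva's _Transform base class."""

    def __call__(self, x: Any) -> Any:
        # Subclasses must override this with their transformation logic.
        raise NotImplementedError


class Scale(BaseTransform):
    """Hypothetical example transform: multiplies a single sample by a constant."""

    def __init__(self, factor: float = 2.0):
        self.factor = factor

    def __call__(self, x: Any) -> Any:
        # Operates on one sample and needs nothing beyond its own parameters.
        return x * self.factor


scale = Scale(factor=3.0)
print(scale(4))  # 12.0
```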

A transform that splits the input data along a given dimension. When applied to a dataset, it splits each input sample into the specified number of parts.

Parameters

num_splits : int

The number of splits to divide the input into.

split_dimension : int

The dimension along which to split the input data.
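To illustrate what the two parameters control, the sketch below applies NumPy's `np.array_split` to a hypothetical 4 x 6 sample. It assumes SplitTransform partitions its input in a comparable way; this page does not state how splits of unequal size are handled, so any remainder behaviour here is NumPy's and not necessarily minerva's.

```python
import numpy as np

# A hypothetical sample: 4 rows x 6 columns.
sample = np.arange(24).reshape(4, 6)

# Analogous to num_splits=2, split_dimension=0: split along the rows.
top, bottom = np.array_split(sample, 2, axis=0)
print(top.shape, bottom.shape)  # (2, 6) (2, 6)

# Analogous to num_splits=3, split_dimension=1: split along the columns.
parts = np.array_split(sample, 3, axis=1)
print([p.shape for p in parts])  # [(4, 2), (4, 2), (4, 2)]
```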

__call__(x)

Split the input data into the specified number of splits.

Parameters

x : Union[np.ndarray, torch.Tensor]

The input data to split.

Returns

Tuple

The split data.
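A minimal sketch of how `__call__` could dispatch on the input type. `SplitSketch` is a hypothetical class, not minerva's implementation; using `np.array_split` for arrays and `Tensor.tensor_split` for tensors (the two share the same splitting semantics) is an assumption. The tensor branch is duck-typed so the sketch runs without PyTorch installed.

```python
from typing import Any, Tuple

import numpy as np


class SplitSketch:
    """Hypothetical re-implementation of SplitTransform, for illustration only."""

    def __init__(self, num_splits: int = 2, split_dimension: int = 0):
        self.num_splits = num_splits
        self.split_dimension = split_dimension

    def __call__(self, x: Any) -> Tuple:
        if isinstance(x, np.ndarray):
            # NumPy: always num_splits sub-arrays; sizes may differ by one.
            parts = np.array_split(x, self.num_splits, axis=self.split_dimension)
        else:
            # Assumed torch.Tensor: tensor_split mirrors np.array_split semantics.
            parts = x.tensor_split(self.num_splits, dim=self.split_dimension)
        return tuple(parts)


split = SplitSketch(num_splits=2, split_dimension=1)
left, right = split(np.ones((3, 8)))
print(left.shape, right.shape)  # (3, 4) (3, 4)
```

Returning a tuple matches the documented `Tuple` return type, so with `num_splits=2` the result can be unpacked directly into two variables.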

num_splits = 2
split_dimension = 0