Source code for minerva.transforms.random_transform
from typing import List, Optional, Tuple, Union
import numpy as np
from minerva.transforms.transform import (
Crop,
Flip,
GrayScale,
Rotation,
Solarize,
_Transform,
)
[docs]
class _RandomSyncedTransform(_Transform):
"""Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them."""
def __init__(self, num_samples: int = 1, seed: Optional[int] = None):
"""Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them.
Parameters
----------
transform : _Transform
A transform that will be applied to the input data.
num_samples : int
The number of samples that will be transformed.
seed : Optional[int], optional
The seed that will be used to generate the random state, by default None.
"""
self.num_samples = num_samples
self.transformations_executed = 0
self.rng = np.random.default_rng(seed)
self.transform = EmptyTransform()
[docs]
def __call__(self, data):
if self.transformations_executed == 0:
self.transform = self.select_transform()
self.transformations_executed += 1
return self.transform(data)
else:
if self.transformations_executed == self.num_samples - 1:
self.transformations_executed = 0
else:
self.transformations_executed += 1
return self.transform(data)
[docs]
def select_transform(self):
raise NotImplementedError(
"This method should be implemented by the child class."
)
[docs]
class RandomFlip(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
possible_axis: Union[int, List[int]] = 0,
seed: Optional[int] = None,
):
"""A transform that flips the input data along a random axis.
Parameters
----------
num_samples : int
The number of samples that will be transformed.
possible_axis : Union[int, List[int]], optional
Possible axis to be transformed, will be chosen at random, by default 0
seed : Optional[int], optional
A seed to ensure deterministic run, by default None
"""
super().__init__(num_samples, seed)
self.possible_axis = possible_axis
[docs]
def select_transform(self):
"""selects the transform to be applied to the data."""
if isinstance(self.possible_axis, int):
flip_axis = self.rng.choice([True, False])
if flip_axis:
return Flip(axis=self.possible_axis)
else:
flip_axis = [
bool(self.rng.choice([True, False]))
for _ in range(len(self.possible_axis))
]
if True in flip_axis:
chosen_axis = [
axis for axis, flip in zip(self.possible_axis, flip_axis) if flip
]
return Flip(axis=chosen_axis)
return EmptyTransform()
[docs]
class RandomCrop(_RandomSyncedTransform):
def __init__(
self,
crop_size: Tuple[int, int],
num_samples: int = 1,
seed: Optional[int] = None,
pad_mode: str = "reflect",
):
super().__init__(num_samples, seed)
self.crop_size = crop_size
self.pad_mode = pad_mode
[docs]
def select_transform(self):
X = self.rng.random()
Y = self.rng.random()
return Crop(output_size=self.crop_size, pad_mode=self.pad_mode, coords=(X, Y))
[docs]
class RandomGrayScale(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
seed: Optional[int] = None,
prob: float = 0.1,
gray: float = 1.0,
):
super().__init__(num_samples, seed)
self.gray = gray
self.prob = prob
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
return GrayScale(gray=self.gray)
else:
return EmptyTransform()
[docs]
class RandomSolarize(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
seed: Optional[int] = None,
threshold: int = 128,
prob: float = 1.0,
):
super().__init__(num_samples, seed)
self.threshold = threshold
self.prob = prob
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
return Solarize(self.threshold)
else:
return EmptyTransform()
[docs]
class RandomRotation(_RandomSyncedTransform):
def __init__(
self,
degrees: float,
prob: float,
num_samples: int = 1,
seed: Optional[int] = None,
):
super().__init__(num_samples, seed)
self.prob = prob
self.degrees = degrees
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
degrees = self.rng.uniform(-self.degrees, self.degrees)
return Rotation(degrees=degrees)
else:
return EmptyTransform()