Source code for minerva.transforms.random_transform
from typing import List, Optional, Tuple, Union
import numpy as np
from minerva.transforms.transform import (
Crop,
Flip,
GrayScale,
Rotation,
Solarize,
_Transform,
)
[docs]
class _RandomSyncedTransform(_Transform):
"""Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them."""
def __init__(self, num_samples: int = 1, seed: Optional[int] = None):
"""Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them.
Parameters
----------
transform : _Transform
A transform that will be applied to the input data.
num_samples : int
The number of samples that will be transformed.
seed : Optional[int], optional
The seed that will be used to generate the random state, by default None.
"""
assert num_samples > 0, "num_samples must be greater than 0"
self.num_samples = num_samples
self.transformations_executed = 0
self.rng = np.random.default_rng(seed)
self.transform = EmptyTransform()
[docs]
def __call__(self, data):
if self.transformations_executed == 0:
self.transform = self.select_transform()
self.transformations_executed = (
self.transformations_executed + 1
) % self.num_samples
return self.transform(data)
[docs]
def select_transform(self):
raise NotImplementedError(
"This method should be implemented by the child class."
)
[docs]
class RandomFlip(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
possible_axis: Union[int, List[int]] = 0,
seed: Optional[int] = None,
):
"""A transform that flips the input data along a random axis.
Parameters
----------
num_samples : int
The number of samples that will be transformed.
possible_axis : Union[int, List[int]], optional
Possible axis to be transformed, will be chosen at random, by default 0
seed : Optional[int], optional
A seed to ensure deterministic run, by default None
"""
super().__init__(num_samples, seed)
self.possible_axis = possible_axis
[docs]
def select_transform(self):
"""selects the transform to be applied to the data."""
if isinstance(self.possible_axis, int):
flip_axis = self.rng.choice([True, False])
if flip_axis:
return Flip(axis=self.possible_axis)
else:
flip_axis = [
bool(self.rng.choice([True, False]))
for _ in range(len(self.possible_axis))
]
if True in flip_axis:
chosen_axis = [
axis for axis, flip in zip(self.possible_axis, flip_axis) if flip
]
return Flip(axis=chosen_axis)
return EmptyTransform()
[docs]
class RandomCrop(_RandomSyncedTransform):
def __init__(
self,
crop_size: Tuple[int, int],
num_samples: int = 1,
seed: Optional[int] = None,
pad_mode: str = "reflect",
):
super().__init__(num_samples, seed)
self.crop_size = crop_size
self.pad_mode = pad_mode
[docs]
def select_transform(self):
X = self.rng.random()
Y = self.rng.random()
return Crop(output_size=self.crop_size, pad_mode=self.pad_mode, coords=(X, Y))
[docs]
class RandomGrayScale(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
seed: Optional[int] = None,
prob: float = 0.1,
method: str = "luminosity",
):
super().__init__(num_samples, seed)
self.method = method
self.prob = prob
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
return GrayScale(method=self.method)
else:
return EmptyTransform()
[docs]
class RandomSolarize(_RandomSyncedTransform):
def __init__(
self,
num_samples: int = 1,
seed: Optional[int] = None,
threshold: int = 128,
prob: float = 1.0,
):
super().__init__(num_samples, seed)
self.threshold = threshold
self.prob = prob
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
return Solarize(self.threshold)
else:
return EmptyTransform()
[docs]
class RandomRotation(_RandomSyncedTransform):
def __init__(
self,
degrees: float,
prob: float,
num_samples: int = 1,
seed: Optional[int] = None,
):
"""
Randomly applies a rotation to the image with a specified probability.
Parameters
----------
degrees : float
Maximum absolute value of the rotation angle in degrees. The angle is sampled
uniformly from [-degrees, +degrees].
prob : float
Probability that the rotation will be applied.
num_samples : int, optional
Number of samples to generate per call (for contrastive learning), default is 1.
seed : int, optional
Seed for the random number generator, useful for reproducibility.
"""
super().__init__(num_samples=num_samples, seed=seed)
self.degrees = degrees
self.prob = prob
[docs]
def select_transform(self):
if self.rng.random() < self.prob:
angle = self.rng.uniform(-self.degrees, self.degrees)
return Rotation(degrees=angle)
else:
return EmptyTransform()