from itertools import product
from typing import Any, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from perlin_noise import PerlinNoise
[docs]
class Flip(_Transform):
    """Reverse the element order of an array along one or more axes."""

    def __init__(self, axis: Union[int, List[int]] = 0):
        """Remember which axis (or axes) should be flipped.

        Parameters
        ----------
        axis : int | List[int], optional
            A single axis, or a list of axes, to flip along. With a list the
            flips are applied one after another, in list order. By default 0.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return a flipped copy of ``x``.

        Parameters
        ----------
        x : np.ndarray
            Array to flip. When ``axis`` is a list, the list may not be longer
            than ``x.ndim``.

        Returns
        -------
        np.ndarray
            A new array (not a view) holding the flipped data.
        """
        axes = self.axis
        if not isinstance(axes, int):
            assert (
                len(axes) <= x.ndim
            ), "Axis list has more dimensions than input data. The length of axis needs to be less or equal to input dimensions."
            flipped = x
            for current_axis in axes:
                flipped = np.flip(flipped, axis=current_axis)
            # np.flip returns a view; hand back an independent array.
            return flipped.copy()
        return np.flip(x, axis=axes).copy()

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Flip(axis={self.axis})"
[docs]
class PerlinMasker(_Transform):
    """Mask an array with the sign of Perlin noise.

    Entries are kept where the noise is negative and zeroed elsewhere. The
    noise seed is drawn with ``torch.randint`` on every call.
    """

    def __init__(self, octaves: int, scale: float = 1):
        """Validate and store the noise parameters.

        Parameters
        ----------
        octaves : int
            Level of detail passed to the Perlin noise generator. Must be
            strictly positive.
        scale : float, optional
            Factor applied to the coordinate divisor, i.e. rescales the noise.
            Must be non-zero. By default 1 (no rescaling).

        Raises
        ------
        ValueError
            If ``octaves`` is not positive or ``scale`` is zero.
        """
        if octaves <= 0:
            raise ValueError(
                f"Number of octaves must be positive, but got {octaves=}"
            )
        if scale == 0:
            raise ValueError("Scale can't be 0")
        self.octaves = octaves
        self.scale = scale

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Multiply ``x`` by a boolean mask derived from Perlin noise.

        Parameters
        ----------
        x : np.ndarray
            The array whose entries are selectively zeroed.

        Returns
        -------
        np.ndarray
            ``x`` multiplied element-wise by the mask (True where the sampled
            noise is negative).
        """
        keep = np.empty_like(x, dtype=bool)
        seed = torch.randint(0, 2**32, (1,)).item()
        generator = PerlinNoise(self.octaves, seed)
        # Coordinates are divided by the largest dimension (times scale)
        # before being fed to the noise generator.
        divisor = self.scale * max(x.shape)
        for index in np.ndindex(*keep.shape):
            coords = [component / divisor for component in index]
            keep[index] = generator(coords) < 0
        return x * keep
[docs]
class Squeeze(_Transform):
    """Drop a length-one axis from an array."""

    def __init__(self, axis: int):
        """Store the axis to drop.

        Parameters
        ----------
        axis : int
            Position of the axis to remove; forwarded to ``np.squeeze``.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with the configured axis squeezed out.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            Result of ``np.squeeze`` on ``x`` for the stored axis.
        """
        target_axis = self.axis
        return np.squeeze(x, axis=target_axis)

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Squeeze(axis={self.axis})"
[docs]
class Unsqueeze(_Transform):
    """Insert a new length-one axis into an array."""

    def __init__(self, axis: int):
        """Store where the new axis goes.

        Parameters
        ----------
        axis : int
            Position at which the new axis is inserted; forwarded to
            ``np.expand_dims``.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with an extra axis at the configured position.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            Result of ``np.expand_dims`` on ``x`` for the stored axis.
        """
        insert_at = self.axis
        return np.expand_dims(x, axis=insert_at)

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Unsqueeze(axis={self.axis})"
[docs]
class Transpose(_Transform):
    """Permute the axes of a numpy array."""

    def __init__(self, axes: Sequence[int]):
        """Store the axis permutation.

        Parameters
        ----------
        axes : Sequence[int]
            The new order of the axes, as accepted by ``np.transpose``.
        """
        self.axes = axes

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with its axes reordered.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            Result of ``np.transpose`` on ``x`` with the stored permutation.
        """
        permutation = self.axes
        return np.transpose(x, permutation)

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Transpose(axes={self.axes})"
[docs]
class CastTo(_Transform):
    """Convert an array to a given data type."""

    def __init__(self, dtype: Union[type, str]):
        """Store the destination data type.

        Parameters
        ----------
        dtype : type | str
            The data type the input is cast to; passed to ``astype``.
        """
        self.dtype = dtype

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` converted with ``x.astype``.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            The input cast to the stored data type.
        """
        destination = self.dtype
        return x.astype(destination)

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"CastTo(dtype={self.dtype})"
[docs]
class Padding(_Transform):
    """Reflect-pad the first two axes of an array up to a minimum size.

    Padding is added only after the existing data (end of axis 0 and end of
    axis 1). Inputs already at least as large as the target are not cropped.
    """

    def __init__(self, target_h_size: int, target_w_size: int):
        """Store the minimum output size.

        Parameters
        ----------
        target_h_size : int
            Minimum size of axis 0 after padding.
        target_w_size : int
            Minimum size of axis 1 after padding.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Pad ``x`` and rearrange its axes.

        Parameters
        ----------
        x : np.ndarray
            A 2-D array, or a 3-D array whose first two axes are the ones to
            pad.

        Returns
        -------
        np.ndarray
            For 2-D input, the padded array with a trailing length-one axis
            appended. For 3-D input, the padded array with its last axis
            moved to the front.

        Notes
        -----
        NOTE(review): the two branches return different layouts (trailing
        singleton axis vs. last axis moved first) - confirm this is intended.
        """
        height, width = x.shape[:2]
        extra_rows = max(0, self.target_h_size - height)
        extra_cols = max(0, self.target_w_size - width)

        if len(x.shape) == 2:
            grown = np.pad(
                x, ((0, extra_rows), (0, extra_cols)), mode="reflect"
            )
            return np.expand_dims(grown, axis=2)

        grown = np.pad(
            x, ((0, extra_rows), (0, extra_cols), (0, 0)), mode="reflect"
        )
        return np.transpose(grown, (2, 0, 1))
[docs]
class Gradient(_Transform):
    """Append a linear 0-to-1 ramp to an array as an extra leading channel."""

    def __init__(self, direction: int):
        """Store the direction of the ramp.

        Parameters
        ----------
        direction : int
            0 -> ramp along the x-axis (width),
            1 -> ramp along the y-axis (height).
        """
        assert direction in [0, 1], "Direction must be 0 (x-axis) or 1 (y-axis)"
        self.direction = direction

    def generate_gradient(self, shape: tuple[int, int]) -> np.ndarray:
        """Build the ramp image.

        Parameters
        ----------
        shape : tuple[int, int]
            Size of the ramp as (H, W).

        Returns
        -------
        np.ndarray
            Array of shape (H, W) going linearly from 0 to 1 along the
            configured direction.

        Raises
        ------
        ValueError
            If ``self.direction`` is neither 0 nor 1 (previously ``None`` was
            returned silently in that case).
        """
        xx, yy = np.meshgrid(
            np.linspace(0, 1, shape[1]), np.linspace(0, 1, shape[0])
        )
        if self.direction == 0:  # Ramp along the x-axis
            return xx
        if self.direction == 1:  # Ramp along the y-axis
            return yy
        raise ValueError("Direction must be 0 (x-axis) or 1 (y-axis)")

    def __call__(self, x):
        """Concatenate the ramp to ``x`` along axis 0.

        Parameters
        ----------
        x : np.ndarray
            Either a 2-D array (H, W) or an array whose trailing axes are
            (H, W) and whose first axis is the one concatenated along.

        Returns
        -------
        np.ndarray
            Array of shape (C + 1, H, W), where C is 1 for 2-D input and
            ``x.shape[0]`` otherwise.
        """
        if x.ndim == 2:
            shape = x.shape
            x_expanded = np.expand_dims(x, axis=0)
        else:
            shape = x.shape[1:]
            x_expanded = x

        gradient = self.generate_gradient(shape)
        gradient_expanded = np.expand_dims(gradient, axis=0)
        output = np.concatenate([x_expanded, gradient_expanded], axis=0)

        # Report the expected shape in the same (C + 1, H, W) order that is
        # actually checked, so a failure message is not misleading.
        expected_shape = (x_expanded.shape[0] + 1, shape[0], shape[1])
        assert (
            output.shape == expected_shape
        ), f"Output shape {output.shape} does not match expected shape {expected_shape}"
        return output
[docs]
class ColorJitter(_Transform):
    """Apply fixed brightness, contrast, saturation and hue changes in HSV."""

    def __init__(
        self,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        hue: float = 0.0,
    ):
        """Store the adjustment factors.

        Parameters
        ----------
        brightness : float, optional
            Multiplier applied to the value channel. 1.0 means no change.
            Defaults to 1.0.
        contrast : float, optional
            Multiplier applied to the value channel's deviation from its
            mean. 1.0 means no change. Defaults to 1.0.
        saturation : float, optional
            Multiplier applied to the saturation channel. 1.0 means no
            change. Defaults to 1.0.
        hue : float, optional
            Offset added to the hue channel before wrapping modulo 180.
            Defaults to 0.0.
            NOTE(review): the offset is in the hue channel's own units, not
            necessarily degrees - confirm against the expected input dtype.
        """
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the colour-adjusted image.

        Parameters
        ----------
        image : np.ndarray
            RGB image accepted by ``cv2.cvtColor``.
            NOTE(review): the 0-255 clipping and modulo-180 hue wrap suggest
            8-bit input is expected - confirm.

        Returns
        -------
        np.ndarray
            RGB image of dtype uint8 with the adjustments applied.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
        # Views into the float buffer; writes below go straight into `hsv`.
        hue_plane = hsv[..., 0]
        sat_plane = hsv[..., 1]
        val_plane = hsv[..., 2]

        # Brightness, then saturation: plain scaling with clipping.
        val_plane[...] = np.clip(val_plane * self.brightness, 0, 255)
        sat_plane[...] = np.clip(sat_plane * self.saturation, 0, 255)

        # Contrast: stretch the value channel around its (post-brightness) mean.
        pivot = val_plane.mean()
        val_plane[...] = np.clip(
            (val_plane - pivot) * self.contrast + pivot, 0, 255
        )

        # Hue: shift and wrap.
        hue_plane[...] = (hue_plane + self.hue) % 180

        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
[docs]
class Crop(_Transform):
    """Crop an image to a fixed size, padding first if it is too small."""

    def __init__(
        self,
        output_size: Tuple[int, int],
        pad_mode: str = "reflect",
        coords: Tuple[float, float] = (0, 0),
    ):
        """Store the crop configuration.

        Parameters
        ----------
        output_size : Tuple[int, int]
            Desired output size as (height, width).
        pad_mode : str, optional
            ``np.pad`` mode used when the output size exceeds the input size
            along an axis. Defaults to 'reflect'.
        coords : Tuple[float, float], optional
            Relative position of the crop's top-left corner. The first value
            is applied along the height axis and the second along the width
            axis; each is multiplied by the slack (image size minus output
            size), so 0 selects the start and 1 the last valid offset.
            Defaults to (0, 0).
        """
        self.output_size = output_size
        self.pad_mode = pad_mode
        self.coords = coords

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the cropped (and, if needed, padded) image.

        Parameters
        ----------
        image : np.ndarray
            Array whose first two axes are height and width; any further axes
            are left untouched.

        Returns
        -------
        np.ndarray
            Array whose first two axes have size ``output_size``.
        """
        rel_row, rel_col = self.coords
        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # Pad symmetrically along any axis that is smaller than the target.
        pad_h = max(new_h - h, 0)
        pad_w = max(new_w - w, 0)
        if pad_h or pad_w:
            pad_width = [
                (pad_h // 2, pad_h - pad_h // 2),
                (pad_w // 2, pad_w - pad_w // 2),
            ]
            # One (0, 0) entry per extra axis so 2-D images work as well as
            # images with a channel axis.
            pad_width.extend([(0, 0)] * (image.ndim - 2))
            image = np.pad(image, tuple(pad_width), mode=self.pad_mode)
            h, w = image.shape[:2]

        # Relative coordinates are floats; slice bounds must be integers.
        top = int((h - new_h) * rel_row)
        left = int((w - new_w) * rel_col)
        return image[top : top + new_h, left : left + new_w]
[docs]
class GrayScale(_Transform):
    """Produce a three-channel stack of a configured gray value.

    NOTE(review): ``__call__`` never reads its ``image`` argument; the output
    depends only on ``gray``. This looks unintended for a transform described
    as an image-to-grayscale conversion - confirm the desired behaviour.
    """

    def __init__(self, gray: float = 0.0):
        """Store the gray value.

        Parameters
        ----------
        gray : float, optional
            Value replicated into each of the three output channels.
            Defaults to 0.0.
        """
        self.gray = gray

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return ``gray`` stacked three times along a new last axis.

        Parameters
        ----------
        image : np.ndarray
            Currently unused.

        Returns
        -------
        np.ndarray
            ``np.stack`` of three copies of ``self.gray`` on axis -1.
        """
        value = self.gray
        channels = [value, value, value]
        return np.stack(channels, axis=-1)
[docs]
class Solarize(_Transform):
    """Invert every pixel value that is at or above a threshold."""

    def __init__(self, threshold: int = 128):
        """Store the inversion threshold.

        Parameters
        ----------
        threshold : int, optional
            Values strictly below this are kept; all others are replaced by
            ``255 - value``. Default is 128.
        """
        self.threshold = threshold

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the solarized image.

        The comparison is element-wise, so grayscale and multi-channel images
        are handled by the same expression and the output always has the same
        shape as the input. No per-channel split/merge is needed, which also
        avoids collapsing a trailing single-channel axis and removes any
        limits on dtype or channel count.

        Parameters
        ----------
        image : np.ndarray
            Image of any shape.
            NOTE(review): inversion uses ``255 - value``, so an 8-bit value
            range is presumably expected - confirm.

        Returns
        -------
        np.ndarray
            New array, same shape as ``image``, with values at or above the
            threshold inverted.
        """
        return np.where(image < self.threshold, image, 255 - image)
[docs]
class Rotation(_Transform):
    """Rotate an image about its centre by a fixed angle."""

    def __init__(self, degrees: float):
        """Store the rotation angle.

        Parameters
        ----------
        degrees : float
            Angle, in degrees, passed to ``cv2.getRotationMatrix2D``.
        """
        self.degrees = degrees

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the rotated image.

        Parameters
        ----------
        image : np.ndarray
            Image whose first two axes are height and width.

        Returns
        -------
        np.ndarray
            Rotated image with the same width and height as the input;
            regions outside the source are filled by reflection.
        """
        height, width = image.shape[:2]
        # Integer centre (x, y) of the image; scale is left at 1.0.
        pivot = (width // 2, height // 2)
        affine = cv2.getRotationMatrix2D(pivot, self.degrees, 1.0)
        output_size = (width, height)
        return cv2.warpAffine(
            image, affine, output_size, borderMode=cv2.BORDER_REFLECT
        )
[docs]
class PadCrop(_Transform):
    """Bring an image to a target size by padding or randomly cropping."""

    def __init__(
        self,
        target_h_size: int,
        target_w_size: int,
        padding_mode: str = "reflect",
        seed: Optional[int] = None,
        constant_values: int = 0,
    ):
        """Configure the target size, padding behaviour and random source.

        Height and width are treated independently: an axis shorter than its
        target is padded, an axis longer than its target is cropped, and an
        axis already at its target is left alone.

        Padding is split between both sides (the extra element, if any, goes
        after the data), so the image ends up centred. Cropping takes a
        window at a random offset.

        The image is expected in C x H x W or H x W format.

        Parameters
        ----------
        target_h_size : int
            Desired height.
        target_w_size : int
            Desired width.
        padding_mode : str, optional
            ``np.pad`` mode, by default "reflect".
        seed : int, optional
            Seed of the random number generator that picks crop offsets.
            By default, None.
        constant_values : int, optional
            Fill value, used only when the padding mode is 'constant'.
            By default 0.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size
        self.padding_mode = padding_mode
        self.seed = seed
        # Generator used exclusively for choosing crop offsets.
        self.rng = np.random.default_rng(seed)
        self.constant_values = constant_values

    def _fit_axis(
        self, x: np.ndarray, axis: int, current: int, target: int
    ) -> np.ndarray:
        """Pad or crop ``x`` along ``axis`` (0 or 1) from ``current`` to ``target``."""
        if target > current:
            missing = target - current
            before = missing // 2
            widths = [(0, 0), (0, 0)]
            widths[axis] = (before, missing - before)
            if len(x.shape) == 3:
                # Channel axis (last, in H x W x C layout) is never padded.
                widths.append((0, 0))
            extra = {}
            if self.padding_mode == "constant":
                extra["constant_values"] = self.constant_values
            return np.pad(x, tuple(widths), mode=self.padding_mode, **extra)

        if target < current:
            start = self.rng.integers(0, current - target + 1)
            window = (slice(None),) * axis + (
                slice(start, start + target),
                Ellipsis,
            )
            return x[window]

        return x

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Pad and/or crop ``x`` to the configured size.

        Parameters
        ----------
        x : np.ndarray
            Image in C x H x W or H x W format.

        Returns
        -------
        np.ndarray
            Image in the same format as the input, with height and width
            equal to the targets.
        """
        channels_first = len(x.shape) == 3
        if channels_first:
            # Work in H x W x C so height and width are always axes 0 and 1.
            x = np.transpose(x, (1, 2, 0))

        height, width = x.shape[:2]

        # Height is processed before width so that random draws happen in a
        # fixed order.
        x = self._fit_axis(x, 0, height, self.target_h_size)
        x = self._fit_axis(x, 1, width, self.target_w_size)

        if len(x.shape) == 3:
            # Restore C x H x W.
            x = np.transpose(x, (2, 0, 1))
        return x

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return (
            f"PadCrop(target_h_size={self.target_h_size}, "
            f"target_w_size={self.target_w_size}, "
            f"padding_mode={self.padding_mode}, "
            f"constant_values={self.constant_values}, "
            f"seed={self.seed})"
        )
[docs]
class Identity(_Transform):
    """No-op transform.

    Returns its input untouched; handy as a placeholder when a pipeline slot
    should be skipped.
    """

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` itself (the same object, no copy)."""
        return x

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return "Identity()"
[docs]
class Indexer(_Transform):
    """Select one entry along the first axis of the input."""

    def __init__(self, index: int):
        """Store the index to select.

        Parameters
        ----------
        index : int
            Index applied to the input's first axis (the channel to extract,
            presumably for channel-first data - verify against callers).
        """
        self.index = index

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x[index]``.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            The selected entry of ``x``.
        """
        position = self.index
        return x[position]

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Indexer(index={self.index})"
[docs]
class Repeat(_Transform):
    """Repeat the elements of an array along one axis."""

    def __init__(self, axis: int, n_repetitions: int):
        """Store the axis and the repetition count.

        Parameters
        ----------
        axis : int
            The axis along which elements are repeated.
        n_repetitions : int
            How many times each element is repeated.
        """
        self.axis = axis
        self.n_repetitions = n_repetitions

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with its elements repeated via ``np.repeat``.

        Parameters
        ----------
        x : np.ndarray
            Input array.

        Returns
        -------
        np.ndarray
            The repeated array.
        """
        count = self.n_repetitions
        along = self.axis
        return np.repeat(x, count, axis=along)

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Repeat(axis={self.axis}, n_repetitions={self.n_repetitions})"
[docs]
class Normalize(_Transform):
    """Standardize channel-first data with per-channel mean and std."""

    def __init__(self, mean, std, to_rgb=False, normalize_labels=False):
        """Store the normalization statistics and options.

        Parameters
        ----------
        mean : List[float]
            Mean of each channel.
        std : List[float]
            Standard deviation of each channel.
        to_rgb : bool, optional
            If True, single-channel data is repeated to three channels before
            normalizing, by default False.
        normalize_labels : bool, optional
            If True, uint8 inputs (treated as labels) are normalized too;
            otherwise they are returned unchanged. By default False.
        """
        assert len(mean) == len(
            std
        ), "Means and standard deviations must have the same length."
        self.mean = mean
        self.std = std
        self.to_rgb = to_rgb
        self.normalize_labels = normalize_labels

    def __call__(self, data):
        """Return ``(data - mean) / std`` computed per channel.

        The result is a new array; the input is never modified. Floating
        point inputs keep their dtype, any other dtype is promoted to
        float32 so that the normalized values are not truncated back into an
        integer buffer.

        Parameters
        ----------
        data : np.ndarray
            Array whose first axis indexes channels.

        Returns
        -------
        np.ndarray
            The normalized array, or ``data`` itself when it is uint8 and
            ``normalize_labels`` is False.
        """
        # uint8 arrays are treated as label maps.
        is_label = data.dtype == np.uint8
        if is_label and not self.normalize_labels:
            return data

        # Convert from gray scale (1 channel) to RGB (3 channels) if requested.
        if self.to_rgb and data.shape[0] == 1:
            data = np.repeat(data, 3, axis=0)

        assert data.shape[0] == len(
            self.mean
        ), f"Number of channels in data does not match the number of provided mean/std. {data.shape}"

        if np.issubdtype(data.dtype, np.floating):
            out_dtype = data.dtype
        else:
            out_dtype = np.float32

        # Shape the statistics as (C, 1, ..., 1) so they broadcast over every
        # non-channel axis.
        stat_shape = (-1,) + (1,) * (data.ndim - 1)
        mean = np.asarray(self.mean, dtype=out_dtype).reshape(stat_shape)
        std = np.asarray(self.std, dtype=out_dtype).reshape(stat_shape)

        return (data.astype(out_dtype, copy=False) - mean) / std

    def __str__(self) -> str:
        """Return a readable description of the transform."""
        return f"Normalize(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb}, normalize_labels={self.normalize_labels})"