Source code for minerva.transforms.transform

from itertools import product
from typing import Any, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
import torch
from perlin_noise import PerlinNoise


[docs] class _Transform: """This class is a base class for all transforms. Transforms is just a fancy word for a function that takes an input and returns an output. The input and output can be anything. However, transforms operates over a single sample of data and does not require any additional information to perform the transformation. The __call__ method should be overridden in subclasses to define the transformation logic. """
[docs] def __call__(self, *args, **kwargs) -> Any: """Implement the transformation logic in this method. Usually, the transformation is applied on a single sample of data. """ raise NotImplementedError
class TransformPipeline(_Transform):
    """Apply a sequence of transforms to a single sample of data.

    The transforms are applied in order: the output of one is the input of
    the next.
    """

    def __init__(self, transforms: Sequence[_Transform]):
        """Create a pipeline.

        Parameters
        ----------
        transforms : Sequence[_Transform]
            The transforms to apply, in application order.
        """
        self.transforms = transforms

    def __call__(self, x: Any) -> Any:
        """Run ``x`` through every transform and return the final result."""
        for transform in self.transforms:
            x = transform(x)
        return x

    def __add__(self, other: _Transform) -> "TransformPipeline":
        """Return a new pipeline with ``other`` applied after this one.

        If ``other`` is itself a pipeline, its transforms are appended
        individually so the result stays flat.
        """
        if isinstance(other, TransformPipeline):
            return TransformPipeline(
                list(self.transforms) + list(other.transforms)
            )
        return TransformPipeline(list(self.transforms) + [other])

    def __radd__(self, other: _Transform) -> "TransformPipeline":
        """Return a new pipeline with ``other`` applied before this one.

        This is the reflected form used for ``other + pipeline``. Since the
        order of transforms matters, ``other`` has to be placed first rather
        than appended.
        """
        if isinstance(other, TransformPipeline):
            return TransformPipeline(
                list(other.transforms) + list(self.transforms)
            )
        return TransformPipeline([other] + list(self.transforms))

    def __str__(self) -> str:
        return f"TransformPipeline(transforms={self.transforms})"
class Flip(_Transform):
    """Reverse the order of elements along one or more axes."""

    def __init__(self, axis: Union[int, List[int]] = 0):
        """Store the axis (or axes) to flip.

        Parameters
        ----------
        axis : int | List[int], optional
            A single axis, or a list of axes that are flipped one after the
            other in the given order. By default 0.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return a flipped copy of ``x``.

        With an integer ``axis`` the input is flipped along that axis only.
        With a list, the input is flipped along each listed axis in turn; the
        list must not contain more entries than ``x`` has dimensions.
        """
        if isinstance(self.axis, int):
            flipped = np.flip(x, axis=self.axis)
            return flipped.copy()

        n_axes = len(self.axis)
        assert (
            n_axes <= x.ndim
        ), "Axis list has more dimensions than input data. The length of axis needs to be less or equal to input dimensions."

        flipped = x
        for single_axis in self.axis:
            flipped = np.flip(flipped, axis=single_axis)
        # np.flip returns a view; hand back an independent array.
        return flipped.copy()

    def __str__(self) -> str:
        return f"Flip(axis={self.axis})"
class PerlinMasker(_Transform):
    """Zero out entries of an array based on the sign of Perlin noise.

    The noise generator is seeded with a value drawn from ``torch.randint``
    on every call.
    """

    def __init__(self, octaves: int, scale: float = 1):
        """Validate and store the noise parameters.

        Parameters
        ----------
        octaves : int
            Level of detail for the Perlin noise generator. Must be positive.
        scale : float, optional
            Rescaling factor for the noise coordinates. Must be non-zero.
            Default is 1 (no rescaling).

        Raises
        ------
        ValueError
            If ``octaves`` is not positive or ``scale`` is 0.
        """
        if octaves <= 0:
            raise ValueError(
                f"Number of octaves must be positive, but got {octaves=}"
            )
        if scale == 0:
            raise ValueError("Scale can't be 0")

        self.scale = scale
        self.octaves = octaves

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Mask ``x`` with freshly seeded Perlin noise.

        Parameters
        ----------
        x : np.ndarray
            The array whose entries to zero.

        Returns
        -------
        np.ndarray
            ``x`` multiplied by a boolean mask that is True wherever the
            noise value is negative, so all other entries become zero.
        """
        seed = torch.randint(0, 2**32, (1,)).item()
        generator = PerlinNoise(self.octaves, seed)

        # Coordinates are divided by the largest dimension (times scale)
        # before being handed to the noise generator.
        divisor = self.scale * max(x.shape)

        keep = np.empty_like(x, dtype=bool)
        index_ranges = [range(extent) for extent in keep.shape]
        for position in product(*index_ranges):
            coordinates = [component / divisor for component in position]
            keep[position] = generator(coordinates) < 0

        return x * keep
class Squeeze(_Transform):
    """Drop a length-one axis from an array."""

    def __init__(self, axis: int):
        """Store the axis to drop.

        Parameters
        ----------
        axis : int
            Position of the axis that will be removed.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with the configured axis squeezed out."""
        squeezed = np.squeeze(x, axis=self.axis)
        return squeezed

    def __str__(self) -> str:
        return f"Squeeze(axis={self.axis})"
class Unsqueeze(_Transform):
    """Insert a new length-one axis into an array."""

    def __init__(self, axis: int):
        """Store where the new axis goes.

        Parameters
        ----------
        axis : int
            Position at which the new axis is inserted.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with an extra axis at the configured position."""
        expanded = np.expand_dims(x, axis=self.axis)
        return expanded

    def __str__(self) -> str:
        return f"Unsqueeze(axis={self.axis})"
class Transpose(_Transform):
    """Permute the axes of a numpy array."""

    def __init__(self, axes: Sequence[int]):
        """Store the axis permutation.

        Parameters
        ----------
        axes : Sequence[int]
            The new order of the axes, as accepted by ``np.transpose``.
        """
        self.axes = axes

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with its axes reordered according to ``axes``."""
        permuted = np.transpose(x, self.axes)
        return permuted

    def __str__(self) -> str:
        return f"Transpose(axes={self.axes})"
class CastTo(_Transform):
    """Convert an array to a given data type."""

    def __init__(self, dtype: Union[type, str]):
        """Store the target data type.

        Parameters
        ----------
        dtype : type | str
            The data type the input is converted to, in any form accepted
            by ``ndarray.astype``.
        """
        self.dtype = dtype

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` converted to the configured data type."""
        converted = x.astype(self.dtype)
        return converted

    def __str__(self) -> str:
        return f"CastTo(dtype={self.dtype})"
class Padding(_Transform):
    """Reflect-pad the first two axes up to a target size, channels first."""

    def __init__(self, target_h_size: int, target_w_size: int):
        """Store the minimum output size.

        Parameters
        ----------
        target_h_size : int
            Minimum size of the first axis after padding.
        target_w_size : int
            Minimum size of the second axis after padding.
        """
        self.target_w_size = target_w_size
        self.target_h_size = target_h_size

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Pad ``x`` at the end of its first two axes and move axis 2 first.

        The first two axes of ``x`` are treated as height and width. Padding
        uses ``mode="reflect"`` and is only added after the existing data;
        an axis that already meets its target is left unchanged. A 2-D input
        gains a trailing axis of length one before the final transpose, so
        the result always has three axes ordered ``(2, 0, 1)`` relative to
        the padded array.
        """
        rows, cols = x.shape[:2]
        spatial_pad = [
            (0, max(0, self.target_h_size - rows)),
            (0, max(0, self.target_w_size - cols)),
        ]

        if len(x.shape) == 2:
            result = np.pad(x, spatial_pad, mode="reflect")
            result = np.expand_dims(result, axis=2)
        else:
            # The third axis is never padded.
            result = np.pad(x, spatial_pad + [(0, 0)], mode="reflect")

        return np.transpose(result, (2, 0, 1))
class Gradient(_Transform):
    """Append a linear 0-to-1 gradient as an extra leading channel."""

    def __init__(self, direction: int):
        """Store the gradient direction.

        Parameters
        ----------
        direction : int
            0 for a gradient along the x-axis (width), 1 for a gradient
            along the y-axis (height).
        """
        assert direction in [0, 1], "Direction must be 0 (x-axis) or 1 (y-axis)"
        self.direction = direction

    def generate_gradient(self, shape: tuple[int, int]) -> np.ndarray:
        """Build a gradient image running from 0 to 1.

        Parameters
        ----------
        shape : tuple[int, int]
            Size of the gradient as (H, W).

        Returns
        -------
        np.ndarray
            Array of shape (H, W) whose values increase linearly from 0 to 1
            along the width (direction 0) or the height (direction 1).
        """
        xx, yy = np.meshgrid(
            np.linspace(0, 1, shape[1]), np.linspace(0, 1, shape[0])
        )
        if self.direction == 0:  # Gradient along the x-axis
            return xx
        elif self.direction == 1:  # Gradient along the y-axis
            return yy

    def __call__(self, x):
        """Concatenate the gradient to ``x`` along axis 0.

        A 2-D input is read as (H, W) and first gains a leading axis; any
        other input takes its spatial size from ``x.shape[1:]``. The result
        has one more entry on axis 0 than the (expanded) input.
        """
        shape = x.shape if x.ndim == 2 else x.shape[1:]
        gradient = self.generate_gradient(shape)

        x_expanded = np.expand_dims(x, axis=0) if x.ndim == 2 else x
        gradient_expanded = np.expand_dims(gradient, axis=0)
        output = np.concatenate([x_expanded, gradient_expanded], axis=0)

        # Report the expected shape in the same (C + 1, H, W) order that is
        # actually being checked.
        expected_shape = (x_expanded.shape[0] + 1, shape[0], shape[1])
        assert (
            output.shape == expected_shape
        ), f"Output shape {output.shape} does not match expected shape {expected_shape}"
        return output
class ColorJitter(_Transform):
    """Apply fixed brightness, contrast, saturation and hue adjustments."""

    def __init__(
        self,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        hue: float = 0.0,
    ):
        """Store the adjustment factors.

        Parameters
        ----------
        brightness : float, optional
            Multiplier for the HSV value channel. 1.0 means no change.
        contrast : float, optional
            Factor by which the value channel is stretched around its mean.
            1.0 means no change.
        saturation : float, optional
            Multiplier for the HSV saturation channel. 1.0 means no change.
        hue : float, optional
            Offset added to the HSV hue channel; the sum wraps modulo 180.
            0.0 means no change.
        """
        self.hue = hue
        self.saturation = saturation
        self.contrast = contrast
        self.brightness = brightness

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the adjusted image.

        The image is converted from RGB to HSV, adjusted in float32, cast to
        uint8 and converted back to RGB.
        NOTE(review): the 0-255 clipping and modulo-180 hue suggest an 8-bit
        RGB input is expected -- confirm against callers.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)

        # Views into the HSV array; writes through them update ``hsv``.
        hue_plane = hsv[..., 0]
        sat_plane = hsv[..., 1]
        val_plane = hsv[..., 2]

        # Brightness first, then saturation.
        val_plane[...] = np.clip(val_plane * self.brightness, 0, 255)
        sat_plane[...] = np.clip(sat_plane * self.saturation, 0, 255)

        # Contrast pivots around the mean of the brightness-adjusted values.
        pivot = val_plane.mean()
        val_plane[...] = np.clip(
            (val_plane - pivot) * self.contrast + pivot, 0, 255
        )

        # Hue wraps around instead of being clipped.
        hue_plane[...] = (hue_plane + self.hue) % 180

        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
class Crop(_Transform):
    """Crop an image to a fixed size, padding first if it is too small."""

    def __init__(
        self,
        output_size: Tuple[int, int],
        pad_mode: str = "reflect",
        coords: Tuple[float, float] = (0, 0),
    ):
        """Store the crop configuration.

        Parameters
        ----------
        output_size : Tuple[int, int]
            Desired output size as (height, width).
        pad_mode : str, optional
            ``np.pad`` mode used when the output size is larger than the
            input size. Defaults to 'reflect'.
        coords : Tuple[float, float], optional
            Relative position of the crop box's top-left corner as
            (vertical, horizontal). Each value goes from 0 to 1, where 0
            places the box at the start of the axis and 1 at the furthest
            position that still fits the output size. Defaults to (0, 0).
        """
        self.output_size = output_size
        self.pad_mode = pad_mode
        self.coords = coords

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the cropped (and, if needed, padded) image.

        The first two axes of ``image`` are treated as height and width; any
        further axes are left untouched.
        """
        rel_top, rel_left = self.coords
        new_h, new_w = self.output_size
        h, w = image.shape[:2]

        pad_h = max(new_h - h, 0)
        pad_w = max(new_w - w, 0)
        if pad_h or pad_w:
            # Split the padding evenly between both sides; only the two
            # spatial axes are padded, so 2-D images work as well.
            pad_width = [
                (pad_h // 2, pad_h - pad_h // 2),
                (pad_w // 2, pad_w - pad_w // 2),
            ]
            pad_width += [(0, 0)] * (image.ndim - 2)
            image = np.pad(image, pad_width, mode=self.pad_mode)
            h, w = image.shape[:2]

        # Relative coordinates are fractional, but slice bounds must be
        # integers, so round to the nearest pixel.
        top = int(round((h - new_h) * rel_top))
        left = int(round((w - new_w) * rel_left))
        return image[top : top + new_h, left : left + new_w]
class GrayScale(_Transform):
    """Produce a three-channel stack of a fixed gray value."""

    def __init__(self, gray: float = 0.0):
        """Store the gray value.

        Parameters
        ----------
        gray : float, optional
            Value that is replicated into each of the three channels.
            Defaults to 0.0.
        """
        self.gray = gray

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return ``gray`` stacked three times along a new last axis.

        NOTE(review): ``image`` is not used, so the result depends only on
        ``gray`` -- confirm whether that is intended.
        """
        channels = (self.gray, self.gray, self.gray)
        return np.stack(channels, axis=-1)
class Solarize(_Transform):
    """Invert pixel values at or above a threshold."""

    def __init__(self, threshold: int = 128):
        """Store the inversion threshold.

        Parameters
        ----------
        threshold : int, optional
            Values below the threshold are kept; values at or above it are
            replaced by ``255 - value``. Default is 128.
        """
        self.threshold = threshold

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the solarized image.

        The operation is purely elementwise, so grayscale and multi-channel
        images go through the same code path and the output has the same
        shape as the input.
        """
        return np.where(image < self.threshold, image, 255 - image)
class Rotation(_Transform):
    """Rotate an image about its centre by a fixed angle."""

    def __init__(self, degrees: float):
        """Store the rotation angle.

        Parameters
        ----------
        degrees : float
            Angle in degrees passed to ``cv2.getRotationMatrix2D``.
        """
        self.degrees = degrees

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the rotated image.

        The output keeps the input's width and height, no scaling is
        applied, and areas outside the source are filled using
        ``cv2.BORDER_REFLECT``.
        """
        height, width = image.shape[:2]
        pivot = (width // 2, height // 2)
        affine = cv2.getRotationMatrix2D(pivot, self.degrees, 1.0)
        output_size = (width, height)
        return cv2.warpAffine(
            image, affine, output_size, borderMode=cv2.BORDER_REFLECT
        )
class PadCrop(_Transform):
    """Pad or crop an image so it matches a target height and width."""

    def __init__(
        self,
        target_h_size: int,
        target_w_size: int,
        padding_mode: str = "reflect",
        seed: Optional[int] = None,
        constant_values: int = 0,
    ):
        """Configure the target size and how to reach it.

        Height and width are handled independently: an axis smaller than
        its target is padded, an axis larger than its target is cropped.
        Padding is split between both sides so the image ends up centred
        (the extra element of an odd amount goes to the end). Cropping
        starts at a random position drawn from a seeded generator.

        The image is expected in C x H x W or H x W format.

        Parameters
        ----------
        target_h_size : int
            Desired height.
        target_w_size : int
            Desired width.
        padding_mode : str, optional
            ``np.pad`` mode to use, by default "reflect".
        seed : int, optional
            Seed of the random number generator used to pick the crop
            position. By default None.
        constant_values : int, optional
            Fill value used when ``padding_mode`` is "constant". By default 0.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size
        self.padding_mode = padding_mode
        self.constant_values = constant_values
        self.seed = seed
        # Generator that chooses the crop offsets.
        self.rng = np.random.default_rng(seed)

    def _fit_axis(
        self, x: np.ndarray, axis: int, current: int, target: int
    ) -> np.ndarray:
        """Pad or crop ``x`` along ``axis`` from ``current`` to ``target``."""
        if target > current:
            missing = target - current
            leading = missing // 2
            widths = [(0, 0)] * (3 if len(x.shape) == 3 else 2)
            widths[axis] = (leading, missing - leading)

            pad_kwargs = {
                "array": x,
                "pad_width": tuple(widths),
                "mode": self.padding_mode,
            }
            # Only the constant mode accepts a fill value.
            if self.padding_mode == "constant":
                pad_kwargs["constant_values"] = self.constant_values
            return np.pad(**pad_kwargs)

        if target < current:
            start = self.rng.integers(0, current - target + 1)
            window = (slice(None),) * axis + (
                slice(start, start + target),
                Ellipsis,
            )
            return x[window]

        return x

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` padded and/or cropped to the target size.

        A 3-D input is read as C x H x W and returned in the same layout;
        any other input has its first two axes treated as H and W.
        """
        # Work in H x W x C so that height and width are axes 0 and 1.
        if len(x.shape) == 3:
            x = np.transpose(x, (1, 2, 0))

        in_h, in_w = x.shape[:2]

        # Height is processed before width; this fixes the order of the
        # random draws when both axes are cropped.
        x = self._fit_axis(x, 0, in_h, self.target_h_size)
        x = self._fit_axis(x, 1, in_w, self.target_w_size)

        # Restore the C x H x W layout.
        if len(x.shape) == 3:
            x = np.transpose(x, (2, 0, 1))
        return x

    def __str__(self) -> str:
        return f"PadCrop(target_h_size={self.target_h_size}, target_w_size={self.target_w_size}, padding_mode={self.padding_mode}, constant_values={self.constant_values}, seed={self.seed})"
class Identity(_Transform):
    """No-op transform that hands its input back unchanged.

    Handy as a placeholder when a pipeline slot should not alter the data.
    """

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` itself, without copying or modifying it."""
        return x

    def __str__(self) -> str:
        return "Identity()"
class Indexer(_Transform):
    """Select one entry along the first axis, e.g. a single channel."""

    def __init__(self, index: int):
        """Store the index to extract.

        Parameters
        ----------
        index : int
            Index of the entry (channel) to extract.
        """
        self.index = index

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x[index]``."""
        selected = x[self.index]
        return selected

    def __str__(self) -> str:
        return f"Indexer(index={self.index})"
class Repeat(_Transform):
    """Repeat the elements of an array along one axis."""

    def __init__(self, axis: int, n_repetitions: int):
        """Store the repetition settings.

        Parameters
        ----------
        axis : int
            Axis along which the data is repeated.
        n_repetitions : int
            How many times each element is repeated.
        """
        self.n_repetitions = n_repetitions
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` repeated via ``np.repeat`` along the stored axis."""
        repeated = np.repeat(x, self.n_repetitions, axis=self.axis)
        return repeated

    def __str__(self) -> str:
        return f"Repeat(axis={self.axis}, n_repetitions={self.n_repetitions})"
class Normalize(_Transform):
    """Standardise channels-first data with per-channel mean and std."""

    def __init__(self, mean, std, to_rgb=False, normalize_labels=False):
        """Store the normalisation statistics.

        Parameters
        ----------
        mean : List[float]
            One mean per channel.
        std : List[float]
            One standard deviation per channel.
        to_rgb : bool, optional
            If True, single-channel data is repeated to three channels
            before normalising. By default False.
        normalize_labels : bool, optional
            If True, ``uint8`` inputs (treated as labels) are normalised
            too; otherwise they are returned untouched. By default False.
        """
        assert len(mean) == len(
            std
        ), "Means and standard deviations must have the same length."
        self.mean = mean
        self.std = std
        self.to_rgb = to_rgb
        self.normalize_labels = normalize_labels

    def __call__(self, data):
        """Return a normalised copy of ``data``.

        ``data`` is expected with channels on axis 0. Inputs of dtype
        ``uint8`` are treated as labels and returned unchanged unless
        ``normalize_labels`` is set. The input array is never modified:
        floating-point inputs keep their dtype, any other dtype is converted
        to ``float32`` so the result is not truncated back to integers.
        """
        is_label = data.dtype == np.uint8
        if is_label and not self.normalize_labels:
            return data

        # Convert from gray scale (1 channel) to RGB (3 channels) on request.
        if self.to_rgb and data.shape[0] == 1:
            data = np.repeat(data, 3, axis=0)

        assert data.shape[0] == len(
            self.mean
        ), f"Number of channels in data does not match the number of provided mean/std. {data.shape}"

        out_dtype = data.dtype if data.dtype.kind == "f" else np.float32

        # Shape the statistics as (C, 1, ..., 1) so they broadcast over
        # every non-channel axis.
        stat_shape = (-1,) + (1,) * (data.ndim - 1)
        mean = np.asarray(self.mean, dtype=out_dtype).reshape(stat_shape)
        std = np.asarray(self.std, dtype=out_dtype).reshape(stat_shape)

        return (data.astype(out_dtype) - mean) / std

    def __str__(self) -> str:
        return f"Normalize(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb}, normalize_labels={self.normalize_labels})"
class ContrastiveTransform(_Transform):
    """Produce two views of a sample by applying one transform twice."""

    def __init__(self, transform: _Transform):
        """Store the wrapped transform.

        Parameters
        ----------
        transform : _Transform
            Transform that is called once for each of the two views.
        """
        self.transform = transform

    def __call__(self, x: np.ndarray) -> Tuple:
        """Return ``(transform(x), transform(x))``, evaluated in that order."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

    def __str__(self) -> str:
        return f"ContrastiveTransform(transform={self.transform})"