from itertools import product
from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from perlin_noise import PerlinNoise
[docs]
class Flip(_Transform):
    """Reverse the element order of an array along one or more axes."""

    def __init__(self, axis: Union[int, List[int]] = 1):
        """Remember which axis (or axes) should be reversed.

        Parameters
        ----------
        axis : int | List[int], optional
            A single axis, or a list of axes, along which the data is
            reversed. Defaults to 1 (the width axis of an (H, W) array).
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return a flipped copy of ``x``.

        Parameters
        ----------
        x : np.ndarray
            Array to flip. When ``axis`` is a list, the list may not be
            longer than the number of dimensions of ``x``.

        Returns
        -------
        np.ndarray
            A new array (not a view) with the selected axes reversed.

        Raises
        ------
        AssertionError
            If ``axis`` is a list with more entries than ``x`` has dimensions.
        """
        axes = self.axis
        if isinstance(axes, list):
            n_axes = len(axes)
            assert (
                n_axes <= x.ndim
            ), "Axis list has more dimensions than input data. The length of axis needs to be less or equal to input dimensions."
        # np.flip hands back a view; copy so the result owns its memory.
        flipped = np.flip(x, axis=axes)
        return flipped.copy()

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Flip(axis={})".format(self.axis)
[docs]
class PerlinMasker(_Transform):
    """Mask an array with a Perlin-noise field.

    Entries are kept where the noise is negative and zeroed elsewhere. The
    noise generator is re-seeded on every call with a value drawn from
    ``torch.randint``.
    """

    def __init__(self, octaves: int, scale: float = 1):
        """Validate and store the noise configuration.

        Parameters
        ----------
        octaves : int
            Level of detail handed to the Perlin noise generator. Must be
            strictly positive.
        scale : float, optional
            Factor applied to the coordinate divisor used when sampling the
            noise. Must be non-zero. Defaults to 1.

        Raises
        ------
        ValueError
            If ``octaves`` is not positive or ``scale`` equals zero.
        """
        if octaves <= 0:
            raise ValueError(f"Number of octaves must be positive, but got {octaves=}")
        if scale == 0:
            raise ValueError("Scale can't be 0")
        self.octaves, self.scale = octaves, scale

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Multiply ``x`` by a boolean mask derived from Perlin noise.

        Parameters
        ----------
        x : np.ndarray
            Array whose entries are to be masked.

        Returns
        -------
        np.ndarray
            ``x`` multiplied element-wise by the mask; positions whose noise
            value is not negative become zero.
        """
        keep = np.empty_like(x, dtype=bool)
        seed = torch.randint(0, 2**32, (1,)).item()
        generator = PerlinNoise(self.octaves, seed)
        # Coordinates are divided by the largest dimension times the scale.
        divisor = self.scale * max(x.shape)
        for index in product(*map(range, keep.shape)):
            sample_point = [coordinate / divisor for coordinate in index]
            keep[index] = generator(sample_point) < 0
        return x * keep

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "PerlinMasker(octaves={}, scale={})".format(self.octaves, self.scale)
[docs]
class Squeeze(_Transform):
    """Drop a length-one axis from an array."""

    def __init__(self, axis: int):
        """Store the axis that will be removed.

        Parameters
        ----------
        axis : int
            Position of the axis passed to ``np.squeeze``.
        """
        self.axis = axis

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with the configured axis squeezed out."""
        squeezed = np.squeeze(x, axis=self.axis)
        return squeezed

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Squeeze(axis={})".format(self.axis)
[docs]
class Unsqueeze(_Transform):
    """Insert a new length-one axis into an array."""

    def __init__(self, axis: int, label_only: bool = False):
        """Store the insertion position and the label-only switch.

        Parameters
        ----------
        axis : int
            Position at which the new axis is inserted.
        label_only : bool, optional
            When True, only arrays with an unsigned-integer dtype are
            expanded; any other array is returned unchanged. When False
            (the default) every input is expanded.
        """
        self.axis = axis
        self.label_only = label_only

    def __call__(self, x: np.ndarray):
        """Return ``x`` with a new axis, or ``x`` itself if it is skipped."""
        # In label-only mode, anything that is not unsigned-integer typed is
        # passed through untouched.
        if self.label_only and not np.issubdtype(x.dtype, np.unsignedinteger):
            return x
        return np.expand_dims(x, axis=self.axis)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Unsqueeze(axis={})".format(self.axis)
[docs]
class Transpose(_Transform):
    """Permute the axes of a numpy array."""

    def __init__(self, axes: Sequence[int]):
        """Store the axis permutation.

        Parameters
        ----------
        axes : Sequence[int]
            The new order of the axes, as accepted by ``np.transpose``.
        """
        self.axes = axes

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with its axes rearranged according to ``axes``."""
        permuted = np.transpose(x, self.axes)
        return permuted

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Transpose(axes={})".format(self.axes)
[docs]
class CastTo(_Transform):
    """Convert an array to a given data type."""

    def __init__(self, dtype: Union[type, str]):
        """Store the target data type.

        Parameters
        ----------
        dtype : type | str
            The data type handed to ``ndarray.astype``.
        """
        self.dtype = dtype

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` converted to the configured data type."""
        target = self.dtype
        return x.astype(target)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "CastTo(dtype={})".format(self.dtype)
[docs]
class Padding(_Transform):
    """Pad arrays at the bottom and right until they reach a target size."""

    def __init__(
        self,
        target_h_size: int,
        target_w_size: int,
        padding_mode: str = "reflect",
        padding_value: int = 0,
        mask_padding_value: int = 255,
    ):
        """
        A transform that pads images and masks to reach target dimensions.

        The first two axes of the input are treated as height and width.
        Padding is appended after the existing data (bottom and right).
        Inputs that already meet or exceed a target dimension are left
        unchanged along that dimension. The fill value used in "constant"
        mode depends on the dtype: unsigned-integer arrays (typically masks)
        use ``mask_padding_value``, every other array uses ``padding_value``.

        Parameters
        ----------
        target_h_size : int
            Target height in pixels after padding.
        target_w_size : int
            Target width in pixels after padding.
        padding_mode : str, default="reflect"
            Padding mode forwarded to ``np.pad``, e.g. "reflect", "constant",
            "edge" or "wrap".
        padding_value : int, default=0
            Fill value for non-unsigned-integer arrays. Only used when
            ``padding_mode`` is "constant".
        mask_padding_value : int, default=255
            Fill value for unsigned-integer arrays. Only used when
            ``padding_mode`` is "constant". Commonly set to 255 to represent
            the ignore class in segmentation masks.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size
        self.padding_mode = padding_mode
        # Fill values are meaningless outside "constant" mode, so drop them.
        self.padding_value = padding_value if padding_mode == "constant" else None
        self.mask_padding_value = (
            mask_padding_value if padding_mode == "constant" else None
        )

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Pad ``x`` up to the configured height and width.

        Parameters
        ----------
        x : np.ndarray
            Array whose first two axes are height and width. Any further
            axes are left unpadded.

        Returns
        -------
        np.ndarray
            The padded array. A 2D input gains a trailing axis of length one,
            so the result has shape (H, W, 1).
        """
        h, w = x.shape[:2]
        pad_h = max(0, self.target_h_size - h)
        pad_w = max(0, self.target_w_size - w)
        pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (x.ndim - 2)

        # np.pad rejects ``constant_values`` for every mode except
        # "constant" (even when the value is None), so only forward it there.
        pad_kwargs = {}
        if self.padding_mode == "constant":
            pad_kwargs["constant_values"] = (
                self.mask_padding_value
                if np.issubdtype(x.dtype, np.unsignedinteger)
                else self.padding_value
            )

        padded = np.pad(
            x, pad_width, mode=self.padding_mode, **pad_kwargs  # type: ignore
        )
        if x.ndim == 2:
            padded = np.expand_dims(padded, axis=2)
        return padded

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return f"Padding(target_h_size={self.target_h_size}, target_w_size={self.target_w_size})"
[docs]
class Gradient(_Transform):
    """Append a linear 0-to-1 ramp to an array as an extra leading channel."""

    directions = {0: "x (width)", 1: "y (height)"}

    def __init__(self, direction: int):
        """Choose the axis along which the ramp increases.

        Parameters
        ----------
        direction : int
            0 for a ramp along the x-axis (width), 1 for a ramp along the
            y-axis (height).

        Raises
        ------
        AssertionError
            If ``direction`` is neither 0 nor 1.
        """
        assert direction in [0, 1], "Direction must be 0 (x-axis) or 1 (y-axis)"
        self.direction = direction

    def generate_gradient(self, shape: tuple[int, int]) -> np.ndarray:
        """Build a ramp from 0 to 1 for a grid of the given (H, W) shape.

        Parameters
        ----------
        shape : tuple[int, int]
            Grid size as (height, width).

        Returns
        -------
        np.ndarray
            Array of shape (H, W) that increases along the width when
            ``direction`` is 0 and along the height when it is 1.
        """
        along_width = np.linspace(0, 1, shape[1])
        along_height = np.linspace(0, 1, shape[0])
        grid_x, grid_y = np.meshgrid(along_width, along_height)
        if self.direction == 0:
            return grid_x
        if self.direction == 1:
            return grid_y

    def __call__(self, x):
        """Stack the ramp onto ``x`` along the first axis.

        A 2D input is treated as (H, W) and first given a leading channel
        axis; otherwise the last axes after the first are taken as (H, W).

        Returns
        -------
        np.ndarray
            Array of shape (C + 1, H, W) whose last channel is the ramp.
        """
        is_plane = x.ndim == 2
        shape = x.shape if is_plane else x.shape[1:]
        ramp = self.generate_gradient(shape)

        channels = np.expand_dims(x, axis=0) if is_plane else x
        ramp_channel = np.expand_dims(ramp, axis=0)
        output = np.concatenate([channels, ramp_channel], axis=0)

        expected = (channels.shape[0] + 1, shape[0], shape[1])
        assert (
            output.shape == expected
        ), f"Output shape {output.shape} does not match expected shape {(shape[0], shape[1], channels.shape[0] + 1)}"
        return output

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        label = self.directions[self.direction]
        return f"Gradient(direction={self.direction} - {label})"
[docs]
class ColorJitter(_Transform):
    """Apply fixed brightness, contrast, saturation and hue changes in HSV space."""

    def __init__(
        self,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        hue: float = 0.0,
    ):
        """
        Store the fixed adjustment factors.

        Parameters
        ----------
        brightness : float, optional
            Multiplier applied to the value channel. 1.0 means no change.
        contrast : float, optional
            Multiplier applied to the deviation of the value channel from its
            mean. 1.0 means no change.
        saturation : float, optional
            Multiplier applied to the saturation channel. 1.0 means no change.
        hue : float, optional
            Offset added to the hue channel before wrapping modulo 180.
            Defaults to 0.0.
            NOTE(review): for 8-bit images OpenCV appears to store hue in
            the range 0-179, so this offset is presumably in those units
            rather than in degrees - confirm.
        """
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return an adjusted copy of an RGB image.

        Parameters
        ----------
        image : np.ndarray
            RGB image accepted by ``cv2.cvtColor``.

        Returns
        -------
        np.ndarray
            The adjusted RGB image as ``uint8``.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
        # Channel views: writing through them updates ``hsv`` in place.
        hue_ch, sat_ch, val_ch = hsv[..., 0], hsv[..., 1], hsv[..., 2]

        # Brightness, then saturation.
        val_ch[...] = np.clip(val_ch * self.brightness, 0, 255)
        sat_ch[...] = np.clip(sat_ch * self.saturation, 0, 255)

        # Contrast: stretch around the mean of the already brightened values.
        pivot = val_ch.mean()
        val_ch[...] = np.clip((val_ch - pivot) * self.contrast + pivot, 0, 255)

        # Hue: shift and wrap.
        hue_ch[...] = (hue_ch + self.hue) % 180

        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "ColorJitter(brightness={}, contrast={}, saturation={}, hue={})".format(
            self.brightness, self.contrast, self.saturation, self.hue
        )
[docs]
class Crop(_Transform):
    """Cut a window out of an image, padding first when the window is too big.

    Inputs are expected in CHW (C, H, W) or 2D (H, W) layout; the last two
    axes are always treated as height and width.
    """

    def __init__(
        self,
        output_size: Tuple[int, int],
        pad_mode: str = "reflect",
        coords: Tuple[float, float] = (0, 0),
        bbox: Optional[Tuple[int, int, int, int]] = None,
    ):
        """
        Store the crop configuration.

        Parameters
        ----------
        output_size : Tuple[int, int]
            Requested output size as (height, width).
        pad_mode : str, optional
            ``np.pad`` mode used when the requested size exceeds the input
            size. Defaults to 'reflect'.
        coords : Tuple[float, float], optional
            Relative (row, col) position of the crop's top-left corner. Each
            value scales the free space left over after subtracting the
            output size, so (0, 0) anchors the crop at the top-left corner.
        bbox : Optional[Tuple[int, int, int, int]], optional
            Explicit crop box as (y1, y2, x1, x2). When given, ``coords`` is
            ignored. Defaults to None.
        """
        self.output_size = output_size
        self.pad_mode = pad_mode
        self.coords = coords
        self.bbox = bbox

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Crop ``image`` according to the stored configuration.

        Parameters
        ----------
        image : np.ndarray
            Input array in CHW (C, H, W) or 2D (H, W) layout.

        Returns
        -------
        np.ndarray
            The cropped array. If the output size exceeds the input along
            either spatial axis, the input is first padded (split evenly
            between both sides, the extra pixel going to the far side).
        """
        height, width = image.shape[-2:]
        target_h, target_w = self.output_size

        if target_h > height or target_w > width:
            extra_h = max(target_h - height, 0)
            extra_w = max(target_w - width, 0)
            spatial_pads = (
                (extra_h // 2, extra_h - extra_h // 2),
                (extra_w // 2, extra_w - extra_w // 2),
            )
            # A 3D input carries a leading channel axis that must stay as is.
            if len(image.shape) == 3:
                pad_width = ((0, 0),) + spatial_pads
            else:
                pad_width = spatial_pads
            image = np.pad(image, pad_width, mode=self.pad_mode)
            height, width = image.shape[-2:]

        if self.bbox is not None:
            top, bottom, left, right = self.bbox
            return image[..., top:bottom, left:right]

        rel_row, rel_col = self.coords
        top = int((height - target_h) * rel_row)
        left = int((width - target_w) * rel_col)
        return image[..., top : top + target_h, left : left + target_w]

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Crop(output_size={}, pad_mode={}, coords={})".format(
            self.output_size, self.pad_mode, self.coords
        )
[docs]
class GrayScale(_Transform):
    """Turn an RGB image into a three-channel grayscale image."""

    def __init__(self, method: Literal["average", "luminosity"] = "luminosity"):
        """
        Select how the gray value is computed.

        Parameters
        ----------
        method : {'average', 'luminosity'}, optional
            - 'average': mean of the three channels.
            - 'luminosity': 0.299R + 0.587G + 0.114B.
            Defaults to 'luminosity'.

        Raises
        ------
        ValueError
            If ``method`` is not one of the two supported names.
        """
        if method not in {"average", "luminosity"}:
            raise ValueError("method must be 'average' or 'luminosity'")
        self.method = method

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """
        Convert ``image`` to grayscale.

        Parameters
        ----------
        image : np.ndarray
            RGB image with shape (H, W, 3).

        Returns
        -------
        np.ndarray
            Array of shape (H, W, 3) with three identical channels, cast
            back to the dtype of the input.

        Raises
        ------
        AssertionError
            If the input does not have shape (H, W, 3).
        """
        has_rgb_layout = image.ndim == 3 and image.shape[2] == 3
        assert has_rgb_layout, "Input must have shape (H, W, 3)"

        if self.method == "average":
            gray = image.mean(axis=2)
        else:
            # Luminosity weights for R, G and B respectively.
            rgb_weights = np.array([0.299, 0.587, 0.114])
            gray = np.dot(image[..., :3], rgb_weights)

        replicated = np.stack([gray] * 3, axis=-1)
        return replicated.astype(image.dtype)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "GrayScale(method='{}')".format(self.method)
[docs]
class Solarize(_Transform):
    """Invert every pixel value that is at or above a threshold."""

    def __init__(self, threshold: int = 128):
        """
        Store the inversion threshold.

        Parameters
        ----------
        threshold : int, optional
            Values below this threshold are kept; all others are replaced by
            ``255 - value``. Defaults to 128.
        """
        self.threshold = threshold

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the solarized version of ``image``.

        A 3D input is split into channels with OpenCV, each channel is
        processed on its own and the results are merged again; any other
        input is processed as a single plane.
        """

        def invert_bright(plane):
            return np.where(plane < self.threshold, plane, 255 - plane)

        if len(image.shape) == 3:
            processed = [invert_bright(channel) for channel in cv2.split(image)]
            return cv2.merge(processed)
        return invert_bright(image)

    def __str__(self):
        """Return a short textual description of the transform."""
        return "Solarize(threshold={})".format(self.threshold)
[docs]
class Rotation(_Transform):
    """Rotate an image around its centre, filling borders by reflection."""

    def __init__(self, degrees: float):
        """
        Store the rotation angle.

        Parameters
        ----------
        degrees : float
            Rotation angle in degrees, forwarded to
            ``cv2.getRotationMatrix2D``.

        Notes
        -----
        - Inputs may have shape (H, W) or (H, W, C) with any number of channels.
        - Every channel of a multi-channel image is warped with the same matrix.
        - Warping uses ``cv2.warpAffine`` with ``cv2.BORDER_REFLECT``.
        """
        self.degrees = degrees

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return the rotated image, with the same height and width.

        Raises
        ------
        ValueError
            If the input is neither 2D nor 3D.
        """
        if image.ndim not in (2, 3):
            raise ValueError(
                f"Unsupported image shape: {image.shape}. Expected 2D or 3D with channels last."
            )

        h, w = image.shape[:2]
        matrix = cv2.getRotationMatrix2D((w // 2, h // 2), self.degrees, 1.0)

        def warp(plane):
            return cv2.warpAffine(
                plane, matrix, (w, h), borderMode=cv2.BORDER_REFLECT
            )

        if image.ndim == 2:
            return warp(image)

        # Warp channel by channel so any channel count is supported.
        n_channels = image.shape[2]
        return np.stack([warp(image[:, :, c]) for c in range(n_channels)], axis=-1)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Rotation(degrees={})".format(self.degrees)
[docs]
class PadCrop(_Transform):
    """Bring an image to a target size by padding or random cropping."""

    def __init__(
        self,
        target_h_size: int,
        target_w_size: int,
        padding_mode: str = "reflect",
        seed: Optional[int] = None,
        constant_values: int = 0,
    ):
        """Configure the target size, padding behaviour and random generator.

        Height and width are handled independently: an axis shorter than its
        target is padded, an axis longer than its target is cropped, and an
        axis that already matches is left alone. Padding is split between
        both sides (the extra pixel, if any, goes to the far side), so the
        image ends up centred. Cropping starts at a random offset.

        The image is expected in C x H x W or H x W layout.

        Parameters
        ----------
        target_h_size : int
            Desired height.
        target_w_size : int
            Desired width.
        padding_mode : str, optional
            ``np.pad`` mode, by default "reflect".
        seed : int, optional
            Seed of the random number generator that picks crop offsets.
            By default None.
        constant_values : int, optional
            Fill value used when ``padding_mode`` is 'constant'. By default 0.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size
        self.padding_mode = padding_mode
        self.seed = seed
        # Dedicated generator so crop offsets are reproducible for a given seed.
        self.rng = np.random.default_rng(seed)
        self.constant_values = constant_values

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Pad and/or crop ``x`` to the configured height and width.

        Parameters
        ----------
        x : np.ndarray
            Image in C x H x W or H x W layout.

        Returns
        -------
        np.ndarray
            The resized image in the same layout as the input.
        """
        # Work in H x W x C internally so height and width are axes 0 and 1.
        if len(x.shape) == 3:
            x = np.transpose(x, (1, 2, 0))
        h, w = x.shape[:2]

        def fit(arr, axis, current, target):
            """Pad or crop ``arr`` along ``axis`` from ``current`` to ``target``."""
            if target > current:
                missing = target - current
                before = missing // 2
                after = missing - before
                widths = [(0, 0)] * (3 if len(arr.shape) == 3 else 2)
                widths[axis] = (before, after)
                extra = {}
                if self.padding_mode == "constant":
                    extra["constant_values"] = self.constant_values
                return np.pad(
                    array=arr,
                    pad_width=tuple(widths),
                    mode=self.padding_mode,
                    **extra,
                )
            if target < current:
                start = self.rng.integers(0, current - target + 1)
                window = (slice(None),) * axis + (
                    slice(start, start + target),
                    Ellipsis,
                )
                return arr[window]
            return arr

        # Height first, then width, so random offsets are drawn in that order.
        x = fit(x, 0, h, self.target_h_size)
        x = fit(x, 1, w, self.target_w_size)

        # Restore C x H x W for inputs that arrived with a channel axis.
        if len(x.shape) == 3:
            x = np.transpose(x, (2, 0, 1))
        return x

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return (
            "PadCrop(target_h_size={}, target_w_size={}, padding_mode={}, "
            "constant_values={}, seed={})"
        ).format(
            self.target_h_size,
            self.target_w_size,
            self.padding_mode,
            self.constant_values,
            self.seed,
        )
[docs]
class Identity(_Transform):
    """No-op transform.

    Handy as a placeholder when a pipeline slot should leave the data alone.
    """

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return the input object itself, without copying it."""
        unchanged = x
        return unchanged

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Identity()"
[docs]
class Indexer(_Transform):
    """Pick one entry of an array by indexing it with a fixed index."""

    def __init__(self, index: int):
        """Store the index to extract.

        Parameters
        ----------
        index : int
            Index applied to the input as ``x[index]``, e.g. the channel of
            a channels-first image.
        """
        self.index = index

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x[index]``."""
        position = self.index
        return x[position]

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Indexer(index={})".format(self.index)
[docs]
class Repeat(_Transform):
    """Repeat the elements of an array along one axis."""

    def __init__(self, axis: int, n_repetitions: int):
        """Store the axis and the repetition count.

        Parameters
        ----------
        axis : int
            Axis along which elements are repeated.
        n_repetitions : int
            How many times each element is repeated.
        """
        self.axis, self.n_repetitions = axis, n_repetitions

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` with its elements repeated via ``np.repeat``."""
        repeated = np.repeat(x, self.n_repetitions, axis=self.axis)
        return repeated

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Repeat(axis={}, n_repetitions={})".format(
            self.axis, self.n_repetitions
        )
[docs]
class Normalize(_Transform):
    """Standardize channels-first data with per-channel mean and std."""

    def __init__(self, mean, std, to_rgb=False, normalize_labels=False):
        """
        Normalize the input data using the provided means and standard
        deviations. The channel axis is assumed to be the first one, e.g.
        data of shape (C, H, W).

        Parameters
        ----------
        mean : List[float]
            List of means for each channel.
        std : List[float]
            List of standard deviations for each channel.
        to_rgb : bool, optional
            If True, single-channel data is expanded to three identical
            channels before normalizing, by default False.
        normalize_labels : bool, optional
            If True, unsigned-integer arrays (treated as labels) are
            normalized as well; otherwise they are returned unchanged.
            By default False.

        Raises
        ------
        AssertionError
            If ``mean`` and ``std`` differ in length.
        """
        assert len(mean) == len(
            std
        ), "Means and standard deviations must have the same length."
        self.mean = mean
        self.std = std
        self.to_rgb = to_rgb
        self.normalize_labels = normalize_labels

    def __call__(self, data):
        """Return a normalized copy of ``data``.

        Parameters
        ----------
        data : np.ndarray
            Channels-first array. Unsigned-integer arrays are treated as
            labels and passed through untouched unless ``normalize_labels``
            is set.

        Returns
        -------
        np.ndarray
            ``(data - mean) / std`` computed per channel. Floating-point
            inputs keep their dtype; other dtypes are promoted to a float
            type so the fractional result is not truncated. The input array
            is never modified.

        Raises
        ------
        AssertionError
            If the number of channels differs from the number of means.
        """
        is_label = np.issubdtype(data.dtype, np.unsignedinteger)
        if is_label and not self.normalize_labels:
            return data

        # Gray (1 channel) -> RGB (3 identical channels). No colour-order
        # conversion is needed afterwards because all channels are equal.
        if self.to_rgb and data.shape[0] == 1:
            data = np.repeat(data, 3, axis=0)

        assert data.shape[0] == len(
            self.mean
        ), f"Number of channels in data does not match the number of provided mean/std. {data.shape}"

        if np.issubdtype(data.dtype, np.floating):
            out_dtype = data.dtype
        else:
            # Integer/bool data cannot hold normalized values.
            out_dtype = np.result_type(data.dtype, np.float32)

        # Shape (C, 1, ..., 1) so the statistics broadcast over each channel.
        stat_shape = (-1,) + (1,) * (data.ndim - 1)
        mean = np.asarray(self.mean, dtype=out_dtype).reshape(stat_shape)
        std = np.asarray(self.std, dtype=out_dtype).reshape(stat_shape)

        # astype copies, so the caller's array is left untouched.
        return (data.astype(out_dtype) - mean) / std

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return f"Normalize(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb}, normalize_labels={self.normalize_labels})"
[docs]
class Reshape(_Transform):
    """Give an array a new shape."""

    def __init__(self, shape: Tuple[int, ...]):
        """Store the target shape.

        Parameters
        ----------
        shape : Tuple[int, ...]
            Shape handed to ``ndarray.reshape``.
        """
        self.shape = shape

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` reshaped to the configured shape."""
        target_shape = self.shape
        return x.reshape(target_shape)

    def __str__(self) -> str:
        """Return a short textual description of the transform."""
        return "Reshape(shape={})".format(self.shape)
[docs]
class Resize(_Transform):
def __init__(
    self,
    target_h_size: int,
    target_w_size: int,
    keep_aspect_ratio: bool = False,
    detect_mask: bool = False,
):
    """
    Configure an image resize to a target height and width.

    Parameters
    ----------
    target_h_size : int
        Target height of the resized image in pixels.
    target_w_size : int
        Target width of the resized image in pixels.
    keep_aspect_ratio : bool, default=False
        If True, both sides are scaled by the smaller of the two
        target/original ratios, so the result fits inside the target size
        without distortion. If False, the image is resized straight to the
        target size, which may distort it.
    detect_mask : bool, default=False
        If True, unsigned-integer inputs (treated as masks) are resized
        with nearest-neighbour interpolation.
    """
    self.target_h_size, self.target_w_size = target_h_size, target_w_size
    self.keep_aspect_ratio = keep_aspect_ratio
    self.detect_mask = detect_mask
[docs]
def __call__(self, x: np.ndarray) -> np.ndarray:
    """Resize ``x`` with OpenCV.

    The first two axes of ``x`` are read as (height, width).

    Returns
    -------
    np.ndarray
        The resized array. Nearest-neighbour interpolation is used for
        unsigned-integer inputs when ``detect_mask`` is set; linear
        interpolation is used otherwise.
    """
    src_h, src_w = x.shape[:2]
    out_w, out_h = self.target_w_size, self.target_h_size

    if self.keep_aspect_ratio:
        # The smaller ratio guarantees the result fits inside the target.
        ratio_w = self.target_w_size / src_w
        ratio_h = self.target_h_size / src_h
        factor = min(ratio_w, ratio_h)
        # int() truncates, i.e. sizes are rounded down.
        out_w = int(src_w * factor)
        out_h = int(src_h * factor)

    treat_as_mask = np.issubdtype(x.dtype, np.unsignedinteger) and self.detect_mask
    interpolation = cv2.INTER_NEAREST if treat_as_mask else cv2.INTER_LINEAR
    return cv2.resize(x, (out_w, out_h), interpolation=interpolation)