From ab3eb488628db5593c895e0986c82dd79a7e6ec2 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:29:37 +0100 Subject: [PATCH 1/4] Add padding transformation --- src/transformers/image_transforms.py | 185 ++++++++----------------- tests/test_image_transforms.py | 193 +++++++++++++++++++-------- 2 files changed, 187 insertions(+), 191 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index d8d1d60935d7e4..be97eb8d88008e 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -18,7 +18,8 @@ import numpy as np -from transformers.utils import TensorType +from transformers.image_utils import PILImageResampling +from transformers.utils import ExplicitEnum from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available @@ -426,145 +427,65 @@ def center_crop( return new_image -def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor": - center_x, center_y, width, height = bboxes_center.unbind(-1) - bbox_corners = torch.stack( - # top left x, top left y, bottom right x, bottom right y - [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], - dim=-1, - ) - return bbox_corners - - -def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: - center_x, center_y, width, height = bboxes_center.T - bboxes_corners = np.stack( - # top left x, top left y, bottom right x, bottom right y - [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], - axis=-1, - ) - return bboxes_corners - - -def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor": - center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1) - bboxes_corners = tf.stack( - # top left x, top left y, bottom right x, bottom right y - [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], - axis=-1, - ) - return bboxes_corners - - -# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py -def center_to_corners_format(bboxes_center: TensorType) -> TensorType: +class PaddingMode(ExplicitEnum): """ - Converts bounding boxes from center format to corners format. - - center format: contains the coordinate for the center of the box and its width, height dimensions - (center_x, center_y, width, height) - corners format: contains the coodinates for the top-left and bottom-right corners of the box - (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + Enum class for the different padding modes to use when padding images. """ - # Function is used during model forward pass, so we use the input framework if possible, without - # converting to numpy - if is_torch_tensor(bboxes_center): - return _center_to_corners_format_torch(bboxes_center) - elif isinstance(bboxes_center, np.ndarray): - return _center_to_corners_format_numpy(bboxes_center) - elif is_tf_tensor(bboxes_center): - return _center_to_corners_format_tf(bboxes_center) - - raise ValueError(f"Unsupported input type {type(bboxes_center)}") - - -def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor": - top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1) - b = [ - (top_left_x + bottom_right_x) / 2, # center x - (top_left_y + bottom_right_y) / 2, # center y - (bottom_right_x - top_left_x), # width - (bottom_right_y - top_left_y), # height - ] - return torch.stack(b, dim=-1) - - -def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: - top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T - bboxes_center = np.stack( - [ - (top_left_x + bottom_right_x) / 2, # center x - (top_left_y + bottom_right_y) / 2, # center y - (bottom_right_x - top_left_x), # width - (bottom_right_y - top_left_y), # height - ], - axis=-1, - ) - return bboxes_center - - -def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor": - top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1) - bboxes_center = tf.stack( - [ - (top_left_x + bottom_right_x) / 2, # center x - (top_left_y + bottom_right_y) / 2, # center y - (bottom_right_x - top_left_x), # width - (bottom_right_y - top_left_y), # height - ], - axis=-1, - ) - return bboxes_center + CONSTANT = "constant" + REFLECT = "reflect" + REPLICATE = "replicate" + SYMMETRIC = "symmetric" -def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: - """ - Converts bounding boxes from corners format to center format. - corners format: contains the coodinates for the top-left and bottom-right corners of the box - (top_left_x, top_left_y, bottom_right_x, bottom_right_y) - center format: contains the coordinate for the center of the box and its the width, height dimensions - (center_x, center_y, width, height) +def pad( + image: np.ndarray, + padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: """ - # Inverse function accepts different input types so implemented here too - if is_torch_tensor(bboxes_corners): - return _corners_to_center_format_torch(bboxes_corners) - elif isinstance(bboxes_corners, np.ndarray): - return _corners_to_center_format_numpy(bboxes_corners) - elif is_tf_tensor(bboxes_corners): - return _corners_to_center_format_tf(bboxes_corners) + Pads the `image` with the specified `padding` and `mode`. - raise ValueError(f"Unsupported input type {type(bboxes_corners)}") + Args: + image (`np.ndarray`): + The image to pad. + padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): + Padding to apply to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths + for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a + shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + Returns: + `np.ndarray`: The padded image. -# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py -# Copyright (c) 2018, Alexander Kirillov -# All rights reserved. -def rgb_to_id(color): """ - Converts RGB color to unique ID. - """ - if isinstance(color, np.ndarray) and len(color.shape) == 3: - if color.dtype == np.uint8: - color = color.astype(np.int32) - return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] - return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) - + mode = PaddingMode(mode) + + if mode == PaddingMode.CONSTANT: + image = np.pad(image, padding, mode="constant", constant_values=constant_values) + elif mode == PaddingMode.REFLECT: + image = np.pad(image, padding, mode="reflect") + elif mode == PaddingMode.REPLICATE: + image = np.pad(image, padding, mode="edge") + elif mode == PaddingMode.SYMMETRIC: + image = np.pad(image, padding, mode="symmetric") + else: + raise ValueError(f"Invalid padding mode: {mode}") -def id_to_rgb(id_map): - """ - Converts unique ID to RGB color. - """ - if isinstance(id_map, np.ndarray): - id_map_copy = id_map.copy() - rgb_shape = tuple(list(id_map.shape) + [3]) - rgb_map = np.zeros(rgb_shape, dtype=np.uint8) - for i in range(3): - rgb_map[..., i] = id_map_copy % 256 - id_map_copy //= 256 - return rgb_map - color = [] - for _ in range(3): - color.append(id_map % 256) - id_map //= 256 - return color + image = to_channel_dimension_format(image, data_format) if data_format is not None else image + return image diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index d0b7c9ade13745..a01c5dc72aaba0 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -41,6 +41,7 @@ get_resize_output_image_size, id_to_rgb, normalize, + pad, resize, rgb_to_id, to_channel_dimension_format, @@ -229,63 +230,137 @@ def test_center_crop(self): self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertTrue(np.allclose(cropped_image, expected_image)) - def test_center_to_corners_format(self): - bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) - expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) - self.assertTrue(np.allclose(center_to_corners_format(bbox_center), expected)) - - # Check that the function and inverse function are inverse of each other - self.assertTrue(np.allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center)) - - def test_corners_to_center_format(self): - bbox_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) - expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) - self.assertTrue(np.allclose(corners_to_center_format(bbox_corners), expected)) - - # Check that the function and inverse function are inverse of each other - self.assertTrue(np.allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners)) - - def test_rgb_to_id(self): - # test list input - rgb = [125, 4, 255] - self.assertEqual(rgb_to_id(rgb), 16712829) - - # test numpy array input - color = np.array( - [ - [ - [213, 54, 165], - [88, 207, 39], - [156, 108, 128], - ], - [ - [183, 194, 46], - [137, 58, 88], - [114, 131, 233], - ], - ] - ) - expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) - self.assertTrue(np.allclose(rgb_to_id(color), expected)) - - def test_id_to_rgb(self): - # test int input - self.assertEqual(id_to_rgb(16712829), [125, 4, 255]) - - # test array input - id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) - color = np.array( - [ - [ - [213, 54, 165], - [88, 207, 39], - [156, 108, 128], - ], - [ - [183, 194, 46], - [137, 58, 88], - [114, 131, 233], - ], - ] + def test_pad(self): + # fmt: off + image = np.array([[ + [0, 1], + [2, 3], + ]]) + # fmt: on + + # Test that exception is raised if unknown padding mode is specified + with self.assertRaises(ValueError): + pad(image, 10, mode="unknown") + + # Test image is padded equally on all sides is padding is an int + # fmt: off + expected_image = np.array([ + [[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], + + [[0, 0, 0, 0], + [0, 0, 1, 0], + [0, 2, 3, 0], + [0, 0, 0, 0]], + + [[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], + ]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, 1))) + + # Test the left and right of each axis is padded (pad_left, pad_right) + # fmt: off + expected_image = np.array([ + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 2, 3, 0], + [0, 0, 0, 0, 0]], + + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, (2, 1)))) + + # Test only one axis is padded (pad_left, pad_right) + # fmt: off + expected_image = np.array([[ + [9, 9], + [9, 9], + [0, 1], + [2, 3], + [9, 9] + ]]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (2, 1), (0, 0)), constant_values=9))) + + # Test padding with a constant value + # fmt: off + expected_image = np.array([[ + [7, 7, 8, 8, 7], + [7, 7, 0, 1, 7], + [7, 7, 2, 3, 7], + [7, 7, 8, 8, 7], + [7, 7, 8, 8, 7] + ]]) + # fmt: on + self.assertTrue( + np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), constant_values=((9, 9), (8, 8), (7, 7)))) ) - self.assertTrue(np.allclose(id_to_rgb(id_array), color)) + + # fmt: off + image = np.array([[ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + ]]) + # fmt: on + + # Test padding with PaddingMode.REFLECT + # fmt: off + expected_image = np.array([[ + [5, 4, 3, 4, 5, 4], + [2, 1, 0, 1, 2, 1], + [5, 4, 3, 4, 5, 4], + [8, 7, 6, 7, 8, 7], + [5, 4, 3, 4, 5, 4], + [2, 1, 0, 1, 2, 1], + ]]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="reflect"))) + + # Test padding with PaddingMode.REPLICATE + # fmt: off + expected_image = np.array([[ + [0, 0, 0, 1, 2, 2], + [0, 0, 0, 1, 2, 2], + [3, 3, 3, 4, 5, 5], + [6, 6, 6, 7, 8, 8], + [6, 6, 6, 7, 8, 8], + [6, 6, 6, 7, 8, 8], + ]]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="replicate"))) + + # Test padding with PaddingMode.SYMMETRIC + # fmt: off + expected_image = np.array([[ + [1, 0, 0, 1, 2, 2], + [1, 0, 0, 1, 2, 2], + [4, 3, 3, 4, 5, 5], + [7, 6, 6, 7, 8, 8], + [7, 6, 6, 7, 8, 8], + [4, 3, 3, 4, 5, 5], + ]]) + # fmt: on + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="symmetric"))) From 391e220b9bb9b823913c93f6ddf4cd32bddc8bd8 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 16 Nov 2022 17:27:08 +0000 Subject: [PATCH 2/4] Add in upstream changes --- src/transformers/image_transforms.py | 170 +++++++++++++++++++++++++-- tests/test_image_transforms.py | 61 ++++++++++ 2 files changed, 221 insertions(+), 10 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index be97eb8d88008e..d20bb31834bab0 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -14,13 +14,18 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import numpy as np -from transformers.image_utils import PILImageResampling from transformers.utils import ExplicitEnum -from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available +from transformers.utils.import_utils import ( + TensorType, + is_flax_available, + is_tf_available, + is_torch_available, + is_vision_available, +) if is_vision_available(): @@ -39,13 +44,14 @@ ) -if TYPE_CHECKING: - if is_torch_available(): - import torch - if is_tf_available(): - import tensorflow as tf - if is_flax_available(): - import jax.numpy as jnp +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax.numpy as jnp def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray: @@ -427,6 +433,150 @@ def center_crop( return new_image +def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor": + center_x, center_y, width, height = bboxes_center.unbind(-1) + bbox_corners = torch.stack( + # top left x, top left y, bottom right x, bottom right y + [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], + dim=-1, + ) + return bbox_corners + + +def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: + center_x, center_y, width, height = bboxes_center.T + bboxes_corners = np.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor": + center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1) + bboxes_corners = tf.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py +def center_to_corners_format(bboxes_center: TensorType) -> TensorType: + """ + Converts bounding boxes from center format to corners format. + + center format: contains the coordinate for the center of the box and its width, height dimensions + (center_x, center_y, width, height) + corners format: contains the coodinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + """ + # Function is used during model forward pass, so we use the input framework if possible, without + # converting to numpy + if is_torch_tensor(bboxes_center): + return _center_to_corners_format_torch(bboxes_center) + elif isinstance(bboxes_center, np.ndarray): + return _center_to_corners_format_numpy(bboxes_center) + elif is_tf_tensor(bboxes_center): + return _center_to_corners_format_tf(bboxes_center) + + raise ValueError(f"Unsupported input type {type(bboxes_center)}") + + +def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1) + b = [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ] + return torch.stack(b, dim=-1) + + +def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T + bboxes_center = np.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) + return bboxes_center + + +def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1) + bboxes_center = tf.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) + return bboxes_center + + +def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: + """ + Converts bounding boxes from corners format to center format. + + corners format: contains the coodinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + center format: contains the coordinate for the center of the box and its the width, height dimensions + (center_x, center_y, width, height) + """ + # Inverse function accepts different input types so implemented here too + if is_torch_tensor(bboxes_corners): + return _corners_to_center_format_torch(bboxes_corners) + elif isinstance(bboxes_corners, np.ndarray): + return _corners_to_center_format_numpy(bboxes_corners) + elif is_tf_tensor(bboxes_corners): + return _corners_to_center_format_tf(bboxes_corners) + + raise ValueError(f"Unsupported input type {type(bboxes_corners)}") + + +# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py +# Copyright (c) 2018, Alexander Kirillov +# All rights reserved. +def rgb_to_id(color): + """ + Converts RGB color to unique ID. + """ + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +def id_to_rgb(id_map): + """ + Converts unique ID to RGB color. + """ + if isinstance(id_map, np.ndarray): + id_map_copy = id_map.copy() + rgb_shape = tuple(list(id_map.shape) + [3]) + rgb_map = np.zeros(rgb_shape, dtype=np.uint8) + for i in range(3): + rgb_map[..., i] = id_map_copy % 256 + id_map_copy //= 256 + return rgb_map + color = [] + for _ in range(3): + color.append(id_map % 256) + id_map //= 256 + return color + + class PaddingMode(ExplicitEnum): """ Enum class for the different padding modes to use when padding images. diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index a01c5dc72aaba0..6e3df16a9c0c1c 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -230,6 +230,67 @@ def test_center_crop(self): self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertTrue(np.allclose(cropped_image, expected_image)) + def test_center_to_corners_format(self): + bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) + expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) + self.assertTrue(np.allclose(center_to_corners_format(bbox_center), expected)) + + # Check that the function and inverse function are inverse of each other + self.assertTrue(np.allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center)) + + def test_corners_to_center_format(self): + bbox_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) + expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) + self.assertTrue(np.allclose(corners_to_center_format(bbox_corners), expected)) + + # Check that the function and inverse function are inverse of each other + self.assertTrue(np.allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners)) + + def test_rgb_to_id(self): + # test list input + rgb = [125, 4, 255] + self.assertEqual(rgb_to_id(rgb), 16712829) + + # test numpy array input + color = np.array( + [ + [ + [213, 54, 165], + [88, 207, 39], + [156, 108, 128], + ], + [ + [183, 194, 46], + [137, 58, 88], + [114, 131, 233], + ], + ] + ) + expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) + self.assertTrue(np.allclose(rgb_to_id(color), expected)) + + def test_id_to_rgb(self): + # test int input + self.assertEqual(id_to_rgb(16712829), [125, 4, 255]) + + # test array input + id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) + color = np.array( + [ + [ + [213, 54, 165], + [88, 207, 39], + [156, 108, 128], + ], + [ + [183, 194, 46], + [137, 58, 88], + [114, 131, 233], + ], + ] + ) + self.assertTrue(np.allclose(id_to_rgb(id_array), color)) + def test_pad(self): # fmt: off image = np.array([[ From e6a71d9d2e0f2635a8ea0f28c211e4a0c1666af2 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 16 Nov 2022 19:34:17 +0000 Subject: [PATCH 3/4] Update tests & docs --- .../en/internal/image_processing_utils.mdx | 2 + src/transformers/image_transforms.py | 53 ++++++++++--- tests/test_image_transforms.py | 79 ++++++++----------- 3 files changed, 76 insertions(+), 58 deletions(-) diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx index f1658e55525d74..831458bedab164 100644 --- a/docs/source/en/internal/image_processing_utils.mdx +++ b/docs/source/en/internal/image_processing_utils.mdx @@ -29,6 +29,8 @@ Most of those are only useful if you are studying the code of the image processo [[autodoc]] image_transforms.normalize +[[autodoc]] image_transforms.pad + [[autodoc]] image_transforms.rgb_to_id [[autodoc]] image_transforms.rescale diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index d20bb31834bab0..578af5986e0aab 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -18,14 +18,8 @@ import numpy as np -from transformers.utils import ExplicitEnum -from transformers.utils.import_utils import ( - TensorType, - is_flax_available, - is_tf_available, - is_torch_available, - is_vision_available, -) +from transformers.utils import ExplicitEnum, TensorType +from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available if is_vision_available(): @@ -594,17 +588,19 @@ def pad( mode: PaddingMode = PaddingMode.CONSTANT, constant_values: Union[float, Iterable[float]] = 0.0, data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """ - Pads the `image` with the specified `padding` and `mode`. + Pads the `image` with the specified (height, width) `padding` and `mode`. Args: image (`np.ndarray`): The image to pad. padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): - Padding to apply to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths - for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a - shortcut for before = after = pad width for all axes. + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - ((before_height, after_height), (before_width, after_width)) unique pad widths for each axis. + - ((before, after),) yields same before and after pad for height and width. + - (pad,) or int is a shortcut for before = after = pad width for all axes. mode (`PaddingMode`): The padding mode to use. Can be one of: - `"constant"`: pads with a constant value. @@ -618,15 +614,46 @@ def pad( The channel dimension format for the output image. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. If unset, will use the inferred format of the input image. Returns: `np.ndarray`: The padded image. """ - mode = PaddingMode(mode) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + def _expand_for_data_format(values): + """ + Convert values to be in the format expected by np.pad based on the data format. + """ + if isinstance(values, (int, float)): + values = ((values, values), (values, values)) + elif isinstance(values, tuple) and len(values) == 1: + values = ((values[0], values[0]), (values[0], values[0])) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int): + values = (values, values) + elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple): + values = values + else: + raise ValueError(f"Unsupported format: {values}") + + # add 0 for channel dimension + values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0)) + + # Add additional padding if there's a batch dimension + values = (0, *values) if image.ndim == 4 else values + return values + + padding = _expand_for_data_format(padding) if mode == PaddingMode.CONSTANT: + constant_values = _expand_for_data_format(constant_values) image = np.pad(image, padding, mode="constant", constant_values=constant_values) elif mode == PaddingMode.REFLECT: image = np.pad(image, padding, mode="reflect") diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 6e3df16a9c0c1c..618181b004d595 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -303,53 +303,30 @@ def test_pad(self): with self.assertRaises(ValueError): pad(image, 10, mode="unknown") + # Test that exception is raised if invalid padding is specified + with self.assertRaises(ValueError): + # Cannot pad on channel dimension + pad(image, (5, 10, 10)) + # Test image is padded equally on all sides is padding is an int # fmt: off expected_image = np.array([ - [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], - [[0, 0, 0, 0], [0, 0, 1, 0], [0, 2, 3, 0], [0, 0, 0, 0]], - - [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], ]) # fmt: on self.assertTrue(np.allclose(expected_image, pad(image, 1))) # Test the left and right of each axis is padded (pad_left, pad_right) # fmt: off - expected_image = np.array([ - [[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - - [[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - + expected_image = np.array( [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 2, 3, 0], - [0, 0, 0, 0, 0]], - - [[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]]) + [0, 0, 0, 0, 0]]) # fmt: on self.assertTrue(np.allclose(expected_image, pad(image, (2, 1)))) @@ -363,21 +340,18 @@ def test_pad(self): [9, 9] ]]) # fmt: on - self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (2, 1), (0, 0)), constant_values=9))) + self.assertTrue(np.allclose(expected_image, pad(image, ((2, 1), (0, 0)), constant_values=9))) # Test padding with a constant value # fmt: off expected_image = np.array([[ - [7, 7, 8, 8, 7], - [7, 7, 0, 1, 7], - [7, 7, 2, 3, 7], - [7, 7, 8, 8, 7], - [7, 7, 8, 8, 7] + [8, 8, 0, 1, 9], + [8, 8, 2, 3, 9], + [8, 8, 7, 7, 9], + [8, 8, 7, 7, 9] ]]) # fmt: on - self.assertTrue( - np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), constant_values=((9, 9), (8, 8), (7, 7)))) - ) + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), constant_values=((6, 7), (8, 9))))) # fmt: off image = np.array([[ @@ -390,7 +364,6 @@ def test_pad(self): # Test padding with PaddingMode.REFLECT # fmt: off expected_image = np.array([[ - [5, 4, 3, 4, 5, 4], [2, 1, 0, 1, 2, 1], [5, 4, 3, 4, 5, 4], [8, 7, 6, 7, 8, 7], @@ -398,12 +371,11 @@ def test_pad(self): [2, 1, 0, 1, 2, 1], ]]) # fmt: on - self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="reflect"))) + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect"))) # Test padding with PaddingMode.REPLICATE # fmt: off expected_image = np.array([[ - [0, 0, 0, 1, 2, 2], [0, 0, 0, 1, 2, 2], [3, 3, 3, 4, 5, 5], [6, 6, 6, 7, 8, 8], @@ -411,12 +383,11 @@ def test_pad(self): [6, 6, 6, 7, 8, 8], ]]) # fmt: on - self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="replicate"))) + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="replicate"))) # Test padding with PaddingMode.SYMMETRIC # fmt: off expected_image = np.array([[ - [1, 0, 0, 1, 2, 2], [1, 0, 0, 1, 2, 2], [4, 3, 3, 4, 5, 5], [7, 6, 6, 7, 8, 8], @@ -424,4 +395,22 @@ def test_pad(self): [4, 3, 3, 4, 5, 5], ]]) # fmt: on - self.assertTrue(np.allclose(expected_image, pad(image, ((0, 0), (1, 2), (2, 1)), mode="symmetric"))) + self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="symmetric"))) + + # Test we can specify the output data format + # Test padding with PaddingMode.REFLECT + # fmt: off + image = np.array([[ + [0, 1], + [2, 3], + ]]) + expected_image = np.array([ + [[0], [1], [0], [1], [0]], + [[2], [3], [2], [3], [2]], + [[0], [1], [0], [1], [0]], + [[2], [3], [2], [3], [2]] + ]) + # fmt: on + self.assertTrue( + np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last")) + ) From d0103495e6f1da0b4aae2ed1f405d20285149d6b Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 17 Nov 2022 11:20:23 +0000 Subject: [PATCH 4/4] Code formatting tuples in docstring --- src/transformers/image_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 578af5986e0aab..1909d04e2a6736 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -598,9 +598,9 @@ def pad( The image to pad. padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): Padding to apply to the edges of the height, width axes. Can be one of three formats: - - ((before_height, after_height), (before_width, after_width)) unique pad widths for each axis. - - ((before, after),) yields same before and after pad for height and width. - - (pad,) or int is a shortcut for before = after = pad width for all axes. + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. mode (`PaddingMode`): The padding mode to use. Can be one of: - `"constant"`: pads with a constant value.