Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add padding image transformation #19838

Merged
merged 4 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/internal/image_processing_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Most of those are only useful if you are studying the code of the image processo

[[autodoc]] image_transforms.normalize

[[autodoc]] image_transforms.pad

[[autodoc]] image_transforms.rgb_to_id

[[autodoc]] image_transforms.rescale
Expand Down
116 changes: 107 additions & 9 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np

from transformers.utils import TensorType
from transformers.utils import ExplicitEnum, TensorType
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available


Expand All @@ -38,13 +38,14 @@
)


if TYPE_CHECKING:
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp
if is_torch_available():
import torch

if is_tf_available():
import tensorflow as tf

if is_flax_available():
import jax.numpy as jnp


def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
Expand Down Expand Up @@ -568,3 +569,100 @@ def id_to_rgb(id_map):
color.append(id_map % 256)
id_map //= 256
return color


class PaddingMode(ExplicitEnum):
    """
    Names of the padding strategies accepted by the `mode` argument of [`pad`].

    Each member's value is the plain string form of the mode, e.g. `"constant"`.
    """

    # Fill the new border with a fixed value (`np.pad(..., mode="constant")`).
    CONSTANT = "constant"
    # Forwarded to `np.pad(..., mode="reflect")`.
    REFLECT = "reflect"
    # Repeat the edge values; forwarded to `np.pad(..., mode="edge")`.
    REPLICATE = "replicate"
    # Forwarded to `np.pad(..., mode="symmetric")`.
    SYMMETRIC = "symmetric"


def pad(
    image: np.ndarray,
    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Only the height and width axes are padded. The channel axis, and the leading batch axis of a 4D input, are
    always left untouched.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of four formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` yields same before and after pad for height and width.
            - `(before, after)` is a shortcut for `((before, after),)`.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`. Accepts the same formats as `padding`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    Raises:
        `ValueError`: If `padding` or `constant_values` is not in one of the supported formats, or if `mode` is not
        a valid padding mode.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            if isinstance(values[0], tuple):
                # `((before, after),)`: reuse the same pair for height and width
                values = (values[0], values[0])
            else:
                # `(pad,)`: same value on every edge
                values = ((values[0], values[0]), (values[0], values[0]))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], (int, float)):
            # `(before, after)`: floats are accepted too so that constant values such as (0.5, 1.0) work
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for channel dimension
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # Add additional padding if there's a batch dimension. This must be a (before, after) pair like every
        # other axis, otherwise np.pad receives a ragged `pad_width` and fails.
        values = ((0, 0), *values) if image.ndim == 4 else values
        return values

    padding = _expand_for_data_format(padding)

    if mode == PaddingMode.CONSTANT:
        constant_values = _expand_for_data_format(constant_values)
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        # numpy calls this mode "edge"
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        image = np.pad(image, padding, mode="symmetric")
    else:
        raise ValueError(f"Invalid padding mode: {mode}")

    image = to_channel_dimension_format(image, data_format) if data_format is not None else image
    return image
125 changes: 125 additions & 0 deletions tests/test_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_resize_output_image_size,
id_to_rgb,
normalize,
pad,
resize,
rgb_to_id,
to_channel_dimension_format,
Expand Down Expand Up @@ -289,3 +290,127 @@ def test_id_to_rgb(self):
]
)
self.assertTrue(np.allclose(id_to_rgb(id_array), color))

def test_pad(self):
    """
    Check `pad` against hand-written expected arrays for each supported padding format and padding mode,
    using small channels-first images of shape (1, height, width).
    """
    # fmt: off
    image = np.array([[
        [0, 1],
        [2, 3],
    ]])
    # fmt: on

    # Test that exception is raised if unknown padding mode is specified
    with self.assertRaises(ValueError):
        pad(image, 10, mode="unknown")

    # Test that exception is raised if invalid padding is specified
    with self.assertRaises(ValueError):
        # Cannot pad on channel dimension
        pad(image, (5, 10, 10))

    # Test image is padded equally on all sides if padding is an int
    # fmt: off
    expected_image = np.array([
        [[0, 0, 0, 0],
         [0, 0, 1, 0],
         [0, 2, 3, 0],
         [0, 0, 0, 0]],
    ])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, 1)))

    # Test a (before, after) tuple is applied to both the height and width axes.
    # The expected array keeps the leading channel axis so the comparison does not rely on broadcasting.
    # fmt: off
    expected_image = np.array([[
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 2, 3, 0],
        [0, 0, 0, 0, 0],
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, (2, 1))))

    # Test only one axis is padded: (before, after) on the height axis, nothing on the width axis
    # fmt: off
    expected_image = np.array([[
        [9, 9],
        [9, 9],
        [0, 1],
        [2, 3],
        [9, 9]
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, ((2, 1), (0, 0)), constant_values=9)))

    # Test padding with a different constant value for each edge
    # fmt: off
    expected_image = np.array([[
        [8, 8, 0, 1, 9],
        [8, 8, 2, 3, 9],
        [8, 8, 7, 7, 9],
        [8, 8, 7, 7, 9]
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), constant_values=((6, 7), (8, 9)))))

    # fmt: off
    image = np.array([[
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
    ]])
    # fmt: on

    # Test padding with PaddingMode.REFLECT
    # fmt: off
    expected_image = np.array([[
        [2, 1, 0, 1, 2, 1],
        [5, 4, 3, 4, 5, 4],
        [8, 7, 6, 7, 8, 7],
        [5, 4, 3, 4, 5, 4],
        [2, 1, 0, 1, 2, 1],
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect")))

    # Test padding with PaddingMode.REPLICATE
    # fmt: off
    expected_image = np.array([[
        [0, 0, 0, 1, 2, 2],
        [3, 3, 3, 4, 5, 5],
        [6, 6, 6, 7, 8, 8],
        [6, 6, 6, 7, 8, 8],
        [6, 6, 6, 7, 8, 8],
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="replicate")))

    # Test padding with PaddingMode.SYMMETRIC
    # fmt: off
    expected_image = np.array([[
        [1, 0, 0, 1, 2, 2],
        [4, 3, 3, 4, 5, 5],
        [7, 6, 6, 7, 8, 8],
        [7, 6, 6, 7, 8, 8],
        [4, 3, 3, 4, 5, 5],
    ]])
    # fmt: on
    self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="symmetric")))

    # Test we can specify the output data format (reflect padding, channels-last output)
    # fmt: off
    image = np.array([[
        [0, 1],
        [2, 3],
    ]])
    expected_image = np.array([
        [[0], [1], [0], [1], [0]],
        [[2], [3], [2], [3], [2]],
        [[0], [1], [0], [1], [0]],
        [[2], [3], [2], [3], [2]]
    ])
    # fmt: on
    self.assertTrue(
        np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
    )