Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add AdjustGamma transform #232

Merged
merged 12 commits into from
Dec 2, 2020
2 changes: 1 addition & 1 deletion mmseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .version import __version__, version_info

MMCV_MIN = '1.1.4'
MMCV_MAX = '1.2.0'
MMCV_MAX = '1.3.0'


def digit_version(version_str):
Expand Down
36 changes: 36 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,42 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class AdjustGamma(object):
"""Using gamma correction to process the image.

Args:
gamma (float or int): Gamma value used in gamma correction.
Default: 1.0.
"""

def __init__(self, gamma=1.0):
assert isinstance(gamma, float) or isinstance(gamma, int)
assert gamma > 0
self.gamma = gamma
inv_gamma = 1.0 / gamma
self.table = np.array([(i / 255.0)**inv_gamma * 255
for i in np.arange(256)]).astype('uint8')

def __call__(self, results):
"""Call function to process the image with gamma correction.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Processed results.
"""

results['img'] = mmcv.lut_transform(
np.array(results['img'], dtype=np.uint8), self.table)

return results

def __repr__(self):
return self.__class__.__name__ + f'(gamma={self.gamma})'


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,42 @@ def test_rgb2gray():
assert results['ori_shape'] == (h, w, c)


def test_adjust_gamma():
# test assertion if gamma <= 0
with pytest.raises(AssertionError):
transform = dict(type='AdjustGamma', gamma=0)
build_from_cfg(transform, PIPELINES)

# test assertion if gamma is list
with pytest.raises(AssertionError):
transform = dict(type='AdjustGamma', gamma=[1.2])
build_from_cfg(transform, PIPELINES)

# test with gamma = 1.2
transform = dict(type='AdjustGamma', gamma=1.2)
transform = build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
original_img = copy.deepcopy(img)
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)

inv_gamma = 1.0 / 1.2
table = np.array([((i / 255.0)**inv_gamma) * 255
for i in np.arange(0, 256)]).astype('uint8')
converted_img = mmcv.lut_transform(
np.array(original_img, dtype=np.uint8), table)
assert np.allclose(results['img'], converted_img)
assert str(transform) == f'AdjustGamma(gamma={1.2})'


def test_rerange():
# test assertion if min_value or max_value is illegal
with pytest.raises(AssertionError):
Expand Down