diff --git a/mmseg/__init__.py b/mmseg/__init__.py index ffc848a934..f301a5dc34 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -3,7 +3,7 @@ from .version import __version__, version_info MMCV_MIN = '1.1.4' -MMCV_MAX = '1.2.0' +MMCV_MAX = '1.3.0' def digit_version(version_str): diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 74be564c57..c138b21c20 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -650,6 +650,42 @@ def __repr__(self): return repr_str +@PIPELINES.register_module() +class AdjustGamma(object): + """Using gamma correction to process the image. + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + @PIPELINES.register_module() class SegRescale(object): """Rescale semantic segmentation maps. diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index d4f81ecc6f..19cf6d5337 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -330,6 +330,42 @@ def test_rgb2gray(): assert results['ori_shape'] == (h, w, c) +def test_adjust_gamma(): + # test assertion if gamma <= 0 + with pytest.raises(AssertionError): + transform = dict(type='AdjustGamma', gamma=0) + build_from_cfg(transform, PIPELINES) + + # test assertion if gamma is list + with pytest.raises(AssertionError): + transform = dict(type='AdjustGamma', gamma=[1.2]) + build_from_cfg(transform, PIPELINES) + + # test with gamma = 1.2 + transform = dict(type='AdjustGamma', gamma=1.2) + transform = build_from_cfg(transform, PIPELINES) + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + + results = transform(results) + + inv_gamma = 1.0 / 1.2 + table = np.array([((i / 255.0)**inv_gamma) * 255 + for i in np.arange(0, 256)]).astype('uint8') + converted_img = mmcv.lut_transform( + np.array(original_img, dtype=np.uint8), table) + assert np.allclose(results['img'], converted_img) + assert str(transform) == f'AdjustGamma(gamma={1.2})' + + def test_rerange(): # test assertion if min_value or max_value is illegal with pytest.raises(AssertionError):