diff --git a/src/torchio/transforms/preprocessing/intensity/rescale.py b/src/torchio/transforms/preprocessing/intensity/rescale.py index 34bcc4dd..2051f18f 100644 --- a/src/torchio/transforms/preprocessing/intensity/rescale.py +++ b/src/torchio/transforms/preprocessing/intensity/rescale.py @@ -1,6 +1,5 @@ import warnings from typing import Optional -from typing import Tuple import numpy as np import torch @@ -48,7 +47,7 @@ def __init__( out_min_max: TypeRangeFloat = (0, 1), percentiles: TypeRangeFloat = (0, 100), masking_method: TypeMaskingMethod = None, - in_min_max: Optional[Tuple[float, float]] = None, + in_min_max: Optional[TypeRangeFloat] = None, **kwargs ): super().__init__(masking_method=masking_method, **kwargs) @@ -60,7 +59,21 @@ def __init__( self.percentiles = self._parse_range( percentiles, 'percentiles', min_constraint=0, max_constraint=100, ) - self.args_names = ['out_min_max', 'percentiles', 'masking_method'] + + self.in_min: Optional[float] + self.in_max: Optional[float] + if self.in_min_max is not None: + self.in_min, self.in_max = self._parse_range( + self.in_min_max, 'in_min_max', + ) + else: + self.in_min = None + self.in_max = None + + self.args_names = [ + 'out_min_max', 'percentiles', 'masking_method', 'in_min_max', + ] + self.invert_transform = False def apply_normalization( self, @@ -91,10 +104,13 @@ def rescale( cutoff = np.percentile(values, self.percentiles) np.clip(array, *cutoff, out=array) # type: ignore[call-overload] if self.in_min_max is None: - in_min, in_max = array.min(), array.max() - else: - in_min, in_max = self.in_min_max - in_range = in_max - in_min + self.in_min_max = self._parse_range( + (array.min(), array.max()), 'in_min_max', + ) + self.in_min, self.in_max = self.in_min_max + assert self.in_min is not None + assert self.in_max is not None + in_range = self.in_max - self.in_min if in_range == 0: # should this be compared using a tolerance? message = ( f'Rescaling image "{image_name}" not possible' @@ -102,9 +118,15 @@ def rescale( ) warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor - array -= in_min - array /= in_range out_range = self.out_max - self.out_min - array *= out_range - array += self.out_min + if self.invert_transform: + array -= self.out_min + array /= out_range + array *= in_range + array += self.in_min + else: + array -= self.in_min + array /= in_range + array *= out_range + array += self.out_min return torch.as_tensor(array) diff --git a/tests/transforms/preprocessing/test_rescale.py b/tests/transforms/preprocessing/test_rescale.py index 1b2f01e6..4053bebb 100644 --- a/tests/transforms/preprocessing/test_rescale.py +++ b/tests/transforms/preprocessing/test_rescale.py @@ -107,3 +107,14 @@ def test_empty_mask(self): rescale = tio.RescaleIntensity(masking_method='label') with pytest.warns(RuntimeWarning): rescale(subject) + + def test_invert_rescaling(self): + torch.manual_seed(0) + transform = tio.RescaleIntensity(out_min_max=(0, 1)) + data = torch.rand(1, 2, 3, 4).double() + subject = tio.Subject(t1=tio.ScalarImage(tensor=data)) + transformed = transform(subject) + assert transformed.t1.data.min() == 0 + assert transformed.t1.data.max() == 1 + inverted = transformed.apply_inverse_transform() + self.assert_tensor_almost_equal(inverted.t1.data, data)