Skip to content

Commit

Permalink
[Feature] Add BioMedical3DPad (#2383)
Browse files Browse the repository at this point in the history
## Motivation
Add the 3d pad transform for biomedical images, which follows the design
of the nnUNet.
  • Loading branch information
suyanzhou626 authored Jan 3, 2023
1 parent 3ca690b commit 6af2b8e
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 10 deletions.
9 changes: 5 additions & 4 deletions mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from .potsdam import PotsdamDataset
from .stare import STAREDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedicalGaussianBlur,
BioMedicalGaussianNoise, BioMedicalRandomGamma,
GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
Expand All @@ -45,5 +46,5 @@
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
'BioMedicalRandomGamma', 'BioMedical3DPad'
]
12 changes: 6 additions & 6 deletions mmseg/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange,
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedicalGaussianBlur,
BioMedicalGaussianNoise, BioMedicalRandomGamma,
GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)

Expand All @@ -20,5 +20,5 @@
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
'BioMedicalRandomGamma', 'BioMedical3DPad'
]
132 changes: 132 additions & 0 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,3 +1805,135 @@ def __repr__(self):
repr_str += f'per_channel={self.per_channel},'
repr_str += f'retain_stats={self.retain_stats}'
return repr_str


@TRANSFORMS.register_module()
class BioMedical3DPad(BaseTransform):
"""Pad the biomedical 3d image & biomedical 3d semantic segmentation maps.
Required Keys:
- img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default,
N is the number of modalities.
- gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
(Z, Y, X) by default.
Modified Keys:
- img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default,
N is the number of modalities.
- gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
(Z, Y, X) by default.
Added Keys:
- pad_shape (Tuple[int, int, int]): The padded shape.
Args:
pad_shape (Tuple[int, int, int]): Fixed padding size.
Expected padding shape (Z, Y, X).
pad_val (float): Padding value for biomedical image.
The padding mode is set to "constant". The value
to be filled in padding area. Default: 0.
seg_pad_val (int): Padding value for biomedical 3d semantic
segmentation maps. The padding mode is set to "constant".
The value to be filled in padding area. Default: 0.
"""

def __init__(self,
pad_shape: Tuple[int, int, int],
pad_val: float = 0.,
seg_pad_val: int = 0) -> None:

# check pad_shape
assert pad_shape is not None
if not isinstance(pad_shape, tuple):
assert len(pad_shape) == 3

self.pad_shape = pad_shape
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val

def _pad_img(self, results: dict) -> None:
"""Pad images according to ``self.pad_shape``
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: The dict contains the padded image and shape
information.
"""
padded_img = self._to_pad(
results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val)

results['img'] = padded_img
results['pad_shape'] = padded_img.shape[1:]

def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to ``self.pad_shape`` if
``gt_seg_map`` is not None in results dict.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Update the padded gt seg map in dict.
"""
if results.get('gt_seg_map', None) is not None:
pad_gt_seg = self._to_pad(
results['gt_seg_map'][None, ...],
pad_shape=results['pad_shape'],
pad_val=self.seg_pad_val)
results['gt_seg_map'] = pad_gt_seg[1:]

@staticmethod
def _to_pad(img: np.ndarray,
pad_shape: Tuple[int, int, int],
pad_val: Union[int, float] = 0) -> np.ndarray:
"""Pad the given 3d image to a certain shape with specified padding
value.
Args:
img (ndarray): Biomedical image with shape (N, Z, Y, X)
to be padded. N is the number of modalities.
pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X).
pad_val (float, int): Values to be filled in padding areas
and the padding_mode is set to 'constant'. Default: 0.
Returns:
ndarray: The padded image.
"""
# compute pad width
d = max(pad_shape[0] - img.shape[1], 0)
pad_d = (d // 2, d - d // 2)
h = max(pad_shape[1] - img.shape[2], 0)
pad_h = (h // 2, h - h // 2)
w = max(pad_shape[2] - img.shape[2], 0)
pad_w = (w // 2, w - w // 2)

pad_list = [(0, 0), pad_d, pad_h, pad_w]

img = np.pad(img, pad_list, mode='constant', constant_values=pad_val)
return img

def transform(self, results: dict) -> dict:
"""Call function to pad images, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results)
self._pad_seg(results)

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'pad_shape={self.pad_shape}, '
repr_str += f'pad_val={self.pad_val}), '
repr_str += f'seg_pad_val={self.seg_pad_val})'
return repr_str
43 changes: 43 additions & 0 deletions tests/test_datasets/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,46 @@ def test_BioMedicalRandomGamma():
results = transform2(results)
transformed_img = results['img']
assert origin_img.shape == transformed_img.shape


def test_BioMedical3DPad():
# test assertion.
with pytest.raises(AssertionError):
transform = dict(type='BioMedical3DPad', pad_shape=None)
TRANSFORMS.build(transform)

with pytest.raises(AssertionError):
transform = dict(type='BioMedical3DPad', pad_shape=[256, 256])
TRANSFORMS.build(transform)

data_info1 = dict(img=np.random.random((8, 6, 4, 4)))

transform = dict(type='BioMedical3DPad', pad_shape=(6, 6, 6))
transform = TRANSFORMS.build(transform)
results = transform(copy.deepcopy(data_info1))
assert results['img'].shape[1:] == (6, 6, 6)
assert results['pad_shape'] == (6, 6, 6)

transform = dict(type='BioMedical3DPad', pad_shape=(4, 6, 6))
transform = TRANSFORMS.build(transform)
results = transform(copy.deepcopy(data_info1))
assert results['img'].shape[1:] == (6, 6, 6)
assert results['pad_shape'] == (6, 6, 6)

data_info2 = dict(
img=np.random.random((8, 6, 4, 4)),
gt_seg_map=np.random.randint(0, 2, (6, 4, 4)))

transform = dict(type='BioMedical3DPad', pad_shape=(6, 6, 6))
transform = TRANSFORMS.build(transform)
results = transform(copy.deepcopy(data_info2))
assert results['img'].shape[1:] == (6, 6, 6)
assert results['gt_seg_map'].shape[1:] == (6, 6, 6)
assert results['pad_shape'] == (6, 6, 6)

transform = dict(type='BioMedical3DPad', pad_shape=(4, 6, 6))
transform = TRANSFORMS.build(transform)
results = transform(copy.deepcopy(data_info2))
assert results['img'].shape[1:] == (6, 6, 6)
assert results['gt_seg_map'].shape[1:] == (6, 6, 6)
assert results['pad_shape'] == (6, 6, 6)

0 comments on commit 6af2b8e

Please sign in to comment.