Skip to content

Commit

Permalink
Add support to pass label keys for dict input (#879)
Browse files Browse the repository at this point in the history
* Add support to pass label keys for dict input

* Add reference to @josegcpa's code

Co-authored-by: josegcpa <[email protected]>
  • Loading branch information
fepegar and josegcpa authored May 15, 2022
1 parent 250f67b commit 149a7be
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
22 changes: 22 additions & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,25 @@ def get_mask(label):
mask = transform.get_mask_from_bounds(3 * (0, 1), tensor)
assert mask[0, 0, 0, 0] == 1
assert mask.sum() == 1

def test_label_keys(self):
# Adapted from the issue in which the feature was requested:
# https:/fepegar/torchio/issues/866#issue-1222255576
size = 1, 10, 10, 10
image = torch.rand(size)
num_classes = 2 # excluding background
label = torch.randint(num_classes + 1, size)

data_dict = {'image': image, 'label': label}

transform = tio.RandomAffine(
include=['image', 'label'],
label_keys=['label'],
)
transformed_label = transform(data_dict)['label']

# If the image is indeed transformed as a label map, nearest neighbor
# interpolation is used by default and therefore no intermediate values
# can exist in the output
num_unique_values = len(torch.unique(transformed_label))
assert num_unique_values <= num_classes + 1
19 changes: 14 additions & 5 deletions torchio/transforms/data_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List, Sequence, Union
from typing import Optional, Sequence, Union

import torch
import numpy as np
Expand All @@ -7,7 +7,7 @@

from ..typing import TypeData
from ..data.subject import Subject
from ..data.image import Image, ScalarImage
from ..data.image import Image, LabelMap, ScalarImage
from ..data.io import nib_to_sitk, sitk_to_nib


Expand All @@ -27,9 +27,11 @@ def __init__(
self,
data: TypeTransformInput,
keys: Optional[Sequence[str]] = None,
label_keys: Optional[Sequence[str]] = None,
):
self.data = data
self.keys = keys
self.label_keys = label_keys
self.default_image_name = 'default_image_name'
self.is_tensor = False
self.is_array = False
Expand Down Expand Up @@ -65,7 +67,11 @@ def get_subject(self):
' https://torchio.readthedocs.io/transforms/transforms.html#torchio.transforms.Transform' # noqa: E501
)
raise RuntimeError(message)
subject = self._get_subject_from_dict(self.data, self.keys)
subject = self._get_subject_from_dict(
self.data,
self.keys,
self.label_keys,
)
self.is_dict = True
else:
raise ValueError(f'Input type not recognized: {type(self.data)}')
Expand Down Expand Up @@ -130,12 +136,15 @@ def _get_subject_from_image(self, image: Image) -> Subject:
@staticmethod
def _get_subject_from_dict(
data: dict,
image_keys: List[str],
image_keys: Sequence[str],
label_keys: Optional[Sequence[str]] = None,
) -> Subject:
subject_dict = {}
label_keys = {} if label_keys is None else label_keys
for key, value in data.items():
if key in image_keys:
value = ScalarImage(tensor=value)
class_ = LabelMap if key in label_keys else ScalarImage
value = class_(tensor=value)
subject_dict[key] = value
return Subject(subject_dict)

Expand Down
18 changes: 14 additions & 4 deletions torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Union, Tuple, Optional, Dict
from typing import Union, Tuple, Optional, Dict, Sequence

import torch
import numpy as np
Expand Down Expand Up @@ -69,12 +69,16 @@ class Transform(ABC):
transform such as :class:`~torchio.transforms.RandomBlur`,
the transform will be only applied to the MRI, as the label map is
excluded by default by spatial transforms.
keep: Dictionary with the names of the images that will be kept in the
subject and their new names.
keep: Dictionary with the names of the input images that will be kept
in the output and their new names. For example:
``{'t1': 't1_original'}``. This might be useful for autoencoders
or registration tasks.
parse_input: If ``True``, the input will be converted to an instance of
:class:`~torchio.Subject`. This is used internally by some special
transforms like
:class:`~torchio.transforms.augmentation.composition.Compose`.
label_keys: If the input is a dictionary, names of images that
correspond to label maps.
"""
def __init__(
self,
Expand All @@ -85,6 +89,7 @@ def __init__(
keys: TypeKeys = None,
keep: Optional[Dict[str, str]] = None,
parse_input: bool = True,
label_keys: Optional[Sequence[str]] = None,
):
self.probability = self.parse_probability(p)
self.copy = copy
Expand All @@ -99,6 +104,7 @@ def __init__(
include, exclude)
self.keep = keep
self.parse_input = parse_input
self.label_keys = label_keys
# args_names is the sequence of parameters from self that need to be
# passed to a non-random version of a random transform. They are also
# used to invert invertible transforms
Expand All @@ -125,7 +131,11 @@ def __call__(

# Some transforms such as Compose should not modify the input data
if self.parse_input:
data_parser = DataParser(data, keys=self.include)
data_parser = DataParser(
data,
keys=self.include,
label_keys=self.label_keys,
)
subject = data_parser.get_subject()
else:
subject = data
Expand Down

0 comments on commit 149a7be

Please sign in to comment.