Skip to content

Commit

Permalink
Add CopyAffine transform (#584)
Browse files Browse the repository at this point in the history
* Discussion #489: add CopyAffine transform

* review changes discussion #489

* update documentation for #489

* Minor edits in tests

* Edit transform and tests

* Fix docstring

Co-authored-by: STEFF Alban <[email protected]>
Co-authored-by: Fernando <[email protected]>
  • Loading branch information
3 people authored Jul 5, 2021
1 parent 3a869f6 commit 6c7a188
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/source/transforms/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ Spatial
:show-inheritance:


:class:`CopyAffine`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: CopyAffine
:show-inheritance:


:class:`Crop`
~~~~~~~~~~~~~

Expand Down
29 changes: 29 additions & 0 deletions tests/transforms/preprocessing/test_copy_affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch

import torchio as tio
from ...utils import TorchioTestCase


class TestCopyAffine(TorchioTestCase):
"""Tests for `CopyAffine`."""

def test_missing_reference(self):
transform = tio.CopyAffine(target='missing')
with self.assertRaises(RuntimeError):
transform(self.sample_subject)

def test_wrong_target_type(self):
with self.assertRaises(ValueError):
tio.CopyAffine(target=[1])

def test_same_affine(self):
image = tio.ScalarImage(tensor=torch.rand(2, 2, 2, 2))
mask = tio.LabelMap(tensor=torch.rand(2, 2, 2, 2))
mask.affine *= 1.1
subject = tio.Subject(t1=image, mask=mask)
transform = tio.CopyAffine('t1')
transformed = transform(subject)
self.assertTensorEqual(
transformed['t1'].affine,
transformed['mask'].affine,
)
2 changes: 2 additions & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_transform(self, channels, is_3d=True, labels=True):
tio.ToCanonical(),
tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
tio.EnsureShapeMultiple(2, method='crop'),
tio.CopyAffine(channels[0]),
tio.Resample((1, 1.1, 1.25)),
tio.RandomFlip(axes=flip_axes, flip_probability=1),
tio.RandomMotion(),
Expand Down Expand Up @@ -132,6 +133,7 @@ def test_transforms_subject_4d(self):
'RemapLabels',
'RemoveLabels',
'SequentialLabels',
'CopyAffine',
)
if transform.name not in exclude:
self.assertEqual(
Expand Down
2 changes: 2 additions & 0 deletions torchio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .preprocessing import Crop
from .preprocessing import Resample
from .preprocessing import CropOrPad
from .preprocessing import CopyAffine
from .preprocessing import ToCanonical
from .preprocessing import ZNormalization
from .preprocessing import RescaleIntensity
Expand Down Expand Up @@ -87,6 +88,7 @@
'RescaleIntensity',
'Mask',
'CropOrPad',
'CopyAffine',
'EnsureShapeMultiple',
'train_histogram',
'OneHot',
Expand Down
2 changes: 2 additions & 0 deletions torchio/transforms/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .spatial.resample import Resample
from .spatial.crop_or_pad import CropOrPad
from .spatial.to_canonical import ToCanonical
from .spatial.copy_affine import CopyAffine
from .spatial.ensure_shape_multiple import EnsureShapeMultiple

from .intensity.mask import Mask
Expand All @@ -24,6 +25,7 @@
'Resample',
'ToCanonical',
'CropOrPad',
'CopyAffine',
'EnsureShapeMultiple',
'Mask',
'RescaleIntensity',
Expand Down
79 changes: 79 additions & 0 deletions torchio/transforms/preprocessing/spatial/copy_affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import copy
from ....data.subject import Subject
from ... import SpatialTransform


class CopyAffine(SpatialTransform):
"""Copy the spatial metadata from a reference image in the subject.
Small unexpected differences in spatial metadata across different images
of a subject can arise due to rounding errors while converting formats.
If the ``shape`` and ``orientation`` of the images are the same and their
``affine`` attributes are different but very similar, this transform can be
used to avoid errors during safety checks in other transforms and samplers.
Args:
target: Name of the image within the subject whose affine matrix will
be used.
Example:
>>> import torch
>>> import torchio as tio
>>> import numpy as np
>>> np.random.seed(0)
>>> affine = np.diag((*(np.random.rand(3) + 0.5), 1))
>>> t1 = tio.ScalarImage(tensor=torch.rand(1, 100, 100, 100), affine=affine)
>>> # Let's simulate a loss of precision
>>> # (caused for example by NIfTI storing spatial metadata in single precision)
>>> bad_affine = affine.astype(np.float16)
>>> t2 = tio.ScalarImage(tensor=torch.rand(1, 100, 100, 100), affine=bad_affine)
>>> subject = tio.Subject(t1=t1, t2=t2)
>>> resample = tio.Resample(0.5)
>>> resample(subject).shape # error as images are in different spaces
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/fernando/git/torchio/torchio/data/subject.py", line 101, in shape
self.check_consistent_attribute('shape')
File "/Users/fernando/git/torchio/torchio/data/subject.py", line 229, in check_consistent_attribute
raise RuntimeError(message)
RuntimeError: More than one shape found in subject images:
{'t1': (1, 210, 244, 221), 't2': (1, 210, 243, 221)}
>>> transform = tio.CopyAffine('t1')
>>> fixed = transform(subject)
>>> resample(fixed).shape
(1, 210, 244, 221)
.. warning:: This transform should be used with caution. Modifying the
spatial metadata of an image manually can lead to incorrect processing
of the position of anatomical structures. For example, a machine
learning algorithm might incorrectly predict that a lesion on the right
lung is on the left lung.
.. note:: For more information, see some related discussions on GitHub:
* https:/fepegar/torchio/issues/354
* https:/fepegar/torchio/discussions/489
* https:/fepegar/torchio/pull/584
* https:/fepegar/torchio/issues/430
* https:/fepegar/torchio/issues/382
* https:/fepegar/torchio/pull/592
""" # noqa: E501
def __init__(self, target: str, **kwargs):
super().__init__(**kwargs)
if not isinstance(target, str):
message = (
f'The target must be a string, but "{type(target)}" was found'
)
raise ValueError(message)
self.target = target

def apply_transform(self, subject: Subject) -> Subject:
if self.target not in subject:
message = f'Target image "{self.target}" not found in subject'
raise RuntimeError(message)
affine = subject[self.target].affine
for image in self.get_images(subject):
image.affine = copy.deepcopy(affine)
return subject

0 comments on commit 6c7a188

Please sign in to comment.