-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Discussion #489: add CopyAffine transform * review changes discussion #489 * update documentation for #489 * Minor edits in tests * Edit transform and tests * Fix docstring Co-authored-by: STEFF Alban <[email protected]> Co-authored-by: Fernando <[email protected]>
- Loading branch information
1 parent
3a869f6
commit 6c7a188
Showing
6 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch | ||
|
||
import torchio as tio | ||
from ...utils import TorchioTestCase | ||
|
||
|
||
class TestCopyAffine(TorchioTestCase):
    """Tests for the :class:`tio.CopyAffine` transform."""

    def test_missing_reference(self):
        # Applying the transform when the target image is absent must fail
        transform = tio.CopyAffine(target='missing')
        with self.assertRaises(RuntimeError):
            transform(self.sample_subject)

    def test_wrong_target_type(self):
        # A non-string target is rejected at construction time
        with self.assertRaises(ValueError):
            tio.CopyAffine(target=[1])

    def test_same_affine(self):
        # After the transform, every image shares the target image's affine
        reference = tio.ScalarImage(tensor=torch.rand(2, 2, 2, 2))
        label = tio.LabelMap(tensor=torch.rand(2, 2, 2, 2))
        label.affine *= 1.1  # introduce a small metadata mismatch
        subject = tio.Subject(t1=reference, mask=label)
        transformed = tio.CopyAffine('t1')(subject)
        self.assertTensorEqual(
            transformed['t1'].affine,
            transformed['mask'].affine,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import copy | ||
from ....data.subject import Subject | ||
from ... import SpatialTransform | ||
|
||
|
||
class CopyAffine(SpatialTransform):
    """Copy the spatial metadata from a reference image in the subject.

    Small unexpected differences in spatial metadata across different images
    of a subject can arise due to rounding errors while converting formats.

    If the ``shape`` and ``orientation`` of the images are the same and their
    ``affine`` attributes are different but very similar, this transform can be
    used to avoid errors during safety checks in other transforms and samplers.

    Args:
        target: Name of the image within the subject whose affine matrix will
            be used.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> import numpy as np
        >>> np.random.seed(0)
        >>> affine = np.diag((*(np.random.rand(3) + 0.5), 1))
        >>> t1 = tio.ScalarImage(tensor=torch.rand(1, 100, 100, 100), affine=affine)
        >>> # Let's simulate a loss of precision
        >>> # (caused for example by NIfTI storing spatial metadata in single precision)
        >>> bad_affine = affine.astype(np.float16)
        >>> t2 = tio.ScalarImage(tensor=torch.rand(1, 100, 100, 100), affine=bad_affine)
        >>> subject = tio.Subject(t1=t1, t2=t2)
        >>> resample = tio.Resample(0.5)
        >>> resample(subject).shape  # error as images are in different spaces
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
          File "/Users/fernando/git/torchio/torchio/data/subject.py", line 101, in shape
            self.check_consistent_attribute('shape')
          File "/Users/fernando/git/torchio/torchio/data/subject.py", line 229, in check_consistent_attribute
            raise RuntimeError(message)
        RuntimeError: More than one shape found in subject images:
        {'t1': (1, 210, 244, 221), 't2': (1, 210, 243, 221)}
        >>> transform = tio.CopyAffine('t1')
        >>> fixed = transform(subject)
        >>> resample(fixed).shape
        (1, 210, 244, 221)

    .. warning:: This transform should be used with caution. Modifying the
        spatial metadata of an image manually can lead to incorrect processing
        of the position of anatomical structures. For example, a machine
        learning algorithm might incorrectly predict that a lesion on the right
        lung is on the left lung.

    .. note:: For more information, see some related discussions on GitHub:

        * https://github.com/fepegar/torchio/issues/354
        * https://github.com/fepegar/torchio/discussions/489
        * https://github.com/fepegar/torchio/pull/584
        * https://github.com/fepegar/torchio/issues/430
        * https://github.com/fepegar/torchio/issues/382
        * https://github.com/fepegar/torchio/pull/592
    """  # noqa: E501
    def __init__(self, target: str, **kwargs):
        super().__init__(**kwargs)
        # Validate eagerly so a bad target fails at construction, not at
        # application time
        if not isinstance(target, str):
            message = (
                f'The target must be a string, but "{type(target)}" was found'
            )
            raise ValueError(message)
        self.target = target

    def apply_transform(self, subject: Subject) -> Subject:
        """Copy the target image's affine onto every image in the subject.

        Raises:
            RuntimeError: If the target image is not present in the subject.
        """
        if self.target not in subject:
            message = f'Target image "{self.target}" not found in subject'
            raise RuntimeError(message)
        affine = subject[self.target].affine
        for image in self.get_images(subject):
            # Deep-copy so images do not share (and later mutate) one array
            image.affine = copy.deepcopy(affine)
        return subject