Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize Axes in Random Transforms. Add Random Axis to RandomMotion #1185

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@
PROTECTED_KEYS = DATA, AFFINE, TYPE, PATH, STEM
TypeBound = Tuple[float, float]
TypeBounds = Tuple[TypeBound, TypeBound, TypeBound]
FLIP_AXIS = {
'L': 'R',
'R': 'L',
'A': 'P',
'P': 'A',
'I': 'S',
'S': 'I',
'T': 'B',
'B': 'T',
}

deprecation_message = (
'Setting the image data with the property setter is deprecated. Use the'
Expand Down Expand Up @@ -378,7 +388,7 @@ def axis_name_to_index(self, axis: str) -> int:
versions and first letters are also valid, as only the first
letter will be used.

.. note:: If you are working with animals, you should probably use
.. note:: If you are working with animals, you should use
``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
for ``'Dorsal'``, ``'Ventral'``, ``'Rostral'`` and ``'Caudal'``,
respectively.
Expand All @@ -392,6 +402,15 @@ def axis_name_to_index(self, axis: str) -> int:
if not isinstance(axis, str):
raise ValueError('Axis must be a string')
axis = axis[0].upper()
if axis not in 'LRPAISTB':
message = (
'Incorrect axis naming. Please use one of: "Left", "Right", '
'"Anterior", "Posterior", "Inferior", "Superior". '
'Lower-case versions and first letters are also valid '
'(i.e., "L", "r", etc). For 2D images, use "Top" and "Bottom" '
'to refer to the vertical (2nd) axis.'
)
raise ValueError(message)

# Generally, TorchIO tensors are (C, W, H, D)
if axis in 'TB': # Top, Bottom
Expand All @@ -400,31 +419,12 @@ def axis_name_to_index(self, axis: str) -> int:
try:
index = self.orientation.index(axis)
except ValueError:
index = self.orientation.index(self.flip_axis(axis))
index = self.orientation.index(FLIP_AXIS[axis])
# Return negative indices so that it does not matter whether we
# refer to spatial dimensions or not
index = -3 + index
return index

@staticmethod
def flip_axis(axis: str) -> str:
"""Return the opposite axis label. For example, ``'L'`` -> ``'R'``.

Args:
axis: Axis label, such as ``'L'`` or ``'left'``.
"""
labels = 'LRPAISTBDV'
first = labels[::2]
last = labels[1::2]
flip_dict = {a: b for a, b in zip(first + last, last + first)}
axis = axis[0].upper()
flipped_axis = flip_dict.get(axis)
if flipped_axis is None:
values = ', '.join(labels)
message = f'Axis not understood. Please use one of: {values}'
raise ValueError(message)
return flipped_axis

def get_spacing_string(self) -> str:
strings = [f'{n:.2f}' for n in self.spacing]
string = f'({", ".join(strings)})'
Expand Down
47 changes: 31 additions & 16 deletions src/torchio/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
from ..typing import TypeData
from ..typing import TypeDataAffine
from ..typing import TypeDirection
from ..typing import TypeDoubletInt
from ..typing import TypePath
from ..typing import TypeQuartetInt
from ..typing import TypeTripletFloat
from ..typing import TypeTripletInt


# Matrices used to switch between LPS and RAS
Expand Down Expand Up @@ -87,26 +85,43 @@ def _read_dicom(directory: TypePath):


def read_shape(path: TypePath) -> TypeQuartetInt:
reader = sitk.ImageFileReader()
reader.SetFileName(str(path))
reader.ReadImageInformation()
num_channels = reader.GetNumberOfComponents()
num_dimensions = reader.GetDimension()
try:
reader = sitk.ImageFileReader()
reader.SetFileName(str(path))
reader.ReadImageInformation()
num_channels = reader.GetNumberOfComponents()
num_dimensions = reader.GetDimension()
shape = reader.GetSize()
except RuntimeError as e: # try with NiBabel
message = f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...'
warnings.warn(message, stacklevel=2)
try:
obj: SpatialImage = nib.load(str(path)) # type: ignore[assignment]
except nib.loadsave.ImageFileError as e:
message = (
f'File "{path}" not understood.'
' Check supported formats by at'
' https://simpleitk.readthedocs.io/en/master/IO.html#images'
' and https://nipy.org/nibabel/api.html#file-formats'
)
raise RuntimeError(message) from e
num_dimensions = obj.ndim
shape = obj.shape
num_channels = 1 if num_dimensions < 4 else shape[-1]
assert 2 <= num_dimensions <= 4
if num_dimensions == 2:
spatial_shape_2d: TypeDoubletInt = reader.GetSize()
assert len(spatial_shape_2d) == 2
si, sj = spatial_shape_2d
assert len(shape) == 2
si, sj = shape
sk = 1
elif num_dimensions == 4:
# We assume bad NIfTI file (channels encoded as spatial dimension)
spatial_shape_4d: TypeQuartetInt = reader.GetSize()
assert len(spatial_shape_4d) == 4
si, sj, sk, num_channels = spatial_shape_4d
assert len(shape) == 4
si, sj, sk, num_channels = shape
elif num_dimensions == 3:
spatial_shape_3d: TypeTripletInt = reader.GetSize()
assert len(spatial_shape_3d) == 3
si, sj, sk = spatial_shape_3d
assert len(shape) == 3
si, sj, sk = shape
else:
raise ValueError(f'Unsupported number of dimensions: {num_dimensions}')
shape = num_channels, si, sj, sk
return shape

Expand Down
34 changes: 16 additions & 18 deletions src/torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import defaultdict
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -60,16 +59,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
if not isinstance(axes, tuple):
try:
axes = tuple(axes) # type: ignore[arg-type]
except TypeError:
axes = (axes,) # type: ignore[assignment]
assert isinstance(axes, Iterable)
for axis in axes:
if not isinstance(axis, str) and axis not in (0, 1, 2):
raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
self.axes = axes
self.axes = self.parse_axes(axes)
self.num_ghosts_range = self._parse_range(
num_ghosts,
'num_ghosts',
Expand All @@ -84,16 +74,13 @@ def __init__(
self.restore = _parse_restore(restore)

def apply_transform(self, subject: Subject) -> Subject:
axes = self.ensure_axes_indices(subject, self.axes)
arguments: Dict[str, dict] = defaultdict(dict)
if any(isinstance(n, str) for n in self.axes):
subject.check_consistent_orientation()
for name, image in self.get_images_dict(subject).items():
is_2d = image.is_2d()
axes = [a for a in self.axes if a != 2] if is_2d else self.axes
for name, _ in self.get_images_dict(subject).items():
min_ghosts, max_ghosts = self.num_ghosts_range
params = self.get_params(
axes,
(int(min_ghosts), int(max_ghosts)),
axes, # type: ignore[arg-type]
self.intensity_range,
)
num_ghosts_param, axis_param, intensity_param = params
Expand All @@ -108,8 +95,8 @@ def apply_transform(self, subject: Subject) -> Subject:

def get_params(
self,
num_ghosts_range: Tuple[int, int],
axes: Tuple[int, ...],
num_ghosts_range: Tuple[int, int],
intensity_range: Tuple[float, float],
) -> Tuple:
ng_min, ng_max = num_ghosts_range
Expand All @@ -118,6 +105,17 @@ def get_params(
intensity = self.sample_uniform(*intensity_range)
return num_ghosts, axis, intensity

@staticmethod
def parse_restore(restore):
try:
restore = float(restore)
except ValueError as e:
raise TypeError(f'Restore must be a float, not "{restore}"') from e
if not 0 <= restore <= 1:
message = f'Restore must be a number between 0 and 1, not {restore}'
raise ValueError(message)
return restore


class Ghosting(IntensityTransform, FourierTransform):
r"""Add MRI ghosting artifact.
Expand Down
Loading
Loading