Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom classes #71

Merged
merged 16 commits into from
Sep 16, 2020
76 changes: 75 additions & 1 deletion mmseg/datasets/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class CustomDataset(Dataset):
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If None, ``cls.CLASSES`` will be used. Default: None.
palette (str | Sequence[str], optional): Specify palette to load.
If None, ``cls.PALETTE`` will be used. Default: None.
"""

CLASSES = None
Expand All @@ -74,7 +78,9 @@ def __init__(self,
data_root=None,
test_mode=False,
ignore_index=255,
reduce_zero_label=False):
reduce_zero_label=False,
classes=None,
palette=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

palette arg is not needed.

self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
Expand All @@ -85,6 +91,8 @@ def __init__(self,
self.test_mode = test_mode
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)

# join paths if data_root is specified
if self.data_root is not None:
Expand Down Expand Up @@ -160,6 +168,8 @@ def get_ann_info(self, idx):
def pre_pipeline(self, results):
    """Prepare results dict for pipeline.

    Resets ``results['seg_fields']`` to an empty list and, when custom
    classes are in use, attaches the class-id remapping so that later
    pipeline steps can relabel the ground-truth segmentation map.

    Args:
        results (dict): Result dict that is modified in place.
    """
    results['seg_fields'] = []
    if self.custom_classes:
        # ``old_id_to_new_id`` maps an original class index to its index in
        # the custom class list (-1 for classes that were dropped). It is
        # only created when the dataset defines default CLASSES, so look it
        # up defensively instead of assuming the attribute exists.
        id_map = getattr(self, 'old_id_to_new_id', None)
        if id_map is not None:
            results['old_id_to_new_id'] = id_map
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use something like label_map and add some comments about it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you mean by label_map. Do you mean changing the key name?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the ambiguity.
Yes, 'old_id_to_new_id' is a bit too long. But 'label_map' may not be straightforward enough. Any idea?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe class_transform, class_id_transform or label_transform?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or label_map could be good with a comment explaining it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then label_map it is. We may add comments about key and value of this dict.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also add comments about what's the purpose of adding this dict.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So shall we rename it to label_map?


def __getitem__(self, idx):
"""Get training/test data after pipeline.
Expand Down Expand Up @@ -220,6 +230,10 @@ def get_gt_seg_maps(self):
for img_info in self.img_infos:
gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
# modify if custom classes
if hasattr(self, 'old_id_to_new_id'):
for old_id, new_id in self.old_id_to_new_id.items():
gt_seg_map[gt_seg_map == old_id] = new_id
if self.reduce_zero_label:
# avoid using underflow conversion
gt_seg_map[gt_seg_map == 0] = 255
Expand All @@ -230,6 +244,66 @@ def get_gt_seg_maps(self):

return gt_seg_maps

def get_classes_and_palette(self, classes=None, palette=None):
    """Get class names and palette of current dataset.

    Also records on the instance whether custom classes are used
    (``self.custom_classes``) and, when the dataset defines default
    ``CLASSES``, a mapping from each original class index to its index in
    the custom class list (``self.old_id_to_new_id``, -1 for classes that
    are not kept).

    Args:
        classes (Sequence[str] | str | None): If classes is None, use
            default CLASSES defined by builtin dataset. If classes is a
            string, take it as a file name. The file contains the name of
            classes where each line contains one class name. If classes is
            a tuple or list, override the CLASSES defined by the dataset.
        palette (Sequence[str] | str | None): If palette is None, use
            default PALETTE defined by builtin dataset. If palette is a
            string, take it as a file name. The file contains the name of
            palette where each line contains one palette value. If palette
            is a tuple or list, override the PALETTE defined by the
            dataset.

    Returns:
        tuple: ``(class_names, palette)``.

    Raises:
        ValueError: If ``classes`` has an unsupported type, or if the
            dataset defines ``CLASSES`` and the requested class names are
            not a subset of them.
    """
    # NOTE(review): a collaborator commented on this signature that the
    # palette arg is not needed.
    if classes is None:
        self.custom_classes = False
        return self.CLASSES, self.PALETTE

    self.custom_classes = True
    if isinstance(classes, str):
        # take it as a file path
        class_names = mmcv.list_from_file(classes)
    elif isinstance(classes, (tuple, list)):
        class_names = classes
    else:
        raise ValueError(f'Unsupported type {type(classes)} of classes.')

    if self.CLASSES:
        # Validate and index against the resolved names, not the raw
        # ``classes`` argument: when ``classes`` is a file path the raw
        # value is a string, whose "elements" are single characters.
        if not set(class_names).issubset(self.CLASSES):
            raise ValueError('classes is not a subset of CLASSES.')
        self.old_id_to_new_id = {
            old_id: class_names.index(name) if name in class_names else -1
            for old_id, name in enumerate(self.CLASSES)
        }

    return class_names, self.get_palette_for_custom_classes(palette)

def get_palette_for_custom_classes(self, palette=None):
    """Resolve the palette to use together with custom classes.

    Args:
        palette (Sequence | str | None): A truthy string is treated as a
            file path and read with ``mmcv.list_from_file``; a truthy
            tuple or list is returned unchanged. When ``palette`` is
            falsy and ``self.old_id_to_new_id`` exists, the entries of
            ``self.PALETTE`` for the kept classes are selected instead,
            ordered by their new class id.

    Returns:
        The resolved palette. When ``palette`` is falsy and no id mapping
        exists, the ``palette`` argument is returned as given.

    Raises:
        ValueError: If ``palette`` is truthy but is not a str, tuple or
            list.
    """
    if palette:
        if isinstance(palette, str):
            # take it as a file path
            # NOTE(review): a collaborator pointed out that the palette is
            # a list of lists, so loading it from a path is not trivial,
            # and suggested removing this function and just selecting the
            # right palette if there is one.
            return mmcv.list_from_file(palette)
        if isinstance(palette, (tuple, list)):
            return palette
        raise ValueError(
            f'Unsupported type {type(palette)} of palette.')

    if not hasattr(self, 'old_id_to_new_id'):
        return palette

    # Keep only classes that survive the remapping (new id != -1) and
    # order their colours by the new class id.
    kept = [(old_id, new_id)
            for old_id, new_id in self.old_id_to_new_id.items()
            if new_id != -1]
    kept.sort(key=lambda pair: pair[1])
    selected = [self.PALETTE[old_id] for old_id, _ in kept]
    # Preserve the container type of the dataset's default palette.
    return type(self.PALETTE)(selected)

def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset.

Expand Down
4 changes: 4 additions & 0 deletions mmseg/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def __call__(self, results):
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze().astype(np.uint8)
# modify if custom classes
if results.get('old_id_to_new_id', None) is not None:
for old_id, new_id in results['old_id_to_new_id'].items():
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion
Expand Down
64 changes: 62 additions & 2 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pytest

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
CustomDataset, PascalVOCDataset, RepeatDataset)
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)


def test_classes():
Expand Down Expand Up @@ -171,3 +172,62 @@ def test_custom_dataset():
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """Check that the ``classes`` argument overrides a dataset's CLASSES.

    For every parametrized dataset the custom classes are passed as a
    tuple, as a list and as a single-element list; in each case the
    instance's ``CLASSES`` must equal the value passed in and differ from
    the class-level default. Passing ``classes=None`` must leave the
    default untouched.
    """
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    def build(custom):
        # Annotation loading and __getitem__ are patched out above, so
        # placeholder objects are enough for the path-like arguments.
        return dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=custom,
            test_mode=True)

    # tuple, list, and a single-class override
    for custom in (classes, list(classes), [classes[0]]):
        custom_dataset = build(custom)
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == custom

    # Test default behavior
    assert build(None).CLASSES == original_classes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may add a test case for LoadAnnotation pipeline when there are custom classes.
To achieve this, we may generate a random ground truth segmentation map.

Copy link
Contributor Author

@igonro igonro Sep 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add it in test_loading.py or in this same file? Also I'm not very sure what type of test I should add.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may add it in test_loading.py.
We should check if the loaded annotation is of custom classes as desired.