-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for custom classes #71
Changes from 7 commits
06d5562
38c617c
538819e
f0ed77f
993e290
cda80d2
5b1af68
f8ccb85
d88f8f9
ea29d50
7478a42
d0a2763
04c5547
36a309a
b13aec1
5d0fda4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,10 @@ class CustomDataset(Dataset): | |
ignore_index (int): The label index to be ignored. Default: 255 | ||
reduce_zero_label (bool): Whether to mark label zero as ignored. | ||
Default: False | ||
classes (str | Sequence[str], optional): Specify classes to load. | ||
If is None, ``cls.CLASSES`` will be used. Default: None. | ||
palette (str | Sequence[str], optional): Specify palette to load. | ||
If is None, ``cls.PALETTE`` will be used. Default: None. | ||
""" | ||
|
||
CLASSES = None | ||
|
@@ -74,7 +78,9 @@ def __init__(self, | |
data_root=None, | ||
test_mode=False, | ||
ignore_index=255, | ||
reduce_zero_label=False): | ||
reduce_zero_label=False, | ||
classes=None, | ||
palette=None): | ||
self.pipeline = Compose(pipeline) | ||
self.img_dir = img_dir | ||
self.img_suffix = img_suffix | ||
|
@@ -85,6 +91,8 @@ def __init__(self, | |
self.test_mode = test_mode | ||
self.ignore_index = ignore_index | ||
self.reduce_zero_label = reduce_zero_label | ||
self.CLASSES, self.PALETTE = self.get_classes_and_palette( | ||
classes, palette) | ||
|
||
# join paths if data_root is specified | ||
if self.data_root is not None: | ||
|
@@ -160,6 +168,8 @@ def get_ann_info(self, idx): | |
def pre_pipeline(self, results): | ||
"""Prepare results dict for pipeline.""" | ||
results['seg_fields'] = [] | ||
if self.custom_classes: | ||
results['old_id_to_new_id'] = self.old_id_to_new_id | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we use something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand what you mean by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the ambiguity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may also add comments about what's the purpose of adding this dict. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So shall we rename it to |
||
|
||
def __getitem__(self, idx): | ||
"""Get training/test data after pipeline. | ||
|
@@ -220,6 +230,10 @@ def get_gt_seg_maps(self): | |
for img_info in self.img_infos: | ||
gt_seg_map = mmcv.imread( | ||
img_info['ann']['seg_map'], flag='unchanged', backend='pillow') | ||
# modify if custom classes | ||
if hasattr(self, 'old_id_to_new_id'): | ||
for old_id, new_id in self.old_id_to_new_id.items(): | ||
gt_seg_map[gt_seg_map == old_id] = new_id | ||
if self.reduce_zero_label: | ||
# avoid using underflow conversion | ||
gt_seg_map[gt_seg_map == 0] = 255 | ||
|
@@ -230,6 +244,66 @@ def get_gt_seg_maps(self): | |
|
||
return gt_seg_maps | ||
|
||
def get_classes_and_palette(self, classes=None, palette=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"""Get class names of current dataset. | ||
|
||
Args: | ||
classes (Sequence[str] | str | None): If classes is None, use | ||
default CLASSES defined by builtin dataset. If classes is a | ||
string, take it as a file name. The file contains the name of | ||
classes where each line contains one class name. If classes is | ||
a tuple or list, override the CLASSES defined by the dataset. | ||
palette (Sequence[str] | str | None): If palette is None, use | ||
default PALETTE defined by builtin dataset. If palette is a | ||
string, take it as a file name. The file contains the name of | ||
palette where each line contains one palette value. If palette | ||
is a tuple or list, override the PALETTE defined by the | ||
dataset. | ||
""" | ||
if classes is None: | ||
self.custom_classes = False | ||
return self.CLASSES, self.PALETTE | ||
|
||
self.custom_classes = True | ||
if isinstance(classes, str): | ||
# take it as a file path | ||
class_names = mmcv.list_from_file(classes) | ||
elif isinstance(classes, (tuple, list)): | ||
class_names = classes | ||
else: | ||
raise ValueError(f'Unsupported type {type(classes)} of classes.') | ||
|
||
if self.CLASSES: | ||
if not set(classes).issubset(self.CLASSES): | ||
raise ValueError('classes is not a subset of CLASSES.') | ||
self.old_id_to_new_id = {} | ||
for i, c in enumerate(self.CLASSES): | ||
if c not in class_names: | ||
self.old_id_to_new_id[i] = -1 | ||
else: | ||
self.old_id_to_new_id[i] = classes.index(c) | ||
|
||
return class_names, self.get_palette_for_custom_classes(palette) | ||
|
||
def get_palette_for_custom_classes(self, palette=None): | ||
|
||
if palette: | ||
if isinstance(palette, str): | ||
# take it as a file path | ||
palette = mmcv.list_from_file(palette) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
elif not isinstance(palette, (tuple, list)): | ||
raise ValueError( | ||
f'Unsupported type {type(palette)} of palette.') | ||
|
||
elif hasattr(self, 'old_id_to_new_id'): | ||
palette = [] | ||
for x in sorted(self.old_id_to_new_id.items(), key=lambda x: x[1]): | ||
if x[1] != -1: | ||
palette.append(self.PALETTE[x[0]]) | ||
palette = type(self.PALETTE)(palette) | ||
|
||
return palette | ||
|
||
def evaluate(self, results, metric='mIoU', logger=None, **kwargs): | ||
"""Evaluate the dataset. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,9 @@ | |
import pytest | ||
|
||
from mmseg.core.evaluation import get_classes, get_palette | ||
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset, | ||
CustomDataset, PascalVOCDataset, RepeatDataset) | ||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, | ||
ConcatDataset, CustomDataset, PascalVOCDataset, | ||
RepeatDataset) | ||
|
||
|
||
def test_classes(): | ||
|
@@ -171,3 +172,62 @@ def test_custom_dataset(): | |
assert 'mIoU' in eval_results | ||
assert 'mAcc' in eval_results | ||
assert 'aAcc' in eval_results | ||
|
||
|
||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) | ||
@patch('mmseg.datasets.CustomDataset.__getitem__', | ||
MagicMock(side_effect=lambda idx: idx)) | ||
@pytest.mark.parametrize('dataset, classes', [ | ||
('ADE20KDataset', ('wall', 'building')), | ||
('CityscapesDataset', ('road', 'sidewalk')), | ||
('CustomDataset', ('bus', 'car')), | ||
('PascalVOCDataset', ('aeroplane', 'bicycle')), | ||
]) | ||
def test_custom_classes_override_default(dataset, classes): | ||
|
||
dataset_class = DATASETS.get(dataset) | ||
|
||
original_classes = dataset_class.CLASSES | ||
|
||
# Test setting classes as a tuple | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=classes, | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == classes | ||
|
||
# Test setting classes as a list | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=list(classes), | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == list(classes) | ||
|
||
# Test overriding not a subset | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=[classes[0]], | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES != original_classes | ||
assert custom_dataset.CLASSES == [classes[0]] | ||
|
||
# Test default behavior | ||
custom_dataset = dataset_class( | ||
pipeline=[], | ||
img_dir=MagicMock(), | ||
split=MagicMock(), | ||
classes=None, | ||
test_mode=True) | ||
|
||
assert custom_dataset.CLASSES == original_classes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may add a test case for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should I add it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may add it in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
palette
arg is not needed.