diff --git a/mmseg/evaluation/__init__.py b/mmseg/evaluation/__init__.py index c28bb75cb4..a82008f3ad 100644 --- a/mmseg/evaluation/__init__.py +++ b/mmseg/evaluation/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .metrics import CitysMetric, IoUMetric +from .metrics import CityscapesMetric, IoUMetric -__all__ = ['IoUMetric', 'CitysMetric'] +__all__ = ['IoUMetric', 'CityscapesMetric'] diff --git a/mmseg/evaluation/metrics/__init__.py b/mmseg/evaluation/metrics/__init__.py index aec08bb071..0aa39e480c 100644 --- a/mmseg/evaluation/metrics/__init__.py +++ b/mmseg/evaluation/metrics/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .citys_metric import CitysMetric +from .citys_metric import CityscapesMetric from .iou_metric import IoUMetric -__all__ = ['IoUMetric', 'CitysMetric'] +__all__ = ['IoUMetric', 'CityscapesMetric'] diff --git a/mmseg/evaluation/metrics/citys_metric.py b/mmseg/evaluation/metrics/citys_metric.py index 50e9ea68a0..a2c008b99d 100644 --- a/mmseg/evaluation/metrics/citys_metric.py +++ b/mmseg/evaluation/metrics/citys_metric.py @@ -1,30 +1,41 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Dict, List, Optional, Sequence +import shutil +from collections import OrderedDict +from typing import Dict, Optional, Sequence + +try: + + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + import cityscapesscripts.helpers.labels as CSLabels +except ImportError: + CSLabels = None + CSEval = None import numpy as np +from mmengine.dist import is_main_process, master_only from mmengine.evaluator import BaseMetric from mmengine.logging import MMLogger, print_log -from mmengine.utils import mkdir_or_exist, scandir +from mmengine.utils import mkdir_or_exist from PIL import Image from mmseg.registry import METRICS @METRICS.register_module() -class CitysMetric(BaseMetric): +class CityscapesMetric(BaseMetric): """Cityscapes evaluation metric. Args: + output_dir (str): The directory for output prediction ignore_index (int): Index that will be ignored in evaluation. Default: 255. - citys_metrics (list[str] | str): Metrics to be evaluated, - Default: ['cityscapes']. - to_label_id (bool): whether convert output to label_id for - submission. Default: True. - suffix (str): The filename prefix of the png files. - If the prefix is "somepath/xxx", the png files will be - named "somepath/xxx.png". Default: '.format_cityscapes'. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + keep_results (bool): Whether to keep the results. When ``format_only`` + is True, ``keep_results`` must be True. Defaults to False. collect_device (str): Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'. @@ -35,19 +46,34 @@ class CitysMetric(BaseMetric): """ def __init__(self, + output_dir: str, ignore_index: int = 255, - citys_metrics: List[str] = ['cityscapes'], - to_label_id: bool = True, - suffix: str = '.format_cityscapes', + format_only: bool = False, + keep_results: bool = False, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: super().__init__(collect_device=collect_device, prefix=prefix) - + if CSEval is None: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + self.output_dir = output_dir self.ignore_index = ignore_index - self.metrics = citys_metrics - assert self.metrics[0] == 'cityscapes' - self.to_label_id = to_label_id - self.suffix = suffix + + self.format_only = format_only + if format_only: + assert keep_results, ( + 'When format_only is True, the results must be keep, please ' + f'set keep_results as True, but got {keep_results}') + self.keep_results = keep_results + self.prefix = prefix + if is_main_process(): + mkdir_or_exist(self.output_dir) + + @master_only + def __del__(self) -> None: + """Clean up.""" + if not self.keep_results: + shutil.rmtree(self.output_dir) def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: """Process one batch of data and data_samples. @@ -59,26 +85,23 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: data_batch (dict): A batch of data from the dataloader. data_samples (Sequence[dict]): A batch of outputs from the model. """ - mkdir_or_exist(self.suffix) + mkdir_or_exist(self.output_dir) for data_sample in data_samples: pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy() - # results2img - if self.to_label_id: - pred_label = self._convert_to_label_id(pred_label) + # when evaluating with official cityscapesscripts, + # labelIds should be used + pred_label = self._convert_to_label_id(pred_label) basename = osp.splitext(osp.basename(data_sample['img_path']))[0] - png_filename = osp.join(self.suffix, f'{basename}.png') + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') - import cityscapesscripts.helpers.labels as CSLabels - palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) - for label_id, label in CSLabels.id2label.items(): - palette[label_id] = label.color - output.putpalette(palette) output.save(png_filename) - - ann_dir = osp.join(data_samples[0]['seg_map_path'].split('val')[0], - 'val') - self.results.append(ann_dir) + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + gt_filename = data_sample['seg_map_path'].replace( + 'labelTrainIds.png', 'labelIds.png') + self.results.append((png_filename, gt_filename)) def compute_metrics(self, results: list) -> Dict[str, float]: """Compute the metrics from processed results. @@ -90,38 +113,28 @@ def compute_metrics(self, results: list) -> Dict[str, float]: dict[str: float]: Cityscapes evaluation results. """ logger: MMLogger = MMLogger.get_current_instance() - try: - import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa - except ImportError: - raise ImportError('Please run "pip install cityscapesscripts" to ' - 'install cityscapesscripts first.') - msg = 'Evaluating in Cityscapes style' + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + msg = 'Evaluating in Cityscapes style' if logger is None: msg = '\n' + msg print_log(msg, logger=logger) - result_dir = self.suffix - eval_results = dict() - print_log(f'Evaluating results under {result_dir} ...', logger=logger) + print_log( + f'Evaluating results under {self.output_dir} ...', logger=logger) CSEval.args.evalInstLevelScore = True - CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.predictionPath = osp.abspath(self.output_dir) CSEval.args.evalPixelAccuracy = True CSEval.args.JSONOutput = False - seg_map_list = [] - pred_list = [] - ann_dir = results[0] - # when evaluating with official cityscapesscripts, - # **_gtFine_labelIds.png is used - for seg_map in scandir(ann_dir, 'gtFine_labelIds.png', recursive=True): - seg_map_list.append(osp.join(ann_dir, seg_map)) - pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + pred_list, gt_list = zip(*results) metric = dict() eval_results.update( - CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args)) metric['averageScoreCategories'] = eval_results[ 'averageScoreCategories'] metric['averageScoreInstCategories'] = eval_results[ @@ -133,7 +146,6 @@ def _convert_to_label_id(result): """Convert trainId to id for cityscapes.""" if isinstance(result, str): result = np.load(result) - import cityscapesscripts.helpers.labels as CSLabels result_copy = result.copy() for trainId, label in CSLabels.trainId2label.items(): result_copy[result == trainId] = label.id diff --git a/tests/test_evaluation/test_metrics/test_citys_metric.py b/tests/test_evaluation/test_metrics/test_citys_metric.py index a6d6db5caa..0a20b41aee 100644 --- a/tests/test_evaluation/test_metrics/test_citys_metric.py +++ b/tests/test_evaluation/test_metrics/test_citys_metric.py @@ -1,15 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp from unittest import TestCase import numpy as np +import pytest import torch from mmengine.structures import BaseDataElement, PixelData -from mmseg.evaluation import CitysMetric +from mmseg.evaluation import CityscapesMetric from mmseg.structures import SegDataSample -class TestCitysMetric(TestCase): +class TestCityscapesMetric(TestCase): def _demo_mm_inputs(self, batch_size=1, @@ -42,9 +44,8 @@ def _demo_mm_inputs(self, gt_sem_seg_data = dict(data=gt_semantic_seg) data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) mm_inputs['data_sample'] = data_sample.to_dict() - mm_inputs['data_sample']['seg_map_path'] = \ - 'tests/data/pseudo_cityscapes_dataset/gtFine/val/\ - frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' + mm_inputs['data_sample'][ + 'seg_map_path'] = 'tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa mm_inputs['seg_map_path'] = mm_inputs['data_sample'][ 'seg_map_path'] @@ -86,9 +87,8 @@ def _demo_mm_model_output(self, for pred in batch_datasampes: if isinstance(pred, BaseDataElement): test_data = pred.to_dict() - test_data['img_path'] = \ - 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\ - frankfurt/frankfurt_000000_000294_leftImg8bit.png' + test_data[ + 'img_path'] = 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' # noqa _predictions.append(test_data) else: @@ -104,15 +104,23 @@ def test_evaluate(self): dict(**data, **result) for data, result in zip(data_batch, predictions) ] - iou_metric = CitysMetric(citys_metrics=['cityscapes']) - iou_metric.process(data_batch, data_samples) - res = iou_metric.evaluate(6) - self.assertIsInstance(res, dict) - # test to_label_id = True - iou_metric = CitysMetric( - citys_metrics=['cityscapes'], to_label_id=True) - iou_metric.process(data_batch, data_samples) - res = iou_metric.evaluate(6) + # test keep_results should be True when format_only is True + with pytest.raises(AssertionError): + CityscapesMetric( + output_dir='tmp', format_only=True, keep_results=False) + + # test evaluate with cityscape metric + metric = CityscapesMetric(output_dir='tmp') + metric.process(data_batch, data_samples) + res = metric.evaluate(2) self.assertIsInstance(res, dict) + + # test format_only + metric = CityscapesMetric( + output_dir='tmp', format_only=True, keep_results=True) + metric.process(data_batch, data_samples) + metric.evaluate(2) + assert osp.exists('tmp') + assert osp.isfile('tmp/frankfurt_000000_000294_leftImg8bit.png') import shutil - shutil.rmtree('.format_cityscapes') + shutil.rmtree('tmp')