From dc1773579f3d9f0bdb7781b05f0fb2307b45968d Mon Sep 17 00:00:00 2001 From: xiexinch Date: Tue, 10 Jan 2023 16:25:29 +0800 Subject: [PATCH] fix random crop image_shape --- mmseg/datasets/transforms/transforms.py | 4 ++-- tests/test_datasets/test_transform.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 21b8e34e33..36292cf87b 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -314,9 +314,9 @@ def transform(self, results: dict) -> dict: # crop semantic seg for key in results.get('seg_fields', []): results[key] = self.crop(results[key], crop_bbox) - img_shape = img.shape + results['img'] = img - results['img_shape'] = img_shape + results['img_shape'] = img.shape[:2] return results def __repr__(self): diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index f052e6e65e..a0c949af38 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -321,7 +321,7 @@ def test_random_crop(): results = pipeline(results) assert results['img'].shape[:2] == (h - 20, w - 20) - assert results['img_shape'][:2] == (h - 20, w - 20) + assert results['img_shape'] == (h - 20, w - 20) assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)