From eee780f765684a19f9cc27e647860e69b2d2ff6a Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Wed, 13 Sep 2023 15:31:22 +0800 Subject: [PATCH] [Feature] Support VPD Depth Estimator (#3321) Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Support depth estimation algorithm [VPD](https://github.com/wl-zhao/VPD) ## Modification 1. add VPD backbone 2. add VPD decoder head for depth estimation 3. add a new segmentor `DepthEstimator` based on `EncoderDecoder` for depth estimation 4. add an integrated metric that calculate common metrics in depth estimation 5. add SiLog loss for depth estimation 6. add config for VPD ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 7. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 8. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 9. The documentation has been modified accordingly, like docstring or example tutorials. --- .circleci/test.yml | 2 + .pre-commit-config.yaml | 20 +- configs/_base_/datasets/nyu.py | 66 +++ configs/_base_/models/vpd_sd.py | 86 ++++ configs/_base_/schedules/schedule_25k.py | 28 ++ configs/vpd/README.md | 49 +++ configs/vpd/metafile.yaml | 34 ++ configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py | 37 ++ mmseg/datasets/transforms/__init__.py | 6 +- mmseg/datasets/transforms/loading.py | 3 + mmseg/datasets/transforms/transforms.py | 127 ++++++ mmseg/engine/__init__.py | 7 +- mmseg/engine/optimizers/__init__.py | 4 +- .../optimizers/force_default_constructor.py | 255 ++++++++++++ mmseg/engine/schedulers/__init__.py | 4 + .../engine/schedulers/poly_ratio_scheduler.py | 62 +++ mmseg/models/backbones/__init__.py | 3 +- mmseg/models/backbones/vpd.py | 383 +++++++++++++++++ mmseg/models/decode_heads/__init__.py | 3 +- mmseg/models/decode_heads/vpd_depth_head.py | 254 ++++++++++++ mmseg/models/losses/__init__.py | 3 +- mmseg/models/losses/silog_loss.py | 117 ++++++ mmseg/models/segmentors/__init__.py | 4 +- mmseg/models/segmentors/depth_estimator.py | 392 ++++++++++++++++++ mmseg/registry/registry.py | 4 +- mmseg/utils/misc.py | 20 +- requirements/optional.txt | 20 + tests/test_config.py | 12 +- tests/test_datasets/test_transform.py | 55 ++- tests/test_models/test_backbones/test_vpd.py | 51 +++ .../test_heads/test_vpd_depth_head.py | 50 +++ .../test_losses/test_silog_loss.py | 20 + .../test_segmentors/test_depth_estimator.py | 64 +++ 33 files changed, 2216 insertions(+), 29 deletions(-) create mode 100644 configs/_base_/datasets/nyu.py create mode 100644 configs/_base_/models/vpd_sd.py create mode 100644 configs/_base_/schedules/schedule_25k.py create mode 100644 configs/vpd/README.md create mode 100644 configs/vpd/metafile.yaml create mode 100644 configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py create mode 100644 mmseg/engine/optimizers/force_default_constructor.py create mode 100644 mmseg/engine/schedulers/__init__.py create mode 100644 mmseg/engine/schedulers/poly_ratio_scheduler.py create mode 100644 mmseg/models/backbones/vpd.py create mode 100644 mmseg/models/decode_heads/vpd_depth_head.py create mode 100644 mmseg/models/losses/silog_loss.py create mode 100644 mmseg/models/segmentors/depth_estimator.py create mode 100644 tests/test_models/test_backbones/test_vpd.py create mode 100644 tests/test_models/test_heads/test_vpd_depth_head.py create mode 100644 tests/test_models/test_losses/test_silog_loss.py create mode 100644 tests/test_models/test_segmentors/test_depth_estimator.py diff --git a/.circleci/test.yml b/.circleci/test.yml index ceef7884f7..57f89e18f5 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -110,6 +110,8 @@ jobs: docker exec mmseg mim install mmcv>=2.0.0 docker exec mmseg pip install mmpretrain>=1.0.0rc7 docker exec mmseg mim install mmdet>=3.0.0 + docker exec mmseg apt-get update + docker exec mmseg apt-get install -y git docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations - run: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf6ecdd8fc..aa5942748a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,15 +42,17 @@ repos: hooks: - id: docformatter args: ["--in-place", "--wrap-descriptions", "79"] - - repo: local - hooks: - - id: update-model-index - name: update-model-index - description: Collect model information and update model-index.yml - entry: .dev_scripts/update_model_index.py - additional_dependencies: [pyyaml] - language: python - require_serial: true + # temporarily remove update-model-index to avoid conflict raised + # by depth estimator models + # - repo: local + # hooks: + # - id: update-model-index + # name: update-model-index + # description: Collect model information and update model-index.yml + # entry: .dev_scripts/update_model_index.py + # additional_dependencies: [pyyaml] + # language: python + # require_serial: true - repo: https://github.com/asottile/pyupgrade rev: v3.0.0 hooks: diff --git a/configs/_base_/datasets/nyu.py b/configs/_base_/datasets/nyu.py new file mode 100644 index 0000000000..332e8b842d --- /dev/null +++ b/configs/_base_/datasets/nyu.py @@ -0,0 +1,66 @@ +# dataset settings +dataset_type = 'NYUDataset' +data_root = 'data/nyu' + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3), + dict(type='RandomDepthMix', prob=0.25), + dict(type='RandomFlip', prob=0.5), + dict(type='RandomCrop', crop_size=(480, 480)), + dict( + type='Albu', + transforms=[ + dict(type='RandomBrightnessContrast'), + dict(type='RandomGamma'), + dict(type='HueSaturationValue'), + ]), + dict( + type='PackSegInputs', + meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'category_id')), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3)), + dict( + type='PackSegInputs', + meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'category_id')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/train', depth_map_path='annotations/train'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + test_mode=True, + data_prefix=dict( + img_path='images/test', depth_map_path='annotations/test'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='DepthMetric', + min_depth_eval=0.001, + max_depth_eval=10.0, + crop_type='nyu_crop') +test_evaluator = val_evaluator diff --git a/configs/_base_/models/vpd_sd.py b/configs/_base_/models/vpd_sd.py new file mode 100644 index 0000000000..880ccfe652 --- /dev/null +++ b/configs/_base_/models/vpd_sd.py @@ -0,0 +1,86 @@ +# model settings +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) + +# adapted from stable-diffusion/configs/stable-diffusion/v1-inference.yaml +stable_diffusion_cfg = dict( + base_learning_rate=0.0001, + target='ldm.models.diffusion.ddpm.LatentDiffusion', + checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/' + 'vpd/stable_diffusion_v1-5_pretrain_third_party.pth', + params=dict( + linear_start=0.00085, + linear_end=0.012, + num_timesteps_cond=1, + log_every_t=200, + timesteps=1000, + first_stage_key='jpg', + cond_stage_key='txt', + image_size=64, + channels=4, + cond_stage_trainable=False, + conditioning_key='crossattn', + monitor='val/loss_simple_ema', + scale_factor=0.18215, + use_ema=False, + scheduler_config=dict( + target='ldm.lr_scheduler.LambdaLinearScheduler', + params=dict( + warm_up_steps=[10000], + cycle_lengths=[10000000000000], + f_start=[1e-06], + f_max=[1.0], + f_min=[1.0])), + unet_config=dict( + target='ldm.modules.diffusionmodules.openaimodel.UNetModel', + params=dict( + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_heads=8, + use_spatial_transformer=True, + transformer_depth=1, + context_dim=768, + use_checkpoint=True, + legacy=False)), + first_stage_config=dict( + target='ldm.models.autoencoder.AutoencoderKL', + params=dict( + embed_dim=4, + monitor='val/rec_loss', + ddconfig=dict( + double_z=True, + z_channels=4, + resolution=256, + in_channels=3, + out_ch=3, + ch=128, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_resolutions=[], + dropout=0.0), + lossconfig=dict(target='torch.nn.Identity'))), + cond_stage_config=dict( + target='ldm.modules.encoders.modules.AbstractEncoder'))) + +model = dict( + type='DepthEstimator', + data_preprocessor=data_preprocessor, + backbone=dict( + type='VPD', + diffusion_cfg=stable_diffusion_cfg, + ), +) + +# some of the parameters in stable-diffusion model will not be updated +# during training +find_unused_parameters = True diff --git a/configs/_base_/schedules/schedule_25k.py b/configs/_base_/schedules/schedule_25k.py new file mode 100644 index 0000000000..825e141ed1 --- /dev/null +++ b/configs/_base_/schedules/schedule_25k.py @@ -0,0 +1,28 @@ +# optimizer +optimizer = dict(type='AdamW', lr=0.001, weight_decay=0.1) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) +# learning policy +param_scheduler = [ + dict( + type='LinearLR', start_factor=3e-2, begin=0, end=12000, + by_epoch=False), + dict( + type='PolyLRRatio', + eta_min_ratio=3e-2, + power=0.9, + begin=12000, + end=24000, + by_epoch=False), + dict(type='ConstantLR', by_epoch=False, factor=1, begin=24000, end=25000) +] +# training schedule for 25k +train_cfg = dict(type='IterBasedTrainLoop', max_iters=25000, val_interval=1000) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='SegVisualizationHook')) diff --git a/configs/vpd/README.md b/configs/vpd/README.md new file mode 100644 index 0000000000..cf80b92aad --- /dev/null +++ b/configs/vpd/README.md @@ -0,0 +1,49 @@ +# VPD + +> [Unleashing Text-to-Image Diffusion Models for Visual Perception](https://arxiv.org/abs/2303.02153) + +## Introduction + + + +Official Repo + +## Abstract + + + +Diffusion models (DMs) have become the new trend of generative models and have demonstrated a powerful ability of conditional synthesis. Among those, text-to-image diffusion models pre-trained on large-scale image-text pairs are highly controllable by customizable prompts. Unlike the unconditional generative models that focus on low-level attributes and details, text-to-image diffusion models contain more high-level knowledge thanks to the vision-language pre-training. In this paper, we propose VPD (Visual Perception with a pre-trained Diffusion model), a new framework that exploits the semantic information of a pre-trained text-to-image diffusion model in visual perception tasks. Instead of using the pre-trained denoising autoencoder in a diffusion-based pipeline, we simply use it as a backbone and aim to study how to take full advantage of the learned knowledge. Specifically, we prompt the denoising decoder with proper textual inputs and refine the text features with an adapter, leading to a better alignment to the pre-trained stage and making the visual contents interact with the text prompts. We also propose to utilize the cross-attention maps between the visual features and the text features to provide explicit guidance. Compared with other pre-training methods, we show that vision-language pre-trained diffusion models can be faster adapted to downstream visual perception tasks using the proposed VPD. Extensive experiments on semantic segmentation, referring image segmentation and depth estimation demonstrates the effectiveness of our method. Notably, VPD attains 0.254 RMSE on NYUv2 depth estimation and 73.3% oIoU on RefCOCO-val referring image segmentation, establishing new records on these two benchmarks. + + + +
+ +
+ +## Usage + +To run training or inference with VPD model, please install the required packages via + +```sh +pip install -r requirements/albu.txt +pip install -r requirements/optional.txt +``` + +## Results and models + +### NYU + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | RMSE | d1 | d2 | d3 | REL | log_10 | config | download | +| ------ | --------------------- | --------- | ------- | -------- | -------------- | ------ | ----- | ----- | ----- | ----- | ----- | ------ | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| VPD | Stable-Diffusion-v1-5 | 480x480 | 25000 | - | - | A100 | 0.253 | 0.964 | 0.995 | 0.999 | 0.069 | 0.030 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json) | + +## Citation + +```bibtex +@article{zhao2023unleashing, + title={Unleashing Text-to-Image Diffusion Models for Visual Perception}, + author={Zhao, Wenliang and Rao, Yongming and Liu, Zuyan and Liu, Benlin and Zhou, Jie and Lu, Jiwen}, + journal={ICCV}, + year={2023} +} +``` diff --git a/configs/vpd/metafile.yaml b/configs/vpd/metafile.yaml new file mode 100644 index 0000000000..d87b51c4fe --- /dev/null +++ b/configs/vpd/metafile.yaml @@ -0,0 +1,34 @@ +Collections: +- Name: VPD + License: Apache License 2.0 + Metadata: + Training Data: + - NYU + Paper: + Title: Unleashing Text-to-Image Diffusion Models for Visual Perception + URL: https://arxiv.org/abs/2303.02153 + README: configs/vpd/README.md + Frameworks: + - PyTorch +Models: +- Name: vpd_sd_4xb8-25k_nyu-480x480 + In Collection: VPD + Results: + Task: Depth Estimation + Dataset: NYU + Metrics: + RMSE: 0.253 + Config: configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py + Metadata: + Training Data: NYU + Batch Size: 32 + Architecture: + - Stable-Diffusion + Training Resources: 8x A100 GPUS + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908-66144bc4.pth + Training log: https://download.openmmlab.com/mmsegmentation/v0.5/vpd/vpd_sd_4xb8-25k_nyu-480x480_20230908.json + Paper: + Title: 'High-Resolution Image Synthesis with Latent Diffusion Models' + URL: https://arxiv.org/abs/2112.10752 + Code: https://github.com/open-mmlab/mmsegmentation/tree/main/mmseg/models/backbones/vpd.py#L333 + Framework: PyTorch diff --git a/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py b/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py new file mode 100644 index 0000000000..d1f3330513 --- /dev/null +++ b/configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py @@ -0,0 +1,37 @@ +_base_ = [ + '../_base_/models/vpd_sd.py', '../_base_/datasets/nyu.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_25k.py' +] + +crop_size = (480, 480) + +model = dict( + type='DepthEstimator', + data_preprocessor=dict(size=crop_size), + backbone=dict( + class_embed_path='https://download.openmmlab.com/mmsegmentation/' + 'v0.5/vpd/nyu_class_embeddings.pth', + class_embed_select=True, + pad_shape=512, + unet_cfg=dict(use_attn=False), + ), + decode_head=dict( + type='VPDDepthHead', + in_channels=[320, 640, 1280, 1280], + max_depth=10, + fmap_border=(1, 1), + ), + test_cfg=dict(mode='slide_flip', crop_size=crop_size, stride=(160, 160))) + +default_hooks = dict(checkpoint=dict(save_best='rmse', rule='less')) + +# custom optimizer +optim_wrapper = dict( + type='ForceDefaultOptimWrapperConstructor', + paramwise_cfg=dict( + bias_decay_mult=0, + force_default_settings=True, + custom_keys={ + 'backbone.encoder_vq': dict(lr_mult=0), + 'backbone.unet': dict(lr_mult=0.01), + })) diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 03c3f866b7..2bdb552663 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -10,7 +10,8 @@ BioMedicalGaussianBlur, BioMedicalGaussianNoise, BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, PhotoMetricDistortion, RandomCrop, RandomCutOut, - RandomMosaic, RandomRotate, RandomRotFlip, Rerange, + RandomDepthMix, RandomFlip, RandomMosaic, + RandomRotate, RandomRotFlip, Rerange, ResizeShortestEdge, ResizeToMultiple, RGB2Gray, SegRescale) @@ -24,5 +25,6 @@ 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', - 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation' + 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix', + 'RandomFlip' ] diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py index c7d6af0ef6..438b5527f0 100644 --- a/mmseg/datasets/transforms/loading.py +++ b/mmseg/datasets/transforms/loading.py @@ -647,6 +647,8 @@ class LoadDepthAnnotation(BaseTransform): - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by default, and data type is float32 if set to_float32 = True. + - depth_rescale_factor (float): The rescale factor of depth map, which + can be used to recover the original value of depth map. Args: decode_backend (str): The data decoding backend type. Options are @@ -691,6 +693,7 @@ def transform(self, results: Dict) -> Dict: gt_depth_map *= self.depth_rescale_factor results['gt_depth_map'] = gt_depth_map results['seg_fields'].append('gt_depth_map') + results['depth_rescale_factor'] = self.depth_rescale_factor return results def __repr__(self): diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 1571c2279c..8eea35b172 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -8,6 +8,7 @@ import mmcv import mmengine import numpy as np +from mmcv.transforms import RandomFlip as MMCV_RandomFlip from mmcv.transforms.base import BaseTransform from mmcv.transforms.utils import cache_randomness from mmengine.utils import is_tuple_of @@ -950,6 +951,86 @@ def __repr__(self): return repr_str +@TRANSFORMS.register_module() +class RandomFlip(MMCV_RandomFlip): + """Flip the image & bbox & segmentation map. Added or Updated + keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and gt_depth_map. + There are 3 flip modes: + + - ``prob`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``prob`` . + E.g., ``prob=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + + - ``prob`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``prob/len(direction)``. + E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + + - ``prob`` is list of float, ``direction`` is list of string: + given ``len(prob) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``prob[i]``. + E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with + probability of 0.3, vertically with probability of 0.5. + + Required Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Modified Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Added Keys: + + - flip + - flip_direction + - swap_seg_labels (optional) + + Args: + prob (float | list[float], optional): The flipping probability. + Defaults to None. + direction(str | list[str]): The flipping direction. Options + If input is a list, the length must equal ``prob``. Each + element in ``prob`` indicates the flip probability of + corresponding direction. Defaults to 'horizontal'. + swap_seg_labels (list, optional): The label pair need to be swapped + for ground truth, like 'left arm' and 'right arm' need to be + swapped after horizontal flipping. For example, ``[(1, 5)]``, + where 1/5 is the label of the left/right arm. Defaults to None. + """ + + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'], + img_shape, + results['flip_direction']) + + # flip seg map + for key in results.get('seg_fields', []): + if results.get(key, None) is not None: + results[key] = self._flip_seg_map( + results[key], direction=results['flip_direction']) + results['swap_seg_labels'] = self.swap_seg_labels + + @TRANSFORMS.register_module() class RandomMosaic(BaseTransform): """Mosaic augmentation. Given 4 images, mosaic transform combines them into @@ -2318,3 +2399,49 @@ def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(input_keys={self.input_keys}, ' return repr_str + + +@TRANSFORMS.register_module() +class RandomDepthMix(BaseTransform): + """This class implements the RandomDepthMix transform. + + Args: + prob (float): Probability of applying the transformation. + Defaults to 0.25. + mix_scale_ratio (float): Ratio to scale the mix width. + Defaults to 0.75. + """ + + def __init__( + self, + prob: float = 0.25, + mix_scale_ratio: float = 0.75, + ): + super().__init__() + + self.prob = prob + self.mix_scale_ratio = mix_scale_ratio + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + h, w = results['img_shape'][:2] + left = int(w * random.random()) + width_ratio = self.mix_scale_ratio * random.random() + width = int(max(1, (w - left) * width_ratio)) + + img = results['img'] + depth_rescale_factor = results.get('depth_rescale_factor', 1) + depth_map = results['gt_depth_map'] / depth_rescale_factor + + if img.ndim == 3: + for c in range(img.shape[-1]): + img[:, left:left + width, c] = depth_map[:, left:left + width] + elif img.ndim == 2: + img[:, left:left + width] = depth_map[:, left:left + width] + else: + raise ValueError(f'Invalid image shape ({img.shape})') + + results['img'] = img + return results diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py index ada4057012..98139a0047 100644 --- a/mmseg/engine/__init__.py +++ b/mmseg/engine/__init__.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from .hooks import SegVisualizationHook -from .optimizers import (LayerDecayOptimizerConstructor, +from .optimizers import (ForceDefaultOptimWrapperConstructor, + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) +from .schedulers import PolyLRRatio __all__ = [ 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', - 'SegVisualizationHook' + 'SegVisualizationHook', 'PolyLRRatio', + 'ForceDefaultOptimWrapperConstructor' ] diff --git a/mmseg/engine/optimizers/__init__.py b/mmseg/engine/optimizers/__init__.py index 4fbf4ecfcd..e4cf58741f 100644 --- a/mmseg/engine/optimizers/__init__.py +++ b/mmseg/engine/optimizers/__init__.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .force_default_constructor import ForceDefaultOptimWrapperConstructor from .layer_decay_optimizer_constructor import ( LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) __all__ = [ - 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'ForceDefaultOptimWrapperConstructor' ] diff --git a/mmseg/engine/optimizers/force_default_constructor.py b/mmseg/engine/optimizers/force_default_constructor.py new file mode 100644 index 0000000000..12c642ad41 --- /dev/null +++ b/mmseg/engine/optimizers/force_default_constructor.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.logging import print_log +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.utils.dl_utils import mmcv_full_available +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from torch.nn import GroupNorm, LayerNorm + +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor): + """Default constructor with forced optimizer settings. + + This constructor extends the default constructor to add an option for + forcing default optimizer settings. This is useful for ensuring that + certain parameters or layers strictly adhere to pre-defined default + settings, regardless of any custom settings specified. + + By default, each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain various fields like 'custom_keys', + 'bias_lr_mult', etc., as well as the additional field + `force_default_settings` which allows for enforcing default settings on + optimizer parameters. + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Defaults to False. + - ``force_default_settings`` (bool): If true, this will override any + custom settings defined by ``custom_keys`` and enforce the use of + default settings for optimizer parameters like ``bias_lr_mult``. + This is particularly useful when you want to ensure that certain layers + or parameters adhere strictly to the pre-defined default settings. + + Note: + + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset layer. + So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset + layer in deformable convs, set ``dcn_offset_lr_mult`` to the original + ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when the + model contains multiple DCN layers in places other than backbone. + + 3. When the option ``force_default_settings`` is true, it will override + any custom settings provided in ``custom_keys``. This ensures that the + default settings for the optimizer parameters are used. + + Args: + optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. + + Required fields of ``optim_wrapper_cfg`` are + + - ``type``: class name of the OptimizerWrapper + - ``optimizer``: The configuration of optimizer. + + Optional fields of ``optim_wrapper_cfg`` are + + - any arguments of the corresponding optimizer wrapper type, + e.g., accumulative_counts, clip_grad, etc. + + Required fields of ``optimizer`` are + + - `type`: class name of the optimizer. + + Optional fields of ``optimizer`` are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optim_wrapper_cfg = dict( + >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, + >>> momentum=0.9, weight_decay=0.0001)) + >>> paramwise_cfg = dict(norm_decay_mult=0.) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( + >>> type='SGD', lr=0.01, weight_decay=0.95)) + >>> paramwise_cfg = dict(custom_keys={ + >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Optional[Union[int, float]] = None) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None) + force_default_settings = self.paramwise_cfg.get( + 'force_default_settings', False) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if bypass_duplicate and self._is_in(param_group, params): + print_log( + f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}', + logger='current', + level=logging.WARNING) + continue + if not param.requires_grad: + params.append(param_group) + continue + + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + # add custom settings to param_group + for k, v in custom_keys[key].items(): + param_group[k] = v + break + + if not is_custom or force_default_settings: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not ( + is_norm or is_dcn_module) and bias_lr_mult is not None: + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and dcn_offset_lr_mult is not None + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm and norm_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # bias lr and decay + elif (name == 'bias' and not is_dcn_module + and bias_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + # depth-wise conv + elif is_dwconv and dwconv_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # flatten parameters except dcn offset + elif (param.ndim == 1 and not is_dcn_module + and flat_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * flat_decay_mult + params.append(param_group) + for key, value in param_group.items(): + if key == 'params': + continue + full_name = f'{prefix}.{name}' if prefix else name + print_log( + f'paramwise_options -- {full_name}:{key}={value}', + logger='current') + + if mmcv_full_available(): + from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) diff --git a/mmseg/engine/schedulers/__init__.py b/mmseg/engine/schedulers/__init__.py new file mode 100644 index 0000000000..3cd3f62113 --- /dev/null +++ b/mmseg/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .poly_ratio_scheduler import PolyLRRatio + +__all__ = ['PolyLRRatio'] diff --git a/mmseg/engine/schedulers/poly_ratio_scheduler.py b/mmseg/engine/schedulers/poly_ratio_scheduler.py new file mode 100644 index 0000000000..057203acc9 --- /dev/null +++ b/mmseg/engine/schedulers/poly_ratio_scheduler.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from mmengine.optim.scheduler import PolyLR + +from mmseg.registry import PARAM_SCHEDULERS + + +@PARAM_SCHEDULERS.register_module() +class PolyLRRatio(PolyLR): + """Implements polynomial learning rate decay with ratio. + + This scheduler adjusts the learning rate of each parameter group + following a polynomial decay equation. The decay can occur in + conjunction with external parameter adjustments made outside this + scheduler. + + Args: + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. + eta_min (float): Minimum learning rate at the end of scheduling. + Defaults to 0. + eta_min_ratio (float, optional): The ratio of the minimum parameter + value to the base parameter value. Either `eta_min` or + `eta_min_ratio` should be specified. Defaults to None. + power (float): The power of the polynomial. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, eta_min_ratio: Optional[int] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta_min_ratio = eta_min_ratio + + def _get_value(self): + """Compute value using chainable form of the scheduler.""" + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + param_groups_value = [] + for base_value, param_group in zip(self.base_values, + self.optimizer.param_groups): + eta_min = self.eta_min if self.eta_min_ratio is None else \ + base_value * self.eta_min_ratio + step_ratio = (1 - 1 / + (self.total_iters - self.last_step + 1))**self.power + step_value = (param_group[self.param_name] - + eta_min) * step_ratio + eta_min + param_groups_value.append(step_value) + + return param_groups_value diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index d9228a500b..784d3dfdb7 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -23,6 +23,7 @@ from .twins import PCPVT, SVT from .unet import UNet from .vit import VisionTransformer +from .vpd import VPD __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', @@ -30,5 +31,5 @@ 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN', - 'DDRNet' + 'DDRNet', 'VPD' ] diff --git a/mmseg/models/backbones/vpd.py b/mmseg/models/backbones/vpd.py new file mode 100644 index 0000000000..8b57be39b2 --- /dev/null +++ b/mmseg/models/backbones/vpd.py @@ -0,0 +1,383 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py +# Original licence: MIT License +# ------------------------------------------------------------------------------ + +import math +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from ldm.modules.diffusionmodules.util import timestep_embedding +from ldm.util import instantiate_from_config +from mmengine.model import BaseModule +from mmengine.runner import CheckpointLoader, load_checkpoint + +from mmseg.registry import MODELS +from mmseg.utils import ConfigType, OptConfigType + + +def register_attention_control(model, controller): + """Registers a control function to manage attention within a model. + + Args: + model: The model to which attention is to be registered. + controller: The control function responsible for managing attention. + """ + + def ca_forward(self, place_in_unet): + """Custom forward method for attention. + + Args: + self: Reference to the current object. + place_in_unet: The location in UNet (down/mid/up). + + Returns: + The modified forward method. + """ + + def forward(x, context=None, mask=None): + h = self.heads + is_cross = context is not None + context = context or x # if context is None, use x + + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + q, k, v = ( + tensor.view(tensor.shape[0] * h, tensor.shape[1], + tensor.shape[2] // h) for tensor in [q, k, v]) + + sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale + + if mask is not None: + mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1) + max_neg_value = -torch.finfo(sim.dtype).max + sim.masked_fill_(~mask, max_neg_value) + + attn = sim.softmax(dim=-1) + attn_mean = attn.view(h, attn.shape[0] // h, + *attn.shape[1:]).mean(0) + controller(attn_mean, is_cross, place_in_unet) + + out = torch.matmul(attn, v) + out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h) + return self.to_out(out) + + return forward + + def register_recr(net_, count, place_in_unet): + """Recursive function to register the custom forward method to all + CrossAttention layers. + + Args: + net_: The network layer currently being processed. + count: The current count of layers processed. + place_in_unet: The location in UNet (down/mid/up). + + Returns: + The updated count of layers processed. + """ + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + if hasattr(net_, 'children'): + return sum( + register_recr(child, 0, place_in_unet) + for child in net_.children()) + return count + + cross_att_count = sum( + register_recr(net[1], 0, place) for net, place in [ + (child, 'down') if 'input_blocks' in name else ( + child, 'up') if 'output_blocks' in name else + (child, + 'mid') if 'middle_block' in name else (None, None) # Default case + for name, child in model.diffusion_model.named_children() + ] if net is not None) + + controller.num_att_layers = cross_att_count + + +class AttentionStore: + """A class for storing attention information in the UNet model. + + Attributes: + base_size (int): Base size for storing attention information. + max_size (int): Maximum size for storing attention information. + """ + + def __init__(self, base_size=64, max_size=None): + """Initialize AttentionStore with default or custom sizes.""" + self.reset() + self.base_size = base_size + self.max_size = max_size or (base_size // 2) + self.num_att_layers = -1 + + @staticmethod + def get_empty_store(): + """Returns an empty store for holding attention values.""" + return { + key: [] + for key in [ + 'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self', + 'up_self' + ] + } + + def reset(self): + """Resets the step and attention stores to their initial states.""" + self.cur_step = 0 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + """Processes a single forward step, storing the attention. + + Args: + attn: The attention tensor. + is_cross (bool): Whether it's cross attention. + place_in_unet (str): The location in UNet (down/mid/up). + + Returns: + The unmodified attention tensor. + """ + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= (self.max_size)**2: + self.step_store[key].append(attn) + return attn + + def between_steps(self): + """Processes and stores attention information between steps.""" + if not self.attention_store: + self.attention_store = self.step_store + else: + for key in self.attention_store: + self.attention_store[key] = [ + stored + step for stored, step in zip( + self.attention_store[key], self.step_store[key]) + ] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + """Calculates and returns the average attention across all steps.""" + return { + key: [item for item in self.step_store[key]] + for key in self.step_store + } + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + """Allows the class instance to be callable.""" + return self.forward(attn, is_cross, place_in_unet) + + @property + def num_uncond_att_layers(self): + """Returns the number of unconditional attention layers (default is + 0).""" + return 0 + + def step_callback(self, x_t): + """A placeholder for a step callback. + + Returns the input unchanged. + """ + return x_t + + +class UNetWrapper(nn.Module): + """A wrapper for UNet with optional attention mechanisms. + + Args: + unet (nn.Module): The UNet model to wrap + use_attn (bool): Whether to use attention. Defaults to True + base_size (int): Base size for the attention store. Defaults to 512 + max_attn_size (int, optional): Maximum size for the attention store. + Defaults to None + attn_selector (str): The types of attention to use. + Defaults to 'up_cross+down_cross' + """ + + def __init__(self, + unet, + use_attn=True, + base_size=512, + max_attn_size=None, + attn_selector='up_cross+down_cross'): + super().__init__() + self.unet = unet + self.attention_store = AttentionStore( + base_size=base_size // 8, max_size=max_attn_size) + self.attn_selector = attn_selector.split('+') + self.use_attn = use_attn + self.init_sizes(base_size) + if self.use_attn: + register_attention_control(unet, self.attention_store) + + def init_sizes(self, base_size): + """Initialize sizes based on the base size.""" + self.size16 = base_size // 32 + self.size32 = base_size // 16 + self.size64 = base_size // 8 + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """Forward pass through the model.""" + diffusion_model = self.unet.diffusion_model + if self.use_attn: + self.attention_store.reset() + hs, emb, out_list = self._unet_forward(x, timesteps, context, y, + diffusion_model) + if self.use_attn: + self._append_attn_to_output(out_list) + return out_list[::-1] + + def _unet_forward(self, x, timesteps, context, y, diffusion_model): + hs = [] + t_emb = timestep_embedding( + timesteps, diffusion_model.model_channels, repeat_only=False) + emb = diffusion_model.time_embed(t_emb) + h = x.type(diffusion_model.dtype) + for module in diffusion_model.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = diffusion_model.middle_block(h, emb, context) + out_list = [] + for i_out, module in enumerate(diffusion_model.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if i_out in [1, 4, 7]: + out_list.append(h) + h = h.type(x.dtype) + out_list.append(h) + return hs, emb, out_list + + def _append_attn_to_output(self, out_list): + avg_attn = self.attention_store.get_average_attention() + attns = {self.size16: [], self.size32: [], self.size64: []} + for k in self.attn_selector: + for up_attn in avg_attn[k]: + size = int(math.sqrt(up_attn.shape[1])) + up_attn = up_attn.transpose(-1, -2).reshape( + *up_attn.shape[:2], size, -1) + attns[size].append(up_attn) + attn16 = torch.stack(attns[self.size16]).mean(0) + attn32 = torch.stack(attns[self.size32]).mean(0) + attn64 = torch.stack(attns[self.size64]).mean(0) if len( + attns[self.size64]) > 0 else None + out_list[1] = torch.cat([out_list[1], attn16], dim=1) + out_list[2] = torch.cat([out_list[2], attn32], dim=1) + if attn64 is not None: + out_list[3] = torch.cat([out_list[3], attn64], dim=1) + + +class TextAdapter(nn.Module): + """A PyTorch Module that serves as a text adapter. + + This module takes text embeddings and adjusts them based on a scaling + factor gamma. + """ + + def __init__(self, text_dim=768): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(text_dim, text_dim), nn.GELU(), + nn.Linear(text_dim, text_dim)) + + def forward(self, texts, gamma): + texts_after = self.fc(texts) + texts = texts + gamma * texts_after + return texts + + +@MODELS.register_module() +class VPD(BaseModule): + """VPD (Visual Perception Diffusion) model. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + diffusion_cfg (dict): Configuration for diffusion model. + class_embed_path (str): Path for class embeddings. + unet_cfg (dict, optional): Configuration for U-Net. + gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4. + class_embed_select (bool, optional): If True, enables class embedding + selection. Defaults to False. + pad_shape (Optional[Union[int, List[int]]], optional): Padding shape. + Defaults to None. + pad_val (Union[int, List[int]], optional): Padding value. + Defaults to 0. + init_cfg (dict, optional): Configuration for network initialization. + """ + + def __init__(self, + diffusion_cfg: ConfigType, + class_embed_path: str, + unet_cfg: OptConfigType = dict(), + gamma: float = 1e-4, + class_embed_select=False, + pad_shape: Optional[Union[int, List[int]]] = None, + pad_val: Union[int, List[int]] = 0, + init_cfg: OptConfigType = None): + + super().__init__(init_cfg=init_cfg) + + if pad_shape is not None: + if not isinstance(pad_shape, (list, tuple)): + pad_shape = (pad_shape, pad_shape) + + self.pad_shape = pad_shape + self.pad_val = pad_val + + # diffusion model + diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None) + sd_model = instantiate_from_config(diffusion_cfg) + if diffusion_checkpoint is not None: + load_checkpoint(sd_model, diffusion_checkpoint, strict=False) + + self.encoder_vq = sd_model.first_stage_model + self.unet = UNetWrapper(sd_model.model, **unet_cfg) + + # class embeddings & text adapter + class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path) + text_dim = class_embeddings.size(-1) + self.text_adapter = TextAdapter(text_dim=text_dim) + self.class_embed_select = class_embed_select + if class_embed_select: + class_embeddings = torch.cat( + (class_embeddings, class_embeddings.mean(dim=0, + keepdims=True)), + dim=0) + self.register_buffer('class_embeddings', class_embeddings) + self.gamma = nn.Parameter(torch.ones(text_dim) * gamma) + + def forward(self, x): + """Extract features from images.""" + + # calculate cross-attn map + if self.class_embed_select: + if isinstance(x, (tuple, list)): + x, class_ids = x[:2] + class_ids = class_ids.tolist() + else: + class_ids = [-1] * x.size(0) + class_embeddings = self.class_embeddings[class_ids] + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(1) + else: + class_embeddings = self.class_embeddings + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1) + + # pad to required input shape for pretrained diffusion model + if self.pad_shape is not None: + pad_width = max(0, self.pad_shape[1] - x.shape[-1]) + pad_height = max(0, self.pad_shape[0] - x.shape[-2]) + x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val) + + # forward the denoising model + with torch.no_grad(): + latents = self.encoder_vq.encode(x).mode().detach() + t = torch.ones((x.shape[0], ), device=x.device).long() + outs = self.unet(latents, t, context=c_crossattn) + + return outs diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index 36c37ec2dd..b63cdc3e2c 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -33,6 +33,7 @@ from .setr_up_head import SETRUPHead from .stdc_head import STDCHead from .uper_head import UPerHead +from .vpd_depth_head import VPDDepthHead __all__ = [ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', @@ -42,5 +43,5 @@ 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', - 'LightHamHead', 'PIDHead', 'DDRHead' + 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead' ] diff --git a/mmseg/models/decode_heads/vpd_depth_head.py b/mmseg/models/decode_heads/vpd_depth_head.py new file mode 100644 index 0000000000..0c54c2da1b --- /dev/null +++ b/mmseg/models/decode_heads/vpd_depth_head.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..builder import build_loss +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class VPDDepthDecoder(BaseModule): + """VPD Depth Decoder class. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_deconv_layers (int): Number of deconvolution layers. + num_deconv_filters (List[int]): List of output channels for + deconvolution layers. + init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration + for weight initialization. Defaults to Normal for Conv2d and + ConvTranspose2d layers. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + num_deconv_layers: int, + num_deconv_filters: List[int], + init_cfg: Optional[Union[Dict, List[Dict]]] = dict( + type='Normal', + std=0.001, + layer=['Conv2d', 'ConvTranspose2d'])): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + ) + + conv_layers = [] + conv_layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=num_deconv_filters[-1], + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1)) + conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1]) + conv_layers.append(nn.ReLU(inplace=True)) + self.conv_layers = nn.Sequential(*conv_layers) + + self.up_sample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, x): + """Forward pass through the decoder network.""" + out = self.deconv_layers(x) + out = self.conv_layers(out) + + out = self.up_sample(out) + out = self.up_sample(out) + + return out + + def _make_deconv_layer(self, num_layers, num_deconv_filters): + """Make deconv layers.""" + + layers = [] + in_channels = self.in_channels + for i in range(num_layers): + + num_channels = num_deconv_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=in_channels, + out_channels=num_channels, + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=False)) + layers.append(nn.BatchNorm2d(num_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = num_channels + + return nn.Sequential(*layers) + + +@MODELS.register_module() +class VPDDepthHead(BaseDecodeHead): + """Depth Prediction Head for VPD. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + max_depth (float): Maximum depth value. Defaults to 10.0. + in_channels (Sequence[int]): Number of input channels for each + convolutional layer. + embed_dim (int): Dimension of embedding. Defaults to 192. + feature_dim (int): Dimension of aggregated feature. Defaults to 1536. + num_deconv_layers (int): Number of deconvolution layers in the + decoder. Defaults to 3. + num_deconv_filters (Sequence[int]): Number of filters for each deconv + layer. Defaults to (32, 32, 32). + fmap_border (Union[int, Sequence[int]]): Feature map border for + cropping. Defaults to 0. + align_corners (bool): Flag for align_corners in interpolation. + Defaults to False. + loss_decode (dict): Configurations for the loss function. Defaults to + dict(type='SiLogLoss'). + init_cfg (dict): Initialization configurations. Defaults to + dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']). + """ + + num_classes = 1 + out_channels = 1 + input_transform = None + + def __init__( + self, + max_depth: float = 10.0, + in_channels: Sequence[int] = [320, 640, 1280, 1280], + embed_dim: int = 192, + feature_dim: int = 1536, + num_deconv_layers: int = 3, + num_deconv_filters: Sequence[int] = (32, 32, 32), + fmap_border: Union[int, Sequence[int]] = 0, + align_corners: bool = False, + loss_decode: dict = dict(type='SiLogLoss'), + init_cfg=dict( + type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']), + ): + + super(BaseDecodeHead, self).__init__(init_cfg=init_cfg) + + # initialize parameters + self.in_channels = in_channels + self.max_depth = max_depth + self.align_corners = align_corners + + # feature map border + if isinstance(fmap_border, int): + fmap_border = (fmap_border, fmap_border) + self.fmap_border = fmap_border + + # define network layers + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + nn.GroupNorm(16, in_channels[0]), + nn.ReLU(), + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + ) + self.conv2 = nn.Conv2d( + in_channels[1], in_channels[1], 3, stride=2, padding=1) + + self.conv_aggregation = nn.Sequential( + nn.Conv2d(sum(in_channels), feature_dim, 1), + nn.GroupNorm(16, feature_dim), + nn.ReLU(), + ) + + self.decoder = VPDDepthDecoder( + in_channels=embed_dim * 8, + out_channels=embed_dim, + num_deconv_layers=num_deconv_layers, + num_deconv_filters=num_deconv_filters) + + self.depth_pred_layer = nn.Sequential( + nn.Conv2d( + embed_dim, embed_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1)) + + # build loss + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_depth_maps = [ + data_sample.gt_depth_map.data for data_sample in batch_data_samples + ] + return torch.stack(gt_depth_maps, dim=0) + + def forward(self, x): + x = [ + x[0], x[1], + torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1) + ] + x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1) + x = self.conv_aggregation(x) + + x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) - + self.fmap_border[1]].contiguous() + x = self.decoder(x) + out = self.depth_pred_layer(x) + + depth = torch.sigmoid(out) * self.max_depth + + return depth + + def loss_by_feat(self, pred_depth_map: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute depth estimation loss. + + Args: + pred_depth_map (Tensor): The output from decode head forward + function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_dpeth_map`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + gt_depth_map = self._stack_batch_gt(batch_data_samples) + loss = dict() + pred_depth_map = resize( + input=pred_depth_map, + size=gt_depth_map.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + pred_depth_map, gt_depth_map) + else: + loss[loss_decode.loss_name] += loss_decode( + pred_depth_map, gt_depth_map) + + return loss diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py index 9af5e40f23..0467cb3ad8 100644 --- a/mmseg/models/losses/__init__.py +++ b/mmseg/models/losses/__init__.py @@ -8,6 +8,7 @@ from .huasdorff_distance_loss import HuasdorffDisstanceLoss from .lovasz_loss import LovaszLoss from .ohem_cross_entropy_loss import OhemCrossEntropy +from .silog_loss import SiLogLoss from .tversky_loss import TverskyLoss from .utils import reduce_loss, weight_reduce_loss, weighted_loss @@ -16,5 +17,5 @@ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', - 'HuasdorffDisstanceLoss' + 'HuasdorffDisstanceLoss', 'SiLogLoss' ] diff --git a/mmseg/models/losses/silog_loss.py b/mmseg/models/losses/silog_loss.py new file mode 100644 index 0000000000..ac840535ac --- /dev/null +++ b/mmseg/models/losses/silog_loss.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +def silog_loss(pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + eps: float = 1e-4, + reduction: Union[str, None] = 'mean', + avg_factor: Optional[int] = None) -> Tensor: + """Computes the Scale-Invariant Logarithmic (SI-Log) loss between + prediction and target. + + Args: + pred (Tensor): Predicted output. + target (Tensor): Ground truth. + weight (Optional[Tensor]): Optional weight to apply on the loss. + eps (float): Epsilon value to avoid division and log(0). + reduction (Union[str, None]): Specifies the reduction to apply to the + output: 'mean', 'sum' or None. + avg_factor (Optional[int]): Optional average factor for the loss. + + Returns: + Tensor: The calculated SI-Log loss. + """ + pred, target = pred.flatten(1), target.flatten(1) + valid_mask = (target > eps).detach().float() + + diff_log = torch.log(target.clamp(min=eps)) - torch.log( + pred.clamp(min=eps)) + diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum( + dim=1) / valid_mask.sum(dim=1).clamp(min=eps) + diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum( + dim=1).clamp(min=eps) + + loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)) + + if weight is not None: + weight = weight.float() + + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class SiLogLoss(nn.Module): + """Compute SiLog loss. + + Args: + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_silog'. + """ + + def __init__(self, + reduction='mean', + loss_weight=1.0, + eps=1e-6, + loss_name='loss_silog'): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ): + + assert pred.shape == target.shape, 'the shapes of pred ' \ + f'({pred.shape}) and target ({target.shape}) are mismatch' + + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + loss = self.loss_weight * silog_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + ) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py index fec0d52c3a..ac63c73f74 100644 --- a/mmseg/models/segmentors/__init__.py +++ b/mmseg/models/segmentors/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseSegmentor from .cascade_encoder_decoder import CascadeEncoderDecoder +from .depth_estimator import DepthEstimator from .encoder_decoder import EncoderDecoder from .seg_tta import SegTTAModel __all__ = [ - 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel' + 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel', + 'DepthEstimator' ] diff --git a/mmseg/models/segmentors/depth_estimator.py b/mmseg/models/segmentors/depth_estimator.py new file mode 100644 index 0000000000..1020637e73 --- /dev/null +++ b/mmseg/models/segmentors/depth_estimator.py @@ -0,0 +1,392 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.logging import print_log +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from ..utils import resize +from .encoder_decoder import EncoderDecoder + + +@MODELS.register_module() +class DepthEstimator(EncoderDecoder): + """Encoder Decoder depth estimator. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional) + _decode_head_forward_train(): decode_head.loss() + _auxiliary_head_forward_train(): auxiliary_head.loss (optional) + + 2. The ``predict`` method is used to predict depth estimation results, + which includes two steps: (1) Run inference function to obtain the list of + depth (2) Call post-processing function to obtain list of + ``SegDataSample`` including ``pred_depth_map``. + + .. code:: text + + predict(): inference() -> postprocess_result() + inference(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + backbone (ConfigType): The config for the backnone of depth estimator. + decode_head (ConfigType): The config for the decode head of depth estimator. + neck (OptConfigType): The config for the neck of depth estimator. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + depth estimator. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + pretrained=pretrained, + init_cfg=init_cfg) + + def extract_feat(self, + inputs: Tensor, + batch_img_metas: Optional[List[dict]] = None) -> Tensor: + """Extract features from images.""" + + if getattr(self.backbone, 'class_embed_select', False) and \ + isinstance(batch_img_metas, list) and \ + 'category_id' in batch_img_metas[0]: + cat_ids = [meta['category_id'] for meta in batch_img_metas] + cat_ids = torch.tensor(cat_ids).to(inputs.device) + inputs = (inputs, cat_ids) + + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a depth map of the same + size as input.""" + x = self.extract_feat(inputs, batch_img_metas) + depth = self.decode_head.predict(x, batch_img_metas, self.test_cfg) + + return depth + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_depth_map`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + x = self.extract_feat(inputs, batch_img_metas) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_depth_map`. + + Returns: + list[:obj:`SegDataSample`]: Depth estimation results of the + input images. Each SegDataSample usually contain: + + - ``pred_depth_max``(PixelData): Prediction of depth estimation. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + depth = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(depth, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_depth_map`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_flip_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap and flip. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The depth estimation results. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is depth tensor map + # with shape [N, C, H, W] + crop_depth_map = self.encode_decode(crop_img, batch_img_metas) + + # average out the original and flipped prediction + crop_depth_map_flip = self.encode_decode( + crop_img.flip(dims=(3, )), batch_img_metas) + crop_depth_map_flip = crop_depth_map_flip.flip(dims=(3, )) + crop_depth_map = (crop_depth_map + crop_depth_map_flip) / 2.0 + + preds += F.pad(crop_depth_map, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + depth = preds / count_mat + + return depth + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The depth estimation results. + """ + assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole', + 'slide_flip'], \ + f'Only "slide", "slide_flip" or "whole" test mode are ' \ + f'supported, but got {self.test_cfg["mode"]}.' + ori_shape = batch_img_metas[0]['ori_shape'] + if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas): + print_log( + 'Image shapes are different in the batch.', + logger='current', + level=logging.WARN) + if self.test_cfg.mode == 'slide': + depth_map = self.slide_inference(inputs, batch_img_metas) + if self.test_cfg.mode == 'slide_flip': + depth_map = self.slide_flip_inference(inputs, batch_img_metas) + else: + depth_map = self.whole_inference(inputs, batch_img_metas) + + return depth_map + + def postprocess_result(self, + depth: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + depth (Tensor): The depth estimation results. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_depth_map`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Depth estomation results of the + input images. Each SegDataSample usually contain: + + - ``pred_depth_map``(PixelData): Prediction of depth estimation. + """ + batch_size, C, H, W = depth.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_depth shape is 1, C, H, W after remove padding + i_depth = depth[i:i + 1, :, padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_depth = i_depth.flip(dims=(3, )) + else: + i_depth = i_depth.flip(dims=(2, )) + + # resize as original shape + i_depth = resize( + i_depth, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_depth = depth[i] + + data_samples[i].set_data( + {'pred_depth_map': PixelData(**{'data': i_depth})}) + + return data_samples diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py index 1e423980d1..37b6a77609 100644 --- a/mmseg/registry/registry.py +++ b/mmseg/registry/registry.py @@ -82,7 +82,9 @@ locations=['mmseg.engine.optimizers']) # mangage all kinds of parameter schedulers like `MultiStepLR` PARAM_SCHEDULERS = Registry( - 'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS) + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmseg.engine.schedulers']) # manage all kinds of metrics METRICS = Registry( diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py index 0a561732e9..dfc469e832 100644 --- a/mmseg/utils/misc.py +++ b/mmseg/utils/misc.py @@ -94,18 +94,28 @@ def stack_batch(inputs: List[torch.Tensor], # pad gt_sem_seg if data_samples is not None: data_sample = data_samples[i] - gt_sem_seg = data_sample.gt_sem_seg.data - del data_sample.gt_sem_seg.data - data_sample.gt_sem_seg.data = F.pad( - gt_sem_seg, padding_size, value=seg_pad_val) + pad_shape = None + if 'gt_sem_seg' in data_sample: + gt_sem_seg = data_sample.gt_sem_seg.data + del data_sample.gt_sem_seg.data + data_sample.gt_sem_seg.data = F.pad( + gt_sem_seg, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_sem_seg.shape if 'gt_edge_map' in data_sample: gt_edge_map = data_sample.gt_edge_map.data del data_sample.gt_edge_map.data data_sample.gt_edge_map.data = F.pad( gt_edge_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_edge_map.shape + if 'gt_depth_map' in data_sample: + gt_depth_map = data_sample.gt_depth_map.data + del data_sample.gt_depth_map.data + data_sample.gt_depth_map.data = F.pad( + gt_depth_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_depth_map.shape data_sample.set_metainfo({ 'img_shape': tensor.shape[-2:], - 'pad_shape': data_sample.gt_sem_seg.shape, + 'pad_shape': pad_shape, 'padding_size': padding_size }) padded_samples.append(data_sample) diff --git a/requirements/optional.txt b/requirements/optional.txt index 5eca649247..b0310f5296 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,2 +1,22 @@ cityscapesscripts +-e git+https://github.com/openai/CLIP.git@main#egg=clip + +# for vpd model +diffusers +einops==0.3.0 +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +invisible-watermark +kornia==0.6 +-e git+https://github.com/CompVis/stable-diffusion@21f890f#egg=latent-diffusion nibabel +omegaconf==2.1.1 +pudb==2019.2 +pytorch-lightning==1.4.2 +streamlit>=0.73.1 +-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +test-tube>=0.7.5 +timm +torch-fidelity==0.3.0 +torchmetrics==0.6.0 +transformers==4.19.2 diff --git a/tests/test_config.py b/tests/test_config.py index 07b93b319f..cdd85ff57c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -30,6 +30,7 @@ def _get_config_directory(): def test_config_build_segmentor(): """Test that all segmentation models defined in the configs can be initialized.""" + init_default_scope('mmseg') config_dpath = _get_config_directory() print(f'Found config_dpath = {config_dpath!r}') @@ -94,7 +95,8 @@ def test_config_data_pipeline(): # remove loading annotation in test pipeline load_anno_idx = -1 for i in range(len(config_mod.test_pipeline)): - if config_mod.test_pipeline[i].type == 'LoadAnnotations': + if config_mod.test_pipeline[i].type in ('LoadAnnotations', + 'LoadDepthAnnotation'): load_anno_idx = i del config_mod.test_pipeline[load_anno_idx] @@ -105,6 +107,7 @@ def test_config_data_pipeline(): if to_float32: img = img.astype(np.float32) seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8) + depth = np.random.rand(1024, 2048).astype(np.float32) results = dict( filename='test_img.png', @@ -112,7 +115,8 @@ def test_config_data_pipeline(): img=img, img_shape=img.shape, ori_shape=img.shape, - gt_seg_map=seg) + gt_seg_map=seg, + gt_depth_map=depth) results['seg_fields'] = ['gt_seg_map'] _check_concat_cd_input(config_mod, results) print(f'Test training data pipeline: \n{train_pipeline!r}') @@ -158,14 +162,14 @@ def _check_decode_head(decode_head_cfg, decode_head): elif input_transform == 'resize_concat': assert sum(in_channels) == decode_head.in_channels else: - assert isinstance(in_channels, int) assert in_channels == decode_head.in_channels - assert isinstance(decode_head.in_index, int) if decode_head_cfg['type'] == 'PointHead': assert decode_head_cfg.channels+decode_head_cfg.num_classes == \ decode_head.fc_seg.in_channels assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes + elif decode_head_cfg['type'] == 'VPDDepthHead': + assert decode_head.out_channels == 1 else: assert decode_head_cfg.channels == decode_head.conv_seg.in_channels assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index 239b3842b2..e73e558ee8 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os.path as osp +from unittest import TestCase import mmcv import numpy as np @@ -11,7 +12,8 @@ from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import (LoadBiomedicalData, LoadBiomedicalImageFromFile, - PhotoMetricDistortion, RandomCrop) + PhotoMetricDistortion, RandomCrop, + RandomDepthMix) from mmseg.registry import TRANSFORMS init_default_scope('mmseg') @@ -184,6 +186,14 @@ def test_flip(): assert np.equal(original_img, results['img']).all() assert np.equal(original_seg, results['gt_semantic_seg']).all() + results['gt_depth_map'] = seg + results['seg_fields'] = ['gt_depth_map'] + results = flip_module(results) + flip_module = TRANSFORMS.build(transform) + results = flip_module(results) + assert np.equal(original_img, results['img']).all() + assert np.equal(original_seg, results['gt_depth_map']).all() + def test_random_rotate_flip(): with pytest.raises(AssertionError): @@ -1218,3 +1228,46 @@ def test_albu_channel_order(): with pytest.raises(AssertionError): np.testing.assert_array_equal(results_albu['img'][..., 0], results_load['img'][..., 0]) + + +class TestRandomDepthMix(TestCase): + + def setUp(self): + self.transform = RandomDepthMix(prob=1.0) + + def test_transform_shape(self): + # Create a dummy result dict + results = { + 'img_shape': (10, 10), + 'img': np.random.rand(10, 10, 3), + 'gt_depth_map': np.random.rand(10, 10) + } + transformed = self.transform.transform(results) + + # Check if the shape remains the same + self.assertEqual(results['img'].shape, transformed['img'].shape) + + def test_transform_values(self): + # Create a dummy result dict + results = { + 'img_shape': (10, 10), + 'img': np.zeros((10, 10, 3)), + 'gt_depth_map': np.ones((10, 10)) + } + transformed = self.transform.transform(results) + + # Assuming the transformation modifies a portion of the image, + # it shouldn't remain all zeros + self.assertFalse(np.all(transformed['img'] == 0)) + + def test_invalid_image_dimension(self): + # Create a dummy result dict with invalid image dimension + results = { + 'img_shape': (10, 10), + 'img': np.random.rand(10, 10, 3, 3), + 'gt_depth_map': np.random.rand(10, 10) + } + + # Check if a ValueError is raised for invalid dimension + with self.assertRaises(ValueError): + self.transform.transform(results) diff --git a/tests/test_models/test_backbones/test_vpd.py b/tests/test_models/test_backbones/test_vpd.py new file mode 100644 index 0000000000..a268159155 --- /dev/null +++ b/tests/test_models/test_backbones/test_vpd.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from os.path import dirname, join +from unittest import TestCase + +import torch +from mmengine import Config + +import mmseg +from mmseg.models.backbones import VPD + + +class TestVPD(TestCase): + + def setUp(self) -> None: + + repo_dpath = dirname(dirname(mmseg.__file__)) + config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py') + vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg + vpd_cfg.pop('checkpoint') + + self.vpd_model = VPD( + diffusion_cfg=vpd_cfg, + class_embed_path='https://download.openmmlab.com/mmsegmentation/' + 'v0.5/vpd/nyu_class_embeddings.pth', + class_embed_select=True, + pad_shape=64, + unet_cfg=dict(use_attn=False), + ) + + def test_forward(self): + # test forward without class_id + x = torch.randn(1, 3, 60, 60) + with torch.no_grad(): + out = self.vpd_model(x) + + self.assertEqual(len(out), 4) + self.assertListEqual(list(out[0].shape), [1, 320, 8, 8]) + self.assertListEqual(list(out[1].shape), [1, 640, 4, 4]) + self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2]) + self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1]) + + # test forward with class_id + x = torch.randn(1, 3, 60, 60) + with torch.no_grad(): + out = self.vpd_model((x, torch.tensor([2]))) + + self.assertEqual(len(out), 4) + self.assertListEqual(list(out[0].shape), [1, 320, 8, 8]) + self.assertListEqual(list(out[1].shape), [1, 640, 4, 4]) + self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2]) + self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1]) diff --git a/tests/test_models/test_heads/test_vpd_depth_head.py b/tests/test_models/test_heads/test_vpd_depth_head.py new file mode 100644 index 0000000000..e3a4f7558e --- /dev/null +++ b/tests/test_models/test_heads/test_vpd_depth_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine.structures import PixelData + +from mmseg.models.decode_heads import VPDDepthHead +from mmseg.structures import SegDataSample + + +class TestVPDDepthHead(TestCase): + + def setUp(self): + """Set up common resources.""" + self.in_channels = [320, 640, 1280, 1280] + self.max_depth = 10.0 + self.loss_decode = dict( + type='SiLogLoss' + ) # Replace with your actual loss type and parameters + self.vpd_depth_head = VPDDepthHead( + max_depth=self.max_depth, + in_channels=self.in_channels, + loss_decode=self.loss_decode) + + def test_forward(self): + """Test the forward method.""" + # Create a mock input tensor. Replace shape as per your needs. + x = [ + torch.randn(1, 320, 32, 32), + torch.randn(1, 640, 16, 16), + torch.randn(1, 1280, 8, 8), + torch.randn(1, 1280, 4, 4) + ] + + output = self.vpd_depth_head.forward(x) + print(output.shape) + + self.assertEqual(output.shape, (1, 1, 256, 256)) + + def test_loss_by_feat(self): + """Test the loss_by_feat method.""" + # Create mock data for `pred_depth_map` and `batch_data_samples`. + pred_depth_map = torch.randn(1, 1, 32, 32) + gt_depth_map = PixelData(data=torch.rand(1, 32, 32)) + batch_data_samples = [SegDataSample(gt_depth_map=gt_depth_map)] + + loss = self.vpd_depth_head.loss_by_feat(pred_depth_map, + batch_data_samples) + + self.assertIsNotNone(loss) diff --git a/tests/test_models/test_losses/test_silog_loss.py b/tests/test_models/test_losses/test_silog_loss.py new file mode 100644 index 0000000000..022434bcc1 --- /dev/null +++ b/tests/test_models/test_losses/test_silog_loss.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmseg.models.losses import SiLogLoss + + +class TestSiLogLoss(TestCase): + + def test_SiLogLoss_forward(self): + pred = torch.tensor([[1.0, 2.0], [3.5, 4.0]], dtype=torch.float32) + target = torch.tensor([[0.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + weight = torch.tensor([1.0, 0.5], dtype=torch.float32) + + loss_module = SiLogLoss() + loss = loss_module.forward(pred, target, weight) + + expected_loss = 0.02 + self.assertAlmostEqual(loss.item(), expected_loss, places=2) diff --git a/tests/test_models/test_segmentors/test_depth_estimator.py b/tests/test_models/test_segmentors/test_depth_estimator.py new file mode 100644 index 0000000000..e819c9e763 --- /dev/null +++ b/tests/test_models/test_segmentors/test_depth_estimator.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from os.path import dirname, join +from unittest import TestCase + +import torch +from mmengine import Config, ConfigDict +from mmengine.structures import PixelData + +import mmseg +from mmseg.models.segmentors import DepthEstimator +from mmseg.structures import SegDataSample + + +class TestDepthEstimator(TestCase): + + def setUp(self) -> None: + repo_dpath = dirname(dirname(mmseg.__file__)) + config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py') + vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg + vpd_cfg.pop('checkpoint') + + backbone_cfg = dict( + type='VPD', + diffusion_cfg=vpd_cfg, + class_embed_path='https://download.openmmlab.com/mmsegmentation/' + 'v0.5/vpd/nyu_class_embeddings.pth', + class_embed_select=True, + pad_shape=64, + unet_cfg=dict(use_attn=False), + ) + + head_cfg = dict( + type='VPDDepthHead', + max_depth=10, + ) + + self.model = DepthEstimator( + backbone=backbone_cfg, decode_head=head_cfg) + + inputs = torch.randn(1, 3, 64, 80) + data_sample = SegDataSample() + data_sample.gt_depth_map = PixelData(data=torch.rand(1, 64, 80)) + data_sample.set_metainfo(dict(img_shape=(64, 80), ori_shape=(64, 80))) + self.data = dict(inputs=inputs, data_samples=[data_sample]) + + def test_slide_flip_inference(self): + + self.model.test_cfg = ConfigDict( + dict(mode='slide_flip', crop_size=(64, 64), stride=(16, 16))) + + with torch.no_grad(): + out = self.model.predict(**deepcopy(self.data)) + + self.assertEqual(len(out), 1) + self.assertIn('pred_depth_map', out[0].keys()) + self.assertListEqual(list(out[0].pred_depth_map.shape), [64, 80]) + + def test__forward(self): + data = deepcopy(self.data) + data['inputs'] = data['inputs'][:, :, :64, :64] + with torch.no_grad(): + out = self.model._forward(**data) + self.assertListEqual(list(out.shape), [1, 1, 64, 64])