
PixMIM #721

Merged · 17 commits · Mar 23, 2023
Changes from 15 commits
1 change: 1 addition & 0 deletions README.md
@@ -157,6 +157,7 @@ Supported algorithms:
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
- [x] [PixMIM (arXiv'2023)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/pixmim)

More algorithms are in our plan.

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -143,6 +143,7 @@ Useful Tools
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
- [x] [PixMIM (arXiv'2023)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/pixmim)

More algorithm implementations are already in our plan.

152 changes: 152 additions & 0 deletions configs/selfsup/pixmim/README.md
@@ -0,0 +1,152 @@
# PixMIM

> [PixMIM: Rethinking Pixel Reconstruction in Masked Image Modeling](https://arxiv.org/abs/2303.02416)

## TL;DR

PixMIM can seamlessly replace MAE as a stronger baseline, with negligible computational overhead.

<!-- [ALGORITHM] -->

## Abstract

Masked Image Modeling (MIM) has achieved promising progress with the advent of Masked Autoencoders (MAE) and BEiT. However, subsequent works have complicated the framework with new auxiliary tasks or extra pretrained models, inevitably increasing computational overhead. This paper undertakes a fundamental analysis of MIM from the perspective of pixel reconstruction, which examines the input image patches and reconstruction target, and highlights two critical but previously overlooked bottlenecks. Based on this analysis, we propose a remarkably simple and effective method, PixMIM, that entails two strategies: 1) filtering the high-frequency components from the reconstruction target to de-emphasize the network’s focus on texture-rich details and 2) adopting a conservative data transform strategy to alleviate the problem of missing foreground in MIM training. PixMIM can be easily integrated into most existing pixel-based MIM approaches (i.e., using raw images as reconstruction target) with negligible additional computation. Without bells and whistles, our method consistently improves three MIM approaches, MAE, ConvMAE, and LSMAE, across various downstream tasks. We believe this effective plug-and-play method will serve as a strong baseline for self-supervised learning and provide insights for future improvements of the MIM framework.

<div align=center>
<img src="https://user-images.githubusercontent.com/30762564/226782993-28b2b20f-9143-4514-8c61-1aa81146d159.png"/>
</div>
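
To make strategy 1 concrete, below is a minimal sketch of producing a low-frequency reconstruction target with an ideal low-pass filter in the Fourier domain. The function name and the `bandwidth` cutoff are our own illustrative choices; the actual filter used by PixMIM is part of its implementation in mmselfsup and may differ in shape and cutoff.

```python
# A minimal sketch (not the mmselfsup implementation) of strategy 1:
# keep only the low-frequency components of the reconstruction target.
import torch


def low_frequency_target(img: torch.Tensor,
                         bandwidth: float = 0.25) -> torch.Tensor:
    """img: (B, C, H, W) batch; returns its low-frequency components."""
    _, _, H, W = img.shape
    # Centered 2D spectrum of each channel.
    spec = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))
    # Normalized radial distance of every frequency from the center.
    ys = torch.linspace(-1, 1, H).view(H, 1)
    xs = torch.linspace(-1, 1, W).view(1, W)
    radius = (ys**2 + xs**2).sqrt()
    mask = (radius <= bandwidth).to(spec.dtype)
    # Drop high frequencies, transform back, keep the real part.
    low = torch.fft.ifft2(torch.fft.ifftshift(spec * mask, dim=(-2, -1)))
    return low.real


target = low_frequency_target(torch.rand(2, 3, 224, 224))  # usage example
```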

## Models and Benchmarks

Here we report the results of the models on ImageNet; the details are below:

<table class="docutils">
<thead>
<tr>
<th rowspan="2">Algorithm</th>
<th rowspan="2">Backbone</th>
<th rowspan="2">Epoch</th>
<th rowspan="2">Batch Size</th>
<th colspan="2" align="center">Results (Top-1 %)</th>
<th colspan="3" align="center">Links</th>
</tr>
<tr>
<th>Linear Probing</th>
<th>Fine-tuning</th>
<th>Pretrain</th>
<th>Linear Probing</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>PixMIM</td>
<td>ViT-base</td>
<td>300</td>
<td>4096</td>
<td>63.3</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.json'> log </a></td>
</tr>
<tr>
<td>PixMIM</td>
<td>ViT-base</td>
<td>800</td>
<td>4096</td>
<td>67.5</td>
<td>83.5</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.json'> log </a></td>
</tr>
</tbody>
</table>

## Pre-train and Evaluation

### Pre-train

If you use a cluster managed by Slurm:

```sh
# all of our experiments can be run on a single machine, with 8 A100 GPUs
bash tools/slurm_train.sh $partition $job_name configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py --amp
```

If you use a single machine without any cluster management software:

```sh
bash tools/dist_train.sh configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py 8 --amp
```

### Linear Probing

If you use a cluster managed by Slurm:

```sh
# all of our experiments can be run on a single machine, with 8 A100 GPUs
bash tools/benchmarks/classification/mim_slurm_train.sh $partition configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py --amp
```

If you use a single machine without any cluster management software:

```sh
bash tools/benchmarks/classification/mim_dist_train.sh configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py 8 --amp
```

### Fine-tuning

If you use a cluster managed by Slurm:

```sh
# all of our experiments can be run on a single machine, with 8 A100 GPUs
bash tools/benchmarks/classification/mim_slurm_train.sh $partition configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py $pretrained_model --amp
```

If you use a single machine without any cluster management software:

```sh
GPUS=8 bash tools/benchmarks/classification/mim_dist_train.sh configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py $pretrained_model --amp
```

## Detection and Segmentation

If you want to evaluate your model on detection or segmentation tasks, we provide a [script](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/model_converters/mmcls2timm.py) to convert the model keys from MMClassification style to timm style.

```sh
cd $MMSELFSUP
python tools/model_converters/mmcls2timm.py $src_ckpt $dst_ckpt
```
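
Conceptually, the conversion just renames checkpoint keys from MMClassification's namespace to timm's. The sketch below is hypothetical; the rename patterns shown are illustrative only, and the actual mapping is defined in `tools/model_converters/mmcls2timm.py`.

```python
# Hypothetical sketch of a checkpoint key conversion; the real mapping
# lives in tools/model_converters/mmcls2timm.py.
import torch


def convert(src_ckpt: str, dst_ckpt: str) -> None:
    ckpt = torch.load(src_ckpt, map_location='cpu')
    state = ckpt.get('state_dict', ckpt)
    new_state = {}
    for key, value in state.items():
        # Illustrative renames, e.g. 'backbone.layers.0.attn.qkv.weight'
        # -> 'blocks.0.attn.qkv.weight' in timm naming.
        key = key.replace('backbone.', '').replace('layers.', 'blocks.')
        new_state[key] = value
    torch.save(new_state, dst_ckpt)
```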

Then, using this converted checkpoint, you can evaluate your model on the detection task following [Detectron2](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet), and on the semantic segmentation task following this [project](https://github.com/implus/mae_segmentation). Alternatively, with the unconverted checkpoint, you can use [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/mae) to evaluate your model.

## Citation

```bibtex
@article{PixMIM,
  author  = {Yuan Liu and Songyang Zhang and Jiacheng Chen and Kai Chen and Dahua Lin},
  journal = {arXiv preprint arXiv:2303.02416},
  title   = {PixMIM: Rethinking Pixel Reconstruction in Masked Image Modeling},
  year    = {2023},
}
```
123 changes: 123 additions & 0 deletions configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py
@@ -0,0 +1,123 @@
# mmcls:: means we use the default settings from MMClassification
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
    'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py',
    'mmcls::_base_/default_runtime.py'
]

# MAE fine-tuning setting

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        avg_token=True,
        output_cls_token=False,
        final_norm=False,
        init_cfg=dict(type='Pretrained', checkpoint='')),
    neck=None,
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=0.3333333333333333,
        fill_color=[103.53, 116.28, 123.675],
        fill_std=[57.375, 57.12, 58.395]),
    dict(type='PackClsInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackClsInputs')
]

train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# optimizer wrapper
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=2e-3,
        weight_decay=0.05,
        eps=1e-8,
        betas=(0.9, 0.999),
        model_type='vit',  # layer-wise lr decay type
        layer_decay_rate=0.65),  # layer-wise lr decay factor
    constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(
        custom_keys={
            '.ln': dict(decay_mult=0.0),
            '.bias': dict(decay_mult=0.0),
            '.cls_token': dict(decay_mult=0.0),
            '.pos_embed': dict(decay_mult=0.0)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=95,
        by_epoch=True,
        begin=5,
        end=100,
        eta_min=1e-6,
        convert_to_iter_based=True)
]

# runtime settings
default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

train_cfg = dict(by_epoch=True, max_epochs=100)

randomness = dict(seed=0, diff_rank_seed=True)
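
The `layer_decay_rate=0.65` above enables layer-wise learning-rate decay: the effective learning rate shrinks geometrically for layers closer to the input. Below is a minimal sketch of the scaling rule, assuming the common BEiT-style depth indexing (`layer_id` 0 for the patch embedding up to `num_layers + 1` for the head); the authoritative logic is `mmselfsup.LearningRateDecayOptimWrapperConstructor`.

```python
# Minimal sketch of layer-wise lr decay; assumes BEiT-style depth indexing.
def layer_wise_lr(base_lr: float, layer_id: int, num_layers: int = 12,
                  decay_rate: float = 0.65) -> float:
    """Scale the base lr down by decay_rate per layer below the head."""
    return base_lr * decay_rate**(num_layers + 1 - layer_id)


# With lr=2e-3 and layer_decay_rate=0.65 as in the config above:
for layer_id in range(14):  # 0 = patch embed, 1..12 = blocks, 13 = head
    print(f'layer {layer_id:2d}: lr = {layer_wise_lr(2e-3, layer_id):.2e}')
```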
12 changes: 12 additions & 0 deletions configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py
@@ -0,0 +1,12 @@
_base_ = '../../../benchmarks/classification/imagenet/vit-base-p16_linear-8xb2048-coslr-90e_in1k.py'  # noqa: E501

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='mmcls.ToPIL', to_rgb=True),
    dict(type='mmselfsup.MAERandomResizedCrop', size=224, interpolation=3),
    dict(type='mmcls.torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='mmcls.ToNumpy', to_rgb=True),
    dict(type='PackClsInputs'),
]
train_dataloader = dict(
    batch_size=2048, dataset=dict(pipeline=train_pipeline), drop_last=True)
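
This pipeline deliberately routes the linear-probing augmentation through PIL/torchvision-style ops (`ToPIL`, then crop/flip, then `ToNumpy`) to reproduce MAE's original torchvision preprocessing; `interpolation=3` is PIL's bicubic mode. Purely as an illustration (the exact crop logic is defined by `MAERandomResizedCrop`), the augmentation roughly corresponds to:

```python
# Illustrative torchvision equivalent of the pipeline above.
import torchvision.transforms as T

transform = T.Compose([
    T.RandomResizedCrop(224, interpolation=T.InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(p=0.5),
])
```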
78 changes: 78 additions & 0 deletions configs/selfsup/pixmim/metafile.yml
@@ -0,0 +1,78 @@
Collections:
  - Name: PixMIM
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
      Training Resources: 8x A100-80G GPUs
      Architecture:
        - ViT
    Paper:
      URL: https://arxiv.org/abs/2303.02416
      Title: "PixMIM: Rethinking Pixel Reconstruction in Masked Image Modeling"
    README: configs/selfsup/pixmim/README.md

Models:
  - Name: pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k
    In Collection: PixMIM
    Metadata:
      Epochs: 300
      Batch Size: 4096
    Results: null
    Config: configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
    Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.pth
    Downstream:
      - Type: Image Classification
        Metadata:
          Epochs: 100
          Batch Size: 1024
        Results:
          - Task: Fine-tuning
            Dataset: ImageNet-1k
            Metrics:
              Top 1 Accuracy: 83.1
        Config: configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py
        Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.pth
      - Type: Image Classification
        Metadata:
          Epochs: 100
          Batch Size: 16384
        Results:
          - Task: Linear Evaluation
            Dataset: ImageNet-1k
            Metrics:
              Top 1 Accuracy: 63.3
        Config: configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py
        Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.pth

  - Name: pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k
    In Collection: PixMIM
    Metadata:
      Epochs: 800
      Batch Size: 4096
    Results: null
    Config: configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
    Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.pth
    Downstream:
      - Type: Image Classification
        Metadata:
          Epochs: 100
          Batch Size: 1024
        Results:
          - Task: Fine-tuning
            Dataset: ImageNet-1k
            Metrics:
              Top 1 Accuracy: 83.5
        Config: configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py
        Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.pth
      - Type: Image Classification
        Metadata:
          Epochs: 100
          Batch Size: 16384
        Results:
          - Task: Linear Evaluation
            Dataset: ImageNet-1k
            Metrics:
              Top 1 Accuracy: 67.5
        Config: configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py
        Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth