diff --git a/.circleci/test.yml b/.circleci/test.yml index 2d5713cf1a..169bba2778 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -73,6 +73,10 @@ jobs: name: Install timm command: | pip install timm + - run: + name: Install transformers + command: | + pip install transformers - when: condition: equal: [ "0.10.0", << parameters.torchvision >> ] @@ -118,6 +122,10 @@ jobs: command: | docker exec mmaction pip install timm docker exec mmaction python -m pip install pytorchvideo + - run: + name: Install transformers + command: | + docker exec mmaction pip install transformers - run: name: Install mmaction dependencies command: | diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml index de01615037..0a0222903a 100644 --- a/.github/workflows/merge_stage_test.yml +++ b/.github/workflows/merge_stage_test.yml @@ -69,6 +69,8 @@ jobs: if: ${{matrix.torchvision == '0.10.0'}} - name: Install timm run: pip install timm + - name: Install transformers + run: pip install transformers - name: Build and install run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report @@ -110,6 +112,8 @@ jobs: run: pip install lmdb - name: Install timm run: pip install timm + - name: Install transformers + run: pip install transformers - name: Install TurboJpeg lib run: sudo apt-get install -y libturbojpeg - name: Install PyTorch @@ -183,6 +187,8 @@ jobs: run: pip install librosa soundfile - name: Install lmdb run: pip install lmdb + - name: Install transformers + run: pip install transformers - name: Install mmaction dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main @@ -240,6 +246,8 @@ jobs: run: python -m pip install pytorchvideo - name: Install timm run: python -m pip install timm + - name: Install transformers + run: python -m pip install transformers - name: Build and install run: | pip install -e . -v diff --git a/configs/multimodal/vindlu/README.md b/configs/multimodal/vindlu/README.md new file mode 100644 index 0000000000..c49fed61fa --- /dev/null +++ b/configs/multimodal/vindlu/README.md @@ -0,0 +1,87 @@ +# VindLU + +[VindLU: A Recipe for Effective Video-and-Language Pretraining](https://arxiv.org/abs/2212.05051) + + + +## Abstract + + + +The last several years have witnessed remarkable progress in video-and-language (VidL) understanding. However, most modern VidL approaches use complex and specialized model architectures and sophisticated pretraining protocols, making the reproducibility, analysis and comparisons of these frameworks difficult. Hence, instead of proposing yet another new VidL model, this paper conducts a thorough empirical study demystifying the most important factors in the VidL model design. Among the factors that we investigate are (i) the spatiotemporal architecture design, (ii) the multimodal fusion schemes, (iii) the pretraining objectives, (iv) the choice of pretraining data, (v) pretraining and finetuning protocols, and (vi) dataset and model scaling. Our empirical study reveals that the most important design factors include: temporal modeling, video-to-text multimodal fusion, masked modeling objectives, and joint training on images and videos. Using these empirical insights, we then develop a step-by-step recipe, dubbed VindLU, for effective VidL pretraining. Our final model trained using our recipe achieves comparable or better than state-of-the-art results on several VidL tasks without relying on external CLIP pretraining. In particular, on the text-to-video retrieval task, our approach obtains 61.2% on DiDeMo, and 55.0% on ActivityNet, outperforming current SOTA by 7.8% and 6.1% respectively. Furthermore, our model also obtains state-of-the-art video question-answering results on ActivityNet-QA, MSRVTT-QA, MSRVTT-MC and TVQA. Our code and pretrained models are publicly available at: https://github.com/klauscc/VindLU. + + + +
+ +
+ +## Results and Models + +### Video Retrieval on MSRVTT-9k + +| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | Recall@1 | config | ckpt | log | +| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :-----------------------------------: | :---------------------------------: | :---------------------------------: | +| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | 44.0 | [config](/configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k.log) | + +### Video Question-Answering on MSRVTT-QA + +| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | top1 acc | config | ckpt | log | +| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :-----------------------------------: | :---------------------------------: | :---------------------------------: | +| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | 43.6 | [config](/configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa_20230906-6e693e64.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa.log) | + +### Multiple-Choice Question-Answering on MSRVTT-MC (Inference) + +| frame sampling strategy | resolution | gpus | vision encoder | text encoder | pretraining | top1 acc | config | ckpt | +| :---------------------: | :--------: | :--: | :------------: | :----------: | :--------------------: | :------: | :----------------------------------------------------: | :---------------------------------------------------: | +| uniform 12 | 224x224 | 8 | BEiT-Base | Bert-Base | C5M (WebVid-2M + CC3M) | 97.6 | [config](/configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth) | + +1. Currently, we only support the fine-tuning stage of VindLU models based on the pretrained checkpoint provided by the [original repo](https://github.com/klauscc/VindLU). + +For more details on data preparation, you can refer to [prepare msrvtt](/tools/data/msrvtt/README.md). + +## Train + +You can use the following command to train a model. + +```shell +python tools/train.py ${CONFIG_FILE} [optional arguments] +``` + +Example: train VindLU model on MSRVTT-9k dataset in a deterministic option with periodic validation. + +```shell +python tools/train.py configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py \ + --seed 0 --deterministic +``` + +For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). + +## Test + +You can use the following command to test a model. + +```shell +python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments] +``` + +Example: test CLIP4Clip model on MSRVTT-9k dataset and dump the result to a pkl file. + +```shell +python tools/test.py cconfigs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py \ + checkpoints/SOME_CHECKPOINT.pth --dump result.pkl +``` + +For more details, you can refer to the **Test** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). + +## Citation + +```BibTeX +@inproceedings{cheng2023vindlu, + title={Vindlu: A recipe for effective video-and-language pretraining}, + author={Cheng, Feng and Wang, Xizi and Lei, Jie and Crandall, David and Bansal, Mohit and Bertasius, Gedas}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={10739--10750}, + year={2023} +} +``` diff --git a/configs/multimodal/vindlu/metafile.yml b/configs/multimodal/vindlu/metafile.yml new file mode 100644 index 0000000000..d7fdf7fe24 --- /dev/null +++ b/configs/multimodal/vindlu/metafile.yml @@ -0,0 +1,55 @@ +Collections: + - Name: VindLU + README: configs/multimodal/vindlu/README.md + Paper: + URL: https://arxiv.org/abs/2212.05051 + Title: 'VindLU: A Recipe for Effective Video-and-Language Pretraining' + +Models: + - Name: vindlu_beit-base_8x16_retrieval_msrvtt-9k + Config: configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py + In Collection: VindLU + Metadata: + Architecture: BEiT-Base + Batch Size: 16 + Epochs: 5 + Training Data: MSRVTT-9k + Training Resources: 8 GPUs + Results: + Dataset: MSRVTT + Task: Video Retrieval + Metrics: + Recall@1: 44.0 + Recall@5: 70.6 + Recall@10: 80.0 + Training Log: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k.log + Weights: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth + + - Name: vindlu_beit-base_8x8_vqa_msrvtt-qa + Config: configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py + In Collection: VindLU + Metadata: + Architecture: BEiT-Base + Batch Size: 8 + Epochs: 10 + Training Data: MSRVTT-qa + Training Resources: 8 GPUs + Results: + Dataset: MSRVTT + Task: Video Question-Answering + Metrics: + Top 1 Accuracy: 43.6 + Training Log: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa.log + Weights: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa/vindlu_beit-base_8x8_vqa_msrvtt-qa_20230906-6e693e64.pth + + - Name: vindlu_beit-base_vqa-mc_msrvtt-mc + Config: configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py + In Collection: VindLU + Metadata: + Architecture: BEiT-Base + Results: + Dataset: MSRVTT-MC + Task: Multiple-Choice Question-Answering + Metrics: + Top 1 Accuracy: 97.6 + Weights: https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k/vindlu_beit-base_8x16_retrieval_msrvtt-9k_20230905-fc36231e.pth diff --git a/configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py b/configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py new file mode 100644 index 0000000000..fd20acbc24 --- /dev/null +++ b/configs/multimodal/vindlu/vindlu_beit-base_8x16_retrieval_msrvtt-9k.py @@ -0,0 +1,200 @@ +_base_ = ['../../_base_/default_runtime.py'] + +video_root = 'data/msrvtt/videos_2fps_224' +anno_file_train = 'data/msrvtt/annotations/msrvtt_ret_train9k.json' +anno_file_test = 'data/msrvtt/annotations/msrvtt_ret_test1k.json' +pretrained_ckpt_url = 'https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_c5m_pretrain.pth' # noqa: E501 + +# model settings +model = dict( + type='VindLURetrieval', + gradient_checkpointing=True, + init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_url), + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[128], + std=[128], + format_shape='NCTHW'), + tokenizer=dict( + type='VindLUTokenizer', + pretrained_model_name_or_path='bert-base-uncased'), + vision_encoder=dict( + type='BeitModel3D', + config='microsoft/beit-base-patch16-224-pt22k-ft22k', + tem_config=dict( + num_frames=12, + temporal_model_block='timesformer', + temporal_model_position='last', + temporal_model_config=dict(input_dim=768), + use_temporal_position_embedding=True), + encoder_width=768, + add_ln=True), + text_encoder=dict( + type='XBertModel', + pretrained_model_name_or_path='bert-base-uncased', + encoder_width=768, + fusion_layer=9, + add_pooling_layer=False), + proj_dim=256, + temperature=0.07, + max_txt_len=32, + topk=128) + +file_client_args = dict(io_backend='disk') +train_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + out_of_bound_opt='repeat_last', + ), + dict(type='DecordDecode'), + dict(type='RandomResizedCrop', area_range=(0.5, 1.0)), + dict( + type='Resize', + scale=(224, 224), + keep_ratio=False, + interpolation='bicubic'), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict( + type='PackActionInputs', + algorithm_keys=( + 'text', + 'gt_video_id', + 'gt_text_id', + )) +] + +val_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + test_mode=True, + out_of_bound_opt='repeat_last'), + dict(type='DecordDecode'), + dict( + type='Resize', + scale=(224, 224), + keep_ratio=False, + interpolation='bicubic'), + dict(type='FormatShape', input_format='NCHW'), + dict( + type='PackActionInputs', + algorithm_keys=( + 'text', + 'gt_video_id', + 'gt_text_id', + )) +] + +test_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + test_mode=True, + out_of_bound_opt='repeat_last'), + dict(type='DecordDecode'), + dict( + type='Resize', + scale=(224, 224), + keep_ratio=False, + interpolation='bicubic'), + dict(type='FormatShape', input_format='NCHW'), + dict( + type='PackActionInputs', + algorithm_keys=( + 'text', + 'gt_video_id', + 'gt_text_id', + )) +] + +dataset_type = 'MSRVTTRetrieval' + +train_dataloader = dict( + batch_size=32, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file=anno_file_train, + pipeline=train_pipeline, + data_prefix=dict(video=video_root), + )) + +val_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=anno_file_test, + pipeline=test_pipeline, + data_prefix=dict(video=video_root), + )) + +test_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=anno_file_test, + pipeline=test_pipeline, + data_prefix=dict(video=video_root), + )) + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=5, val_begin=1, val_interval=1) +val_cfg = dict(type='RetrievalValLoop') +test_cfg = dict(type='RetrievalTestLoop') + +val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10)) +test_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10)) + +param_scheduler = [ + dict( + type='CosineAnnealingLR', + T_max=5, + eta_min_ratio=0.01, + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True) +] + +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.02), + paramwise_cfg=dict( + bypass_duplicate=True, norm_decay_mult=0.0, bias_decay_mult=0.0), + clip_grad=dict(max_norm=50, norm_type=2), +) + +model_wrapper_cfg = dict(type='MMDistributedDataParallel', static_graph=True) + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=1, + save_best='t2i/retrieval/Recall@1', + rule='greater'), + logger=dict(type='LoggerHook', interval=20, ignore_last=False)) + +auto_scale_lr = dict(enable=True, base_batch_size=128) + +find_unused_parameters = True + +custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)] diff --git a/configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py b/configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py new file mode 100644 index 0000000000..461b045cdb --- /dev/null +++ b/configs/multimodal/vindlu/vindlu_beit-base_8x8_vqa_msrvtt-qa.py @@ -0,0 +1,190 @@ +_base_ = ['../../_base_/default_runtime.py'] + +video_root = 'data/msrvtt/videos_2fps_224' +anno_file_train = 'data/msrvtt/annotations/msrvtt_qa_train.json' +anno_file_val = 'data/msrvtt/annotations/msrvtt_qa_val.json' +anno_file_test = 'data/msrvtt/annotations/msrvtt_qa_test.json' +answer_list_file = 'data/msrvtt/annotations/msrvtt_qa_answer_list.json' +pretrained_ckpt_url = 'https://download.openmmlab.com/mmaction/v1.0/multimodal/vindlu/vindlu_c5m_pretrain.pth' # noqa: E501 + +# model settings +model = dict( + type='VindLUVQA', + init_cfg=dict(type='Pretrained', checkpoint=pretrained_ckpt_url), + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[128], + std=[128], + format_shape='NCTHW'), + tokenizer=dict( + type='VindLUTokenizer', + pretrained_model_name_or_path='bert-base-uncased', + ), + vision_encoder=dict( + type='BeitModel3D', + config='microsoft/beit-base-patch16-224-pt22k-ft22k', + tem_config=dict( + num_frames=12, + temporal_model_block='timesformer', + temporal_model_position='last', + temporal_model_config=dict(input_dim=768), + use_temporal_position_embedding=True), + encoder_width=768, + add_ln=True), + text_encoder=dict( + type='XBertModel', + pretrained_model_name_or_path='bert-base-uncased', + encoder_width=768, + fusion_layer=9, + add_pooling_layer=False), + text_decoder=dict( + type='BertDecoder', + pretrained_model_name_or_path='bert-base-uncased', + encoder_width=768, + fusion_layer=0, + num_hidden_layers=3, + add_pooling_layer=True), + proj_dim=256, + temperature=0.07, + max_question_len=25, + max_answer_len=5, + num_ans_candidates=128, + gradient_checkpointing=True, + answer_list_path=answer_list_file) + +file_client_args = dict(io_backend='disk') + +train_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + out_of_bound_opt='repeat_last'), + dict(type='DecordDecode'), + dict(type='RandomResizedCrop', area_range=(0.5, 1.0)), + dict( + type='Resize', + scale=(224, 224), + keep_ratio=False, + interpolation='bicubic'), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict( + type='PackActionInputs', + algorithm_keys=( + 'question', + 'question_id', + 'gt_answer', + 'gt_answer_weight', + )) +] + +val_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + test_mode=True, + out_of_bound_opt='repeat_last'), + dict(type='DecordDecode'), + dict( + type='Resize', + scale=(224, 224), + keep_ratio=False, + interpolation='bicubic'), + dict(type='FormatShape', input_format='NCHW'), + dict( + type='PackActionInputs', + algorithm_keys=( + 'question', + 'gt_answer', + 'question_id', + )) +] + +test_pipeline = val_pipeline + +dataset_type = 'MSRVTTVQA' + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file=anno_file_train, + pipeline=train_pipeline, + data_prefix=dict(video=video_root), + )) + +val_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=anno_file_val, + pipeline=val_pipeline, + data_prefix=dict(video=video_root), + )) + +test_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=anno_file_test, + pipeline=test_pipeline, + data_prefix=dict(video=video_root), + )) + +val_evaluator = dict(type='VQAAcc') +test_evaluator = dict(type='VQAAcc') + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=10, val_begin=1, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=0.01, + by_epoch=True, + begin=0, + end=1, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=10, + eta_min_ratio=0.01, + by_epoch=True, + begin=1, + end=10, + convert_to_iter_based=True) +] + +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.02), + paramwise_cfg=dict( + bypass_duplicate=True, norm_decay_mult=0.0, bias_decay_mult=0.0), + clip_grad=dict(max_norm=50, norm_type=2), +) + +model_wrapper_cfg = dict(type='MMDistributedDataParallel', static_graph=True) + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=20, ignore_last=False)) + +auto_scale_lr = dict(enable=True, base_batch_size=32) + +find_unused_parameters = True diff --git a/configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py b/configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py new file mode 100644 index 0000000000..7ec0271928 --- /dev/null +++ b/configs/multimodal/vindlu/vindlu_beit-base_vqa-mc_msrvtt-mc.py @@ -0,0 +1,80 @@ +_base_ = ['../../_base_/default_runtime.py'] + +video_root = 'data/msrvtt/videos_2fps_224' +anno_file_test = 'data/msrvtt/annotations/msrvtt_mc_test.json' + +# model settings +model = dict( + type='VindLURetrievalMC', + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[128], + std=[128], + format_shape='NCTHW'), + tokenizer=dict( + type='VindLUTokenizer', + pretrained_model_name_or_path='bert-base-uncased'), + vision_encoder=dict( + type='BeitModel3D', + config='microsoft/beit-base-patch16-224-pt22k-ft22k', + tem_config=dict( + num_frames=12, + temporal_model_block='timesformer', + temporal_model_position='last', + temporal_model_config=dict(input_dim=768), + use_temporal_position_embedding=True), + encoder_width=768, + add_ln=True), + text_encoder=dict( + type='XBertModel', + pretrained_model_name_or_path='bert-base-uncased', + encoder_width=768, + fusion_layer=9, + add_pooling_layer=False), + text_decoder=dict( + type='BertDecoder', + pretrained_model_name_or_path='bert-base-uncased', + encoder_width=768, + fusion_layer=0, + num_hidden_layers=3, + add_pooling_layer=True), + proj_dim=256, + temperature=0.07, + max_txt_len=32, + gradient_checkpointing=True) + +file_client_args = dict(io_backend='disk') + +test_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=12, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs', algorithm_keys=('caption_options', )) +] + +dataset_type = 'MSRVTTVQAMC' + +test_dataloader = dict( + batch_size=32, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=anno_file_test, + pipeline=test_pipeline, + data_prefix=dict(video=video_root), + )) + +test_evaluator = dict(type='VQAMCACC') +test_cfg = dict(type='TestLoop') + +default_hooks = dict( + logger=dict(type='LoggerHook', interval=20, ignore_last=False), ) diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py index ded946b727..cc838f8f31 100644 --- a/mmaction/datasets/__init__.py +++ b/mmaction/datasets/__init__.py @@ -3,6 +3,7 @@ from .audio_dataset import AudioDataset from .ava_dataset import AVADataset, AVAKineticsDataset from .base import BaseActionDataset +from .msrvtt_datasets import MSRVTTVQA, MSRVTTVQAMC, MSRVTTRetrieval from .pose_dataset import PoseDataset from .rawframe_dataset import RawframeDataset from .repeat_aug_dataset import RepeatAugDataset, repeat_pseudo_collate @@ -13,5 +14,6 @@ __all__ = [ 'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset', 'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset', - 'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset' + 'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset', + 'MSRVTTRetrieval', 'MSRVTTVQA', 'MSRVTTVQAMC' ] diff --git a/mmaction/datasets/msrvtt_datasets.py b/mmaction/datasets/msrvtt_datasets.py new file mode 100644 index 0000000000..058734c01d --- /dev/null +++ b/mmaction/datasets/msrvtt_datasets.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +import re +from collections import Counter +from typing import Dict, List + +from mmengine.fileio import exists + +from mmaction.registry import DATASETS +from .base import BaseActionDataset + + +@DATASETS.register_module() +class MSRVTTVQA(BaseActionDataset): + """MSR-VTT Video Question Answering dataset.""" + + def load_data_list(self) -> List[Dict]: + """Load annotation file to get video information.""" + exists(self.ann_file) + data_list = [] + + with open(self.ann_file) as f: + data_lines = json.load(f) + for data in data_lines: + answers = data['answer'] + if isinstance(answers, str): + answers = [answers] + count = Counter(answers) + answer_weight = [i / len(answers) for i in count.values()] + data_item = dict( + question_id=data['question_id'], + filename=osp.join(self.data_prefix['video'], + data['video']), + question=pre_text(data['question']), + gt_answer=list(count.keys()), + gt_answer_weight=answer_weight) + data_list.append(data_item) + + return data_list + + +@DATASETS.register_module() +class MSRVTTVQAMC(BaseActionDataset): + """MSR-VTT VQA multiple choices dataset.""" + + def load_data_list(self) -> List[Dict]: + """Load annotation file to get video information.""" + exists(self.ann_file) + data_list = [] + + with open(self.ann_file) as f: + data_lines = json.load(f) + for data in data_lines: + data_item = dict( + filename=osp.join(self.data_prefix['video'], + data['video']), + label=data['answer'], + caption_options=[pre_text(c) for c in data['caption']]) + data_list.append(data_item) + + return data_list + + +@DATASETS.register_module() +class MSRVTTRetrieval(BaseActionDataset): + """MSR-VTT Retrieval dataset.""" + + def load_data_list(self) -> List[Dict]: + """Load annotation file to get video information.""" + exists(self.ann_file) + data_list = [] + + with open(self.ann_file) as f: + data_lines = json.load(f) + video_idx = 0 + text_idx = 0 + for data in data_lines: + # don't consider multiple videos or multiple captions + video_path = osp.join(self.data_prefix['video'], data['video']) + data_item = dict( + filename=video_path, + text=[], + gt_video_id=[], + gt_text_id=[]) + if isinstance(data['caption'], str): + data['caption'] = [data['caption']] + + for text in data['caption']: + text = pre_text(text) + data_item['text'].append(text) + data_item['gt_video_id'].append(video_idx) + data_item['gt_text_id'].append(text_idx) + text_idx += 1 + + video_idx += 1 + data_list.append(data_item) + self.num_videos = video_idx + self.num_texts = text_idx + + return data_list + + +def pre_text(text, max_l=None): + text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower()) + text = text.replace('-', ' ').replace('/', + ' ').replace('', 'person') + + text = re.sub(r'\s{2,}', ' ', text) + text = text.rstrip('\n').strip(' ') + + if max_l: # truncate + words = text.split(' ') + if len(words) > max_l: + text = ' '.join(words[:max_l]) + return text diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py index fb67e10c0e..a8e9b9ab82 100644 --- a/mmaction/datasets/transforms/formatting.py +++ b/mmaction/datasets/transforms/formatting.py @@ -20,6 +20,8 @@ class PackActionInputs(BaseTransform): meta_keys (Sequence[str]): The meta keys to saved in the `metainfo` of the `data_sample`. Defaults to ``('img_shape', 'img_key', 'video_id', 'timestamp')``. + algorithm_keys (Sequence[str]): The keys of custom elements to be used + in the algorithm. Defaults to an empty tuple. """ mapping_table = { @@ -28,13 +30,15 @@ class PackActionInputs(BaseTransform): } def __init__( - self, - collect_keys: Optional[Tuple[str]] = None, - meta_keys: Sequence[str] = ('img_shape', 'img_key', 'video_id', - 'timestamp') + self, + collect_keys: Optional[Tuple[str]] = None, + meta_keys: Sequence[str] = ('img_shape', 'img_key', 'video_id', + 'timestamp'), + algorithm_keys: Sequence[str] = (), ) -> None: self.collect_keys = collect_keys self.meta_keys = meta_keys + self.algorithm_keys = algorithm_keys def transform(self, results: Dict) -> Dict: """The transform function of :class:`PackActionInputs`. @@ -88,6 +92,12 @@ def transform(self, results: Dict) -> Dict: if 'label' in results: data_sample.set_gt_label(results['label']) + # Set custom algorithm keys + for key in self.algorithm_keys: + if key in results: + data_sample.set_field(results[key], key) + + # Set meta keys img_meta = {k: results[k] for k in self.meta_keys if k in results} data_sample.set_metainfo(img_meta) packed_results['data_samples'] = data_sample diff --git a/mmaction/datasets/transforms/processing.py b/mmaction/datasets/transforms/processing.py index 3d432bd723..6d54c6bf24 100644 --- a/mmaction/datasets/transforms/processing.py +++ b/mmaction/datasets/transforms/processing.py @@ -613,8 +613,9 @@ class Resize(BaseTransform): keep_ratio (bool): If set to True, Images will be resized without changing the aspect ratio. Otherwise, it will resize images to a given size. Default: True. - interpolation (str): Algorithm used for interpolation: - "nearest" | "bilinear". Default: "bilinear". + interpolation (str): Algorithm used for interpolation, + accepted values are "nearest", "bilinear", "bicubic", "area", + "lanczos". Default: "bilinear". lazy (bool): Determine whether to apply lazy operation. Default: False. """ diff --git a/mmaction/engine/runner/__init__.py b/mmaction/engine/runner/__init__.py index c7dc511ea8..9bc36f001b 100644 --- a/mmaction/engine/runner/__init__.py +++ b/mmaction/engine/runner/__init__.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .multi_loop import MultiLoaderEpochBasedTrainLoop +from .retrieval_loop import RetrievalTestLoop, RetrievalValLoop -__all__ = ['MultiLoaderEpochBasedTrainLoop'] +__all__ = [ + 'MultiLoaderEpochBasedTrainLoop', 'RetrievalValLoop', 'RetrievalTestLoop' +] diff --git a/mmaction/engine/runner/retrieval_loop.py b/mmaction/engine/runner/retrieval_loop.py new file mode 100644 index 0000000000..dc884876da --- /dev/null +++ b/mmaction/engine/runner/retrieval_loop.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from mmengine.model import is_model_wrapper +from mmengine.runner import TestLoop, ValLoop, autocast + +from mmaction.registry import LOOPS + + +@LOOPS.register_module() +class RetrievalValLoop(ValLoop): + """Loop for multimodal retrieval val. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 valing. Defaults to + False. + """ + + def run(self) -> dict: + """Launch val.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + num_videos = self.dataloader.dataset.num_videos + num_texts = self.dataloader.dataset.num_texts + with torch.no_grad(): + with autocast(enabled=self.fp16): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=num_videos, + num_texts=num_texts, + ) + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(num_videos) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(num_texts) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + self.runner.call_hook('after_val_epoch', metrics=metrics) + self.runner.call_hook('after_val') + return metrics + + +@LOOPS.register_module() +class RetrievalTestLoop(TestLoop): + """Loop for multimodal retrieval test. + + Args: + runner (Runner): A reference of runner. + dataloader (Dataloader or dict): A dataloader object or a dict to + build a dataloader. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 testing. Defaults to + False. + """ + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + + feats_local = [] + data_samples_local = [] + + for idx, data_batch in enumerate(self.dataloader): + with torch.no_grad(): + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + if is_model_wrapper(self.runner.model): + data_preprocessor = self.runner.model.module.data_preprocessor # noqa: E501 + else: + data_preprocessor = self.runner.model.data_preprocessor + # get features for retrieval instead of data samples + data_batch = data_preprocessor(data_batch, False) + feats = self.runner.model._run_forward( + data_batch, mode='tensor') + feats_local.append(feats) + data_samples_local.extend(data_batch['data_samples']) + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=feats) + + # concatenate different features + feats_local = { + k: torch.cat([dic[k] for dic in feats_local]) + for k in feats_local[0] + } + + # get predictions + if is_model_wrapper(self.runner.model): + predict_all_fn = self.runner.model.module.predict_all + else: + predict_all_fn = self.runner.model.predict_all + + num_videos = self.dataloader.dataset.num_videos + num_texts = self.dataloader.dataset.num_texts + with torch.no_grad(): + with autocast(enabled=self.fp16): + i2t_data_samples, t2i_data_samples = predict_all_fn( + feats_local, + data_samples_local, + num_images=num_videos, + num_texts=num_texts, + ) + + # process in evaluator and compute metrics + self.evaluator.process(i2t_data_samples, None) + i2t_metrics = self.evaluator.evaluate(num_videos) + i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()} + self.evaluator.process(t2i_data_samples, None) + t2i_metrics = self.evaluator.evaluate(num_texts) + t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()} + metrics = {**i2t_metrics, **t2i_metrics} + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + return metrics diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py index 8bf22c6672..341ec577ce 100644 --- a/mmaction/evaluation/metrics/__init__.py +++ b/mmaction/evaluation/metrics/__init__.py @@ -2,10 +2,12 @@ from .acc_metric import AccMetric, ConfusionMatrix from .anet_metric import ANetMetric from .ava_metric import AVAMetric +from .multimodal_metric import VQAMCACC, ReportVQA, RetrievalRecall, VQAAcc from .multisports_metric import MultiSportsMetric from .retrieval_metric import RetrievalMetric __all__ = [ 'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix', - 'MultiSportsMetric', 'RetrievalMetric' + 'MultiSportsMetric', 'RetrievalMetric', 'VQAAcc', 'ReportVQA', 'VQAMCACC', + 'RetrievalRecall' ] diff --git a/mmaction/evaluation/metrics/multimodal_metric.py b/mmaction/evaluation/metrics/multimodal_metric.py new file mode 100644 index 0000000000..2c144ac10a --- /dev/null +++ b/mmaction/evaluation/metrics/multimodal_metric.py @@ -0,0 +1,565 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copied from mmpretrain +# Partly adopted from https://github.com/GT-Vision-Lab/VQA +# Copyright (c) 2014, Aishwarya Agrawal +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger +from mmengine.utils import is_seq_of + +from mmaction.registry import METRICS +from mmaction.structures.action_data_sample import format_label +from .acc_metric import to_tensor + + +def _process_punctuation(inText): + import re + outText = inText + punct = [ + ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', + '>', '<', '@', '`', ',', '?', '!' + ] + commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605 + periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605 + for p in punct: + if (p + ' ' in inText or ' ' + p in inText) or (re.search( + commaStrip, inText) is not None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = periodStrip.sub('', outText, re.UNICODE) + return outText + + +def _process_digit_article(inText): + outText = [] + tempText = inText.lower().split() + articles = ['a', 'an', 'the'] + manualMap = { + 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + } + contractions = { + 'aint': "ain't", + 'arent': "aren't", + 'cant': "can't", + 'couldve': "could've", + 'couldnt': "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + 'didnt': "didn't", + 'doesnt': "doesn't", + 'dont': "don't", + 'hadnt': "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + 'hasnt': "hasn't", + 'havent': "haven't", + 'hed': "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + 'hes': "he's", + 'howd': "how'd", + 'howll': "how'll", + 'hows': "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + 'Im': "I'm", + 'Ive': "I've", + 'isnt': "isn't", + 'itd': "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + 'itll': "it'll", + "let's": "let's", + 'maam': "ma'am", + 'mightnt': "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + 'mightve': "might've", + 'mustnt': "mustn't", + 'mustve': "must've", + 'neednt': "needn't", + 'notve': "not've", + 'oclock': "o'clock", + 'oughtnt': "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + 'shant': "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + 'shouldve': "should've", + 'shouldnt': "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": 'somebodyd', + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + 'somebodyll': "somebody'll", + 'somebodys': "somebody's", + 'someoned': "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + 'someonell': "someone'll", + 'someones': "someone's", + 'somethingd': "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + 'somethingll': "something'll", + 'thats': "that's", + 'thered': "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + 'therere': "there're", + 'theres': "there's", + 'theyd': "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + 'theyll': "they'll", + 'theyre': "they're", + 'theyve': "they've", + 'twas': "'twas", + 'wasnt': "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + 'weve': "we've", + 'werent': "weren't", + 'whatll': "what'll", + 'whatre': "what're", + 'whats': "what's", + 'whatve': "what've", + 'whens': "when's", + 'whered': "where'd", + 'wheres': "where's", + 'whereve': "where've", + 'whod': "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + 'wholl': "who'll", + 'whos': "who's", + 'whove': "who've", + 'whyll': "why'll", + 'whyre': "why're", + 'whys': "why's", + 'wont': "won't", + 'wouldve': "would've", + 'wouldnt': "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + 'yall': "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + 'youd': "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + 'youll': "you'll", + 'youre': "you're", + 'youve': "you've", + } + for word in tempText: + word = manualMap.setdefault(word, word) + if word not in articles: + outText.append(word) + for wordId, word in enumerate(outText): + if word in contractions: + outText[wordId] = contractions[word] + outText = ' '.join(outText) + return outText + + +@METRICS.register_module() +class VQAAcc(BaseMetric): + '''VQA Acc metric. + Args: + + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + ''' + default_prefix = 'VQA' + + def __init__(self, + full_score_weight: float = 0.3, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + self.full_score_weight = full_score_weight + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + gt_answer = sample.get('gt_answer') + gt_answer_weight = sample.get('gt_answer_weight') + if isinstance(gt_answer, str): + gt_answer = [gt_answer] + if gt_answer_weight is None: + gt_answer_weight = [1. / (len(gt_answer))] * len(gt_answer) + + result = { + 'pred_answer': sample.get('pred_answer'), + 'gt_answer': gt_answer, + 'gt_answer_weight': gt_answer_weight, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = [ + self._process_answer(answer) for answer in result['gt_answer'] + ] + answer_weight = result['gt_answer_weight'] + + weight_sum = 0 + for i, gt in enumerate(gt_answer): + if gt == pred_answer: + weight_sum += answer_weight[i] + vqa_acc = min(1.0, weight_sum / self.full_score_weight) + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + + metrics = {'acc': accuracy} + return metrics + + def _process_answer(self, answer): + answer = answer.replace('\n', ' ') + answer = answer.replace('\t', ' ') + answer = answer.strip() + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer + + +@METRICS.register_module() +class ReportVQA(BaseMetric): + """Dump VQA result to the standard json format for VQA evaluation. + + Args: + file_path (str): The file path to save the result file. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + """ + default_prefix = 'VQA' + + def __init__(self, + file_path: str, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + if not file_path.endswith('.json'): + raise ValueError('The output file must be a json file.') + self.file_path = file_path + + def process(self, data_batch, data_samples) -> None: + """transfer tensors in predictions to CPU.""" + for sample in data_samples: + question_id = sample['question_id'] + pred_answer = sample['pred_answer'] + + result = { + 'question_id': int(question_id), + 'answer': pred_answer, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Dump the result to json file.""" + mmengine.dump(results, self.file_path) + logger = MMLogger.get_current_instance() + logger.info(f'Results has been saved to {self.file_path}.') + return {} + + +@METRICS.register_module() +class VQAMCACC(BaseMetric): + '''VQA multiple choice Acc metric. + Args: + + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Should be modified according to the + `retrieval_type` for unambiguous results. Defaults to TR. + ''' + default_prefix = 'VQAMC' + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None): + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch, data_samples): + """Process one batch of data samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch: A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for sample in data_samples: + # gt_labels in datasample is a LabelData + label = sample['gt_label'].item() + result = { + 'pred_label': sample.get('pred_label'), + 'gt_label': label, + } + + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + preds = np.array([x['pred_label'] for x in results]) + labels = np.array([x['gt_label'] for x in results]) + + accuracy = np.sum(preds == labels) / len(preds) * 100 + + metrics = {'acc': accuracy} + return metrics + + +@METRICS.register_module() +class RetrievalRecall(BaseMetric): + r"""Recall evaluation metric for image retrieval. + + Args: + topk (int | Sequence[int]): If the ground truth label matches one of + the best **k** predictions, the sample will be regard as a positive + prediction. If the parameter is a tuple, all of top-k recall will + be calculated and outputted together. Defaults to 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + """ + default_prefix: Optional[str] = 'retrieval' + + def __init__(self, + topk: Union[int, Sequence[int]], + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + topk = (topk, ) if isinstance(topk, int) else topk + + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + self.topk = topk + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_score = data_sample['pred_score'].cpu() + gt_label = format_label(data_sample['gt_label']) + + if 'gt_score' in data_sample: + target = data_sample.get('gt_score').clone() + else: + num_classes = pred_score.size()[-1] + target = F.one_hot(gt_label, num_classes) + + # Because the retrieval output logit vector will be much larger + # compared to the normal classification, to save resources, the + # evaluation results are computed each batch here and then reduce + # all results at the end. + result = RetrievalRecall.calculate( + pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk) + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + result_metrics = dict() + for i, k in enumerate(self.topk): + recall_at_k = sum([r[i].item() for r in results]) / len(results) + result_metrics[f'Recall@{k}'] = recall_at_k + + return result_metrics + + @staticmethod + def calculate(pred: Union[np.ndarray, torch.Tensor], + target: Union[np.ndarray, torch.Tensor], + topk: Union[int, Sequence[int]], + pred_indices: (bool) = False, + target_indices: (bool) = False) -> float: + """Calculate the average recall. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, M)`` or a sequence of index/onehot + format labels. + topk (int, Sequence[int]): Predictions with the k-th highest + scores are considered as positive. + pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. Defaults to False. + + Returns: + List[float]: the average recalls. + """ + topk = (topk, ) if isinstance(topk, int) else topk + for k in topk: + if k <= 0: + raise ValueError('`topk` must be a ingter larger than 0 ' + 'or seq of ingter larger than 0.') + + max_keep = max(topk) + pred = _format_pred(pred, max_keep, pred_indices) + target = _format_target(target, target_indices) + + assert len(pred) == len(target), ( + f'Length of `pred`({len(pred)}) and `target` ({len(target)}) ' + f'must be the same.') + + num_samples = len(pred) + results = [] + for k in topk: + recalls = torch.zeros(num_samples) + for i, (sample_pred, + sample_target) in enumerate(zip(pred, target)): + sample_pred = np.array(to_tensor(sample_pred).cpu()) + sample_target = np.array(to_tensor(sample_target).cpu()) + recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max()) + results.append(recalls.mean() * 100) + return results + + +def _format_pred(label, topk=None, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`pred` must be Sequence of indices when' \ + f' `pred_indices` set to True, but get {type(label)}' + for i, sample_pred in enumerate(label): + assert is_seq_of(sample_pred, int) or isinstance( + sample_pred, (np.ndarray, torch.Tensor)), \ + '`pred` should be Sequence of indices when `pred_indices`' \ + f'set to True. but pred[{i}] is {sample_pred}' + if topk: + label[i] = sample_pred[:min(topk, len(sample_pred))] + return label + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + topk = topk if topk else label.size()[-1] + _, indices = label.topk(topk) + return indices + + +def _format_target(label, is_indices=False): + """format various label to List[indices].""" + if is_indices: + assert isinstance(label, Sequence), \ + '`target` must be Sequence of indices when' \ + f' `target_indices` set to True, but get {type(label)}' + for i, sample_gt in enumerate(label): + assert is_seq_of(sample_gt, int) or isinstance( + sample_gt, (np.ndarray, torch.Tensor)), \ + '`target` should be Sequence of indices when ' \ + f'`target_indices` set to True. but target[{i}] is {sample_gt}' + return label + + if isinstance(label, np.ndarray): + label = torch.from_numpy(label) + elif isinstance(label, Sequence) and not mmengine.is_str(label): + label = torch.tensor(label) + elif not isinstance(label, torch.Tensor): + raise TypeError(f'The pred must be type of torch.tensor, ' + f'np.ndarray or Sequence but get {type(label)}.') + + indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label] + return indices diff --git a/mmaction/models/__init__.py b/mmaction/models/__init__.py index 6c53b29254..08f7d41f52 100644 --- a/mmaction/models/__init__.py +++ b/mmaction/models/__init__.py @@ -5,6 +5,7 @@ from .heads import * # noqa: F401,F403 from .localizers import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 +from .multimodal import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 from .recognizers import * # noqa: F401,F403 from .roi_heads import * # noqa: F401,F403 diff --git a/mmaction/models/multimodal/__init__.py b/mmaction/models/multimodal/__init__.py new file mode 100644 index 0000000000..9a5f2a99df --- /dev/null +++ b/mmaction/models/multimodal/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmaction.utils.dependency import WITH_MULTIMODAL + +if WITH_MULTIMODAL: + from .vindlu import * # noqa: F401,F403 + +else: + from mmaction.registry import MODELS + from mmaction.utils.dependency import register_multimodal_placeholder + + register_multimodal_placeholder( + ['VindLUVQA', 'VindLURetrievalMC', 'VindLURetrieval'], MODELS) diff --git a/mmaction/models/multimodal/vindlu/__init__.py b/mmaction/models/multimodal/vindlu/__init__.py new file mode 100644 index 0000000000..e17c193246 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beit3d import BeitModel3D +from .tokenizer import VindLUTokenizer +from .vindlu_ret import VindLURetrieval +from .vindlu_ret_mc import VindLURetrievalMC +from .vindlu_vqa import VindLUVQA +from .xbert import BertDecoder, BertModel + +__all__ = [ + 'VindLUVQA', 'VindLURetrievalMC', 'VindLURetrieval', 'VindLUTokenizer', + 'BeitModel3D', 'BertDecoder', 'BertModel' +] diff --git a/mmaction/models/multimodal/vindlu/beit3d.py b/mmaction/models/multimodal/vindlu/beit3d.py new file mode 100644 index 0000000000..8e0d6f2fc3 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/beit3d.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +from typing import Dict, Optional, Tuple, Union + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.beit import BeitConfig, BeitModel +from transformers.models.beit.modeling_beit import BeitAttention, BeitDropPath +from transformers.models.beit.modeling_beit import \ + BeitEmbeddings as BeitEmbeddings2D +from transformers.models.beit.modeling_beit import BeitLayer as BeitLayer2D +from transformers.models.beit.modeling_beit import BeitRelativePositionBias +from transformers.models.beit.modeling_beit import \ + BeitRelativePositionBias as BeitRelativePositionBias2D + +from mmaction.registry import MODELS +from .temporal_model import (X_CLIP, STAdapter, TemporalAttention, + WindowTemporalAttention) + + +def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): + """ + temp_embed_old: (1, num_frames_old, 1, d) + Returns: + temp_embed_new: (1, num_frames_new, 1, d) + """ + temp_embed_old = temp_embed_old.squeeze(2).permute( + 0, 2, 1) # (1, d, num_frames_old) + temp_embed_new = F.interpolate( + temp_embed_old, num_frames_new, + mode='linear') # (1, d, num_frames_new) + temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( + 2) # (1, num_frames_new, 1, d) + return temp_embed_new + + +class TemporalAttentionBeit(nn.Module): + """temporal attention using BeitAttention.""" + + def __init__(self, config: BeitConfig): + """TODO: to be defined.""" + super().__init__() + + self.layernorm_before = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.attention = BeitAttention(config, window_size=None) + self.scale = nn.Parameter( + config.temporal_model_init_value * torch.ones( + (config.hidden_size)), + requires_grad=True, + ) + self.drop_path = BeitDropPath(config.drop_path_rate) + + def forward(self, hidden_states: torch.Tensor): + """forward function. + + Args: + hidden_states (torch.Tensor): The input. Shape: [b,t,l,c] + + Returns: TODO + """ + b = hidden_states.shape[0] + output = einops.rearrange(hidden_states, 'b t l c -> (b l) t c') + output = self.layernorm_before(output) + output = self.attention(output) + output = einops.rearrange(output[0], '(b l) t c -> b t l c', b=b) + return hidden_states + self.drop_path(output[0]) * self.scale + + +class BeitPooler3D(nn.Module): + + def __init__(self, config: BeitConfig) -> None: + super().__init__() + self.num_prompts = config.add_k_prompts + self.layernorm = ( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + if config.use_mean_pooling else None) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Shape: [B,T,L,C] + """ + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + # patch_tokens = hidden_states[:, 1 + self.num_prompts :, :] + if self.num_prompts > 0: + patch_tokens = hidden_states[:, :, 1:-self.num_prompts, :] + else: + patch_tokens = hidden_states[:, :, 1:, :] + pooled_output = self.layernorm(patch_tokens.mean(2)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, :, 0] + + return pooled_output + + +class BeitRelativePositionBias3D(BeitRelativePositionBias2D): + + def __init__(self, config: BeitConfig, window_size: tuple) -> None: + super().__init__(config, window_size) + + # add bias for prompts + self.k = config.add_k_prompts + if self.k > 0: + self.prompt_bias_table = nn.parameter.Parameter( + torch.zeros((2 + self.k) * self.k, config.num_attention_heads) + ) # k prompt-to-token, k token-to-prompt, k*k prompt-to-promt + else: + self.prompt_bias_table = None + + def forward(self) -> torch.Tensor: + # relative position bias 2d + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + + # add bias for prompts + k = self.k + if k > 0: + l = self.window_size[0] * self.window_size[1] + 1 # noqa: E741 + bias = torch.zeros(l + k, l + k, + relative_position_bias.shape[-1]).to( + relative_position_bias.device) + bias[:l, :l] = relative_position_bias + bias[l:, :l] = self.prompt_bias_table[:k].view( + k, 1, -1) # prompt to token + bias[:l, + l:] = self.prompt_bias_table[k:2 * + k].view(1, k, + -1) # token to prompt + bias[l:, l:] = self.prompt_bias_table[2 * k, :].view( + k, k, -1) # prompt to prompt + else: + bias = relative_position_bias + + return bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BeitEmbeddings3D(BeitEmbeddings2D): + """Construct the CLS token, position and patch embeddings. + + Optionally, also the mask token. + """ + + def __init__(self, config: BeitConfig) -> None: + super().__init__(config) + + if config.use_temporal_position_embedding: + self.temporal_position_embeddings = nn.parameter.Parameter( + torch.zeros(1, config.num_frames, 1, config.hidden_size)) + else: + self.temporal_position_embeddings = None + + if config.add_k_prompts > 0: + self.prompt_tokens = nn.parameter.Parameter( + torch.zeros(1, config.add_k_prompts, config.hidden_size)) + else: + self.prompt_tokens = None + + def forward(self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input image patches. + Shape: [B, T, C, H, W]. + + + """ + t = pixel_values.shape[1] + pixel_values = einops.rearrange(pixel_values, + 'b t c h w -> (b t) c h w') + + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() # [(b t) l c] + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + if self.prompt_tokens is not None: + prompt_tokens = self.prompt_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings, prompt_tokens), + dim=1) + else: + embeddings = torch.cat((cls_tokens, embeddings), + dim=1) # [B*T, L, C] + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = einops.rearrange(embeddings, '(b t) l c -> b t l c', t=t) + if self.temporal_position_embeddings is not None: + if t <= self.temporal_position_embeddings.shape[1]: + embeddings = embeddings + \ + self.temporal_position_embeddings[:, :t] + else: + tpe = interpolate_temporal_pos_embed( + self.temporal_position_embeddings, t) + embeddings = embeddings + tpe + + embeddings = self.dropout(embeddings) + + return embeddings + + +class BeitLayer3D(BeitLayer2D): + + def __init__(self, + config: BeitConfig, + window_size: Optional[tuple] = None, + drop_path_rate: float = 0.0) -> None: + super().__init__(config, window_size, drop_path_rate) + + self.temporal_model_position = config.temporal_model_position + if config.temporal_model_block == 'st_adapter': + self.temp_model = STAdapter(**config.temporal_model_config) + elif config.temporal_model_block == 'timesformer': + self.temp_model = TemporalAttention(**config.temporal_model_config) + elif config.temporal_model_block == 'ta_beit': + self.temp_model = TemporalAttentionBeit(config) + elif config.temporal_model_block == 'window_attention': + self.temp_model = WindowTemporalAttention( + **config.temporal_model_config) + elif config.temporal_model_block == 'xclip': + self.temp_model = X_CLIP(**config.temporal_model_config) + elif config.temporal_model_block == 'none': + self.temp_model = None + else: + raise ValueError( + f'not accepted temporal model: {config.temporal_model_block}') + + self.temporal_model_block = config.temporal_model_block + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional['BeitRelativePositionBias'] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + + b, t, l, c = hidden_states.shape + + if self.temporal_model_block == 'xclip': + assert (self.temporal_model_position == 'first' + and self.config.add_k_prompts + == 1), ('xclip must be put before the attention and' + 'add_k_prompts must be 1.') + + if self.temp_model is not None and \ + self.temporal_model_position == 'first': + hidden_states = self.temp_model(hidden_states) + + hidden_states = einops.rearrange(hidden_states, 'b t l c -> (b t) l c') + + self_attention_outputs = self.attention( + self.layernorm_before( + hidden_states + ), # in BEiT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + ) + attention_output = self_attention_outputs[0] + + # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + layer_output = einops.rearrange( + layer_output, '(b t) l c -> b t l c', b=b) + + # apply temporal modeling block + if self.temp_model is not None and \ + self.temporal_model_position == 'last': + layer_output = self.temp_model(layer_output) + + outputs = (layer_output, ) + outputs + + return outputs + + +class BeitConfig3D(BeitConfig): + + def __init__(self, + num_frames=1, + temporal_model_block='none', + temporal_model_position='last', + temporal_model_init_value=0.0, + temporal_model_config={}, + use_temporal_position_embedding=False, + add_k_prompts=0, + **kwargs) -> None: + + super().__init__(**kwargs) + self.temporal_model_block = temporal_model_block + self.temporal_model_config = temporal_model_config + self.temporal_model_position = temporal_model_position + self.temporal_model_init_value = temporal_model_init_value + self.use_temporal_position_embedding = use_temporal_position_embedding + self.add_k_prompts = add_k_prompts + self.num_frames = num_frames + + +@MODELS.register_module() +class BeitModel3D(BeitModel): + + def __init__(self, + config: BeitConfig, + tem_config: Dict, + add_pooling_layer: bool = True) -> None: + # hack to replace original 2D modules with 3D modules + beit_package = importlib.import_module( + 'transformers.models.beit.modeling_beit') + beit_package.BeitEmbeddings = BeitEmbeddings3D + beit_package.BeitPooler = BeitPooler3D + beit_package.BeitLayer = BeitLayer3D + beit_package.BeitRelativePositionBias = BeitRelativePositionBias3D + + config = BeitConfig3D.from_pretrained(config, **tem_config) + super().__init__(config, add_pooling_layer) diff --git a/mmaction/models/multimodal/vindlu/modeling_bert.py b/mmaction/models/multimodal/vindlu/modeling_bert.py new file mode 100644 index 0000000000..5ffba79bdc --- /dev/null +++ b/mmaction/models/multimodal/vindlu/modeling_bert.py @@ -0,0 +1,1740 @@ +# flake8: noqa +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from mmengine.logging import MMLogger +from torch import Tensor, device, dtype, nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +# from transformers.models.bert.configuration_bert import BertConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.file_utils import (ModelOutput, add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, MaskedLMOutput, + MultipleChoiceModelOutput, NextSentencePredictorOutput, + QuestionAnsweringModelOutput, SequenceClassifierOutput, + TokenClassifierOutput) +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) + +transformers.logging.set_verbosity_error() + +_CONFIG_FOR_DOC = 'BertConfig' +_TOKENIZER_FOR_DOC = 'BertTokenizer' + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'bert-base-uncased', + 'bert-large-uncased', + 'bert-base-cased', + 'bert-large-cased', + 'bert-base-multilingual-uncased', + 'bert-base-multilingual-cased', + 'bert-base-chinese', + 'bert-base-german-cased', + 'bert-large-uncased-whole-word-masking', + 'bert-large-cased-whole-word-masking', + 'bert-large-uncased-whole-word-masking-finetuned-squad', + 'bert-large-cased-whole-word-masking-finetuned-squad', + 'bert-base-cased-finetuned-mrpc', + 'bert-base-german-dbmdz-cased', + 'bert-base-german-dbmdz-uncased', + 'cl-tohoku/bert-base-japanese', + 'cl-tohoku/bert-base-japanese-whole-word-masking', + 'cl-tohoku/bert-base-japanese-char', + 'cl-tohoku/bert-base-japanese-char-whole-word-masking', + 'TurkuNLP/bert-base-finnish-cased-v1', + 'TurkuNLP/bert-base-finnish-uncased-v1', + 'wietsedv/bert-base-dutch-cased', + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to + instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BERT + [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import BertModel, BertConfig + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = 'bert' + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type='absolute', + use_cache=True, + classifier_dropout=None, + cross_module='ca', + encoder_width=768, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.cross_module = cross_module + self.encoder_width = encoder_width + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + logger = MMLogger.get_current_instance() + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' + 'https://www.tensorflow.org/install/ for installation instructions.' + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info('Converting TensorFlow checkpoint from {}'.format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info('Loading TF weight {} with shape {}'.format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in [ + 'adam_v', + 'adam_m', + 'AdamWeightDecayOptimizer', + 'AdamWeightDecayOptimizer_1', + 'global_step', + ] for n in name): + logger.info('Skipping {}'.format('/'.join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + scope_names = re.split(r'_(\d+)', m_name) + else: + scope_names = [m_name] + if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif scope_names[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + elif scope_names[0] == 'squad': + pointer = getattr(pointer, 'classifier') + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info('Skipping {}'.format('/'.join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + + logger.info('Initialize PyTorch weight {}'.format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type + embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length: + seq_length + + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == 'absolute': + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, 'embedding_size'): + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, + 'position_embedding_type', + 'absolute') + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + + if (self.position_embedding_type == 'relative_key' + or self.position_embedding_type == 'relative_key_query'): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, + device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == 'relative_key': + relative_position_scores = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == 'relative_key_query': + relative_position_scores_query = torch.einsum( + 'bhld,lrd->bhlr', query_layer, positional_embedding) + relative_position_scores_key = torch.einsum( + 'bhrd,lrd->bhlr', key_layer, positional_embedding) + attention_scores = ( + attention_scores + relative_position_scores_query + + relative_position_scores_key) + + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + + # added `attention_scores` to return tuple + outputs = ((context_layer, attention_probs, + attention_scores) if output_attentions else + (context_layer, )) + + outputs = outputs + (past_key_value, ) + return outputs + + +class BertSelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, config, is_cross_attention=False): + super().__init__() + + self.self = BertSelfAttention(config, is_cross_attention) + + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len( + heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + # add attentions if we output them + outputs = (attention_output, ) + self_outputs[1:] + return outputs # (context_layer, attention_probs, attention_scores, past_key_value,) + + +class BertIntermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = layer_num >= config.fusion_layer + if self.has_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[: + 2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), 'encoder_hidden_states must be given for cross-attention layers' + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[(self.layer_num - + self.config.fusion_layer) % + len(encoder_hidden_states)], + encoder_attention_mask[(self.layer_num - + self.config.fusion_layer) % + len(encoder_hidden_states)], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output, ) + outputs + + outputs = outputs + (present_key_value, ) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)]) + logger = MMLogger.get_current_instance() + logger.info(f'build bert with cross_module: {config.cross_module}') + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multi_modal', + normalize_attention=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + if (mode == 'text' or mode == 'temporal' + ): # temporal is added and used for temporal att module. + start_layer = 0 + output_layer = self.config.fusion_layer + + elif mode == 'fusion': + start_layer = self.config.fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == 'multi_modal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[ + i] if past_key_values is not None else None + + if getattr(self.config, 'gradient_checkpointing', + False) and self.training: + + if use_cache: + logger = MMLogger.get_current_instance() + logger.warn( + '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting ' + '`use_cache=False`...') + use_cache = False + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, past_key_value, + output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_reentrant=False, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1], ) + if output_attentions: + # whether to output normalized attention, + # note for unnormalized attention, there is a mask added + offset = int(normalize_attention) + # all_self_attentions = all_self_attentions + (layer_outputs[1], ) + all_self_attentions = all_self_attentions + ( + layer_outputs[2 - offset], ) + if hasattr(layer_module, 'crossattention'): + # all_cross_attentions = all_cross_attentions + (layer_outputs[3], ) + all_cross_attentions = all_cross_attentions + ( + layer_outputs[4 - offset], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple + interface for downloading and loading pretrained models.""" + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = 'bert' + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """Output type of :class:`~transformers.BertForPreTraining`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Bert Model transformer outputting raw hidden-states without any specific head on top.', + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """The model can behave as an encoder (with only self-attention) as well as + a decoder, in which case a layer of cross-attention is added between the + self-attention layers, following the architecture described in `Attention + is all you need `__ by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. + + Gomez, Lukasz Kaiser and Illia Polosukhin. argument and + :obj:`add_cross_attention` set to :obj:`True`; an + :obj:`encoder_hidden_states` is then expected as an input to the forward + pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """Prunes heads of the model. + + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask(self, attention_mask: Tensor, + input_shape: Tuple[int], device: device, + is_decoder: bool) -> Tensor: + """Makes broadcastable attention and causal masks so that future and + masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= + seq_ids[None, :, None]) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[ + 1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * + attention_mask[:, None, None, :]) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + 'Wrong shape for input_ids (shape {}) or attention_mask (shape {})' + .format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + normalize_attention=True, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions if output_attentions is not None else + self.config.output_attentions) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds or encoder_embeds' + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] + if past_key_values is not None else 0) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), + device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( + ) + encoder_hidden_shape = (encoder_batch_size, + encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) + for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, + self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + normalize_attention=normalize_attention, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler( + sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @replace_return_docstrings( + output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), + next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss, ) + + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, + BERT_START_DOCSTRING, +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @replace_return_docstrings( + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + mode='multi_modal', + normalize_attention=True, + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + normalize_attention=normalize_attention, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, : + -1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, + dim=-1) + loss_distill = (loss_distill * (labels != -100)).sum(1) + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((lm_loss, ) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past=None, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + 'input_ids': + input_ids, + 'attention_mask': + attention_mask, + 'past_key_values': + past, + 'encoder_hidden_states': + model_kwargs.get('encoder_hidden_states', None), + 'encoder_attention_mask': + model_kwargs.get('encoder_attention_mask', None), + 'is_decoder': + True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) + return reordered_past + + +@dataclass +class MaskedLMOutputWithDistill(MaskedLMOutput): + loss_aux: Optional[torch.FloatTensor] = None + loss_distill: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + """Bert Model with a `language modeling` head on top. """, + BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def tie_aux_decoder_weights(self, module, aux_modules): + """Tie decoder weights of all `aux_modules` to `module`, (not bias)""" + for m in aux_modules: + m.predictions.decoder.weight = module.predictions.decoder.weight + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + normalize_attention=True, + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + normalize_attention=normalize_attention, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + masked_lm_loss_aux = 0.0 + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1) + loss_distill = loss_distill[labels != -100].mean() + masked_lm_loss = (1 - + alpha) * masked_lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores, ) + outputs[2:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + # changed from MaskedLMOutput to MaskedLMOutputWithDistill + return MaskedLMOutputWithDistill( + loss=masked_lm_loss, + loss_aux=masked_lm_loss_aux, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + attention_mask=None, + **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert (self.config.pad_token_id + is not None), 'The PAD token should be defined for generation' + attention_mask = torch.cat([ + attention_mask, + attention_mask.new_zeros((attention_mask.shape[0], 1)) + ], + dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/mmaction/models/multimodal/vindlu/temporal_model.py b/mmaction/models/multimodal/vindlu/temporal_model.py new file mode 100644 index 0000000000..7271aedc8a --- /dev/null +++ b/mmaction/models/multimodal/vindlu/temporal_model.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import einops +import torch +from einops import rearrange +from timm.models.layers import DropPath +from torch import nn +from torch.nn import LayerNorm, Linear, MultiheadAttention + + +class STAdapter(nn.Module): + """ST Adapter.""" + + def __init__( + self, + kernel_size=(3, 3, 3), + input_dim=768, + hidden_dim=384, + img_size=224, + patch_size=16, + drop_prob=0.1, + ): + super(STAdapter, self).__init__() + self.kernel_size = kernel_size + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.h = self.w = img_size // patch_size + + self.linear1 = nn.Linear(input_dim, hidden_dim) + self.linear2 = nn.Linear(hidden_dim, input_dim) + self.act = nn.ReLU() + self.conv = nn.Conv3d( + hidden_dim, + hidden_dim, + kernel_size=kernel_size, + padding='same', + groups=hidden_dim) + self.droppath = DropPath(drop_prob=drop_prob) + + self.scale = nn.parameter.Parameter(torch.zeros([])) + + def forward(self, x: torch.Tensor): + """forward. + + Args: + x (torch.Tensor): input features. + Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + + shortcut = x + x = self.linear1(x) + cls = x[:, :, :1, :] + tokens = x[:, :, 1:, :] + tokens = einops.rearrange( + tokens, 'b t (h w) c -> b c t h w', h=self.h).contiguous() + tokens = self.conv(tokens) + tokens = einops.rearrange(tokens, 'b c t h w -> b t (h w) c') + x = torch.cat([cls, tokens], dim=2) # [b, t, 1+h*w, c] + x = self.act(x) + x = self.linear2(x) + + return shortcut + self.scale * self.droppath(x) + + +class TemporalAttention(nn.Module): + """perform temporal self-attention.""" + + def __init__(self, input_dim=768, droppath_rate=0.1): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + + self._input_dim = input_dim + self.temporal_attn = MultiheadAttention( + input_dim, num_heads=input_dim // 64) + self.norm = LayerNorm(input_dim, eps=1e-12) + self.linear = Linear(input_dim, input_dim) + self.droppath = DropPath(droppath_rate) + self.scale = nn.parameter.Parameter(torch.zeros([])) + + def forward(self, x: torch.Tensor): + """forward. + + Args: + x (torch.Tensor): input features. + Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + + shortcut = x + x = einops.rearrange(x, 'b t l c -> t (b l) c') + x = self.norm(x) + x = self.temporal_attn(x, x, x)[0] + x = einops.rearrange(x, 't (b l) c -> b t l c', b=shortcut.shape[0]) + return shortcut + self.scale * self.droppath(x) + + +class WindowTemporalAttention(nn.Module): + """perform windowed temporal self-attention.""" + + def __init__(self, input_dim=768, droppath_rate=0.1, window_size=(2, 2)): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + + self._input_dim = input_dim + self.temporal_attn = MultiheadAttention( + input_dim, num_heads=input_dim // 64) + self.norm = LayerNorm(input_dim, eps=1e-12) + self.droppath = DropPath(droppath_rate) + self.scale = nn.parameter.Parameter(torch.zeros([])) + self.wh, self.ww = window_size + + def forward(self, x: torch.Tensor): + """forward. + + Args: + x (torch.Tensor): input features. + Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + shortcut = x + + h = w = int(math.sqrt(x.shape[2] - 1)) + cls_token = x[:, :, :1, :] + x = einops.rearrange( + x[:, :, 1:, :], + 'b t (nh wh nw ww) c -> (t wh ww) (b nh nw) c', + nh=h // self.wh, + wh=self.wh, + nw=w // self.ww, + ww=self.ww, + ) + x = self.norm(x) + x = self.temporal_attn(x, x, x)[0] + x = einops.rearrange( + x, + '(t wh ww) (b nh nw) c -> b t (nh wh nw ww) c', + wh=self.wh, + ww=self.ww, + nh=h // self.wh, + nw=w // self.ww, + ) + # add back cls token. + x = torch.concat([cls_token, x], dim=2) + return shortcut + self.scale * self.droppath(x) + + +class X_CLIP(nn.Module): + """perform windowed temporal self-attention.""" + + def __init__(self, input_dim=768, droppath_rate=0.1, num_prompts=1): + """ + + Kwargs: + input_dim (int): The input feature dimension. + + + """ + super().__init__() + + d_model = input_dim + + self.message_fc = nn.Linear(d_model, d_model) + self.message_ln = LayerNorm(d_model, eps=1e-12) + self.message_attn = nn.MultiheadAttention(d_model, d_model // 64) + self.num_prompts = num_prompts + + self.droppath = DropPath(droppath_rate) + + def forward(self, x: torch.Tensor): + """forward. + + Args: + x (torch.Tensor): input features. + Shape: [bs, nframes, l, c]. l = 1 + h*w + + Returns: features after adapter. The same shape as input. + """ + if x.shape[1] == 1: # for single frame, return itself. + return x + msg_token = self.message_ln(self.message_fc(x[:, :, + 0, :])) # [b, t, c] + msg_token = rearrange(msg_token, 'b t c -> t b c') + msg_token = msg_token + self.droppath( + self.message_attn(msg_token, msg_token, msg_token)[0]) + msg_token = rearrange(msg_token, 't b c -> b t c') + # replace the last prompt token with msg_token. + x = torch.cat([x[:, :, :-1, :], + msg_token.unsqueeze(2)], dim=2) # [b, t, l+1, c] + return x diff --git a/mmaction/models/multimodal/vindlu/tokenizer.py b/mmaction/models/multimodal/vindlu/tokenizer.py new file mode 100644 index 0000000000..92be293dff --- /dev/null +++ b/mmaction/models/multimodal/vindlu/tokenizer.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from transformers import BertTokenizer + +from mmaction.registry import TOKENIZER + + +class VindLUTokenizer(BertTokenizer): + """VindLUTokenizer inherit BertTokenizer. + + The main difference from BertTokenizer is removing the last separate token + for a single sequence. + """ + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + """Build model inputs from a sequence or a pair of sequence for + sequence classification tasks by concatenating and adding special + tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with + the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + +TOKENIZER.register_module( + 'VindLUTokenizer', module=VindLUTokenizer.from_pretrained) diff --git a/mmaction/models/multimodal/vindlu/utils.py b/mmaction/models/multimodal/vindlu/utils.py new file mode 100644 index 0000000000..8737dde9ea --- /dev/null +++ b/mmaction/models/multimodal/vindlu/utils.py @@ -0,0 +1,195 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.dist as dist +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.logging import MMLogger +from scipy import interpolate + + +def all_gather_concat(data: torch.Tensor) -> torch.Tensor: + """Gather tensors with different first-dimension size and concat to one + tenosr. + + Note: + Only the first dimension should be different. + + Args: + data (Tensor): Tensor to be gathered. + + Returns: + torch.Tensor: The concatenated tenosr. + """ + if dist.get_world_size() == 1: + return data + + data_size = torch.tensor(data.size(0), device=data.device) + sizes_list = dist.all_gather(data_size) + + total_length = sum(sizes_list) + max_length = max(sizes_list) + size_diff = max_length.item() - data_size.item() + if size_diff: + padding = torch.zeros( + size_diff, *data.size()[1:], device=data.device, dtype=data.dtype) + data = torch.cat((data, padding)) + + gather_list = dist.all_gather(data) + + # gather all data according to the default DDP sampler. For instance, + # 8 samples on 2 GPUs, GPU0: [0,2,4,6], GPU1: [1,3,5,7], will be gathered + # as [0,1,2,3,4,5,6,7] + all_data = [] + for gather_batch in zip(*gather_list): + all_data.extend(gather_batch) + + return torch.stack(all_data)[:total_length] + + +def interpolate_pos_embed_beit(state_dict, new_model): + """interpolate the positional embeddings. The spatial pe is relative and + temporal pe is absolute. additional temporal pe is padded with 0. + + Args: + state_dict (dict): The state_dict. + new_model (nn.Module): The created model. + + Returns: dict. The state_dict with updated positional embeddings. + """ + state_dict = interpolate_pos_relative_bias_beit( + state_dict_old=state_dict, + state_dict_new=new_model.state_dict(), + patch_shape_new=new_model.vision_encoder.embeddings.patch_embeddings. + patch_shape, + ) + # absolute temporal pos bias + temporal_pe_key = 'vision_encoder.embeddings.temporal_position_embeddings' + if temporal_pe_key in state_dict: + logger = MMLogger.get_current_instance() + logger.info( + f'interpolate temporal positional embeddings: {temporal_pe_key}') + state_dict[temporal_pe_key] = load_temp_embed_with_mismatch( + temp_embed_old=state_dict[temporal_pe_key], + temp_embed_new=new_model.state_dict()[temporal_pe_key], + ) + return state_dict + + +def load_temp_embed_with_mismatch(temp_embed_old, + temp_embed_new, + add_zero=True): + """Add/Remove extra temporal_embeddings as needed. + https://arxiv.org/abs/2104.00650 shows adding zero paddings works. + + temp_embed_old: (1, num_frames_old, 1, d) + temp_embed_new: (1, num_frames_new, 1, d) + add_zero: bool, if True, add zero, else, interpolate trained embeddings. + """ + # TODO zero pad + num_frms_new = temp_embed_new.shape[1] + num_frms_old = temp_embed_old.shape[1] + logger = MMLogger.get_current_instance() + logger.info( + f'Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}') + if num_frms_new > num_frms_old: + if add_zero: + temp_embed_new[:, :num_frms_old] \ + = temp_embed_old # untrained embeddings are zeros. + else: + temp_embed_new = interpolate_temporal_pos_embed( + temp_embed_old, num_frms_new) + elif num_frms_new < num_frms_old: + temp_embed_new = temp_embed_old[:, :num_frms_new] + else: # = + temp_embed_new = temp_embed_old + return temp_embed_new + + +def interpolate_temporal_pos_embed(temp_embed_old, num_frames_new): + """ + temp_embed_old: (1, num_frames_old, 1, d) + Returns: + temp_embed_new: (1, num_frames_new, 1, d) + """ + temp_embed_old = temp_embed_old.squeeze(2).permute( + 0, 2, 1) # (1, d, num_frames_old) + temp_embed_new = F.interpolate( + temp_embed_old, num_frames_new, + mode='linear') # (1, d, num_frames_new) + temp_embed_new = temp_embed_new.permute(0, 2, 1).unsqueeze( + 2) # (1, num_frames_new, 1, d) + return temp_embed_new + + +def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, + patch_shape_new): + """ + Args: + state_dict_old: loaded state dict + state_dict_new: state dict for model with new image size + patch_shape_new: new model patch_shape + ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py # noqa: E501 + """ + all_keys = list(state_dict_old.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict_old.pop(key) + + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict_old[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = state_dict_new[key].size() + dst_patch_shape = patch_shape_new + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, + src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to( + rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), + dim=0) + state_dict_old[key] = new_rel_pos_bias + return state_dict_old diff --git a/mmaction/models/multimodal/vindlu/vindlu.py b/mmaction/models/multimodal/vindlu/vindlu.py new file mode 100644 index 0000000000..1f6f9dcff2 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/vindlu.py @@ -0,0 +1,227 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Optional + +import torch +from mmengine.logging import MMLogger +from mmengine.model import BaseModel +from mmengine.runner.checkpoint import _load_checkpoint +from torch import nn + +from mmaction.registry import MODELS, TOKENIZER +from mmaction.utils import ForwardResults, SampleList +from .utils import (interpolate_pos_embed_beit, + interpolate_pos_relative_bias_beit) + + +class VindLUBase(BaseModel): + """VindLU base Model. + + Args: + tokenizer: (dict): The config for tokenizer. + vision_encoder (dict): Backbone for extracting image features. + text_encoder (dict): Backbone for extracting text features. + temperature (float): Temperature parameter that controls the + concentration level of the distribution. Defaults to 0.07. + gradient_checkpointing (bool): Whether to do gradient_checkpointing. + Using checkpoint will save some memory while slowing down the + training speed. Defaults to False. + data_preprocessor (Optional[dict]): The config for preprocessing input + data. + init_cfg (Optional[dict]): the config to control the initialization. + Defaults to None. + """ + + def __init__( + self, + tokenizer: dict, + vision_encoder: dict, + text_encoder: dict, + proj_dim: int = 256, + temperature: float = 0.07, + gradient_checkpointing: bool = False, + pretrined_vl: bool = True, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None, + ): + if data_preprocessor is None: + data_preprocessor = dict(type='ActionDataPreprocessor') + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + self.tokenizer = TOKENIZER.build(tokenizer) + self.vision_cfg = vision_encoder + self.text_encoder_cfg = text_encoder + self.gradient_checkpointing = gradient_checkpointing + self.text_encoder_cfg.gradient_checkpointing = gradient_checkpointing + + self.vision_width = vision_encoder.pop('encoder_width') + self.text_width = text_encoder.encoder_width + self.pretrined_vl = pretrined_vl + + if self.vision_cfg.pop('add_ln'): + self.vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12) + else: + self.vision_layernorm = nn.Identity() + + self.vision_encoder = MODELS.build(self.vision_cfg) + + if gradient_checkpointing: + self.vision_encoder.gradient_checkpointing_enable() + + self.text_encoder = MODELS.build(self.text_encoder_cfg) + + self.vision_proj = nn.Linear(self.vision_width, proj_dim) + self.text_proj = nn.Linear(self.text_width, proj_dim) + + self.temp = nn.parameter.Parameter(torch.ones([]) * temperature) + self.itm_head = nn.Linear(self.text_width, 2) + + def extract_feat(self, inputs: torch.Tensor, **kwargs) -> ForwardResults: + """Extract features from raw inputs.""" + + @abstractmethod + def loss(self, inputs: torch.Tensor, data_samples: SampleList, + **kwargs) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + def forward(self, inputs, data_samples, mode: str = 'loss'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: + + - ``tensor``: Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - ``predict``: Forward and return the predictions, which are fully + processed to a list of :obj:`ActionDataSample`. + - ``loss``: Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[``ActionDataSample], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``tensor``. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of ``ActionDataSample``. + - If ``mode="loss"``, return a dict of tensor. + """ + + if mode == 'tensor': + return self.extract_feat(inputs, data_samples) + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def encode_vision(self, image): + """encode image / videos as features. + + Args: + image (torch.Tensor): The input images. + + Returns: tuple. + - vision_embeds (torch.Tensor): The features of all patches. + Shape: [B,T,L,C]. + - pooled_vision_embeds (torch.Tensor): The pooled features. + Shape: [B,T,C]. + """ + output_dict = self.vision_encoder(image) + vision_embeds = self.vision_layernorm(output_dict.last_hidden_state) + pooled_vision_embeds = output_dict.pooler_output + + return vision_embeds, pooled_vision_embeds + + def encode_text(self, text): + """encode text. + Args: + text (dict): The output of huggingface's `PreTrainedTokenizer`. + contains keys: + - input_ids (torch.Tensor): Token ids to be fed to a model. + Shape: [B,L]. + - attention_mask (torch.Tensor): The mask indicate padded tokens. + Shape: [B,L]. 0 is padded token. + - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". # noqa: E501 + Returns: tuple. + - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C]. + - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C]. + + """ + text_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode='text', + ) + text_embeds = text_output.last_hidden_state + pooled_text_embeds = text_embeds[:, 0] + return text_embeds, pooled_text_embeds + + @torch.no_grad() + def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5): + """Seems only used during pre-training.""" + self.temp.clamp_(min_val, max_val) + + @property + def device(self): + return next(self.parameters()).device + + def preprocess_state_dict(self, state_dict): + """Preprocess pretrained checkpoint for text_encoder.""" + for key in list(state_dict.keys()): + if 'bert' in key: + encoder_key = key.replace('bert.', '') + state_dict[encoder_key] = state_dict[key] + del state_dict[key] + return state_dict + + def load_from_pretrainded_beit(self): + from transformers.models.beit.modeling_beit import BeitModel + beit2d = BeitModel.from_pretrained( + self.vision_cfg.pretrained_model_name_or_path) + ori_state_dict = beit2d.state_dict() + del beit2d + # interpolate relative pos bias + state_dict = interpolate_pos_relative_bias_beit( + state_dict_old=ori_state_dict, + state_dict_new=self.vision_encoder.state_dict(), + patch_shape_new=self.vision_encoder.embeddings.patch_embeddings. + patch_shape, + ) + + for k in list(state_dict.keys()): + if 'prompt_bias_table' in k: + del state_dict[k] + + msg = self.vision_encoder.load_state_dict(state_dict, strict=False) + logger = MMLogger.get_current_instance() + logger.info(msg) + + def init_weights(self): + if self.vision_cfg.get('pretrained2d', False): + self.load_from_pretrainded_beit() + + if self.pretrined_vl: + assert self.init_cfg.get('type') == 'Pretrained', ( + 'Please specify ' + 'init_cfg to use pretrained video-language checkpoint') + self.pretrained = self.init_cfg.get('checkpoint') + checkpoint = _load_checkpoint(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + state_dict = interpolate_pos_embed_beit(state_dict, self) + state_dict = self.preprocess_state_dict(state_dict) + msg = self.load_state_dict(state_dict, strict=False) + logger = MMLogger.get_current_instance() + logger.info(msg) + else: + super().init_weights() diff --git a/mmaction/models/multimodal/vindlu/vindlu_ret.py b/mmaction/models/multimodal/vindlu/vindlu_ret.py new file mode 100644 index 0000000000..cc80982c39 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/vindlu_ret.py @@ -0,0 +1,464 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import mmengine.dist as dist +import torch +import torch.nn.functional as F +from einops import rearrange +from torch.distributed.nn import all_gather as all_gather_with_grad + +from mmaction.registry import MODELS +from mmaction.structures import ActionDataSample +from mmaction.utils import track_on_main_process +from .utils import all_gather_concat +from .vindlu import VindLUBase + + +@MODELS.register_module() +class VindLURetrieval(VindLUBase): + """VindLU retriever. + + max_txt_len (int): Max text length of input text, used for retrieval + from multiple choices. Defaults to 32. + topk (int): Select topk similarity as candidates for compute matching + scores. Defaults to 256. + negative_all_rank (bool): Whether to sample negative data from all + ranks for image text matching in training. Defaults to False. + fast_match (bool): If False, select topk similarity as candidates and + compute the matching score. If True, return the similarity as the + matching score directly. Defaults to False. + **kwargs: Other keyword arguments to initialize the VindLU base model. + """ + + def __init__(self, + max_txt_len: int = 32, + topk: int = 128, + negative_all_rank: bool = False, + fast_match: bool = False, + **kwargs): + super().__init__(**kwargs) + + self.max_txt_len = max_txt_len + self.topk = topk + self.negative_all_rank = negative_all_rank + self.fast_match = fast_match + + def loss( + self, + inputs: torch.Tensor, + data_samples: Optional[List[ActionDataSample]] = None, + ) -> Dict[str, torch.tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + """ + output = self.extract_feat(inputs, data_samples) + + text_embeds = output['text_embeds'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(self.device) + + # ITC Loss + # B*world_size, D + image_feat_all = torch.cat(dist.all_gather(image_feat)) + # B*world_size, D + text_feat_all = torch.cat(dist.all_gather(text_feat)) + + # image to text similarity + # B, B*world_size + sim_i2t = torch.einsum('mld,nd->mln', image_feat, + text_feat_all).mean(1) / self.temp + # text-image similarity + # B, B*world_size + sim_t2i = torch.einsum('md,nld->mln', text_feat, + image_feat_all).mean(1) / self.temp + + rank = dist.get_rank() + bs = inputs.size(0) + itc_targets = torch.linspace( + rank * bs, rank * bs + bs - 1, bs, dtype=int).to(self.device) + + itc_loss = (F.cross_entropy(sim_i2t, itc_targets) + + F.cross_entropy(sim_t2i, itc_targets)) / 2 + + # prepare for itm + output_pos = self.text_encoder( + encoder_embeds=text_embeds, + attention_mask=text_attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + mode='fusion', + ) + + idx = torch.tensor([i.gt_video_id for i in data_samples]).view(-1, 1) + bs = idx.size(0) + if self.negative_all_rank: + idxs = torch.cat(dist.all_gather(idx)) + image_feat_world = torch.cat(dist.all_gather(image_feat)) + text_feat_world = torch.cat(dist.all_gather(text_feat)) + att_mask_world = torch.cat(dist.all_gather(text_attn_mask)) + text_embeds_world = torch.cat(all_gather_with_grad(text_embeds)) + image_embeds_world = torch.cat(all_gather_with_grad(image_embeds)) + else: + idxs = idx + image_feat_world = image_feat.detach() + text_feat_world = text_feat.detach() + image_embeds_world = image_embeds + text_embeds_world = text_embeds + att_mask_world = text_attn_mask + + with torch.no_grad(): + # compute sample similarity + sim_i2t = torch.einsum('mld,nd->mln', image_feat, + text_feat_world).mean(1) / self.temp + sim_t2i = torch.einsum('md,nld->mln', text_feat, + image_feat_world).mean(1) / self.temp + + mask = torch.eq(idx, idxs.t()).to(self.device) + weights_i2t = F.softmax(sim_i2t + 1e-4, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i + 1e-4, dim=1) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image for each text + neg_idx = torch.multinomial(weights_t2i, 1).squeeze() + image_embeds_neg = image_embeds_world[neg_idx] + + # select a negative text for each image + neg_idx = torch.multinomial(weights_i2t, 1).squeeze() + text_embeds_neg = text_embeds_world[neg_idx] + text_atts_neg = att_mask_world[neg_idx] + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text_attn_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_encoder( + encoder_embeds=text_embeds_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + mode='fusion', + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + + itm_targets = torch.ones((3 * bs, ), + dtype=torch.long, + device=inputs.device) + itm_targets[bs:] = 0 + itm_logit = self.itm_head(vl_embeddings) + itm_loss = F.cross_entropy(itm_logit, itm_targets) + + return dict(itc_loss=itc_loss, itm_loss=itm_loss) + + def preprocess_text(self, data_samples): + sample_item = data_samples[0] + + if sample_item is not None and 'text' in sample_item: + if isinstance(sample_item.get('text'), (list, tuple)): + texts = [] + for sample in data_samples: + texts.extend(sample.get('text')) + elif isinstance(sample_item.get('text'), str): + texts = [sample.get('text') for sample in data_samples] + else: + raise TypeError('text must be a string or a list of strings') + else: + return None + + # perform tokenize first if satisfied conditions + texts = self.tokenizer( + texts, + padding='max_length', + truncation=True, + max_length=self.max_txt_len, + return_tensors='pt', + ).to(self.device) + + return texts + + def extract_feat( + self, + images: torch.Tensor = None, + data_samples: List[ActionDataSample] = None, + return_texts=True, + ) -> Dict[str, torch.Tensor]: + """Extract features from the input dict. + + Args: + images (tensor, optional): The images to extract features. + Defaults to None. + data_samples (list, optional): The data samples containing texts + to extract features. Defaults to None. + return_texts (bool): Whether to return the tokenized text and the + corresponding attention masks. Defaults to True. + + Returns: + Tuple[torch.Tensor]: The output features. + If multimodal_backbone is not exist, tuple of torch.Tensor + will be returned. + """ + if data_samples is not None: + texts = self.preprocess_text(data_samples) + else: + texts = None + + assert images is not None or texts is not None, \ + 'At least single modality should be passed as inputs.' + + results = {} + if texts is not None and return_texts: + results.update({ + 'text_ids': texts.input_ids, + 'text_attn_mask': texts.attention_mask, + }) + + # extract image features + if images is not None: + image_embeds, pooled_image_embeds = self.encode_vision(images) + # concat temporal embeds + image_embeds = rearrange(image_embeds, + 'b t l c -> b (t l) c').contiguous() + results['image_embeds'] = image_embeds + results['image_feat'] = F.normalize( + self.vision_proj(pooled_image_embeds), dim=-1) + + # extract text features + if texts is not None: + texts_output = self.text_encoder( + texts.input_ids, + attention_mask=texts.attention_mask, + return_dict=True, + mode='text') + + text_embeds = texts_output.last_hidden_state + pooled_text_feat = text_embeds[:, 0] + results['text_embeds'] = text_embeds + results['text_feat'] = F.normalize( + self.text_proj(pooled_text_feat), dim=-1) + + return results + + def predict(self, images, data_samples, cal_i2t=True, cal_t2i=True): + feats = self.extract_feat(images, data_samples) + + return self.predict_all( + feats, data_samples, cal_i2t=cal_i2t, cal_t2i=cal_t2i) + + def predict_all(self, + feats, + data_samples, + num_images=None, + num_texts=None, + cal_i2t=True, + cal_t2i=True): + text_attn_mask = feats['text_attn_mask'] + image_embeds = feats.get('image_embeds', None) + image_feat = feats['image_feat'] + text_embeds = feats['text_embeds'] + text_feat = feats['text_feat'] + + num_images = num_images or image_feat.size(0) + num_texts = num_texts or text_feat.size(0) + + image_embeds_all = all_gather_concat(image_embeds)[:num_images] + image_feat_all = all_gather_concat(image_feat)[:num_images] + text_feat_all = all_gather_concat(text_feat)[:num_texts] + text_embeds_all = all_gather_concat(text_embeds)[:num_texts] + text_attn_mask_all = all_gather_concat(text_attn_mask)[:num_texts] + + results = [] + if cal_i2t: + result_i2t = self.compute_score_matrix_i2t( + image_feat, + image_embeds, + text_feat_all, + text_embeds_all, + text_attn_mask_all, + ) + results.append( + self._get_predictions(result_i2t, data_samples, mode='i2t')) + if cal_t2i: + result_t2i = self.compute_score_matrix_t2i( + image_feat_all, + image_embeds_all, + text_feat, + text_embeds, + text_attn_mask, + ) + results.append( + self._get_predictions(result_t2i, data_samples, mode='t2i')) + return tuple(results) + + def compute_score_matrix_i2t(self, img_feats, img_embeds, text_feats, + text_embeds, text_atts): + """Compare the score matrix for image-to-text retrieval. Every image + should compare to all the text features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_embeds (torch.Tensor): The input tensor with shape (N, C). + text_atts (torch.Tensor): The input tensor with shape (N, C). + + Returns: + torch.Tensor: Score matrix of image-to-text retrieval. + """ + # compute i2t sim matrix + sim_matrix_i2t = torch.einsum('mld,nd->mln', img_feats, + text_feats).mean(1) + if self.fast_match: + return sim_matrix_i2t + + score_matrix_i2t = torch.full((img_feats.size(0), text_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(img_feats.size(0)), 'Compute I2T scores...'): + sims = sim_matrix_i2t[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + topk_bz = 32 + encoder_output = img_embeds[i].repeat(topk_bz, 1, 1) + encoder_att = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(self.device) + for j in range(0, self.topk // topk_bz): + batch_topk = topk_idx[j * topk_bz:(j + 1) * topk_bz] + output = self.text_encoder( + encoder_embeds=text_embeds[batch_topk], + attention_mask=text_atts[batch_topk], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + mode='fusion') + score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_i2t[i, batch_topk] = score + return score_matrix_i2t + + def compute_score_matrix_t2i(self, img_feats, img_embeds, text_feats, + text_embeds, text_atts): + """Compare the score matrix for text-to-image retrieval. Every text + should compare to all the image features. + + Args: + img_feats (torch.Tensor): The input img feats tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + img_embeds (torch.Tensor): The input img embeds tensor with shape + (M, C). M stands for numbers of samples on a single GPU. + text_feats (torch.Tensor): The input text feats tensor with shape + (N, C). N stands for numbers of all samples on all GPUs. + text_embeds (torch.Tensor): The input tensor with shape (M, C). + text_atts (torch.Tensor): The input tensor with shape (M, C). + + Returns: + torch.Tensor: Score matrix of text-to-image retrieval. + """ + # compute t2i sim matrix + sim_matrix_t2i = torch.einsum('md,nld->mln', text_feats, + img_feats).mean(1) + + if self.fast_match: + return sim_matrix_t2i + + score_matrix_t2i = torch.full((text_feats.size(0), img_feats.size(0)), + -100.0).to(self.device) + for i in track_on_main_process( + range(text_feats.size(0)), 'Compute T2I scores...'): + sims = sim_matrix_t2i[i] + topk_sim, topk_idx = sims.topk(k=self.topk, dim=0) + topk_bz = 32 + for j in range(0, self.topk // topk_bz): + batch_topk = topk_idx[j * topk_bz:(j + 1) * topk_bz] + encoder_output = img_embeds[batch_topk] + encoder_att = torch.ones( + encoder_output.size()[:-1], + dtype=torch.long).to(self.device) + output = self.text_encoder( + encoder_embeds=text_embeds[i].repeat(topk_bz, 1, 1), + attention_mask=text_atts[i].repeat(topk_bz, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + mode='fusion') + score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_t2i[i, batch_topk] = score + return score_matrix_t2i + + def _get_predictions(self, + result: torch.Tensor, + data_samples: List[ActionDataSample], + mode: str = 'i2t'): + """Post-process the output of retriever. + + Args: + result (torch.Tensor): Score matrix of single retrieve, + either from image or text. + data_samples (List[ActionDataSample], optional): The annotation + data of every samples. + mode (str): Retrieve mode, either `i2t` for image to text, or `t2i` + text to image. Defaults to `i2t`. + + Returns: + List[ActionDataSample]: the raw data_samples with + the predicted results. + """ + + # create data sample if not exists + if data_samples is None: + data_samples = [ActionDataSample() for _ in range(result.size(0))] + elif mode == 't2i': + # Process data samples to align with the num of texts. + new_data_samples = [] + for sample in data_samples: + if isinstance(sample.text, (list, tuple)): + texts = sample.text + else: + texts = [sample.text] + for i, text in enumerate(texts): + new_sample = ActionDataSample(text=text) + if 'gt_video_id' in sample: + new_sample.gt_label = sample.gt_video_id[i] + new_data_samples.append(new_sample) + assert len(new_data_samples) == result.size(0) + data_samples = new_data_samples + elif mode == 'i2t': + for sample in data_samples: + if 'gt_text_id' in sample: + sample.gt_label = sample.gt_text_id + else: + raise ValueError(f'Type {mode} is not supported.') + + for data_sample, score in zip(data_samples, result): + idx = score.argmax(keepdim=True).detach() + + data_sample.set_pred_score(score) + data_sample.set_pred_label(idx) + return data_samples diff --git a/mmaction/models/multimodal/vindlu/vindlu_ret_mc.py b/mmaction/models/multimodal/vindlu/vindlu_ret_mc.py new file mode 100644 index 0000000000..d701438bb7 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/vindlu_ret_mc.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from einops import rearrange + +from mmaction.registry import MODELS +from .vindlu_ret import VindLURetrieval + + +@MODELS.register_module() +class VindLURetrievalMC(VindLURetrieval): + """VindLU VQA retrieval multiple choice. + + score_weight (float): Weight coefficient for itm_head score to compute the + choice score. similarity_weight (float): Weight coefficient for similarity + score to compute the choice score. + """ + + def __init__(self, score_weight=0.7, similarity_weight=0.3, **kwargs): + kwargs.pop('text_decoder') + super().__init__(**kwargs) + self.score_weight = score_weight + self.similarity_weight = similarity_weight + + def predict(self, inputs, data_samples, **kwargs): + """Predict captions from a batch of inputs. + + Args: + images (torch.Tensor): The input images tensor with shape + (N, C, ...) in general. + data_samples (List[DataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + + Returns: + List[ActionDataSample]: Return list of data samples. + """ + num_options_per_q = len(data_samples[0].caption_options) + for sample in data_samples: + sample.text = sample.caption_options + + output = self.extract_feat(inputs, data_samples) + + text_embeds = output['text_embeds'] + text_attn_mask = output['text_attn_mask'] + image_embeds = output['image_embeds'] + image_feat = output['image_feat'] + text_feat = output['text_feat'] + + # compute similarity between vision feat and caption feat + text_feat = rearrange( + text_feat, '(b n) c -> b c n', n=num_options_per_q) + sim = torch.matmul(image_feat.mean(1, keepdim=True), + text_feat).squeeze(1) / self.temp + sim = F.softmax(sim, dim=1).flatten() + + # cross-modal encode + encoder_output = image_embeds.repeat_interleave( + num_options_per_q, dim=0) + image_atts = torch.ones( + encoder_output.size()[:-1], dtype=torch.long).to(inputs.device) + output = self.text_encoder( + encoder_embeds=text_embeds, + attention_mask=text_attn_mask, + encoder_hidden_states=encoder_output, + encoder_attention_mask=image_atts, + return_dict=True, + mode='fusion', + ) + itm_embeds = output.last_hidden_state[:, 0] # [CLS] + + itm_score = F.softmax(self.itm_head(itm_embeds), dim=1)[:, 1] # [bs*5] + score = itm_score * self.score_weight + sim * self.similarity_weight + + pred_answers = score.view(-1, num_options_per_q).max(1)[1].cpu() + + # assemble predictions + ensemble_scores = score.view(-1, num_options_per_q).cpu() # (bsz, 5) + + out_data_samples = [] + for data_sample, ensemble_score, pred_ans in \ + zip(data_samples, ensemble_scores, pred_answers): + data_sample.pred_label = pred_ans.item() + data_sample.score = ensemble_score.numpy() + out_data_samples.append(data_sample) + + return out_data_samples diff --git a/mmaction/models/multimodal/vindlu/vindlu_vqa.py b/mmaction/models/multimodal/vindlu/vindlu_vqa.py new file mode 100644 index 0000000000..87233b9b21 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/vindlu_vqa.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import mmengine +import torch +import torch.nn.functional as F +from einops import rearrange + +from mmaction.registry import MODELS +from .vindlu import VindLUBase + + +@MODELS.register_module() +class VindLUVQA(VindLUBase): + """VindLU VQA. + + Args: + text_decoder (dict): Backbone for extracting + multi-modal features. We apply this part as VQA fusion module. + answer_list_path (str, optional): Path to `answer_list.json`. + max_question_len (int): Max text length of question text. + Defaults to 25. + max_answer_len (int): Max text length of answer text. Defaults to 5. + num_ans_candidates (int): Number of answer candidates, used to filter + out answers with low probability. Defaults to 128. + **kwargs: Other keyword arguments accepted by the VindLUBase. + """ + + def __init__(self, + text_decoder: dict, + answer_list_path: Optional[str] = None, + max_question_len: int = 25, + max_answer_len: int = 5, + num_ans_candidates: int = 128, + **kwargs): + super().__init__(**kwargs) + + self.max_question_len = max_question_len + self.max_answer_len = max_answer_len + self.num_ans_candidates = num_ans_candidates + self.answer_list_path = answer_list_path + self.text_decoder_cfg = text_decoder + + # for inference only + if answer_list_path: + self.answer_list = mmengine.load(answer_list_path) + + # delete extra/unnecessary modules inherited from VindLUBase + extra_attributes = ['vision_proj', 'text_proj', 'temp', 'itm_head'] + for attr in extra_attributes: + delattr(self, attr) + + self.text_decoder_cfg.gradient_checkpointing = \ + self.gradient_checkpointing + self.text_decoder = MODELS.build(self.text_decoder_cfg) + + def forward_encoder(self, inputs, data_samples): + # forward vision encoder + image_embeds, _ = self.encode_vision(inputs) + image_embeds = rearrange(image_embeds, 'b t l c -> b (t l) c') + image_atts = torch.ones( + image_embeds.size()[:-1], dtype=torch.long).to(inputs.device) + + # forward text encoder + questions = [sample.question for sample in data_samples] + questions = self.tokenizer( + questions, + padding='max_length', + truncation=True, + max_length=self.max_question_len, + return_tensors='pt').to(inputs.device) + + question_output = self.text_encoder( + questions.input_ids, + attention_mask=questions.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True) + + return questions, question_output + + def loss(self, inputs, data_samples): + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (dict): A batch of inputs. The input tensor with of + at least one modality. For image, the value is a tensor + of shape (N, C, ...) in general. + For text, the value is a dict of tokenized text inputs. + data_samples (Optional[List[DataSample]]): + The annotation data of every samples. Defaults to None. + + Returns: + Dict[str, torch.tensor]: a dictionary of loss components of + """ + + questions, question_output = self.forward_encoder(inputs, data_samples) + + weights = torch.cat( + [torch.tensor(sample.gt_answer_weight) for sample in data_samples], + dim=0).to(inputs.device) + raw_answers = [] + for sample in data_samples: + raw_answers.extend(sample.gt_answer) + answer_count = torch.tensor([ + len(sample.gt_answer) for sample in data_samples + ]).to(inputs.device) + answers = [a + ' ' + '[SEP]' for a in raw_answers] + answers = self.tokenizer( + answers, + padding='max_length', + truncation=True, + max_length=self.max_answer_len, + return_tensors='pt').to(inputs.device) + + answer_targets = answers.input_ids.masked_fill( + answers.input_ids == self.tokenizer.pad_token_id, -100) + + question_states = [] + question_atts = [] + for b, n in enumerate(answer_count): + question_states += [question_output.last_hidden_state[b]] * n + question_atts += [questions.attention_mask[b]] * n + question_states = torch.stack(question_states, 0).to(inputs.device) + question_atts = torch.stack(question_atts, 0).to(inputs.device) + + answer_output = self.text_decoder( + answers.input_ids, + attention_mask=answers.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + reduction='none', + ) + loss = weights * answer_output.loss + loss = loss.sum() / inputs.size(0) + + return dict(loss=loss) + + def predict(self, inputs, data_samples, **kwargs): + + questions, question_output = self.forward_encoder(inputs, data_samples) + + raw_answers = self.answer_list + answers = [a + ' ' + '[SEP]' for a in raw_answers] + answers = self.tokenizer( + answers, + padding='max_length', + truncation=True, + max_length=self.max_answer_len, + return_tensors='pt', + ).to(inputs.device) + + topk_ids, topk_probs = self.rank_answer( + question_output.last_hidden_state, questions.attention_mask, + answers.input_ids, answers.attention_mask, self.num_ans_candidates) + + out_data_samples = [] + for data_sample, topk_id, topk_prob in zip(data_samples, topk_ids, + topk_probs): + _, pred = topk_prob.max(dim=0) + data_sample.pred_answer = raw_answers[topk_id[pred]] + out_data_samples.append(data_sample) + + return out_data_samples + + def rank_answer(self, question_states, question_atts, answer_ids, + answer_atts, k): + """ + question_states: (bsz, Lq, d) + answer_ids: answer input id after tokenization, (#answers, La) + """ + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.text_decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction='none', + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax( + logits, dim=1).index_select( + dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk(k, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100) + + question_states = question_states.repeat_interleave(k, dim=0) + question_atts = question_atts.repeat_interleave(k, dim=0) + + output = self.text_decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction='none', + ) + + answer_loss = output.loss + answer_loss = answer_loss.view(input_ids.size(0), -1) + + # topk_prob: first token probability + topk_probs = topk_probs.view(-1, 1) + log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1) + + # re-calculate log probabilities for the answer sequences + # using chain rule + log_probs_sum = log_probs.sum(1) + log_probs_sum = log_probs_sum.view(num_ques, k) + + topk_probs = F.softmax(log_probs_sum, dim=-1) + # get top-k after re-ranking + topk_probs, rerank_id = topk_probs.topk(k, dim=1) + topk_ids = torch.gather(topk_ids, 1, rerank_id) + + return topk_ids, topk_probs + + def preprocess_state_dict(self, state_dict): + """Preprocess pretrained checkpoint for text_encoder and + text_decoder.""" + for key in list(state_dict.keys()): + if 'bert' in key: + encoder_key = key.replace('bert.', '') + state_dict[encoder_key] = state_dict[key] + + # init text decoder as multimodal encoder + # (last 6 layers of model.text_encoder) + # only for generation tasks like VQA + if self.text_decoder_cfg and 'text_encoder' in key: + if 'layer' in key: + encoder_keys = key.split('.') + layer_num = int(encoder_keys[4]) + if layer_num < self.text_encoder_cfg.fusion_layer: + del state_dict[key] + continue + else: + decoder_layer_num = layer_num - 9 + encoder_keys[4] = str(decoder_layer_num) + encoder_key = '.'.join(encoder_keys) + else: + encoder_key = key + decoder_key = encoder_key.replace('text_encoder', + 'text_decoder') + state_dict[decoder_key] = state_dict[key] + del state_dict[key] + return state_dict diff --git a/mmaction/models/multimodal/vindlu/xbert.py b/mmaction/models/multimodal/vindlu/xbert.py new file mode 100644 index 0000000000..df020ce535 --- /dev/null +++ b/mmaction/models/multimodal/vindlu/xbert.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmaction.registry import MODELS +from .modeling_bert import (BertConfig, BertForMaskedLM, BertLMHeadModel, + BertModel) + + +@MODELS.register_module() +class XBertForMaskedLM(BertForMaskedLM): + + def __init__(self, pretrained_model_name_or_path, fusion_layer, + encoder_width, **kwargs): + config = BertConfig.from_pretrained(pretrained_model_name_or_path) + config.fusion_layer = fusion_layer + config.encoder_width = encoder_width + config.update(kwargs) + super().__init__(config) + + +@MODELS.register_module() +class XBertModel(BertModel): + + def __init__(self, pretrained_model_name_or_path, fusion_layer, + encoder_width, add_pooling_layer, **kwargs): + config = BertConfig.from_pretrained(pretrained_model_name_or_path) + config.fusion_layer = fusion_layer + config.encoder_width = encoder_width + config.update(kwargs) + super().__init__(config, add_pooling_layer) + + +@MODELS.register_module() +class BertDecoder(BertLMHeadModel): + + def __init__(self, pretrained_model_name_or_path, fusion_layer, + encoder_width, **kwargs): + config = BertConfig.from_pretrained(pretrained_model_name_or_path) + config.fusion_layer = fusion_layer + config.encoder_width = encoder_width + config.update(kwargs) + super().__init__(config) diff --git a/mmaction/registry.py b/mmaction/registry.py index 6d7d831db1..f214d514e5 100644 --- a/mmaction/registry.py +++ b/mmaction/registry.py @@ -54,7 +54,7 @@ DATA_SAMPLERS = Registry( 'data sampler', parent=MMENGINE_DATA_SAMPLERS, - locations=['mmaction.engine']) + locations=['mmaction.datasets']) TRANSFORMS = Registry( 'transform', parent=MMENGINE_TRANSFORMS, @@ -132,3 +132,9 @@ # manage function FUNCTION = Registry( 'function', parent=MMENGINE_FUNCTION, locations=['mmaction.mmengine']) + +# Tokenizer to encode sequence +TOKENIZER = Registry( + 'tokenizer', + locations=['mmaction.models'], +) diff --git a/mmaction/utils/__init__.py b/mmaction/utils/__init__.py index af91d382c4..54e78dd2b6 100644 --- a/mmaction/utils/__init__.py +++ b/mmaction/utils/__init__.py @@ -3,17 +3,12 @@ from .gradcam_utils import GradCAM from .misc import (VideoWriter, frame_extract, get_random_string, get_shm_dir, get_str_type, get_thread_id) +from .progress import track, track_on_main_process from .setup_env import register_all_modules from .typing_utils import * # noqa: F401,F403 __all__ = [ - 'collect_env', - 'get_random_string', - 'get_thread_id', - 'get_shm_dir', - 'frame_extract', - 'GradCAM', - 'register_all_modules', - 'VideoWriter', - 'get_str_type', + 'collect_env', 'get_random_string', 'get_thread_id', 'get_shm_dir', + 'frame_extract', 'GradCAM', 'register_all_modules', 'VideoWriter', + 'get_str_type', 'track', 'track_on_main_process' ] diff --git a/mmaction/utils/dependency.py b/mmaction/utils/dependency.py new file mode 100644 index 0000000000..dd8df115ec --- /dev/null +++ b/mmaction/utils/dependency.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from functools import wraps +from inspect import isfunction + +from importlib_metadata import PackageNotFoundError, distribution +from mmengine.utils import digit_version + + +def satisfy_requirement(dep): + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, dep, maxsplit=1) + parts = [p.strip() for p in parts] + package = parts[0] + if len(parts) > 1: + op, version = parts[1:] + op = { + '>=': '__ge__', + '==': '__eq__', + '>': '__gt__', + '<': '__lt__', + '<=': '__le__' + }[op] + else: + op, version = None, None + + try: + dist = distribution(package) + if op is None or getattr(digit_version(dist.version), op)( + digit_version(version)): + return True + except PackageNotFoundError: + pass + + return False + + +def require(dep, install=None): + """A wrapper of function for extra package requirements. + + Args: + dep (str): The dependency package name, like ``transformers`` + or ``transformers>=4.28.0``. + install (str, optional): The installation command hint. Defaults + to None, which means to use "pip install dep". + """ + + def wrapper(fn): + assert isfunction(fn) + + @wraps(fn) + def ask_install(*args, **kwargs): + name = fn.__qualname__.replace('.__init__', '') + ins = install or f'pip install "{dep}"' + raise ImportError( + f'{name} requires {dep}, please install it by `{ins}`.') + + if satisfy_requirement(dep): + fn._verify_require = getattr(fn, '_verify_require', lambda: None) + return fn + + ask_install._verify_require = ask_install + return ask_install + + return wrapper + + +WITH_MULTIMODAL = all( + satisfy_requirement(item) for item in ['transformers>=4.28.0']) + + +def register_multimodal_placeholder(names, registry): + for name in names: + + def ask_install(*args, **kwargs): + raise ImportError( + f'{name} requires extra multi-modal dependencies, please ' + 'install it by `pip install "mmaction2[multimodal]"` ' + 'or `pip install -e ".[multimodal]"`.') + + registry.register_module(name=name, module=ask_install) diff --git a/mmaction/utils/progress.py b/mmaction/utils/progress.py new file mode 100644 index 0000000000..b23f976a42 --- /dev/null +++ b/mmaction/utils/progress.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import mmengine.dist as dist +import rich.progress as progress +from rich.live import Live + +disable_progress_bar = False +global_progress = progress.Progress( + '{task.description}', + progress.BarColumn(), + progress.TaskProgressColumn(show_speed=True), + progress.TimeRemainingColumn(), +) +global_live = Live(global_progress, refresh_per_second=10) + + +def track(sequence, description: str = '', total: Optional[float] = None): + if disable_progress_bar: + yield from sequence + else: + global_live.start() + task_id = global_progress.add_task(description, total=total) + task = global_progress._tasks[task_id] + try: + yield from global_progress.track(sequence, task_id=task_id) + finally: + if task.total is None: + global_progress.update(task_id, total=task.completed) + if all(task.finished for task in global_progress.tasks): + global_live.stop() + for task_id in global_progress.task_ids: + global_progress.remove_task(task_id) + + +def track_on_main_process(sequence, description='', total=None): + if not dist.is_main_process() or disable_progress_bar: + yield from sequence + else: + yield from track(sequence, total=total, description=description) diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt new file mode 100644 index 0000000000..c3503a0875 --- /dev/null +++ b/requirements/multimodal.txt @@ -0,0 +1 @@ +transformers>=4.28.0 diff --git a/setup.py b/setup.py index 4776e54145..94471e5220 100644 --- a/setup.py +++ b/setup.py @@ -191,5 +191,6 @@ def add_mim_extension(): 'tests': parse_requirements('requirements/tests.txt'), 'optional': parse_requirements('requirements/optional.txt'), 'mim': parse_requirements('requirements/mminstall.txt'), + 'multimodal': parse_requirements('requirements/multimodal.txt'), }, zip_safe=False) diff --git a/tests/evaluation/metrics/test_retrieval_metric.py b/tests/evaluation/metrics/test_retrieval_metric.py index cb1f1c72ba..fffc0dbacc 100644 --- a/tests/evaluation/metrics/test_retrieval_metric.py +++ b/tests/evaluation/metrics/test_retrieval_metric.py @@ -1,8 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np import pytest import torch -from mmaction.evaluation.metrics import RetrievalMetric +from mmaction.evaluation.metrics import RetrievalMetric, RetrievalRecall +from mmaction.registry import METRICS +from mmaction.structures import ActionDataSample def generate_data(num_samples=5, feat_dim=10, random_label=False): @@ -47,3 +52,114 @@ def test_acc_metric(): assert eval_results['R1'] == eval_results['R5'] == eval_results[ 'R10'] == 100.0 assert eval_results['MdR'] == eval_results['MnR'] == 1.0 + + +class TestRetrievalRecall(TestCase): + + def test_evaluate(self): + """Test using the metric in the same way as Evalutor.""" + pred = [ + ActionDataSample().set_pred_score(i).set_gt_label(k).to_dict() + for i, k in zip([ + torch.tensor([0.7, 0.0, 0.3]), + torch.tensor([0.5, 0.2, 0.3]), + torch.tensor([0.4, 0.5, 0.1]), + torch.tensor([0.0, 0.0, 1.0]), + torch.tensor([0.0, 0.0, 1.0]), + torch.tensor([0.0, 0.0, 1.0]), + ], [[0], [0], [1], [2], [2], [0]]) + ] + + # Test with score (use score instead of label if score exists) + metric = METRICS.build(dict(type='RetrievalRecall', topk=1)) + metric.process(None, pred) + recall = metric.evaluate(6) + self.assertIsInstance(recall, dict) + self.assertAlmostEqual( + recall['retrieval/Recall@1'], 5 / 6 * 100, places=4) + + # Test with invalid topk + with self.assertRaisesRegex(RuntimeError, 'selected index k'): + metric = METRICS.build(dict(type='RetrievalRecall', topk=10)) + metric.process(None, pred) + metric.evaluate(6) + + with self.assertRaisesRegex(ValueError, '`topk` must be a'): + METRICS.build(dict(type='RetrievalRecall', topk=-1)) + + # Test initialization + metric = METRICS.build(dict(type='RetrievalRecall', topk=5)) + self.assertEqual(metric.topk, (5, )) + + # Test initialization + metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5))) + self.assertEqual(metric.topk, (1, 2, 5)) + + def test_calculate(self): + """Test using the metric from static method.""" + + # seq of indices format + y_true = [[0, 2, 5, 8, 9], [1, 4, 6]] + y_pred = [np.arange(10)] * 2 + + # test with average is 'macro' + recall_score = RetrievalRecall.calculate( + y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + expect_recall = 50. + self.assertEqual(recall_score[0].item(), expect_recall) + + # test with tensor input + y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1], + [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]]) + y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2) + recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1) + expect_recall = 50. + self.assertEqual(recall_score[0].item(), expect_recall) + + # test with topk is 5 + y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2) + recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2) + expect_recall = 100. + self.assertEqual(recall_score[0].item(), expect_recall) + + # test with topk is (1, 5) + y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2) + recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5)) + expect_recalls = [50., 100.] + self.assertEqual(len(recall_score), len(expect_recalls)) + for i in range(len(expect_recalls)): + self.assertEqual(recall_score[i].item(), expect_recalls[i]) + + # Test with invalid pred + y_pred = dict() + y_true = [[0, 2, 5, 8, 9], [1, 4, 6]] + with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'): + RetrievalRecall.calculate(y_pred, y_true, True, True) + + # Test with invalid target + y_true = dict() + y_pred = [np.arange(10)] * 2 + with self.assertRaisesRegex(AssertionError, '`target` must be Seq'): + RetrievalRecall.calculate( + y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + + # Test with different length `pred` with `target` + y_true = [[0, 2, 5, 8, 9], [1, 4, 6]] + y_pred = [np.arange(10)] * 3 + with self.assertRaisesRegex(AssertionError, 'Length of `pred`'): + RetrievalRecall.calculate( + y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + + # Test with invalid pred + y_true = [[0, 2, 5, 8, 9], dict()] + y_pred = [np.arange(10)] * 2 + with self.assertRaisesRegex(AssertionError, '`target` should be'): + RetrievalRecall.calculate( + y_pred, y_true, topk=1, pred_indices=True, target_indices=True) + + # Test with invalid target + y_true = [[0, 2, 5, 8, 9], [1, 4, 6]] + y_pred = [np.arange(10), dict()] + with self.assertRaisesRegex(AssertionError, '`pred` should be'): + RetrievalRecall.calculate( + y_pred, y_true, topk=1, pred_indices=True, target_indices=True) diff --git a/tools/data/msrvtt/README.md b/tools/data/msrvtt/README.md new file mode 100644 index 0000000000..e9e72ad6b4 --- /dev/null +++ b/tools/data/msrvtt/README.md @@ -0,0 +1,68 @@ +# Preparing MSR-VTT Retrieval/ Video Question-Answering Dataset + +## Introduction + + + +```BibTeX +@inproceedings{xu2016msr, + title={Msr-vtt: A large video description dataset for bridging video and language}, + author={Xu, Jun and Mei, Tao and Yao, Ting and Rui, Yong}, + booktitle={CVPR}, + pages={5288--5296}, + year={2016} +} +``` + +Before preparing the dataset, please make sure that the directory is located at `$MMACTION2/tools/data/msrvtt/`. + +## Step 1. Download Annotation Files + +You can directly download the following annotation files related to MSR-VTT from the [Google Drive link](https://drive.google.com/drive/folders/12cr94wT8j7pR09AR2nmQg6o26Y1arI50) provided by [VindLU](https://github.com/klauscc) and place them in the `$MMACTION2/tools/data/msrvtt/annotations` directory: + +- [msrvtt_qa_train.json](https://drive.google.com/file/d/12dJq5_7v8FytrJwrPB_f22tET1MmGCNh/view?usp=drive_link) +- [msrvtt_qa_val.json](https://drive.google.com/file/d/138q-A-V8fCC2nBYJgqkQa3gBfXVNbNNd/view?usp=drive_link) +- [msrvtt_qa_test.json](https://drive.google.com/file/d/13IiEcUMHiNppWhGwVY1eAaip6iSJM35A/view?usp=drive_link) +- [msrvtt_qa_answer_list.json](https://drive.google.com/file/d/131euz_dssRkDTk3-ioAS5ZsvIxS_Tt4M/view?usp=drive_link) +- [msrvtt_mc_test.json](https://drive.google.com/file/d/13FrUQ2ZDsNDraP7lfnKvTArPIgdtHuLC/view?usp=drive_link) +- [msrvtt_ret_train9k.json](https://drive.google.com/file/d/13OVo0XRdVWTHlFFxbKg3daYCHsMbJxyd/view?usp=drive_link) +- [msrvtt_ret_train7k.json](https://drive.google.com/file/d/13ID97BX4ExO6mWPIUMp-GzXcPBkviSLx/view?usp=drive_link) +- [msrvtt_ret_test1k.json](https://drive.google.com/file/d/13FLrjI-aleKeU7LbJMDrYgktX7MbTbzu/view?usp=drive_link) +- [msrvtt_test1k.json](https://drive.google.com/file/d/12z6y-DNwIfICSzOhekbJwSbf7z2hlibE/view?usp=drive_link) + +## Step 2. Prepare Video Data + +You can refer to the [official website](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) of this dataset for basic information. Run the following commands to prepare the MSRVTT video files: + +```shell +# Download original videos +bash download_msrvtt.sh +# Preprocess videos to lower FPS and dimensions +bash compress_msrvtt.sh +``` + +After completing the above preparation steps, the directory structure will be as follows: + +``` +mmaction2 +├── mmaction +├── tools +├── configs +├── data +│ └── msrvtt +│ │ ├── annotations +│ │ │ ├── msrvtt_qa_train.json +│ │ │ ├── msrvtt_qa_val.json +│ │ │ ├── msrvtt_qa_test.json +│ │ │ ├── msrvtt_qa_answer_list.json +│ │ │ ├── msrvtt_mc_test.json +│ │ │ ├── msrvtt_ret_train9k.json +│ │ │ ├── msrvtt_ret_train7k.json +│ │ │ ├── msrvtt_ret_test1k.json +│ │ │ └── msrvtt_test1k.json +│ │ └── videos_2fps_224 +│ │ ├── video0.mp4 +│ │ ├── video1.mp4 +│ │ ├── ... +│ │ └── video9999.mp4 +``` diff --git a/tools/data/msrvtt/README_zh-CN.md b/tools/data/msrvtt/README_zh-CN.md new file mode 100644 index 0000000000..bbd3a009c4 --- /dev/null +++ b/tools/data/msrvtt/README_zh-CN.md @@ -0,0 +1,68 @@ +# 准备 MSR-VTT 检索/视频问答数据集 + +## 简介 + + + +```BibTeX +@inproceedings{xu2016msr, + title={Msr-vtt: A large video description dataset for bridging video and language}, + author={Xu, Jun and Mei, Tao and Yao, Ting and Rui, Yong}, + booktitle={CVPR}, + pages={5288--5296}, + year={2016} +} +``` + +在数据集准备前,请确保命令行当前路径为 `$MMACTION2/tools/data/msrvtt/`。 + +## 步骤 1. 下载标注文件 + +用户可从 [VindLU](https://github.com/klauscc/VindLU) 提供的 [Google Drive 链接](https://drive.google.com/drive/folders/12cr94wT8j7pR09AR2nmQg6o26Y1arI50)中直接下载以下与 MSR-VTT 相关的标注文件, 并放置到 `$MMACTION2/tools/data/msrvtt/annotations` 路径下: + +- [msrvtt_qa_train.json](https://drive.google.com/file/d/12dJq5_7v8FytrJwrPB_f22tET1MmGCNh/view?usp=drive_link) +- [msrvtt_qa_val.json](https://drive.google.com/file/d/138q-A-V8fCC2nBYJgqkQa3gBfXVNbNNd/view?usp=drive_link) +- [msrvtt_qa_test.json](https://drive.google.com/file/d/13IiEcUMHiNppWhGwVY1eAaip6iSJM35A/view?usp=drive_link) +- [msrvtt_qa_answer_list.json](https://drive.google.com/file/d/131euz_dssRkDTk3-ioAS5ZsvIxS_Tt4M/view?usp=drive_link) +- [msrvtt_mc_test.json](https://drive.google.com/file/d/13FrUQ2ZDsNDraP7lfnKvTArPIgdtHuLC/view?usp=drive_link) +- [msrvtt_ret_train9k.json](https://drive.google.com/file/d/13OVo0XRdVWTHlFFxbKg3daYCHsMbJxyd/view?usp=drive_link) +- [msrvtt_ret_train7k.json](https://drive.google.com/file/d/13ID97BX4ExO6mWPIUMp-GzXcPBkviSLx/view?usp=drive_link) +- [msrvtt_ret_test1k.json](https://drive.google.com/file/d/13FLrjI-aleKeU7LbJMDrYgktX7MbTbzu/view?usp=drive_link) +- [msrvtt_test1k.json](https://drive.google.com/file/d/12z6y-DNwIfICSzOhekbJwSbf7z2hlibE/view?usp=drive_link) + +## 步骤 2. 准备视频数据 + +用户可参考该数据集的[官网](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/),以获取数据集相关的基本信息。运行下面的命令准备 MSRVTT 视频文件: + +```shell +# download original videos +bash download_msrvtt.sh +# preprocess videos to lower FPS and dimension +bash compress_msrvtt.sh +``` + +完成上述准备步骤后,文件目录如下: + +``` +mmaction2 +├── mmaction +├── tools +├── configs +├── data +│ └── msrvtt +│ │ ├── annotations +│ │ │ ├── msrvtt_qa_train.json +│ │ │ ├── msrvtt_qa_val.json +│ │ │ ├── msrvtt_qa_test.json +│ │ │ ├── msrvtt_qa_answer_list.json +│ │ │ ├── msrvtt_mc_test.json +│ │ │ ├── msrvtt_ret_train9k.json +│ │ │ ├── msrvtt_ret_train7k.json +│ │ │ ├── msrvtt_ret_test1k.json +│ │ │ └── msrvtt_test1k.json +│ │ └── videos_2fps_224 +│ │ ├── video0.mp4 +│ │ ├── video1.mp4 +│ │ ├── ... +│ │ └── video9999.mp4 +``` diff --git a/tools/data/msrvtt/compress.py b/tools/data/msrvtt/compress.py new file mode 100644 index 0000000000..48f022ddba --- /dev/null +++ b/tools/data/msrvtt/compress.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Used to compress videos (FPS and dimensions) in the Singularity project. + +copied from https://github.com/klauscc/VindLU +""" +import argparse +import os +import shutil +import subprocess +from multiprocessing import Pool +from os.path import exists, join +from pathlib import Path + +try: + from psutil import cpu_count +except ImportError: + from multiprocessing import cpu_count + +from functools import partial + +from PIL import Image +from tqdm import tqdm + + +def resize_image(input_path, output_path, size=224): + with Image.open(input_path) as img: + w, h = img.width, img.height + r = 1. * w / h + if w > h: + h = size + w = r * size + else: + h = size / r + w = size + + img_resized = img.resize((int(w), int(h))) + img_resized.save(output_path) + + +def _compress_images(input_output_pair, size=224): + """Scale and downsample an input image to a given fps and size (shorter + side size). + + This also removes the audio from the image. + """ + input_image_path, output_image_path = input_output_pair + try: + resize_image(input_image_path, output_image_path, size) + except Exception as e: + print(f'Caught Exception {e}') + + +def _compress_videos(input_output_pair, size=224, fps=3): + """Scale and downsample an input video to a given fps and size (shorter + side size). + + This also removes the audio from the video. + """ + input_file_path, output_file_path = input_output_pair + try: + command = [ + 'ffmpeg', + '-y', # (optional) overwrite output file if it exists + '-i', + input_file_path, + '-filter:v', # no audio + f"scale='if(gt(a,1),trunc(oh*a/2)*2,{size})':'if(gt(a,1),{size},trunc(ow*a/2)*2)'", # noqa: E501 + '-map', + '0:v', # no audio + '-r', + str(fps), # frames per second + # '-g', str(16), + output_file_path, + ] + subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + except Exception as e: + raise e + + +def _compress(input_output_pair, fps=3, size=224, file_type='image'): + if file_type == 'image': + _compress_images(input_output_pair, size) + elif file_type == 'video': + _compress_videos(input_output_pair, size, fps) + + +def prepare_input_output_pairs(input_root, + output_root, + input_file_list_path=None): + # filename list in `input_file_list_path` can be created very fast using `ls -U . >> ../video_filenames.txt` # noqa: E501 + if input_file_list_path: + with open(input_file_list_path, 'r') as f: + filenames = [s.strip() for s in f.readlines()] + else: + filenames = [ + video_path.name for video_path in Path(input_root).glob('*.mp4') + ] + print(f'There are {len(filenames)} video/images files loaded from list.') + input_file_path_list = [] + output_file_path_list = [] + for e in tqdm(filenames, desc='find un-processed videos/images'): + input_file_path = join(input_root, e) + output_file_path = join(output_root, e) + if not exists(output_file_path): + input_file_path_list.append(input_file_path) + output_file_path_list.append(output_file_path) + return input_file_path_list, output_file_path_list + + +def run_compress(): + parser = argparse.ArgumentParser( + description='Compress videos/images for speed-up') + parser.add_argument( + '--input_root', type=str, help='input root', required=True) + parser.add_argument( + '--input_file_list_path', + type=str, + default=None, + help='list of video filenames under args.input_root, it can be ' + 'created efficiently with `ls -U /path/to/video >> /path/to/video_filenames.txt`' # noqa: E501 + ) + parser.add_argument( + '--output_root', type=str, help='output root', required=True) + parser.add_argument( + '--size', + type=int, + default=224, + help='shorter side size, aspect ratio is kept') + parser.add_argument('--num_workers', type=int, default=24, help='#workers') + parser.add_argument( + '--fps', + type=int, + default=3, + help='fps for output video, ignored if file_type == image') + parser.add_argument( + '--file_type', + type=str, + choices=['image', 'video'], + help='input file type') + args = parser.parse_args() + + # set paths + input_root = args.input_root + output_root = args.output_root + assert input_root != output_root + if not exists(output_root): + os.makedirs(output_root, exist_ok=True) + + # prepare and find un-processed + input_file_path_list, output_file_path_list = prepare_input_output_pairs( + input_root, + output_root, + input_file_list_path=args.input_file_list_path, + ) + print(f'input_file_path_list[:3] {input_file_path_list[:3]}') + print(f'output_file_path_list[:3] {output_file_path_list[:3]}') + print('Total videos/images need to process: {}'.format( + len(input_file_path_list))) + + # start parallel jobs + num_cores = cpu_count() + num_workers = args.num_workers + print( + f'Begin with {num_cores}-core logical processor, {num_workers} workers' + ) + compress = partial( + _compress, fps=args.fps, size=args.size, file_type=args.file_type) + input_pairs = list(zip(input_file_path_list, output_file_path_list)) + with Pool(num_workers) as pool, tqdm( + total=len(input_file_path_list), + desc='re-encoding videos/images') as pbar: + for idx, _ in enumerate( + pool.imap_unordered(compress, input_pairs, chunksize=32)): + pbar.update(1) + + # copy-paste failed files + print('Compress finished, copy-paste failed files...') + copy_count = 0 + for input_file_path, output_file_path in zip(input_file_path_list, + output_file_path_list): + if exists(input_file_path): + if exists(output_file_path) is False or os.path.getsize( + output_file_path) < 1.: + copy_count += 1 + shutil.copyfile(input_file_path, output_file_path) + print('Copy and replace file: {}'.format(output_file_path)) + print(f'copy_count {copy_count}') + + +if __name__ == '__main__': + run_compress() diff --git a/tools/data/msrvtt/compress_msrvtt.sh b/tools/data/msrvtt/compress_msrvtt.sh new file mode 100644 index 0000000000..18822ce312 --- /dev/null +++ b/tools/data/msrvtt/compress_msrvtt.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +FPS=2 +SIZE=224 +DATA_DIR="../../../data/msrvtt/videos" +OUT_DIR="../../../data/msrvtt/videos_2fps_224" + +python compress.py \ + --input_root=${DATA_DIR} --output_root=${OUT_DIR} \ + --fps=${FPS} --size=${SIZE} --file_type=video --num_workers 24 diff --git a/tools/data/msrvtt/download_msrvtt.sh b/tools/data/msrvtt/download_msrvtt.sh new file mode 100644 index 0000000000..6ae40d942d --- /dev/null +++ b/tools/data/msrvtt/download_msrvtt.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +DATA_DIR="../../../data/msrvtt" +mkdir -p ${DATA_DIR} + +if [ -f "MSRVTT.zip" ]; then + echo "MSRVTT.zip exists, skip downloading!" +else + echo "Downloading MSRVTT.zip." + wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip +fi + +echo "Processing videos started." +unzip -q MSRVTT.zip -d ${DATA_DIR} +mkdir -p "${DATA_DIR}/videos/" && find "${DATA_DIR}/MSRVTT/videos/all" -name "video*.mp4" -exec mv {} "${DATA_DIR}/videos/" \; +echo "Processing videos completed." + +rm -rf "${DATA_DIR}/MSRVTT" +rm -rf "${DATA_DIR}/msrvtt_data" +rm msrvtt_data.zip +rm MSRVTT.zip +echo "The preparation of the msrvtt dataset has been successfully completed." diff --git a/tools/data/video_retrieval/README_zh-CN.md b/tools/data/video_retrieval/README_zh-CN.md index 1d4374daea..1814ff36e2 100644 --- a/tools/data/video_retrieval/README_zh-CN.md +++ b/tools/data/video_retrieval/README_zh-CN.md @@ -34,7 +34,7 @@ bash prepare_msrvtt.sh ``` -完场上述准备步骤后,文件目录如下: +完成上述准备步骤后,文件目录如下: ``` mmaction2