[Update] Update audio-based model (#2570)
Dai-Wenxun authored Jul 20, 2023
1 parent dd034bb commit c548788
Showing 13 changed files with 145 additions and 577 deletions.
11 changes: 0 additions & 11 deletions configs/_base_/models/tsn_r18_audio.py

This file was deleted.

17 changes: 5 additions & 12 deletions configs/recognition_audio/resnet/README.md
@@ -8,7 +8,7 @@

<!-- [ABSTRACT] -->

We present Audiovisual SlowFast Networks, an architecture for integrated audiovisual perception. AVSlowFast has Slow and Fast visual pathways that are deeply inte- grated with a Faster Audio pathway to model vision and sound in a unified representation. We fuse audio and vi- sual features at multiple layers, enabling audio to con- tribute to the formation of hierarchical audiovisual con- cepts. To overcome training difficulties that arise from dif- ferent learning dynamics for audio and visual modalities, we introduce DropPathway, which randomly drops the Au- dio pathway during training as an effective regularization technique. Inspired by prior studies in neuroscience, we perform hierarchical audiovisual synchronization to learn joint audiovisual features. We report state-of-the-art results on six video action classification and detection datasets, perform detailed ablation studies, and show the gener- alization of AVSlowFast to learn self-supervised audiovi- sual features. Code will be made available at: https: /facebookresearch/SlowFast.
We present Audiovisual SlowFast Networks, an architecture for integrated audiovisual perception. AVSlowFast has Slow and Fast visual pathways that are deeply integrated with a Faster Audio pathway to model vision and sound in a unified representation. We fuse audio and visual features at multiple layers, enabling audio to contribute to the formation of hierarchical audiovisual concepts. To overcome training difficulties that arise from different learning dynamics for audio and visual modalities, we introduce DropPathway, which randomly drops the Audio pathway during training as an effective regularization technique. Inspired by prior studies in neuroscience, we perform hierarchical audiovisual synchronization to learn joint audiovisual features. We report state-of-the-art results on six video action classification and detection datasets, perform detailed ablation studies, and show the generalization of AVSlowFast to learn self-supervised audiovisual features.

<!-- [IMAGE] -->

@@ -20,16 +20,9 @@ We present Audiovisual SlowFast Networks, an architecture for integrated audiovi

### Kinetics-400

| frame sampling strategy | n_fft | gpus | backbone | pretrain | top1 acc | top5 acc | testing protocol | gpu_mem(M) | config | ckpt | log |
| :---------------------: | :---: | :--: | :------: | :------: | :------: | :------: | :--------------: | :--------: | :------------------------------------: | :----------------------------------: | :----------------------------------: |
| 64x1x1 | 1024 | 8 | Resnet18 | None | 19.7 | 35.75 | 10 clips | 1897 | [config](/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20201012-bf34df6c.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log) |

1. The **gpus** column indicates the number of GPUs used to obtain the checkpoint. The provided configs assume 8 GPUs by default.
   According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may scale the learning rate in proportion to the total batch size if you train with a different number of GPUs or videos per GPU,
   e.g., lr=0.01 for 4 GPUs x 2 videos/GPU and lr=0.08 for 16 GPUs x 4 videos/GPU (a short scaling sketch follows the results table below).
2. The validation set of Kinetics400 we used consists of 19796 videos. These videos are available at [Kinetics400-Validation](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136485_link_cuhk_edu_hk/EbXw2WX94J1Hunyt3MWNDJUBz-nHvQYhO9pvKqm6g39PMA?e=a9QldB). The corresponding [data list](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_val_list.txt) (each line is of the format 'video_id, num_frames, label_index') and the [label map](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_class2ind.txt) are also available.

For more details on data preparation, you can refer to `Prepare audio` in [Data Preparation Tutorial](/docs/en/user_guides/prepare_dataset.md).
| frame sampling strategy | n_fft | gpus | backbone | pretrain | top1 acc | top5 acc | testing protocol | FLOPs | params | config | ckpt | log |
| :---------------------: | :---: | :--: | :------: | :------: | :------: | :------: | :--------------: | :---: | :----: | :------------------------------------: | :----------------------------------: | :---------------------------------: |
| 64x1x1 | 1024 | 8 | Resnet18 | None | 13.7 | 27.3 | 1 clips | 0.37G | 11.4M | [config](/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20230702-e4642fb0.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log) |
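
The Linear Scaling Rule mentioned in the notes can be made concrete with a small helper. This is an illustrative sketch only; the function below is not part of MMAction2, and the numbers simply reproduce the example from the note.

```python
def scale_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear Scaling Rule: scale the learning rate with the total batch size."""
    return base_lr * new_batch / base_batch


# lr=0.01 for 4 GPUs x 2 videos/GPU  ->  lr=0.08 for 16 GPUs x 4 videos/GPU
print(scale_lr(base_lr=0.01, base_batch=4 * 2, new_batch=16 * 4))  # 0.08
```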

## Train

@@ -43,7 +36,7 @@ Example: train ResNet model on Kinetics-400 audio dataset in a deterministic opt

```shell
python tools/train.py configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py \
--cfg-options randomness.seed=0 randomness.deterministic=True
--seed 0 --deterministic
```

For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
12 changes: 8 additions & 4 deletions configs/recognition_audio/resnet/metafile.yml
@@ -11,16 +11,20 @@ Models:
    In Collection: Audio
    Metadata:
      Architecture: ResNet18
      Batch Size: 320
      Epochs: 100
      FLOPs: 0.37G
      Parameters: 11.4M
      Pretrained: None
      n_fft: 1024
      Training Data: Kinetics-400
      Training Resources: 8 GPUs
      n_fft: 1024
    Modality: Audio
    Results:
      - Dataset: Kinetics-400
        Task: Action Recognition
        Metrics:
          Top 1 Accuracy: 19.7
          Top 5 Accuracy: 35.75
          Top 1 Accuracy: 13.7
          Top 5 Accuracy: 27.3
    Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log
    Weights: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20201012-bf34df6c.pth
    Weights: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20230702-e4642fb0.pth
configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py
@@ -1,14 +1,24 @@
_base_ = [
    '../../_base_/models/tsn_r18_audio.py', '../../_base_/default_runtime.py'
]
_base_ = '../../_base_/default_runtime.py'

# model settings
model = dict(
    type='RecognizerAudio',
    backbone=dict(type='ResNet', depth=18, in_channels=1, norm_eval=False),
    cls_head=dict(
        type='TSNAudioHead',
        num_classes=400,
        in_channels=512,
        dropout_ratio=0.5,
        init_std=0.01,
        average_clips='prob'))

# dataset settings
dataset_type = 'AudioDataset'
data_root = 'data/kinetics400/audio_features_train'
data_root_val = 'data/kinetics400/audio_features_val'
ann_file_train = 'data/kinetics400/kinetics400_val_list_audio_features.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_audio_features.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_audio_features.txt'
data_root = 'data/kinetics400'
ann_file_train = 'kinetics400_train_list_audio_features.txt'
ann_file_val = 'kinetics400_val_list_audio_features.txt'
ann_file_test = 'kinetics400_val_list_audio_features.txt'

train_pipeline = [
    dict(type='LoadAudioFeature'),
    dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1),
@@ -28,53 +38,42 @@
    dict(type='FormatAudioShape', input_format='NCTF'),
    dict(type='PackActionInputs')
]
test_pipeline = [
    dict(type='LoadAudioFeature'),
    dict(
        type='SampleFrames',
        clip_len=64,
        frame_interval=1,
        num_clips=10,
        test_mode=True),
    dict(type='AudioFeatureSelector'),
    dict(type='FormatAudioShape', input_format='NCTF'),
    dict(type='PackActionInputs')
]
test_pipeline = val_pipeline

train_dataloader = dict(
    batch_size=320,
    num_workers=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=dict(audio=data_root_val),
        suffix='.npy',
        pipeline=train_pipeline))
        pipeline=train_pipeline,
        data_root=data_root,
        data_prefix=dict(audio='audio_features_train')))
val_dataloader = dict(
    batch_size=320,
    num_workers=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        pipeline=val_pipeline,
        data_prefix=dict(audio=data_root_val),
        suffix='.npy',
        data_root=data_root,
        data_prefix=dict(audio='audio_features_val'),
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        pipeline=test_pipeline,
        data_prefix=dict(audio=data_root_val),
        suffix='.npy',
        data_root=data_root,
        data_prefix=dict(audio='audio_features_val'),
        test_mode=True))

val_evaluator = dict(type='AccMetric')
@@ -90,8 +89,7 @@
]

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001),
    optimizer=dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=40, norm_type=2))

default_hooks = dict(
    checkpoint=dict(max_keep_ckpts=3, interval=5), logger=dict(interval=20))
default_hooks = dict(checkpoint=dict(max_keep_ckpts=3, interval=5))
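
As a rough usage sketch, and not part of this diff, a config like the one above is typically consumed through MMEngine's `Runner`; the `work_dir` value here is an assumption chosen for illustration.

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load the audio-feature config updated in this commit.
cfg = Config.fromfile(
    'configs/recognition_audio/resnet/'
    'tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py')
cfg.work_dir = './work_dirs/tsn_r18_audio_feature'  # assumed output directory

# Runner builds the model, dataloaders and optimizer from the dicts above.
runner = Runner.from_cfg(cfg)
runner.train()
```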

This file was deleted.

47 changes: 15 additions & 32 deletions mmaction/datasets/audio_dataset.py
@@ -1,27 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch
from mmengine.utils import check_file_exist

from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset


@DATASETS.register_module()
class AudioDataset(BaseActionDataset):
    """Audio dataset for action recognition. Annotation file can be that of the
    rawframe dataset, or:
    """Audio dataset for action recognition.
    .. code-block:: txt
        some/directory-1.wav 163 1
        some/directory-2.wav 122 1
        some/directory-3.wav 258 2
        some/directory-4.wav 234 2
        some/directory-5.wav 295 3
        some/directory-6.wav 121 3
    The ann_file is a text file with multiple lines, and each line indicates
    a sample audio or extracted audio feature with the filepath, total frames
    of the raw video and label, which are split with a whitespace.
    Example of an annotation file:
    .. code-block:: txt
        some/directory-1.npy 163 1
@@ -33,26 +27,22 @@ class AudioDataset(BaseActionDataset):
    Args:
        ann_file (str): Path to the annotation file.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        data_prefix (dict or ConfigDict, optional): Path to a directory where
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (dict): Path to a directory where
            audios are held. Defaults to ``dict(audio='')``.
        multi_class (bool): Determines whether it is a multi-class
            recognition dataset. Defaults to False.
        num_classes (int, optional): Number of classes in the dataset.
            Defaults to None.
        suffix (str): The suffix of the audio file. Defaults to ``.wav``.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[ConfigType, Callable]],
                 data_prefix: ConfigType = dict(audio=''),
                 pipeline: List[Union[Dict, Callable]],
                 data_prefix: Dict = dict(audio=''),
                 multi_class: bool = False,
                 num_classes: Optional[int] = None,
                 suffix: str = '.wav',
                 **kwargs) -> None:
        self.suffix = suffix
        super().__init__(
            ann_file,
            pipeline,
@@ -62,8 +52,8 @@ def __init__(self,
            modality='Audio',
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation file to get video information."""
    def load_data_list(self) -> List[Dict]:
        """Load annotation file to get audio information."""
        check_file_exist(self.ann_file)
        data_list = []
        with open(self.ann_file, 'r') as fin:
@@ -73,25 +63,18 @@ def load_data_list(self) -> List[Dict]:
                idx = 0
                filename = line_split[idx]
                if self.data_prefix['audio'] is not None:
                    if not filename.endswith(self.suffix):
                        filename = osp.join(self.data_prefix['audio'],
                                            filename + self.suffix)
                    else:
                        filename = osp.join(self.data_prefix['audio'],
                                            filename)
                    filename = osp.join(self.data_prefix['audio'], filename)
                video_info['audio_path'] = filename
                idx += 1
                # idx for total_frames
                video_info['total_frames'] = int(line_split[idx])
                idx += 1
                # idx for label[s]
                # idx for label
                label = [int(x) for x in line_split[idx:]]
                assert label, f'missing label in line: {line}'
                if self.multi_class:
                    assert self.num_classes is not None
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                    video_info['label'] = onehot
                    video_info['label'] = label
                else:
                    assert len(label) == 1
                    video_info['label'] = label[0]
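
To illustrate the parsing above, here is a minimal sketch of what a single annotation line turns into; the path and numbers are made up, and the final `audio_path` additionally depends on how the base dataset joins `data_root` with `data_prefix`.

```python
# Annotation line format: <relative path> <total frames> <label>
line = 'class_x/video_001.npy 163 1'
path, total_frames, label = line.strip().split()

# With data_prefix=dict(audio='data/kinetics400/audio_features_train'),
# the resulting entry roughly looks like:
sample = dict(
    audio_path='data/kinetics400/audio_features_train/' + path,
    total_frames=int(total_frames),
    label=int(label),  # kept as a list of ints instead when multi_class=True
)
print(sample)
```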