diff --git a/configs/_base_/models/tsn_r18_audio.py b/configs/_base_/models/tsn_r18_audio.py deleted file mode 100644 index be21b44c0b..0000000000 --- a/configs/_base_/models/tsn_r18_audio.py +++ /dev/null @@ -1,11 +0,0 @@ -# model settings -model = dict( - type='RecognizerAudio', - backbone=dict(type='ResNet', depth=18, in_channels=1, norm_eval=False), - cls_head=dict( - type='TSNAudioHead', - num_classes=400, - in_channels=512, - dropout_ratio=0.5, - init_std=0.01, - average_clips='prob')) diff --git a/configs/recognition_audio/resnet/README.md b/configs/recognition_audio/resnet/README.md index f6386e313f..3a58b201c7 100644 --- a/configs/recognition_audio/resnet/README.md +++ b/configs/recognition_audio/resnet/README.md @@ -8,7 +8,7 @@ -We present Audiovisual SlowFast Networks, an architecture for integrated audiovisual perception. AVSlowFast has Slow and Fast visual pathways that are deeply inte- grated with a Faster Audio pathway to model vision and sound in a unified representation. We fuse audio and vi- sual features at multiple layers, enabling audio to con- tribute to the formation of hierarchical audiovisual con- cepts. To overcome training difficulties that arise from dif- ferent learning dynamics for audio and visual modalities, we introduce DropPathway, which randomly drops the Au- dio pathway during training as an effective regularization technique. Inspired by prior studies in neuroscience, we perform hierarchical audiovisual synchronization to learn joint audiovisual features. We report state-of-the-art results on six video action classification and detection datasets, perform detailed ablation studies, and show the gener- alization of AVSlowFast to learn self-supervised audiovi- sual features. Code will be made available at: https: //github.com/facebookresearch/SlowFast. +We present Audiovisual SlowFast Networks, an architecture for integrated audiovisual perception. AVSlowFast has Slow and Fast visual pathways that are deeply integrated with a Faster Audio pathway to model vision and sound in a unified representation. We fuse audio and visual features at multiple layers, enabling audio to contribute to the formation of hierarchical audiovisual concepts. To overcome training difficulties that arise from different learning dynamics for audio and visual modalities, we introduce DropPathway, which randomly drops the Au- dio pathway during training as an effective regularization technique. Inspired by prior studies in neuroscience, we perform hierarchical audiovisual synchronization to learn joint audiovisual features. We report state-of-the-art results on six video action classification and detection datasets, perform detailed ablation studies, and show the generalization of AVSlowFast to learn self-supervised audiovisual features. @@ -20,16 +20,9 @@ We present Audiovisual SlowFast Networks, an architecture for integrated audiovi ### Kinetics-400 -| frame sampling strategy | n_fft | gpus | backbone | pretrain | top1 acc | top5 acc | testing protocol | gpu_mem(M) | config | ckpt | log | -| :---------------------: | :---: | :--: | :------: | :------: | :------: | :------: | :--------------: | :--------: | :------------------------------------: | :----------------------------------: | :----------------------------------: | -| 64x1x1 | 1024 | 8 | Resnet18 | None | 19.7 | 35.75 | 10 clips | 1897 | [config](/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20201012-bf34df6c.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log) | - -1. The **gpus** indicates the number of gpus we used to get the checkpoint. It is noteworthy that the configs we provide are used for 8 gpus as default. - According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, - e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. -2. The validation set of Kinetics400 we used consists of 19796 videos. These videos are available at [Kinetics400-Validation](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136485_link_cuhk_edu_hk/EbXw2WX94J1Hunyt3MWNDJUBz-nHvQYhO9pvKqm6g39PMA?e=a9QldB). The corresponding [data list](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_val_list.txt) (each line is of the format 'video_id, num_frames, label_index') and the [label map](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_class2ind.txt) are also available. - -For more details on data preparation, you can refer to `Prepare audio` in [Data Preparation Tutorial](/docs/en/user_guides/prepare_dataset.md). +| frame sampling strategy | n_fft | gpus | backbone | pretrain | top1 acc | top5 acc | testing protocol | FLOPs | params | config | ckpt | log | +| :---------------------: | :---: | :--: | :------: | :------: | :------: | :------: | :--------------: | :---: | :----: | :------------------------------------: | :----------------------------------: | :---------------------------------: | +| 64x1x1 | 1024 | 8 | Resnet18 | None | 13.7 | 27.3 | 1 clips | 0.37G | 11.4M | [config](/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20230702-e4642fb0.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log) | ## Train @@ -43,7 +36,7 @@ Example: train ResNet model on Kinetics-400 audio dataset in a deterministic opt ```shell python tools/train.py configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py \ - --cfg-options randomness.seed=0 randomness.deterministic=True + --seed 0 --deterministic ``` For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md). diff --git a/configs/recognition_audio/resnet/metafile.yml b/configs/recognition_audio/resnet/metafile.yml index f82d234e9a..26f495cd9e 100644 --- a/configs/recognition_audio/resnet/metafile.yml +++ b/configs/recognition_audio/resnet/metafile.yml @@ -11,16 +11,20 @@ Models: In Collection: Audio Metadata: Architecture: ResNet18 + Batch Size: 320 + Epochs: 100 + FLOPs: 0.37G + Parameters: 11.4M Pretrained: None + n_fft: 1024 Training Data: Kinetics-400 Training Resources: 8 GPUs - n_fft: 1024 Modality: Audio Results: - Dataset: Kinetics-400 Task: Action Recognition Metrics: - Top 1 Accuracy: 19.7 - Top 5 Accuracy: 35.75 + Top 1 Accuracy: 13.7 + Top 5 Accuracy: 27.3 Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.log - Weights: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20201012-bf34df6c.pth + Weights: https://download.openmmlab.com/mmaction/v1.0/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature_20230702-e4642fb0.pth diff --git a/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py b/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py index 8a37ab5bad..9b00c34796 100644 --- a/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py +++ b/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio-feature.py @@ -1,14 +1,24 @@ -_base_ = [ - '../../_base_/models/tsn_r18_audio.py', '../../_base_/default_runtime.py' -] +_base_ = '../../_base_/default_runtime.py' + +# model settings +model = dict( + type='RecognizerAudio', + backbone=dict(type='ResNet', depth=18, in_channels=1, norm_eval=False), + cls_head=dict( + type='TSNAudioHead', + num_classes=400, + in_channels=512, + dropout_ratio=0.5, + init_std=0.01, + average_clips='prob')) # dataset settings dataset_type = 'AudioDataset' -data_root = 'data/kinetics400/audio_features_train' -data_root_val = 'data/kinetics400/audio_features_val' -ann_file_train = 'data/kinetics400/kinetics400_val_list_audio_features.txt' -ann_file_val = 'data/kinetics400/kinetics400_val_list_audio_features.txt' -ann_file_test = 'data/kinetics400/kinetics400_val_list_audio_features.txt' +data_root = 'data/kinetics400' +ann_file_train = 'kinetics400_train_list_audio_features.txt' +ann_file_val = 'kinetics400_val_list_audio_features.txt' +ann_file_test = 'kinetics400_val_list_audio_features.txt' + train_pipeline = [ dict(type='LoadAudioFeature'), dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1), @@ -28,53 +38,42 @@ dict(type='FormatAudioShape', input_format='NCTF'), dict(type='PackActionInputs') ] -test_pipeline = [ - dict(type='LoadAudioFeature'), - dict( - type='SampleFrames', - clip_len=64, - frame_interval=1, - num_clips=10, - test_mode=True), - dict(type='AudioFeatureSelector'), - dict(type='FormatAudioShape', input_format='NCTF'), - dict(type='PackActionInputs') -] +test_pipeline = val_pipeline train_dataloader = dict( batch_size=320, - num_workers=2, + num_workers=4, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( type=dataset_type, ann_file=ann_file_train, - data_prefix=dict(audio=data_root_val), - suffix='.npy', - pipeline=train_pipeline)) + pipeline=train_pipeline, + data_root=data_root, + data_prefix=dict(audio='audio_features_train'))) val_dataloader = dict( batch_size=320, - num_workers=2, + num_workers=4, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( type=dataset_type, ann_file=ann_file_val, pipeline=val_pipeline, - data_prefix=dict(audio=data_root_val), - suffix='.npy', + data_root=data_root, + data_prefix=dict(audio='audio_features_val'), test_mode=True)) test_dataloader = dict( batch_size=1, - num_workers=2, + num_workers=4, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( type=dataset_type, ann_file=ann_file_test, pipeline=test_pipeline, - data_prefix=dict(audio=data_root_val), - suffix='.npy', + data_root=data_root, + data_prefix=dict(audio='audio_features_val'), test_mode=True)) val_evaluator = dict(type='AccMetric') @@ -90,8 +89,7 @@ ] optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001), + optimizer=dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=0.0001), clip_grad=dict(max_norm=40, norm_type=2)) -default_hooks = dict( - checkpoint=dict(max_keep_ckpts=3, interval=5), logger=dict(interval=20)) +default_hooks = dict(checkpoint=dict(max_keep_ckpts=3, interval=5)) diff --git a/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio.py b/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio.py deleted file mode 100644 index ccae1b251f..0000000000 --- a/configs/recognition_audio/resnet/tsn_r18_8xb320-64x1x1-100e_kinetics400-audio.py +++ /dev/null @@ -1,100 +0,0 @@ -_base_ = [ - '../../_base_/models/tsn_r18_audio.py', '../../_base_/default_runtime.py' -] - -# dataset settings -dataset_type = 'AudioDataset' -data_root = 'data/kinetics400/audios_train' -data_root_val = 'data/kinetics400/audios_val' -ann_file_train = 'data/kinetics400/kinetics400_train_list_audios.txt' -ann_file_val = 'data/kinetics400/kinetics400_val_list_audios.txt' -ann_file_test = 'data/kinetics400/kinetics400_val_list_audios.txt' -train_pipeline = [ - dict(type='AudioDecodeInit'), - dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1), - dict(type='AudioDecode'), - dict(type='AudioAmplify', ratio=1.5), - dict(type='MelSpectrogram'), - dict(type='FormatAudioShape', input_format='NCTF'), - dict(type='PackActionInputs') -] -val_pipeline = [ - dict(type='AudioDecodeInit'), - dict( - type='SampleFrames', - clip_len=64, - frame_interval=1, - num_clips=1, - test_mode=True), - dict(type='AudioDecode'), - dict(type='AudioAmplify', ratio=1.5), - dict(type='MelSpectrogram'), - dict(type='FormatAudioShape', input_format='NCTF'), - dict(type='PackActionInputs') -] -test_pipeline = [ - dict(type='AudioDecodeInit'), - dict( - type='SampleFrames', - clip_len=64, - frame_interval=1, - num_clips=10, - test_mode=True), - dict(type='AudioDecode'), - dict(type='AudioAmplify', ratio=1.5), - dict(type='MelSpectrogram'), - dict(type='FormatAudioShape', input_format='NCTF'), - dict(type='PackActionInputs') -] - -train_dataloader = dict( - batch_size=320, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - dataset=dict( - type=dataset_type, - ann_file=ann_file_train, - data_prefix=dict(audio=data_root), - pipeline=train_pipeline)) -val_dataloader = dict( - batch_size=320, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_val, - pipeline=val_pipeline, - data_prefix=dict(audio=data_root_val), - test_mode=True)) -test_dataloader = dict( - batch_size=1, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - ann_file=ann_file_test, - pipeline=test_pipeline, - data_prefix=dict(audio=data_root_val), - test_mode=True)) - -val_evaluator = dict(type='AccMetric') -test_evaluator = val_evaluator - -train_cfg = dict( - type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=5) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -param_scheduler = [ - dict(type='CosineAnnealingLR', eta_min=0, T_max=100, by_epoch=True) -] - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001), - clip_grad=dict(max_norm=40, norm_type=2)) - -default_hooks = dict( - checkpoint=dict(max_keep_ckpts=3, interval=5), logger=dict(interval=20)) diff --git a/mmaction/datasets/audio_dataset.py b/mmaction/datasets/audio_dataset.py index 42c98fb091..07aae25143 100644 --- a/mmaction/datasets/audio_dataset.py +++ b/mmaction/datasets/audio_dataset.py @@ -1,27 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union -import torch from mmengine.utils import check_file_exist from mmaction.registry import DATASETS -from mmaction.utils import ConfigType from .base import BaseActionDataset @DATASETS.register_module() class AudioDataset(BaseActionDataset): - """Audio dataset for action recognition. Annotation file can be that of the - rawframe dataset, or: + """Audio dataset for action recognition. - .. code-block:: txt - some/directory-1.wav 163 1 - some/directory-2.wav 122 1 - some/directory-3.wav 258 2 - some/directory-4.wav 234 2 - some/directory-5.wav 295 3 - some/directory-6.wav 121 3 + The ann_file is a text file with multiple lines, and each line indicates + a sample audio or extracted audio feature with the filepath, total frames + of the raw video and label, which are split with a whitespace. + Example of a annotation file: .. code-block:: txt some/directory-1.npy 163 1 @@ -33,26 +27,22 @@ class AudioDataset(BaseActionDataset): Args: ann_file (str): Path to the annotation file. - pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of - data transforms. - data_prefix (dict or ConfigDict, optional): Path to a directory where + pipeline (list[dict | callable]): A sequence of data transforms. + data_prefix (dict): Path to a directory where audios are held. Defaults to ``dict(audio='')``. multi_class (bool): Determines whether it is a multi-class recognition dataset. Defaults to False. num_classes (int, optional): Number of classes in the dataset. Defaults to None. - suffix (str): The suffix of the audio file. Defaults to ``.wav``. """ def __init__(self, ann_file: str, - pipeline: List[Union[ConfigType, Callable]], - data_prefix: ConfigType = dict(audio=''), + pipeline: List[Union[Dict, Callable]], + data_prefix: Dict = dict(audio=''), multi_class: bool = False, num_classes: Optional[int] = None, - suffix: str = '.wav', **kwargs) -> None: - self.suffix = suffix super().__init__( ann_file, pipeline, @@ -62,8 +52,8 @@ def __init__(self, modality='Audio', **kwargs) - def load_data_list(self) -> List[dict]: - """Load annotation file to get video information.""" + def load_data_list(self) -> List[Dict]: + """Load annotation file to get audio information.""" check_file_exist(self.ann_file) data_list = [] with open(self.ann_file, 'r') as fin: @@ -73,25 +63,18 @@ def load_data_list(self) -> List[dict]: idx = 0 filename = line_split[idx] if self.data_prefix['audio'] is not None: - if not filename.endswith(self.suffix): - filename = osp.join(self.data_prefix['audio'], - filename + self.suffix) - else: - filename = osp.join(self.data_prefix['audio'], - filename) + filename = osp.join(self.data_prefix['audio'], filename) video_info['audio_path'] = filename idx += 1 # idx for total_frames video_info['total_frames'] = int(line_split[idx]) idx += 1 - # idx for label[s] + # idx for label label = [int(x) for x in line_split[idx:]] assert label, f'missing label in line: {line}' if self.multi_class: assert self.num_classes is not None - onehot = torch.zeros(self.num_classes) - onehot[label] = 1.0 - video_info['label'] = onehot + video_info['label'] = label else: assert len(label) == 1 video_info['label'] = label[0] diff --git a/mmaction/datasets/transforms/__init__.py b/mmaction/datasets/transforms/__init__.py index d8b8cc4eb3..f2670cd929 100644 --- a/mmaction/datasets/transforms/__init__.py +++ b/mmaction/datasets/transforms/__init__.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .formatting import (FormatAudioShape, FormatGCNInput, FormatShape, PackActionInputs, PackLocalizationInputs, Transpose) -from .loading import (ArrayDecode, AudioDecode, AudioDecodeInit, - AudioFeatureSelector, BuildPseudoClip, DecordDecode, - DecordInit, DenseSampleFrames, +from .loading import (ArrayDecode, AudioFeatureSelector, BuildPseudoClip, + DecordDecode, DecordInit, DenseSampleFrames, GenerateLocalizationLabels, ImageDecode, LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature, LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit, @@ -15,29 +14,28 @@ MMDecode, MMUniformSampleFrames, PadTo, PoseCompact, PoseDecode, PreNormalize2D, PreNormalize3D, ToMotion, UniformSampleFrames) -from .processing import (AudioAmplify, CenterCrop, ColorJitter, Flip, Fuse, - MelSpectrogram, MultiScaleCrop, RandomCrop, - RandomRescale, RandomResizedCrop, Resize, TenCrop, - ThreeCrop) +from .processing import (CenterCrop, ColorJitter, Flip, Fuse, MultiScaleCrop, + RandomCrop, RandomRescale, RandomResizedCrop, Resize, + TenCrop, ThreeCrop) from .text_transforms import CLIPTokenize from .wrappers import ImgAug, PytorchVideoWrapper, TorchVisionWrapper __all__ = [ - 'ArrayDecode', 'AudioAmplify', 'AudioDecode', 'AudioDecodeInit', - 'AudioFeatureSelector', 'BuildPseudoClip', 'CenterCrop', 'ColorJitter', - 'DecordDecode', 'DecordInit', 'DecordInit', 'DenseSampleFrames', 'Flip', - 'FormatAudioShape', 'FormatGCNInput', 'FormatShape', 'Fuse', 'GenSkeFeat', - 'GenerateLocalizationLabels', 'GeneratePoseTarget', 'ImageDecode', - 'ImgAug', 'JointToBone', 'LoadAudioFeature', 'LoadHVULabel', - 'LoadKineticsPose', 'LoadLocalizationFeature', 'LoadProposals', - 'LoadRGBFromFile', 'MelSpectrogram', 'MergeSkeFeat', 'MultiScaleCrop', - 'OpenCVDecode', 'OpenCVInit', 'OpenCVInit', 'PIMSDecode', 'PIMSInit', - 'PackActionInputs', 'PackLocalizationInputs', 'PadTo', 'PoseCompact', - 'PoseDecode', 'PreNormalize2D', 'PreNormalize3D', 'PyAVDecode', - 'PyAVDecodeMotionVector', 'PyAVInit', 'PyAVInit', 'PytorchVideoWrapper', - 'RandomCrop', 'RandomRescale', 'RandomResizedCrop', 'RawFrameDecode', - 'Resize', 'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', - 'ToMotion', 'TorchVisionWrapper', 'Transpose', 'UniformSample', - 'UniformSampleFrames', 'UntrimmedSampleFrames', 'MMUniformSampleFrames', - 'MMDecode', 'MMCompact', 'CLIPTokenize' + 'ArrayDecode', 'AudioFeatureSelector', 'BuildPseudoClip', 'CenterCrop', + 'ColorJitter', 'DecordDecode', 'DecordInit', 'DecordInit', + 'DenseSampleFrames', 'Flip', 'FormatAudioShape', 'FormatGCNInput', + 'FormatShape', 'Fuse', 'GenSkeFeat', 'GenerateLocalizationLabels', + 'GeneratePoseTarget', 'ImageDecode', 'ImgAug', 'JointToBone', + 'LoadAudioFeature', 'LoadHVULabel', 'LoadKineticsPose', + 'LoadLocalizationFeature', 'LoadProposals', 'LoadRGBFromFile', + 'MergeSkeFeat', 'MultiScaleCrop', 'OpenCVDecode', 'OpenCVInit', + 'OpenCVInit', 'PIMSDecode', 'PIMSInit', 'PackActionInputs', + 'PackLocalizationInputs', 'PadTo', 'PoseCompact', 'PoseDecode', + 'PreNormalize2D', 'PreNormalize3D', 'PyAVDecode', 'PyAVDecodeMotionVector', + 'PyAVInit', 'PyAVInit', 'PytorchVideoWrapper', 'RandomCrop', + 'RandomRescale', 'RandomResizedCrop', 'RawFrameDecode', 'Resize', + 'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', 'ToMotion', + 'TorchVisionWrapper', 'Transpose', 'UniformSample', 'UniformSampleFrames', + 'UntrimmedSampleFrames', 'MMUniformSampleFrames', 'MMDecode', 'MMCompact', + 'CLIPTokenize' ] diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py index 6ca61a4ccc..9b9cb375a9 100644 --- a/mmaction/datasets/transforms/formatting.py +++ b/mmaction/datasets/transforms/formatting.py @@ -361,8 +361,17 @@ def __repr__(self) -> str: class FormatAudioShape(BaseTransform): """Format final audio shape to the given input_format. - Required keys are ``audios``, ``num_clips`` and ``clip_len``, added or - modified keys are ``audios`` and ``input_shape``. + Required Keys: + + - audios + + Modified Keys: + + - audios + + Added Keys: + + - input_shape Args: input_format (str): Define the final imgs format. @@ -374,7 +383,7 @@ def __init__(self, input_format: str) -> None: raise ValueError( f'The input format {self.input_format} is invalid.') - def transform(self, results: dict) -> dict: + def transform(self, results: Dict) -> Dict: """Performs the FormatShape formatting. Args: @@ -389,7 +398,7 @@ def transform(self, results: dict) -> dict: results['input_shape'] = audios.shape return results - def __repr__(self): + def __repr__(self) -> str: repr_str = self.__class__.__name__ repr_str += f"(input_format='{self.input_format}')" return repr_str diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py index e876143cd3..22070371a1 100644 --- a/mmaction/datasets/transforms/loading.py +++ b/mmaction/datasets/transforms/loading.py @@ -1621,105 +1621,39 @@ def transform(self, results): @TRANSFORMS.register_module() -class AudioDecodeInit(BaseTransform): - """Using librosa to initialize the audio reader. - - Required keys are ``audio_path``, added or modified keys are ``length``, - ``sample_rate``, ``audios``. - - Args: - io_backend (str): io backend where frames are store. - Defaults to ``disk``. - sample_rate (int): Audio sampling times per second. Defaults to 16000. - pad_method (str): Padding method. Defaults to ``zero``. - """ - - def __init__(self, - io_backend: str = 'disk', - sample_rate: int = 16000, - pad_method: str = 'zero', - **kwargs) -> None: - self.io_backend = io_backend - self.sample_rate = sample_rate - if pad_method in ['random', 'zero']: - self.pad_method = pad_method - else: - raise NotImplementedError - self.kwargs = kwargs - self.file_client = None - - @staticmethod - def _zero_pad(shape: int) -> np.ndarray: - """Zero padding method.""" - return np.zeros(shape, dtype=np.float32) - - @staticmethod - def _random_pad(shape: int) -> np.ndarray: - """Random padding method.""" - # librosa load raw audio file into a distribution of -1~+1 - return np.random.rand(shape).astype(np.float32) * 2 - 1 - - def transform(self, results: dict) -> dict: - """Perform the librosa initialization. - - Args: - results (dict): The resulting dict to be modified and passed - to the next transform in pipeline. - """ - try: - import librosa - except ImportError: - raise ImportError('Please install librosa first.') +class LoadAudioFeature(BaseTransform): + """Load offline extracted audio features. - if self.file_client is None: - self.file_client = FileClient(self.io_backend, **self.kwargs) - if osp.exists(results['audio_path']): - file_obj = io.BytesIO(self.file_client.get(results['audio_path'])) - y, sr = librosa.load(file_obj, sr=self.sample_rate) - else: - # Generate a random dummy 10s input - pad_func = getattr(self, f'_{self.pad_method}_pad') - y = pad_func(int(round(10.0 * self.sample_rate))) - sr = self.sample_rate + Required Keys: - results['length'] = y.shape[0] - results['sample_rate'] = sr - results['audios'] = y - return results - - def __repr__(self): - repr_str = (f'{self.__class__.__name__}(' - f'io_backend={self.io_backend}, ' - f'sample_rate={self.sample_rate}, ' - f'pad_method={self.pad_method})') - return repr_str + - audio_path + Added Keys: -@TRANSFORMS.register_module() -class LoadAudioFeature(BaseTransform): - """Load offline extracted audio features. + - length + - audios - Required keys are "audio_path", added or modified keys are "length", - audios". + Args: + pad_method (str): Padding method. Defaults to ``'zero'``. """ - def __init__(self, pad_method='zero'): + def __init__(self, pad_method: str = 'zero') -> None: if pad_method not in ['zero', 'random']: raise NotImplementedError self.pad_method = pad_method @staticmethod - def _zero_pad(shape): + def _zero_pad(shape: int) -> np.ndarray: """Zero padding method.""" return np.zeros(shape, dtype=np.float32) @staticmethod - def _random_pad(shape): + def _random_pad(shape: int) -> np.ndarray: """Random padding method.""" # spectrogram is normalized into a distribution of 0~1 return np.random.rand(shape).astype(np.float32) - def transform(self, results): + def transform(self, results: Dict) -> Dict: """Perform the numpy loading. Args: @@ -1738,68 +1672,12 @@ def transform(self, results): results['audios'] = feature_map return results - def __repr__(self): + def __repr__(self) -> str: repr_str = (f'{self.__class__.__name__}(' f'pad_method={self.pad_method})') return repr_str -@TRANSFORMS.register_module() -class AudioDecode(BaseTransform): - """Sample the audio w.r.t. the frames selected. - - Args: - fixed_length (int): As the audio clip selected by frames sampled may - not be exactly the same, ``fixed_length`` will truncate or pad them - into the same size. Defaults to 32000. - - Required keys are ``frame_inds``, ``num_clips``, ``total_frames``, - ``length``, added or modified keys are ``audios``, ``audios_shape``. - """ - - def __init__(self, fixed_length: int = 32000) -> None: - self.fixed_length = fixed_length - - def transform(self, results: dict) -> dict: - """Perform the ``AudioDecode`` to pick audio clips.""" - audio = results['audios'] - frame_inds = results['frame_inds'] - num_clips = results['num_clips'] - resampled_clips = list() - frame_inds = frame_inds.reshape(num_clips, -1) - for clip_idx in range(num_clips): - clip_frame_inds = frame_inds[clip_idx] - start_idx = max( - 0, - int( - round((clip_frame_inds[0] + 1) / results['total_frames'] * - results['length']))) - end_idx = min( - results['length'], - int( - round((clip_frame_inds[-1] + 1) / results['total_frames'] * - results['length']))) - cropped_audio = audio[start_idx:end_idx] - if cropped_audio.shape[0] >= self.fixed_length: - truncated_audio = cropped_audio[:self.fixed_length] - else: - truncated_audio = np.pad( - cropped_audio, - ((0, self.fixed_length - cropped_audio.shape[0])), - mode='constant') - - resampled_clips.append(truncated_audio) - - results['audios'] = np.array(resampled_clips) - results['audios_shape'] = results['audios'].shape - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += f"(fixed_length='{self.fixed_length}')" - return repr_str - - @TRANSFORMS.register_module() class BuildPseudoClip(BaseTransform): """Build pseudo clips with one single image by repeating it n times. @@ -1840,19 +1718,32 @@ def __repr__(self): class AudioFeatureSelector(BaseTransform): """Sample the audio feature w.r.t. the frames selected. - Required keys are "audios", "frame_inds", "num_clips", "length", - "total_frames", added or modified keys are "audios", "audios_shape". + Required Keys: + + - audios + - frame_inds + - num_clips + - length + - total_frames + + Modified Keys: + + - audios + + Added Keys: + + - audios_shape Args: fixed_length (int): As the features selected by frames sampled may not be exactly the same, `fixed_length` will truncate or pad them - into the same size. Default: 128. + into the same size. Defaults to 128. """ - def __init__(self, fixed_length=128): + def __init__(self, fixed_length: int = 128) -> None: self.fixed_length = fixed_length - def transform(self, results): + def transform(self, results: Dict) -> Dict: """Perform the ``AudioFeatureSelector`` to pick audio feature clips. Args: @@ -1891,7 +1782,7 @@ def transform(self, results): results['audios_shape'] = results['audios'].shape return results - def __repr__(self): + def __repr__(self) -> str: repr_str = (f'{self.__class__.__name__}(' f'fix_length={self.fixed_length})') return repr_str diff --git a/mmaction/datasets/transforms/processing.py b/mmaction/datasets/transforms/processing.py index 13637dcf38..3d432bd723 100644 --- a/mmaction/datasets/transforms/processing.py +++ b/mmaction/datasets/transforms/processing.py @@ -1276,117 +1276,6 @@ def __repr__(self): return repr_str -@TRANSFORMS.register_module() -class AudioAmplify(BaseTransform): - """Amplify the waveform. - - Required keys are ``audios``, added or modified keys are ``audios``, - ``amplify_ratio``. - - Args: - ratio (float): The ratio used to amplify the audio waveform. - """ - - def __init__(self, ratio: float) -> None: - if isinstance(ratio, float): - self.ratio = ratio - else: - raise TypeError('Amplification ratio should be float.') - - def transform(self, results: dict) -> dict: - """Perform the audio amplification. - - Args: - results (dict): The resulting dict to be modified and passed - to the next transform in pipeline. - """ - - assert 'audios' in results - results['audios'] *= self.ratio - results['amplify_ratio'] = self.ratio - - return results - - def __repr__(self): - repr_str = f'{self.__class__.__name__}(ratio={self.ratio})' - return repr_str - - -@TRANSFORMS.register_module() -class MelSpectrogram(BaseTransform): - """MelSpectrogram. Transfer an audio wave into a melspectogram figure. - - Required keys are ``audios``, ``sample_rate``, ``num_clips``, added or - modified keys are ``audios``. - - Args: - window_size (int): The window size in millisecond. Defaults to 32. - step_size (int): The step size in millisecond. Defaults to 16. - n_mels (int): Number of mels. Defaults to 80. - fixed_length (int): The sample length of melspectrogram maybe not - exactly as wished due to different fps, fix the length for batch - collation by truncating or padding. Defaults to 128. - """ - - def __init__(self, - window_size: int = 32, - step_size: int = 16, - n_mels: int = 80, - fixed_length: int = 128) -> None: - if all( - isinstance(x, int) - for x in [window_size, step_size, n_mels, fixed_length]): - self.window_size = window_size - self.step_size = step_size - self.n_mels = n_mels - self.fixed_length = fixed_length - else: - raise TypeError('All arguments should be int.') - - def transform(self, results: dict) -> dict: - """Perform MelSpectrogram transformation. - - Args: - results (dict): The resulting dict to be modified and passed - to the next transform in pipeline. - """ - try: - import librosa - except ImportError: - raise ImportError('Install librosa first.') - signals = results['audios'] - sample_rate = results['sample_rate'] - n_fft = int(round(sample_rate * self.window_size / 1000)) - hop_length = int(round(sample_rate * self.step_size / 1000)) - melspectrograms = list() - for clip_idx in range(results['num_clips']): - clip_signal = signals[clip_idx] - mel = librosa.feature.melspectrogram( - y=clip_signal, - sr=sample_rate, - n_fft=n_fft, - hop_length=hop_length, - n_mels=self.n_mels) - if mel.shape[0] >= self.fixed_length: - mel = mel[:self.fixed_length, :] - else: - mel = np.pad( - mel, ((0, self.fixed_length - mel.shape[0]), (0, 0)), - mode='edge') - melspectrograms.append(mel) - - results['audios'] = np.array(melspectrograms) - return results - - def __repr__(self): - repr_str = (f'{self.__class__.__name__}' - f'(window_size={self.window_size}), ' - f'step_size={self.step_size}, ' - f'n_mels={self.n_mels}, ' - f'fixed_length={self.fixed_length})') - return repr_str - - @TRANSFORMS.register_module() class RandomErasing(BaseTransform): """Randomly selects a rectangle region in an image and erase pixels. diff --git a/tests/datasets/transforms/test_loading.py b/tests/datasets/transforms/test_loading.py index 035a2213cc..ee2cc64717 100644 --- a/tests/datasets/transforms/test_loading.py +++ b/tests/datasets/transforms/test_loading.py @@ -10,8 +10,7 @@ from mmengine.testing import assert_dict_has_keys from numpy.testing import assert_array_almost_equal -from mmaction.datasets.transforms import (AudioDecode, AudioDecodeInit, - DecordDecode, DecordInit, +from mmaction.datasets.transforms import (DecordDecode, DecordInit, GenerateLocalizationLabels, LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature, @@ -533,42 +532,6 @@ def test_rawframe_decode(self): f'{frame_selector.__class__.__name__}(io_backend=disk, ' f'decoding_backend=turbojpeg)') - def test_audio_decode_init(self): - try: - import soundfile as sf # noqa: F401 - except (OSError, ImportError): - return - target_keys = ['audios', 'length', 'sample_rate'] - inputs = copy.deepcopy(self.audio_results) - audio_decode_init = AudioDecodeInit() - results = audio_decode_init(inputs) - assert assert_dict_has_keys(results, target_keys) - - # test when no audio file exists - inputs = copy.deepcopy(self.audio_results) - inputs['audio_path'] = 'foo/foo/bar.wav' - audio_decode_init = AudioDecodeInit() - results = audio_decode_init(inputs) - assert assert_dict_has_keys(results, target_keys) - assert results['audios'].shape == (10.0 * - audio_decode_init.sample_rate, ) - assert repr(audio_decode_init) == ( - f'{audio_decode_init.__class__.__name__}(' - f'io_backend=disk, ' - f'sample_rate=16000, ' - f'pad_method=zero)') - - def test_audio_decode(self): - target_keys = ['frame_inds', 'audios'] - inputs = copy.deepcopy(self.audio_results) - inputs['frame_inds'] = np.arange(0, self.audio_total_frames, - 2)[:, np.newaxis] - inputs['num_clips'] = 1 - inputs['length'] = 1280 - audio_selector = AudioDecode() - results = audio_selector(inputs) - assert assert_dict_has_keys(results, target_keys) - def test_pyav_decode_motion_vector(self): pyav_init = PyAVInit() pyav = PyAVDecodeMotionVector() diff --git a/tests/datasets/transforms/test_processing.py b/tests/datasets/transforms/test_processing.py index 028f5d7129..cc7c18add2 100644 --- a/tests/datasets/transforms/test_processing.py +++ b/tests/datasets/transforms/test_processing.py @@ -7,11 +7,10 @@ from mmengine.testing import assert_dict_has_keys from numpy.testing import assert_array_almost_equal -from mmaction.datasets.transforms import (AudioAmplify, CenterCrop, - ColorJitter, Flip, Fuse, - MelSpectrogram, MultiScaleCrop, - RandomCrop, RandomResizedCrop, - Resize, TenCrop, ThreeCrop) +from mmaction.datasets.transforms import (CenterCrop, ColorJitter, Flip, Fuse, + MultiScaleCrop, RandomCrop, + RandomResizedCrop, Resize, TenCrop, + ThreeCrop) def check_crop(origin_imgs, result_imgs, result_bbox, num_crops=1): @@ -70,59 +69,6 @@ def check_flip(origin_imgs, result_imgs, flip_type): return True -class TestAudio: - - @staticmethod - def test_audio_amplify(): - target_keys = ['audios', 'amplify_ratio'] - with pytest.raises(TypeError): - # ratio should be float - AudioAmplify(1) - - audio = (np.random.rand(8, )) - results = dict(audios=audio) - amplifier = AudioAmplify(1.5) - results = amplifier(results) - assert assert_dict_has_keys(results, target_keys) - assert repr(amplifier) == (f'{amplifier.__class__.__name__}' - f'(ratio={amplifier.ratio})') - - @staticmethod - def test_melspectrogram(): - target_keys = ['audios'] - with pytest.raises(TypeError): - # ratio should be float - MelSpectrogram(window_size=12.5) - audio = (np.random.rand(1, 160000)) - - # test padding - results = dict(audios=audio, sample_rate=16000) - results['num_clips'] = 1 - results['sample_rate'] = 16000 - mel = MelSpectrogram() - try: - import soundfile as sf # noqa: F401 - except (OSError, ImportError): - return - - results = mel(results) - assert assert_dict_has_keys(results, target_keys) - - # test truncating - audio = (np.random.rand(1, 160000)) - results = dict(audios=audio, sample_rate=16000) - results['num_clips'] = 1 - results['sample_rate'] = 16000 - mel = MelSpectrogram(fixed_length=1) - results = mel(results) - assert assert_dict_has_keys(results, target_keys) - assert repr(mel) == (f'{mel.__class__.__name__}' - f'(window_size={mel.window_size}), ' - f'step_size={mel.step_size}, ' - f'n_mels={mel.n_mels}, ' - f'fixed_length={mel.fixed_length})') - - class TestColor: @staticmethod diff --git a/tools/data/build_audio_features.py b/tools/data/build_audio_features.py index 28356a0e64..cd3070bace 100644 --- a/tools/data/build_audio_features.py +++ b/tools/data/build_audio_features.py @@ -38,11 +38,16 @@ class AudioTools: `_. Args: - frame_rate (int): The frame rate per second of the video. Default: 30. - sample_rate (int): The sample rate for audio sampling. Default: 16000. - num_mels (int): Number of channels of the melspectrogram. Default: 80. - fft_size (int): fft_size / sample_rate is window size. Default: 1280. - hop_size (int): hop_size / sample_rate is step size. Default: 320. + frame_rate (int): The frame rate per second of the video. + Defaults to 30. + sample_rate (int): The sample rate for audio sampling. + Defaults to 16000. + num_mels (int): Number of channels of the melspectrogram. + Defaults to 80. + fft_size (int): fft_size / sample_rate is window size. + Defaults to 1280. + hop_size (int): hop_size / sample_rate is step size. + Defaults to 320. """ def __init__(self, @@ -290,15 +295,15 @@ def extract_audio_feature(wav_path, audio_tools, mel_out_dir): parser.add_argument('audio_home_path', type=str) parser.add_argument('spectrogram_save_path', type=str) parser.add_argument('--level', type=int, default=1) - parser.add_argument('--ext', default='.m4a') + parser.add_argument('--ext', default='m4a') parser.add_argument('--num-workers', type=int, default=4) parser.add_argument('--part', type=str, default='1/1') args = parser.parse_args() mmengine.mkdir_or_exist(args.spectrogram_save_path) - files = glob.glob( - osp.join(args.audio_home_path, '*/' * args.level, '*' + args.ext)) + files = glob.glob(args.audio_home_path + '/*' * args.level + '.' + + args.ext) print(f'found {len(files)} files.') files = sorted(files) if args.part is not None: