Skip to content

Commit

Permalink
Add audio support to DataPack (#585)
Browse files Browse the repository at this point in the history
* Add audio support to DataPack

* Fix PR comments

* Update soundfile dependency and docs

* Update README for extra requirements

* Add soundfile to docs/requirements.txt

* Update README for extra requirements

* Add wikipedia extra req to README

Co-authored-by: Suqi Sun <[email protected]>
  • Loading branch information
mylibrar and Suqi Sun authored Jan 12, 2022
1 parent 818564c commit 0a16879
Show file tree
Hide file tree
Showing 13 changed files with 265 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
${{ runner.os }}-
- name: Install dependencies
run: |
sudo apt-get install -y libsndfile1-dev
python -m pip install --progress-bar off --upgrade pip
pip install --progress-bar off Django django-guardian
pip install --progress-bar off pylint==2.10.2 flake8==3.9.2 mypy==0.910 pytest==5.1.3 black==20.8b1
Expand Down Expand Up @@ -77,7 +78,7 @@ jobs:
rm -rf texar-pytorch
- name: Install Forte
run: |
pip install --use-feature=in-tree-build --progress-bar off .[ner,test,example,wikipedia,augment,stave]
pip install --use-feature=in-tree-build --progress-bar off .[ner,test,example,wikipedia,augment,stave,audio_ext]
- name: Build ontology
run: |
./scripts/build_ontology_specs.sh
Expand Down Expand Up @@ -121,6 +122,7 @@ jobs:
${{ runner.os }}-
- name: Install dependencies
run: |
sudo apt-get install -y libsndfile1-dev
python -m pip install --progress-bar off --upgrade pip
pip install --progress-bar off -r requirements.txt
pip install --progress-bar off -r docs/requirements.txt
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ cd forte-wrappers
pip install src/spacy
```

Some components or modules in forte may require some [extra requirements](https:/asyml/forte/blob/master/setup.py#L45):

* `pip install forte[ner]`: Install packages required for [ner_trainer](https:/asyml/forte/blob/master/forte/trainer/ner_trainer.py)
* `pip install forte[test]`: Install packages required for running [unit tests](https:/asyml/forte/tree/master/tests).
* `pip install forte[example]`: Install packages required for running [forte examples](https:/asyml/forte/tree/master/examples).
* `pip install forte[wikipedia]`: Install packages required for reading [wikipedia datasets](https:/asyml/forte/tree/master/forte/datasets/wikipedia).
* `pip install forte[augment]`: Install packages required for [data augmentation module](https:/asyml/forte/tree/master/forte/processors/data_augment).
* `pip install forte[stave]`: Install packages required for [StaveProcessor](https:/asyml/forte/blob/master/forte/processors/stave/stave_processor.py).
* `pip install forte[audio_ext]`: Install packages required for [AudioReader](https:/asyml/forte/blob/master/forte/data/readers/audio_reader.py).

## Getting Started

* [Examples](./examples)
Expand Down
5 changes: 5 additions & 0 deletions data_samples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
This folder contains a list of data samples that are used by forte to facilitate test cases.

# List of Data Samples
## audio_reader_test
This directory consists of audio files that are used in a unit test for verifying the AudioReader in `forte/tests/forte/data/readers/audio_reader_test.py`. Currently it contains two `.flac` files excerpted from a HuggingFace dataset called [patrickvonplaten/librispeech_asr_dummy](https://huggingface.co/datasets/patrickvonplaten/librispeech_asr_dummy) for automatic speech recognition.
Binary file added data_samples/audio_reader_test/test_audio_0.flac
Binary file not shown.
Binary file added data_samples/audio_reader_test/test_audio_1.flac
Binary file not shown.
26 changes: 26 additions & 0 deletions docs/audio_processing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Audio Processing

## Audio DataPack
`DataPack` includes a payload for audio data and a metadata for sample rate. You can set them by calling the `set_audio` method:
```python
from forte.data.data_pack import DataPack

pack: DataPack = DataPack()
pack.set_audio(audio, sample_rate)
```
The input parameter `audio` should be a numpy array of raw waveform and `sample_rate` should be an integer the specifies the sample rate. Now you can access these data using `DataPack.audio` and `DataPack.sample_rate`.

## Audio Reader
`AudioReader` supports reading in the audio data from files under a specific directory. You can set it as the reader of your forte pipeline whenever you need to process audio files:
```python
from forte.pipeline import Pipeline
from forte.data.readers.audio_reader import AudioReader

Pipeline().set_reader(
reader=AudioReader(),
config={"file_ext": ".wav"}
).run(
"path-to-audio-directory"
)
```
The example above builds a simple pipeline that can walk through the specified directory and load all the files with extension of `.wav`. `AudioReader` will create a `DataPack` for each file with the corresponding audio payload and the sample rate.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Welcome to Forte's documentation!

examples.md
ontology_generation.md
audio_processing.md

API
====
Expand Down
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ nltk==3.4.5
# FastAPI
fastapi==0.65.2
uvicorn==0.14.0

# soundfile
soundfile>=0.10.3
27 changes: 27 additions & 0 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ class Meta(BaseMeta):
language: The language used by this data pack, default is English.
span_unit: The unit used for interpreting the Span object of this
data pack. Default is character.
sample_rate: An integer specifying the sample rate of audio payload.
Default is None.
info: Store additional string based information that the user add.
Attributes:
pack_name: storing the provided `pack_name`.
language: storing the provided `language`.
sample_rate: storing the provided `sample_rate`.
info: storing the provided `info`.
record: Initialized as a dictionary. This is not a required field.
The key of the record should be the entry type and values should
Expand All @@ -87,11 +90,13 @@ def __init__(
pack_name: Optional[str] = None,
language: str = "eng",
span_unit: str = "character",
sample_rate: Optional[int] = None,
info: Optional[Dict[str, str]] = None,
):
super().__init__(pack_name)
self.language = language
self.span_unit = span_unit
self.sample_rate: Optional[int] = sample_rate
self.record: Dict[str, Set[str]] = {}
self.info: Dict[str, str]
if info is None:
Expand Down Expand Up @@ -155,6 +160,7 @@ class DataPack(BasePack[Entry, Link, Group]):
def __init__(self, pack_name: Optional[str] = None):
super().__init__(pack_name)
self._text = ""
self._audio: Optional[np.ndarray] = None

self.annotations: SortedList[Annotation] = SortedList()
self.links: SortedList[Link] = SortedList()
Expand Down Expand Up @@ -242,6 +248,16 @@ def text(self) -> str:
r"""Return the text of the data pack"""
return self._text

@property
def audio(self) -> Optional[np.ndarray]:
r"""Return the audio of the data pack"""
return self._audio

@property
def sample_rate(self) -> Optional[int]:
r"""Return the sample rate of the audio data"""
return getattr(self._meta, "sample_rate")

@property
def all_annotations(self) -> Iterator[Annotation]:
"""
Expand Down Expand Up @@ -365,6 +381,17 @@ def set_text(
self.__orig_text_len,
) = data_utils_io.modify_text_and_track_ops(text, span_ops)

def set_audio(self, audio: np.ndarray, sample_rate: int):
r"""Set the audio payload and sample rate of the :class:`DataPack`
object.
Args:
audio: A numpy array storing the audio waveform.
sample_rate: An integer specifying the sample rate.
"""
self._audio = audio
self.set_meta(sample_rate=sample_rate)

def get_original_text(self):
r"""Get original unmodified text from the :class:`DataPack` object.
Expand Down
1 change: 1 addition & 0 deletions forte/data/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
from forte.data.readers.ag_news_reader import *
from forte.data.readers.largemovie_reader import *
from forte.data.readers.misc_readers import *
from forte.data.readers.audio_reader import *
84 changes: 84 additions & 0 deletions forte/data/readers/audio_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2022 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The reader that reads audio files into Datapacks.
"""
import os
from typing import Any, Iterator

from forte.data.data_pack import DataPack
from forte.data.data_utils_io import dataset_path_iterator
from forte.data.base_reader import PackReader

__all__ = [
"AudioReader",
]


class AudioReader(PackReader):
r""":class:`AudioReader` is designed to read in audio files."""

try:
import soundfile # pylint: disable=import-outside-toplevel
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"AudioReader requires 'soundfile' package to be installed."
" You can run 'pip install soundfile' or 'pip install forte"
"[audio_ext]'. Note that additional steps might apply to Linux"
" users (refer to "
"https://pysoundfile.readthedocs.io/en/latest/#installation)."
) from e

def _collect(self, audio_directory) -> Iterator[Any]: # type: ignore
r"""Should be called with param ``audio_directory`` which is a path to a
folder containing audio files.
Args:
audio_directory: audio directory containing the files.
Returns: Iterator over paths to audio files
"""
return dataset_path_iterator(audio_directory, self.configs.file_ext)

def _cache_key_function(self, audio_file: str) -> str:
return os.path.basename(audio_file)

def _parse_pack(self, file_path: str) -> Iterator[DataPack]:
pack: DataPack = DataPack()

# Read in audio data and store in DataPack
audio, sample_rate = self.soundfile.read(
file=file_path, **(self.configs.read_kwargs or {})
)
pack.set_audio(audio=audio, sample_rate=sample_rate)
pack.pack_name = file_path

yield pack

@classmethod
def default_configs(cls):
r"""This defines a basic configuration structure for audio reader.
Here:
- file_ext (str): The file extension to find the target audio files
under a specific directory path. Default value is ".flac".
- read_kwargs (dict): A dictionary containing all the keyword
arguments for `soundfile.read` method. For details, refer to
https://pysoundfile.readthedocs.io/en/latest/#soundfile.read.
Default value is None.
Returns: The default configuration of audio reader.
"""
return {"file_ext": ".flac", "read_kwargs": None}
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
# transformers 4.10.0 will break the translation model we used here
"augment": ["transformers>=3.1, <=4.9.2", "nltk"],
"stave": ["stave>=0.0.1.dev12"],
"audio_ext": ["soundfile>=0.10.3"],
},
entry_points={
"console_scripts": [
Expand Down
104 changes: 104 additions & 0 deletions tests/forte/data/readers/audio_reader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2022 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for AudioReader.
"""
import os
import unittest
from typing import Dict
from torch import argmax
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from forte.common.configuration import Config
from forte.common.resources import Resources
from forte.common.exception import ProcessFlowException
from forte.data.data_pack import DataPack
from forte.data.readers import AudioReader
from forte.pipeline import Pipeline
from forte.processors.base.pack_processor import PackProcessor


class TestASRProcessor(PackProcessor):
"""
An audio processor for automatic speech recognition.
"""
def initialize(self, resources: Resources, configs: Config):
super().initialize(resources, configs)

# Initialize tokenizer and model
pretrained_model: str = "facebook/wav2vec2-base-960h"
self._tokenizer = Wav2Vec2Processor.from_pretrained(pretrained_model)
self._model = Wav2Vec2ForCTC.from_pretrained(pretrained_model)

def _process(self, input_pack: DataPack):
required_sample_rate: int = 16000
if input_pack.sample_rate != required_sample_rate:
raise ProcessFlowException(
f"A sample rate of {required_sample_rate} Hz is requied by the"
" pretrained model."
)

# tokenize
input_values = self._tokenizer(
input_pack.audio, return_tensors="pt", padding="longest"
).input_values # Batch size 1

# take argmax and decode
transcription = self._tokenizer.batch_decode(
argmax(self._model(input_values).logits, dim=-1)
)

input_pack.set_text(text=transcription[0])


class AudioReaderPipelineTest(unittest.TestCase):
"""
Test AudioReader by running audio processing pipelines
"""

def setUp(self):
self._test_audio_path: str = os.path.abspath(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
os.pardir,
os.pardir,
os.pardir,
os.pardir,
"data_samples/audio_reader_test"
)
)

# Define and config the Pipeline
self._pipeline = Pipeline[DataPack]()
self._pipeline.set_reader(AudioReader())
self._pipeline.add(TestASRProcessor())
self._pipeline.initialize()

def test_asr_pipeline(self):
target_transcription: Dict[str, str] = {
self._test_audio_path + "/test_audio_0.flac":
"A MAN SAID TO THE UNIVERSE SIR I EXIST",
self._test_audio_path + "/test_audio_1.flac": (
"NOR IS MISTER QUILTER'S MANNER LESS INTERESTING "
"THAN HIS MATTER"
)
}

# Verify the ASR result of each datapack
for pack in self._pipeline.process_dataset(self._test_audio_path):
self.assertEqual(pack.text, target_transcription[pack.pack_name])


if __name__ == "__main__":
unittest.main()

0 comments on commit 0a16879

Please sign in to comment.