
Commit: Fix pipeline test
Niels Rogge authored and Niels Rogge committed Nov 18, 2022
1 parent 0d69e1d commit c2eff5f
Showing 2 changed files with 13 additions and 8 deletions.
@@ -32,13 +32,13 @@
 
 class ASTFeatureExtractor(SequenceFeatureExtractor):
     r"""
-    Constructs a Audio Spectrogram Transformer feature extractor.
+    Constructs a Audio Spectrogram Transformer (AST) feature extractor.
 
     This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
     most of the main methods. Users should refer to this superclass for more information regarding those methods.
 
-    This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral
-    mean and variance normalization to the extracted features.
+    This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed
+    length and normalizes them using a mean and standard deviation.
 
     Args:
         feature_size (`int`, *optional*, defaults to 1):
@@ -47,6 +47,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
             The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
         num_mel_bins (`int`, *optional*, defaults to 128):
             Number of Mel-frequency bins.
+        max_length (`int`, *optional*, defaults to 1024):
+            Maximum length to which to pad/truncate the extracted features.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether or not to normalize the log-Mel features using `mean` and `std`.
         mean (`float`, *optional*, defaults to -4.2677393):
@@ -65,6 +67,7 @@ def __init__(
         feature_size=1,
         sampling_rate=16000,
         num_mel_bins=128,
+        max_length=1024,
         padding_value=0.0,
         do_normalize=True,
         mean=-4.2677393,
@@ -74,6 +77,7 @@
     ):
         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
         self.num_mel_bins = num_mel_bins
+        self.max_length = max_length
         self.do_normalize = do_normalize
         self.mean = mean
         self.std = std
@@ -121,7 +125,6 @@ def normalize(self, input_values: np.ndarray) -> np.ndarray:
     def __call__(
         self,
         raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
-        max_length: int = 1024,
         sampling_rate: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs
@@ -133,8 +136,6 @@ def __call__(
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
                 values, a list of numpy arrays or a list of list of float values.
-            max_length (`int`, *optional*, defaults to 1024):
-                Maximum length of the returned list and optionally padding length (see above).
             sampling_rate (`int`, *optional*):
                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                 `sampling_rate` at the forward call to prevent silent errors.
@@ -175,8 +176,8 @@ def __call__(
         if not is_batched:
             raw_speech = [raw_speech]
 
-        # extract fbank features (padded/truncated to max_length)
-        features = [self._extract_fbank_features(waveform, max_length=max_length) for waveform in raw_speech]
+        # extract fbank features and pad/truncate to max_length
+        features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]
 
         # convert into BatchFeature
         padded_inputs = BatchFeature({"input_values": features})
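
The net effect of the changes in this file: `max_length` becomes a constructor argument stored on the feature extractor (defaulting to 1024) instead of a per-call argument of `__call__`, so the padded/truncated length is fixed when the extractor is created. A minimal usage sketch of the new API, not part of the commit itself; the dummy waveform, printed shape, and the assumption that torchaudio is installed are illustrative only:

import numpy as np
from transformers import ASTFeatureExtractor

# After this commit, max_length is configured once at construction time
# rather than passed to every __call__.
feature_extractor = ASTFeatureExtractor(max_length=1024, num_mel_bins=128)

# Dummy 1-second waveform at 16 kHz (illustrative input only).
waveform = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="np")

# Expected shape: (batch, max_length frames, num_mel_bins) -> (1, 1024, 128).
print(inputs["input_values"].shape)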
4 changes: 4 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -155,6 +155,10 @@ def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_
     if hasattr(tiny_config, "image_size") and feature_extractor:
         feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
 
+    # Audio Spectogram Transformer specific.
+    if feature_extractor.__class__.__name__ == "ASTFeatureExtractor":
+        feature_extractor = feature_extractor.__class__(max_length=24, num_mel_bins=16)
+
     # Speech2TextModel specific.
     if hasattr(tiny_config, "input_feat_per_channel") and feature_extractor:
         feature_extractor = feature_extractor.__class__(
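
The test-side fix builds the AST feature extractor with a tiny `max_length` and `num_mel_bins`, so the common pipeline tests work on small inputs instead of full 1024x128 spectrograms. A small sketch of what that tiny extractor produces; the dummy waveform and printed shape are illustrative and assume torchaudio is available:

import numpy as np
from transformers import ASTFeatureExtractor

# Mirror the tiny configuration used in the pipeline test above.
tiny_extractor = ASTFeatureExtractor(max_length=24, num_mel_bins=16)

# Half a second of silence at 16 kHz (illustrative input only).
inputs = tiny_extractor(np.zeros(8000, dtype=np.float32), sampling_rate=16000, return_tensors="np")

# (1, 24, 16) instead of the default (1, 1024, 128).
print(inputs["input_values"].shape)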
