[FlaxWav2Vec2Model] Fix bug in attention mask #16725
Merged
Currently, the FlaxWav2Vec2 reduced attention mask is computed by calling the function `_get_feat_extract_output_lengths`, without explicit specification of whether an (optional) adapter module is used:

transformers/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py, lines 959 to 960 in 924484e
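Roughly, the pattern behind that permalink looks as follows (a paraphrased sketch, not the verbatim lines; the exact code may differ slightly):

```python
# Paraphrased sketch of the pre-fix Flax pattern: the per-sample lengths are
# taken from the raw attention mask and passed to the length helper without
# specifying add_adapter.
import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 1, 0, 0],
                            [1, 1, 0, 0, 0, 0]])

# Number of non-padded input samples per batch element (here 4 and 2).
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

# Pre-fix, the module then calls (paraphrased):
#     output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths)
# leaving add_adapter at its default of None.
print(non_padded_lengths)
```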
By default, if `add_adapter` is `None`, the boolean `add_adapter` will be set based on the `config`:

transformers/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py, lines 1001 to 1008 in 924484e
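For reference, here is a self-contained paraphrase of that length helper; the `DummyConfig` values are illustrative assumptions (loosely following the base Wav2Vec2 feature extractor), and the real implementation lives at the permalink above:

```python
from dataclasses import dataclass


@dataclass
class DummyConfig:
    # Illustrative values, not taken from any particular checkpoint.
    conv_kernel: tuple = (10, 3, 3, 3, 3, 2, 2)
    conv_stride: tuple = (5, 2, 2, 2, 2, 2, 2)
    add_adapter: bool = True
    adapter_stride: int = 2
    num_adapter_layers: int = 3


def get_feat_extract_output_lengths(config, input_length, add_adapter=None):
    # If the caller does not specify add_adapter, fall back to the config flag --
    # this is the default path the Flax model hits when no argument is passed.
    add_adapter = config.add_adapter if add_adapter is None else add_adapter

    def conv_out_length(length, kernel, stride):
        # Standard 1D convolution output-length formula (no padding).
        return (length - kernel) // stride + 1

    for kernel, stride in zip(config.conv_kernel, config.conv_stride):
        input_length = conv_out_length(input_length, kernel, stride)

    # When the adapter is counted, its strided convolutions shorten the lengths
    # further than what the encoder actually receives.
    if add_adapter:
        for _ in range(config.num_adapter_layers):
            input_length = conv_out_length(input_length, 1, config.adapter_stride)

    return input_length


cfg = DummyConfig()
print(get_feat_extract_output_lengths(cfg, 16000))                     # adapter included (config default)
print(get_feat_extract_output_lengths(cfg, 16000, add_adapter=False))  # feature extractor only
```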
For this default setting, if the model contains an adapter module, then `add_adapter` will be set to `True`. This results in the convolutional output-length formula including the downsampling performed by the convolutional layers of both the feature extractor and the adapter module.

However, since the reduced attention mask is required for the encoder module, it should be computed based on the convolutional layers of the feature extractor only, and not those of the subsequent adapter module. This is highlighted by the PyTorch Wav2Vec2 modelling code, which explicitly passes `add_adapter=False`:

transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py, lines 1350 to 1354 in 924484e
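To illustrate the difference the flag makes, here is a small sketch using the (private) PyTorch helpers; the adapter config values are assumptions chosen for illustration:

```python
import torch
from transformers import Wav2Vec2Config, Wav2Vec2Model

# Illustrative config with an adapter module (values are assumptions).
config = Wav2Vec2Config(add_adapter=True, adapter_stride=2, num_adapter_layers=2)
model = Wav2Vec2Model(config)

attention_mask = torch.ones(2, 16000, dtype=torch.long)
attention_mask[1, 8000:] = 0  # second sample is half padding
input_lengths = attention_mask.sum(-1)

# Lengths after the feature extractor only -- what the encoder mask needs.
encoder_lengths = model._get_feat_extract_output_lengths(input_lengths, add_adapter=False)

# Lengths after feature extractor + adapter -- what the pre-fix Flax code used.
adapter_lengths = model._get_feat_extract_output_lengths(input_lengths, add_adapter=True)

print("encoder lengths:", encoder_lengths.tolist())
print("adapter lengths:", adapter_lengths.tolist())

# The PyTorch model builds the encoder's reduced mask with add_adapter=False:
reduced_mask = model._get_feature_vector_attention_mask(
    int(encoder_lengths.max()), attention_mask, add_adapter=False
)
print("reduced mask shape:", tuple(reduced_mask.shape))
```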
The following code snippet demonstrates the effect of this bug by means of a PyTorch-Flax cross-test:
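A minimal sketch of such a cross-test is shown below; the config values, the padding setup, and the temporary checkpoint round-trip are illustrative assumptions, not necessarily those used in the original test:

```python
import tempfile

import numpy as np
import torch
from transformers import FlaxWav2Vec2Model, Wav2Vec2Config, Wav2Vec2Model

# Illustrative config with an adapter module (values are assumptions).
config = Wav2Vec2Config(add_adapter=True, adapter_stride=2, num_adapter_layers=2)

# Randomly initialised PyTorch model, converted to Flax so both share weights.
pt_model = Wav2Vec2Model(config).eval()
with tempfile.TemporaryDirectory() as tmp:
    pt_model.save_pretrained(tmp)
    fx_model = FlaxWav2Vec2Model.from_pretrained(tmp, from_pt=True)

# Batch with padding so that the attention mask actually matters.
rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 16000)).astype(np.float32)
attention_mask = np.ones((2, 16000), dtype=np.int32)
attention_mask[1, 8000:] = 0  # second sample is half padding

with torch.no_grad():
    pt_out = pt_model(
        torch.from_numpy(inputs), attention_mask=torch.from_numpy(attention_mask)
    ).last_hidden_state.numpy()

fx_out = fx_model(inputs, attention_mask=attention_mask).last_hidden_state

# Prior to the fix, the padded sample diverges because the Flax reduced mask
# also accounts for the adapter's downsampling; after the fix, both match.
print("max abs diff:", np.abs(pt_out - np.asarray(fx_out)).max())
```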
Output prior to fix:
Output following fix: