Skip to content

Commit

Permalink
[FlaxWav2Vec2Model] Fix bug in attention mask (#16725)
Browse files Browse the repository at this point in the history
* [FlaxWav2Vec2Model] Fix bug in attention mask

* more fixes

* add (Flax)SpeechEncoderDecoderModel PT-FX cross-test
  • Loading branch information
sanchit-gandhi authored Apr 12, 2022
1 parent 6adefba commit a960406
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
35 changes: 18 additions & 17 deletions src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def __call__(
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
return self.module._get_feat_extract_output_lengths(input_lengths)
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


class FlaxWav2Vec2Module(nn.Module):
Expand Down Expand Up @@ -956,15 +956,10 @@ def __call__(

# make sure that no loss is computed on padded inputs
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))

attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)

# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)

hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
Expand Down Expand Up @@ -1034,12 +1029,10 @@ def _get_feature_vector_attention_mask(
batch_size = attention_mask.shape[0]

attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# these two operations makes sure that all values before the output lengths idxs are attended to
idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
attention_mask = attention_mask.at[idx].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)

attention_mask = jnp.array(attention_mask, dtype=bool)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
return attention_mask


Expand Down Expand Up @@ -1286,11 +1279,15 @@ def __call__(
attentions=outputs.attentions,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
Expand All @@ -1299,6 +1296,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,12 @@ def test_pt_flax_equivalence(self):
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)

# check `add_adapter` works as expected
config.add_adapter = True
self.assertTrue(config.add_adapter)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)

@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
Expand Down

0 comments on commit a960406

Please sign in to comment.