
[SpeechEncoderDecoderModel] Fix bug in reshaping labels #16748

Merged

Conversation

sanchit-gandhi
Contributor

@sanchit-gandhi sanchit-gandhi commented Apr 13, 2022

Currently, the target labels are reshaped using the view method before being passed into the loss function:

loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

The view method requires the tensor's memory layout to be compatible with the requested shape, which in practice means the tensor must be contiguous (cf. https://pytorch.org/docs/stable/generated/torch.Tensor.view.html).

There are certain operations commonly performed on the labels that can leave them non-contiguous, slicing being one example. For speech seq2seq models, if the bos token is appended in the tokenisation step, we cut it by slicing the labels as follows:

# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
    labels = labels[:, 1:]

This slicing operation leaves the labels non-contiguous. If the labels are not contiguous, calling labels.view(-1) throws a RuntimeError, as the following snippet demonstrates:

import torch

labels = torch.ones((2, 10), dtype=torch.int64)
print(f"Contiguous without slicing: {labels.is_contiguous()}")
labels.view(-1)

labels = torch.ones((2, 10), dtype=torch.int64)
labels = labels[:, 1:]
print(f"Contiguous with slicing: {labels.is_contiguous()}")
labels.view(-1)

Output:

Contiguous without slicing: True
Contiguous with slicing: False
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [137], in <cell line: 10>()
      8 labels = labels[:, 1:]
      9 print(f"Contiguous with slicing: {labels.is_contiguous()}")
---> 10 labels.view(-1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
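The reason the sliced tensor is non-contiguous comes down to strides: the slice reuses the original tensor's storage, so its row stride no longer matches what a freshly allocated tensor of that shape would have. A minimal sketch (illustrative, not part of this PR):

```python
import torch

# A (2, 10) tensor is laid out row-major, so its strides are (10, 1).
labels = torch.ones((2, 10), dtype=torch.int64)
print(labels.stride())  # (10, 1)

# Slicing off the first column keeps the same storage and strides:
# the (2, 9) result still has rows that start 10 elements apart.
sliced = labels[:, 1:]
print(sliced.stride())          # (10, 1) -- a contiguous (2, 9) tensor would have (9, 1)
print(sliced.is_contiguous())   # False

# view(-1) cannot describe this layout with a single stride, but a
# contiguous copy can be flattened freely:
flat = sliced.contiguous().view(-1)
print(flat.shape)  # torch.Size([18])
```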

And similarly for the speech encoder-decoder model:

import torch
from transformers import SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_pretrained('hf-internal-testing/tiny-random-speech-encoder-decoder')

input_values = torch.ones((2, 1000), dtype=torch.float32)

labels = torch.ones((2, 10), dtype=torch.int64)
labels = labels[:, 1:]

outputs = model(input_values, labels=labels)

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [138], in <cell line: 11>()
      8 labels = torch.ones((2, 10), dtype=torch.int64)
      9 labels = labels[:, 1:]
---> 11 outputs = model(input_values, labels=labels)

File ~/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/transformers/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py:560, in SpeechEncoderDecoderModel.forward(self, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, input_values, input_features, return_dict, **kwargs)
    558     logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
    559     loss_fct = CrossEntropyLoss()
--> 560     loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
    562 if not return_dict:
    563     if loss is not None:

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

This PR follows the advice given in the PyTorch error message and docs, calling .reshape(...) instead of .view(...). Calling reshape returns a view if the memory layout allows it, and copies the data (equivalent to first calling contiguous()) otherwise.

import torch

labels = torch.ones((2, 10), dtype=torch.int64)
labels = labels[:, 1:]
print(f"Contiguous with slicing: {labels.is_contiguous()}")
labels.reshape(-1)  # no error despite labels being non-contiguous
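To see that reshape only copies when it has to, one can compare the underlying data pointers of the input and the result (an illustrative sketch, not part of this PR):

```python
import torch

labels = torch.ones((2, 10), dtype=torch.int64)

# Contiguous case: reshape is free -- it returns a view over the same storage.
flat = labels.reshape(-1)
print(flat.data_ptr() == labels.data_ptr())  # True: same underlying memory

# Non-contiguous case: reshape silently falls back to a copy.
sliced = labels[:, 1:]
flat_sliced = sliced.reshape(-1)
print(flat_sliced.data_ptr() == sliced.data_ptr())  # False: freshly allocated memory
```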

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 13, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor

Sounds good!

@patrickvonplaten
Contributor

If I remember correctly reshape() == view() if the tensor does not need to call contiguous(), so good for me!

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Apr 14, 2022

If I remember correctly reshape() == view() if the tensor does not need to call contiguous(), so good for me!

Yes, exactly that! Calling reshape() returns a view if the shapes are compatible, and copies (equivalent to calling contiguous()) otherwise.

@sanchit-gandhi sanchit-gandhi merged commit de8b06f into huggingface:main Apr 14, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
@sanchit-gandhi sanchit-gandhi deleted the pt-speech-encoder-decoder branch June 25, 2023 09:52