
BART can only generate a maximum of 20 tokens #16622

Closed · ayaka14732 opened this issue Apr 6, 2022 · 3 comments · Fixed by #16668

ayaka14732 (Contributor) commented Apr 6, 2022

Environment info

  • transformers version: 4.18.0.dev0
  • Platform: Linux-5.11.0-1018-gcp-x86_64-with-glibc2.31
  • Python version: 3.10.4
  • Huggingface_hub version: 0.4.0
  • PyTorch version (GPU?): 1.11.0+cu102 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.4.1 (tpu)
  • Jax version: 0.3.4
  • JaxLib version: 0.3.2
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@patil-suraj @patrickvonplaten

Information

Model I am using: BART

To reproduce

Steps to reproduce the behavior:

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

sentences = ['At the launch of the latest report by the Intergovernmental Panel on Climate Change, on the mitigation of climate change, the UN Secretary-General called for an urgent shift of investments and subsidies from fossil fuels to renewable energy, warning that investing in new fossil fuels infrastructure is moral and economic madness.']

inputs = tokenizer(sentences, return_tensors='pt')
print('Input shape:', inputs.input_ids.shape)

generate_ids = model.generate(inputs.input_ids, num_beams=5, min_length=50)  # max_length is left unset, so the default of 20 applies
print('Generated shape:', generate_ids.shape)

print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

Output:

Input shape: torch.Size([1, 60])
Generated shape: torch.Size([1, 20])
At the launch of the latest report by the Intergovernmental Panel on Climate Change, on

Expected behavior

The output should not be truncated.

Actual behavior

The output is truncated.

Note that the output is truncated even if min_length=50 is specified.

gante (Member) commented Apr 6, 2022

Hi @ayaka14732 👋 That happens because the stopping conditions take precedence over everything else. The default for max_length is 20, which is why you see 20 generated tokens. In your example, if you rewrite the generate line as generate_ids = model.generate(inputs.input_ids, num_beams=5, min_length=50, max_length=100), you'll get the results you expect.
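
For concreteness, here is a minimal sketch of the fixed call, continuing from the reproduction script above (max_length=100 is the value gante suggests, not a requirement):

# Setting max_length explicitly lifts the default stopping condition of 20 tokens.
generate_ids = model.generate(inputs.input_ids, num_beams=5, min_length=50, max_length=100)
print('Generated shape:', generate_ids.shape)  # now up to 100 tokens instead of 20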

@patrickvonplaten @patil-suraj should we raise an exception in this case? (min_length > max_length)

gante self-assigned this Apr 6, 2022
ayaka14732 added a commit to ayaka14732/bart-base-jax that referenced this issue Apr 6, 2022
patrickvonplaten (Contributor) commented

@gante, yes this would work for me! Let's maybe do this in generate() before we jump into the sub-generation methods
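
As an illustration only (not the exact code merged in #16668), such a check at the top of generate() might look like the sketch below; the function name _validate_generation_lengths is hypothetical:

def _validate_generation_lengths(min_length, max_length):
    # Hypothetical sketch: fail fast instead of silently truncating at max_length.
    if min_length is not None and max_length is not None and min_length > max_length:
        raise ValueError(
            f"min_length ({min_length}) cannot be larger than max_length ({max_length})."
        )

Calling this with min_length=50 and max_length=20 would raise immediately, which is the fail-fast behavior gante proposed.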

gante (Member) commented Apr 11, 2022

@ayaka14732 if you pull from master (or install transformers==4.19.0.dev0), you shall see an informative Exception if you try to run your original script.

Thank you for reporting this issue :D
