Generate: min length can't be larger than max length #16668

Merged: 2 commits, Apr 11, 2022
Changes from all commits
src/transformers/generation_flax_utils.py (6 additions, 1 deletion)
@@ -259,6 +259,7 @@ def generate(
         ```"""
         # set init values
         max_length = max_length if max_length is not None else self.config.max_length
+        min_length = min_length if min_length is not None else self.config.min_length
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -269,6 +270,11 @@

         if decoder_start_token_id is None and self.config.is_encoder_decoder:
             raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
+        if min_length is not None and min_length > max_length:
+            raise ValueError(
+                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
+                f"length ({max_length})"
+            )

         if self.config.is_encoder_decoder:
             # add encoder_outputs to model_kwargs
@@ -389,7 +395,6 @@ def _get_logits_processor(
         no_repeat_ngram_size = (
             no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
         )
-        min_length = min_length if min_length is not None else self.config.min_length
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         forced_bos_token_id = (
             forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
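The default lookup for `min_length` moves out of `_get_logits_processor` and into `generate` so the value is resolved before the new guard runs, meaning an infeasible `config.min_length` is caught just like an explicit argument. A minimal standalone sketch of that resolve-then-validate order (`resolve_lengths` is a hypothetical helper, not transformers code):

```python
# Hypothetical helper mirroring the diff: resolve defaults from the config
# first, then validate, so config-level values are checked too.
def resolve_lengths(min_length, max_length, config_min_length, config_max_length):
    max_length = max_length if max_length is not None else config_max_length
    min_length = min_length if min_length is not None else config_min_length
    if min_length is not None and min_length > max_length:
        raise ValueError(
            f"Unfeasable length constraints: the minimum length ({min_length}) "
            f"is larger than the maximum length ({max_length})"
        )
    return min_length, max_length

print(resolve_lengths(None, 10, config_min_length=0, config_max_length=20))  # (0, 10)
resolve_lengths(15, 10, config_min_length=0, config_max_length=20)  # raises ValueError
```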
src/transformers/generation_tf_utils.py (5 additions)
@@ -1489,6 +1489,11 @@ def _generate(
         if pad_token_id is None and eos_token_id is not None:
             logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
             pad_token_id = eos_token_id
+        if min_length is not None and min_length > max_length:
+            raise ValueError(
+                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
+                f"length ({max_length})"
+            )

         # 2. Define model inputs
         input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
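For context on why such constraints are infeasible rather than merely unusual: the min-length logits processor masks the EOS token's score until the sequence reaches `min_length`, while the decoding loop hard-stops at `max_length`, so with `min_length > max_length` the model is never allowed to end the sequence. A toy pure-Python simulation of that interaction (the loop and scores are illustrative, not the real TF decoding loop):

```python
# Toy greedy loop: stops at max_length, while an illustrative min-length
# rule forbids EOS before min_length is reached.
EOS = 0

def toy_generate(min_length, max_length):
    tokens = []
    for cur_len in range(max_length):
        scores = {EOS: 5.0, 1: 1.0}  # the model "wants" to stop immediately
        if cur_len < min_length:
            scores[EOS] = float("-inf")  # min-length rule masks EOS
        tokens.append(max(scores, key=scores.get))
        if tokens[-1] == EOS:
            break
    return tokens

print(toy_generate(min_length=2, max_length=5))  # [1, 1, 0]: stops once EOS is allowed
print(toy_generate(min_length=9, max_length=5))  # [1, 1, 1, 1, 1]: never allowed to stop
```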
src/transformers/generation_utils.py (6 additions, 1 deletion)
@@ -700,7 +700,6 @@ def _get_logits_processor(
             else self.config.encoder_no_repeat_ngram_size
         )
         bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
-        min_length = min_length if min_length is not None else self.config.min_length
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
         forced_bos_token_id = (
@@ -1185,7 +1184,13 @@ def generate(
         )
         # default to config if still None
         max_length = max_length if max_length is not None else self.config.max_length
+        min_length = min_length if min_length is not None else self.config.min_length

+        if min_length is not None and min_length > max_length:
+            raise ValueError(
+                f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
+                f"length ({max_length})"
+            )
         if input_ids_seq_length >= max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
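With the guard in `generate`, a caller gets a fast, explicit failure instead of a silently truncated sequence. A usage sketch against the PyTorch API (the `gpt2` checkpoint is an arbitrary choice for illustration, not something this PR prescribes):

```python
# Sketch: infeasible length constraints now fail fast with a ValueError.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

try:
    model.generate(**inputs, min_length=20, max_length=10)
except ValueError as err:
    print(err)  # "Unfeasable length constraints: the minimum length (20) is larger ..."
```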
tests/generation/test_generation_utils.py (1 addition, 1 deletion)
@@ -102,7 +102,7 @@ def _get_logits_processor_and_kwargs(
         diversity_penalty=None,
     ):
         process_kwargs = {
-            "min_length": input_length + 1,
+            "min_length": input_length + 1 if max_length is None else max_length - 1,
             "bad_words_ids": [[1, 0]],
             "no_repeat_ngram_size": 2,
             "repetition_penalty": 1.2,
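The test tweak keeps the helper's `min_length` feasible whenever a `max_length` is passed, so the new ValueError isn't triggered by the shared test kwargs. A quick standalone check of the updated expression (`pick_min_length` is a made-up name for illustration):

```python
# The conditional expression from the test helper, checked with sample values.
def pick_min_length(input_length, max_length):
    return input_length + 1 if max_length is None else max_length - 1

assert pick_min_length(input_length=5, max_length=None) == 6  # old behavior kept
assert pick_min_length(input_length=5, max_length=10) == 9    # always < max_length now
```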