Skip to content

Commit

Permalink
Generate: model_kwargs can also be an input to `prepare_inputs_for_…
Browse files Browse the repository at this point in the history
…generation` (#20353)
  • Loading branch information
gante authored Nov 21, 2022
1 parent d21c97c commit 4cf3814
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/generation/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,9 +1445,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.call).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):

unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
Expand Down
8 changes: 6 additions & 2 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3007,8 +3007,8 @@ def test_contrastive_search_batched(self):
self.assertTrue(max_score_diff < 1e-5)

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")

encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
Expand All @@ -3021,3 +3021,7 @@ def test_validate_generation_inputs(self):
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)

# However, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
model.generate(input_ids, **valid_model_kwargs)

0 comments on commit 4cf3814

Please sign in to comment.