From 4cf38148dc98b3df1df6eb2f06e4f02448026b19 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 21 Nov 2022 16:20:27 +0000
Subject: [PATCH] Generate: `model_kwargs` can also be an input to
 `prepare_inputs_for_generation` (#20353)

---
 src/transformers/generation/flax_utils.py | 6 +++---
 src/transformers/generation/tf_utils.py   | 6 +++---
 src/transformers/generation/utils.py      | 6 +++---
 tests/generation/test_utils.py            | 8 ++++++--
 4 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index d9d2eae8795cfe..5d936ce5b1dccd 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -194,9 +194,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.__call__).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 3758431ff91c35..e437e55f48a36c 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -1445,9 +1445,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):

         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.call).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 997e2a5769a81f..3d945b2be37a74 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -981,9 +981,9 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):

         unused_model_args = []
         model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
-        # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
-        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
-        if "kwargs" in model_args:
+        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+        if "kwargs" in model_args or "model_kwargs" in model_args:
             model_args |= set(inspect.signature(self.forward).parameters)
         for key, value in model_kwargs.items():
             if value is not None and key not in model_args:
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 5d9c9fbad2fc27..a03f0d12b9d147 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3007,8 +3007,8 @@ def test_contrastive_search_batched(self):
         self.assertTrue(max_score_diff < 1e-5)

     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")

         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -3021,3 +3021,7 @@ def test_validate_generation_inputs(self):
         with self.assertRaisesRegex(ValueError, "foo"):
             fake_model_kwargs = {"foo": "bar"}
             model.generate(input_ids, **fake_model_kwargs)
+
+        # However, valid model_kwargs are accepted
+        valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
+        model.generate(input_ids, **valid_model_kwargs)
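
For context, below is a minimal standalone sketch of the validation logic touched by this patch, mirroring the PyTorch `_validate_model_kwargs` hunk above; the helper names `allowed_generate_args` and `validate_model_kwargs` are illustrative only, not part of the library API:

    import inspect

    def allowed_generate_args(model):
        # Every parameter of `prepare_inputs_for_generation` is a valid `model_kwargs` key.
        model_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
        # If that signature is open-ended (`**kwargs` or, after this patch, `**model_kwargs`),
        # the forward pass signature is also consulted, since extra inputs are passed through to it.
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(model.forward).parameters)
        return model_args

    def validate_model_kwargs(model, model_kwargs):
        # Reject any non-None kwarg that neither function can consume.
        unused = [
            key
            for key, value in model_kwargs.items()
            if value is not None and key not in allowed_generate_args(model)
        ]
        if unused:
            raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")

As the updated test shows, a model whose `prepare_inputs_for_generation` only exposes `**model_kwargs` (presumably why the test switches to the tiny RoBERTa checkpoint) now accepts valid forward-pass inputs such as `attention_mask`, while a typo like `foo` still raises a ValueError.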