From 8259910d621b01d857b57a3ca2cc651acc1e3407 Mon Sep 17 00:00:00 2001
From: jsnfly
Date: Mon, 18 Apr 2022 09:10:04 +0200
Subject: [PATCH 1/3] Add passing encoder_outputs as tuple to existing test

---
 .../test_modeling_encoder_decoder.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/encoder_decoder/test_modeling_encoder_decoder.py b/tests/encoder_decoder/test_modeling_encoder_decoder.py
index 7e1d3b0c9774c4..46a1bf7b686a5a 100644
--- a/tests/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/encoder_decoder/test_modeling_encoder_decoder.py
@@ -142,6 +142,22 @@ def check_encoder_decoder_model(
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )
 
+        # Test passing encoder_outputs as tuple.
+        encoder_outputs = (encoder_hidden_states,)
+        outputs_encoder_decoder = enc_dec_model(
+            encoder_outputs=encoder_outputs,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+
+        self.assertEqual(
+            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
+        )
+        self.assertEqual(
+            outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
+        )
+
     def check_encoder_decoder_model_from_pretrained_using_model_paths(
         self,
         config,

From 4fad2abf136213241f5aa6d2410f525ee2a73d71 Mon Sep 17 00:00:00 2001
From: jsnfly
Date: Mon, 18 Apr 2022 09:18:40 +0200
Subject: [PATCH 2/3] Add check for tuple

---
 .../models/encoder_decoder/modeling_encoder_decoder.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
index 7bad5f98d3778e..972b80db7b4dbe 100644
--- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -22,7 +22,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from ..auto.configuration_auto import AutoConfig
@@ -494,6 +494,8 @@ def forward(
                 return_dict=return_dict,
                 **kwargs_encoder,
             )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
 
         encoder_hidden_states = encoder_outputs[0]
 

From 45a9a18df6eae4a231677865c8463fdfd8cc28b8 Mon Sep 17 00:00:00 2001
From: jsnfly
Date: Mon, 18 Apr 2022 09:28:11 +0200
Subject: [PATCH 3/3] Add check for tuple also for speech and vision

---
 .../speech_encoder_decoder/modeling_speech_encoder_decoder.py | 4 +++-
 .../vision_encoder_decoder/modeling_vision_encoder_decoder.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index db5037eb531ab7..1453cf9370d098 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -22,7 +22,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from ..auto.configuration_auto import AutoConfig
@@ -514,6 +514,8 @@ def forward(
                 return_dict=return_dict,
                 **kwargs_encoder,
             )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
 
         encoder_hidden_states = encoder_outputs[0]
 

diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 999ba2d2db8933..37072270a567d8 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -22,7 +22,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
-from ...modeling_outputs import Seq2SeqLMOutput
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from ..auto.configuration_auto import AutoConfig
@@ -466,6 +466,8 @@ def forward(
                 return_dict=return_dict,
                 **kwargs_encoder,
             )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
 
         encoder_hidden_states = encoder_outputs[0]
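
Usage sketch, mirroring the new test in PATCH 1/3: with these patches applied,
encoder_outputs can be passed to forward() as a plain tuple (the format produced
with return_dict=False) and the model wraps it in a BaseModelOutput internally.
The "bert-base-uncased" checkpoints below are illustrative stand-ins for any
encoder/decoder pair; the snippet assumes torch and a transformers build that
includes this series.

    import torch
    from transformers import BertTokenizer, EncoderDecoderModel

    # Illustrative checkpoints; any supported encoder/decoder pair works.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )

    inputs = tokenizer("Hello world", return_tensors="pt")

    # Run the encoder once and keep only its last hidden state, as the
    # patched test does: a 1-tuple stands in for the full output object.
    last_hidden_state = model.encoder(input_ids=inputs.input_ids).last_hidden_state
    encoder_outputs = (last_hidden_state,)

    # Previously the tuple had to be wrapped in a BaseModelOutput by hand;
    # the new elif branch lets forward() do the conversion itself.
    outputs = model(
        encoder_outputs=encoder_outputs,
        attention_mask=inputs.attention_mask,
        decoder_input_ids=inputs.input_ids,
    )
    print(outputs.logits.shape)  # (batch, decoder_seq_len, vocab_size)

Reusing a precomputed encoder pass this way avoids re-encoding the same source
when decoding it multiple times, which is the scenario the added test covers.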