From 4d0f8c05f5acc277176985a3b93d283d3867f0fd Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 22 Sep 2022 13:14:39 +0200
Subject: [PATCH] Add `accelerate` support for ViLT (#18683)

---
 src/transformers/models/vilt/modeling_vilt.py |  3 ++-
 src/transformers/testing_utils.py             |  4 ----
 tests/models/vilt/test_modeling_vilt.py       |  2 --
 tests/test_modeling_common.py                 | 10 ++++++++++
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
index eefa8f641ff187..020aa9a6afc647 100755
--- a/src/transformers/models/vilt/modeling_vilt.py
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -491,7 +491,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_att
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
 
         # first residual connection
-        hidden_states = attention_output + hidden_states
+        hidden_states = attention_output + hidden_states.to(attention_output.device)
 
         # in ViLT, layernorm is also applied after self-attention
         layer_output = self.layernorm_after(hidden_states)
@@ -573,6 +573,7 @@ class ViltPreTrainedModel(PreTrainedModel):
     config_class = ViltConfig
     base_model_prefix = "vilt"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["ViltSelfAttention"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 2e99a76232c27c..b14ed5d589c593 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -772,7 +772,6 @@ class CaptureStd:
     ```"""
 
     def __init__(self, out=True, err=True, replay=True):
-
         self.replay = replay
 
         if out:
@@ -1122,7 +1121,6 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
             tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
         """
         if tmp_dir is not None:
-
             # defining the most likely desired behavior for when a custom path is provided.
             # this most likely indicates the debug mode where we want an easily locatable dir that:
             # 1. gets cleared out before the test (if it already exists)
@@ -1200,7 +1198,6 @@ def python_one_liner_max_rss(self, one_liner_str):
         return max_rss
 
     def tearDown(self):
-
         # get_auto_remove_tmp_dir feature: remove registered temp dirs
         for path in self.teardown_tmp_dirs:
             shutil.rmtree(path, ignore_errors=True)
@@ -1472,7 +1469,6 @@ def tee(line, sink, pipe, label=""):
 
 
 def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
-
     loop = asyncio.get_event_loop()
     result = loop.run_until_complete(
         _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
diff --git a/tests/models/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py
index 82aa0767470eba..280eff70d979af 100644
--- a/tests/models/vilt/test_modeling_vilt.py
+++ b/tests/models/vilt/test_modeling_vilt.py
@@ -215,7 +215,6 @@ def prepare_pixel_values(self):
 
 @require_torch
 class ViltModelTest(ModelTesterMixin, unittest.TestCase):
-
     all_model_classes = (
         (
             ViltModel,
@@ -512,7 +511,6 @@ def test_model_from_pretrained(self):
 
 @require_torch
 class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCase):
-
     all_model_classes = (ViltForImagesAndTextClassification,) if is_torch_available() else ()
 
     def setUp(self):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 082f2a8a9057f9..42ecad03c6aee9 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2307,6 +2307,7 @@ def test_disk_offload(self):
         inputs_dict = self._prepare_for_class(inputs_dict, model_class)
         model = model_class(config).eval()
         model = model.to(torch_device)
+        torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
         model_size = compute_module_sizes(model)[""]
@@ -2324,6 +2325,7 @@ def test_disk_offload(self):
             )
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
 
             self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2340,6 +2342,8 @@ def test_cpu_offload(self):
         inputs_dict = self._prepare_for_class(inputs_dict, model_class)
         model = model_class(config).eval()
         model = model.to(torch_device)
+
+        torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
         model_size = compute_module_sizes(model)[""]
@@ -2355,6 +2359,8 @@ def test_cpu_offload(self):
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+            torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
 
             self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2371,6 +2377,8 @@ def test_model_parallelism(self):
         inputs_dict = self._prepare_for_class(inputs_dict, model_class)
         model = model_class(config).eval()
         model = model.to(torch_device)
+
+        torch.manual_seed(0)
        base_output = model(**inputs_dict)
 
         model_size = compute_module_sizes(model)[""]
@@ -2386,6 +2394,8 @@ def test_model_parallelism(self):
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+            torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
 
             self.assertTrue(torch.allclose(base_output[0], new_output[0]))
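
Usage sketch: with `_no_split_modules` set, ViLT checkpoints can be dispatched
across devices via `accelerate`'s `device_map="auto"`. A minimal example,
assuming `accelerate` is installed and using the public `dandelin/vilt-b32-mlm`
checkpoint; the checkpoint, prompt, and image URL are illustrative choices,
not part of this patch:

    import requests
    import torch
    from PIL import Image
    from transformers import ViltProcessor, ViltForMaskedLM

    # device_map="auto" asks accelerate to shard the model across the
    # available devices; each ViltSelfAttention block stays on a single
    # device thanks to _no_split_modules, and the
    # .to(attention_output.device) cast added above keeps the residual
    # connection's operands co-located.
    processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
    model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm", device_map="auto")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(image, "a photo of two [MASK]", return_tensors="pt")

    # accelerate's dispatch hooks move inputs to the right device per module
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)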