
Add accelerate support for ViLT #18683

Merged · 25 commits · Sep 22, 2022

Conversation

@younesbelkada (Contributor) commented Aug 18, 2022

Motivation

Add bnb support for the ViLT model, as requested by a user in bitsandbytes-foundation/bitsandbytes#14. This involved adding accelerate support for this model.

What does this PR do?

Adds the _no_split_modules attribute to the ViltModel class to support loading the model with device_map="auto". This also required adding a .to operation inside ViltLayer.
I also redefined the accelerate tests, since for this model the hidden states are not deterministic. However, it is still possible to check the correctness of the operation by comparing output attributes such as logits or pooler_output.
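
For context, a minimal usage sketch of what this enables (the checkpoint name is the public ViLT VQA checkpoint; the exact options are illustrative, not taken from this PR):

```python
# Hedged sketch: loading ViLT with the weights dispatched by accelerate.
from transformers import ViltProcessor, ViltForQuestionAnswering

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained(
    "dandelin/vilt-b32-finetuned-vqa",
    device_map="auto",  # made possible by the new _no_split_modules attribute
    # load_in_8bit=True,  # optional: 8-bit weights via bitsandbytes, the original motivation
)
```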

Questions

The test ViltModelIntegrationTest::test_inference_natural_language_visual_reasoning seems to never pass on my machine (i.e. even without _no_split_modules); is this related to something I am missing? Also, it seems those tests were already failing on the nightly run: https://github.com/huggingface/transformers/runs/7882898294?check_suite_focus=true

cc @NielsRogge @ArthurZucker @ydshieh

- redefine  `accelerate` tests by picking the correct model output
- redefined the test tolerances due to stochasticity
- slow tests are passing except `ViltModelIntegrationTest::test_inference_natural_language_visual_reasoning`
- but the test above seems to never pass anyway
@ydshieh (Collaborator) commented Aug 18, 2022

I also redefined accelerate tests since for this model the hidden states are not deterministic.

Could you elaborate on this part in more depth, please?

@younesbelkada (Contributor Author)

Sure,
The accelerate tests defined in test_modeling_common.py, such as test_cpu_offload or test_model_parallelism, compare the first element of the output dictionary returned by the model (i.e. base_output[0] and new_output[0]). This is not correct in the case of ViLT since, in most cases, the first element of the output is a hidden state that appears to be stochastic according to this line. This can also be confirmed by running several inferences and observing that the hidden states are not the same. This is because ViLT samples image tokens from a multinomial distribution, according to the linked line above.
However, you can still check correctness using logits and pooler_output instead of the hidden states, so I added those checks in this PR.
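
For illustration, a minimal sketch of the comparison being described (the helper name and tolerance are assumptions, not the exact test code):

```python
import torch

def assert_deterministic_outputs_match(base_output, new_output, atol=1e-5):
    # Compare deterministic attributes (logits, pooler_output) instead of output[0],
    # which for ViLT is a hidden state built from stochastically sampled image patches.
    for key in ("logits", "pooler_output"):
        base = getattr(base_output, key, None)
        new = getattr(new_output, key, None)
        if base is not None and new is not None:
            assert torch.allclose(base, new, atol=atol), f"{key} differs between the two runs"
```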

@HuggingFaceDocBuilderDev commented Aug 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh (Collaborator) commented Aug 18, 2022

Thank you, @younesbelkada

Let me bother you a bit more: why are logits and pooler_output deterministic while hidden_states is not? That seems strange to me, but I must be missing some detail in this model.

@younesbelkada (Contributor Author) commented Aug 18, 2022

No worries!
After looking deeply at the code, it appears that the pooler and the other heads used by the model take as input the first element of the hidden_states variable (the latter having shape batch_size x nb_tokens x hidden_dim). The hidden states contain an embedding for each image patch (stochastic) concatenated with the text embeddings (non-stochastic). Since the heads use the first text embedding (the CLS token) to perform their downstream task, there is no stochasticity involved there; the stochasticity only affects the image patch embedding side.
This can also be confirmed from the model architecture (taken from https://arxiv.org/pdf/2102.03334.pdf):
[Screenshot of the ViLT model architecture, 2022-08-18]
Although one could always take the first hidden state to perform the accelerate tests, I think the proposed fix is slightly better since it checks the output of modules that would not be checked if we only considered the first hidden state (e.g. the Pooler).
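
For reference, a minimal sketch of a BERT-style pooler (ViltPooler follows the same pattern), showing that only the first token's embedding feeds the pooled output:

```python
from torch import nn

class SimplePooler(nn.Module):
    # BERT-style pooler: only the first ([CLS]/text) token of the sequence is used.
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: (batch_size, nb_tokens, hidden_dim)
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
```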

@ArthurZucker (Collaborator)

My intuition says it should still be deterministic; let me have a look.

@younesbelkada (Contributor Author)

After looking into it with @ArthurZucker, fixing a manual seed before each forward pass in each accelerate test seems to fix the failing tests.
We can either keep the changes as they are, or redefine the tests with a seed that is set before each forward pass.

@ydshieh (Collaborator) commented Aug 18, 2022

I am open to the fix of using a seed; it makes the change smaller.
But I am confused by the fact that the daily scheduled CI never reports this issue (I should check the Slack channels too).

https://github.com/huggingface/transformers/runs/7891383348?check_suite_focus=true

The nightly CI you mentioned above uses the nightly torch version (i.e. daily-built torch), for which more test failures are expected.

@ydshieh (Collaborator) commented Aug 18, 2022

I think it's better for us to figure out why the relevant tests haven't failed on the scheduled CI for a long period but fail here. This seems contradictory.

@ydshieh (Collaborator) commented Aug 18, 2022

Could you mention which hardware you tested on, @younesbelkada?

- replace everything by manually setting the seed
@younesbelkada (Contributor Author) commented Aug 18, 2022

For the accelerate tests it is normal that they were passing before: the _no_split_modules attribute was never defined in the model class, so these tests were never run.
Regarding the second test, maybe it is indeed a hardware issue! Here are the details of my setup:

Python version: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0]
transformers version: 4.22.0.dev0
Torch version: 1.12.1
Cuda available: True
Cuda version: 11.3
CuDNN version: 8302
Number of GPUs available: 2
NCCL version: (2, 10, 3)
DeepSpeed version: None
TensorFlow version: 2.9.1
TF GPUs available: True
Number of TF GPUs available: 2

@younesbelkada (Contributor Author)

The failing test passes on my VM with atol=4e-2, by the way.

@@ -530,7 +668,8 @@ def test_for_token_classification(self):

# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
Contributor Author

@ydshieh maybe it's because of this line? Is ./tests/fixtures/tests_samples/COCO/000000039769.png stored in the Docker image?

Collaborator

It's in the repo, no? The CI clones the repo, so it should be there.

Thanks for the explanation about _no_split_modules not being set, that part is clear to me now! So I'm fine with using a seed for those tests.

Contributor Author

Great, my bad then, I reverted that change. However, this does not seem to fix the initial issue.

Contributor Author

It's probably hardware / library version related; I saw somewhere that different PIL versions can yield different results.

@younesbelkada (Contributor Author), Aug 18, 2022

Check here:

if is_pillow_less_than_9:

My Pillow version is:
Pillow 9.2.0

Collaborator

Could we just set the seed, then call super().test_cpu_offload()?

@younesbelkada (Contributor Author), Aug 19, 2022

The only thing I am worried about is that you need to set a seed before each inference step, so more than two times: once before the first forward pass and again before each subsequent forward pass. That is why I think we cannot simply call super().test_cpu_offload() and instead have to define our own function.
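
To make the pattern concrete, a hedged sketch of resetting the seed right before each forward pass (the helper name is an assumption, not the actual test code):

```python
import torch

def forward_with_fixed_seed(model, inputs, seed=0):
    # Reset the generator right before the forward pass so that ViLT's
    # multinomial sampling of image patches picks the same tokens on every call.
    torch.manual_seed(seed)
    with torch.no_grad():
        return model(**inputs)

# base_output = forward_with_fixed_seed(model, inputs)            # plain model
# new_output = forward_with_fixed_seed(offloaded_model, inputs)   # model dispatched with accelerate
```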

Contributor Author

Maybe there is a way to fix the seed without having to reset it at each inference step.

Contributor Author

Let me check

@younesbelkada (Contributor Author), Aug 19, 2022

I propose adding a context manager in 3cda114. With that, there is no need to redefine the whole testing functions; we just need to apply a set_reproducible decorator to these functions to ensure reproducibility.
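
For context, a hedged sketch of what such a sys.settrace-based ContextDecorator could look like (a reconstruction for illustration, not necessarily the exact code in 3cda114):

```python
import sys
from contextlib import ContextDecorator

import torch


class set_reproducible(ContextDecorator):
    # Resets the torch seed whenever a `forward` defined in a modeling file is entered.
    def __init__(self, seed):
        self.seed = seed

    def trace_calls(self, frame, event, arg):
        # frame.f_code identifies the function being called; filter on modeling files.
        if event == "call" and "/modeling_" in frame.f_code.co_filename and frame.f_code.co_name == "forward":
            torch.manual_seed(self.seed)
        return None  # no per-line tracing needed inside the frame

    def __enter__(self):
        sys.settrace(self.trace_calls)
        return self

    def __exit__(self, *exc):
        sys.settrace(None)
        return False
```

Applied as `@set_reproducible(2)` on a test, it installs the trace on entry and removes it on exit.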



# adapted from https://stackoverflow.com/questions/32163436/python-decorator-for-printing-every-line-executed-by-a-function
class set_reproducible(ContextDecorator):
Contributor

This is interesting, does that mean we can use this decorator every time we want to test PyTorch models that involve stochasticity? Because ViTMAE is another example (it creates a random boolean mask inside).

I'm currently using torch.manual_seed as seen here. Should we use this decorator instead?

Collaborator

I think this change could well be applied to ViTMAE - it has exactly the same stochasticity situation.

Collaborator

It does work for the mentioned test 🥳

@younesbelkada (Contributor Author) commented Sep 12, 2022

Thanks a lot for your comments @ydshieh & @NielsRogge

@NielsRogge: I can confirm this decorator works fine for ViTMAE too; I replaced the function you pointed me to with:

@set_reproducible(2)
def test_save_load(self):
    super().test_save_load()

And the test passed, so I guess we can safely replace stochastic tests with this kind of trick.

I would love a quick review from @sgugger or @stas00 if possible 🙏 as I think this decorator could be useful for future models.

Thanks 💪

@ArthurZucker (Collaborator)

Just a small comment: in terms of performance, I think the decorator could be improved a bit to only run on the model's forward and not on every single forward call (if I understand correctly, all functions named forward are currently affected, instead of just the model's main forward pass).

Improve the context manager's efficiency by filtering the forward calls based on the file origin
@ArthurZucker (Collaborator)

I can confirm that it now fixes the tests for vit_mae, as suggested by @NielsRogge. I just added a quick fix to only set the seed on the model's forward. LGTM





def trace_calls(self, frame, event, arg):
# Set the seed when it is a call to a forward function from the model
if "/modeling_" in frame.f_code.co_filename:
Collaborator

This was a bit tricky to get right. If someone knows where to find the documentation to link to for the frame attribute, that would be great.

Collaborator

BTW, the frames looked a bit like this:

<frame at 0x564b14f28360, file '/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py', line 761, code forward>
<code object forward at 0x7fe2c0610f50, file "/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py", line 761>
/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py
<frame at 0x564b14d71fe0, file '/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py', line 200, code forward>
<code object forward at 0x7fe2c060d0e0, file "/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py", line 200>
/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py
<frame at 0x564b1432e910, file '/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py', line 265, code forward>
<code object forward at 0x7fe2c060d2f0, file "/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py", line 265>
/home/arthur_huggingface_co/transformers/src/transformers/models/vilt/modeling_vilt.py
<frame at 0x564b14ce68b0, file '/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py', line 157, code forward>
<code object forward at 0x7fe40cd5dc90, file "/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 157>
/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py
<frame at 0x564b14ce68b0, file '/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py', line 157, code forward>
<code object forward at 0x7fe40cd5dc90, file "/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 157>
/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py
<frame at 0x564b14ce68b0, file '/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py', line 157, code forward>
<code object forward at 0x7fe40cd5dc90, file "/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 157>
/opt/conda/envs/jukebox/lib/python3.8/site-packages/torch/nn/modules/sparse.py
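
For reference, frame and code objects (with f_code, co_filename, co_name) are documented in the Python data model and the inspect module; a tiny sketch of how the filter reads them:

```python
import sys

def show_forward_calls(frame, event, arg):
    # frame.f_code is the code object of the function being entered;
    # co_filename and co_name identify the file and the function.
    if event == "call" and frame.f_code.co_name == "forward":
        print(frame.f_code.co_filename, frame.f_code.co_firstlineno)
    return None

sys.settrace(show_forward_calls)
```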

src/transformers/testing_utils.py (outdated review thread, resolved)
small nit

Co-authored-by: Younes Belkada <[email protected]>
def trace_calls(self, frame, event, arg):
# Set the seed when it is a call to a forward function from the model
if "/modeling_" in frame.f_code.co_filename:
return self.set_seed
@stas00 (Contributor), Sep 12, 2022

So if I understand it correctly, you're using this hack to force the same seed before every Python call, correct?

What is it that you're testing then, since this is not the normal behavior? A normal code base sets the seed once to be reproducible.

In other words, why doesn't setting the seed at the beginning of the program lead to a reproducible outcome? Isn't that an indicator of a bug in the proposed software that this hack tries to hide?

And if you're really going to keep this, can we call it something with a better mnemonic, like reset_seed_on_every_frame? Because to me, set_reproducible means setting a seed once on enter and leaving it alone.

And kudos on finding a very cool way to hack this functionality in, @younesbelkada! I'm just not sure it is needed.

@younesbelkada (Contributor Author), Sep 12, 2022

Thanks @stas00 for your reply!

Regarding your first point: yes, but we set the seed before each forward call of a PyTorch module rather than before every Python call (following @ArthurZucker's modification).

Regarding your second point, I was also very confused at the beginning, but setting the PyTorch seed before each forward pass seems to be needed to ensure reproducibility for stochastic operations. E.g., this code snippet shows it:

>>> import torch
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x10576b6f0>
>>> torch.randn(1, 3)
tensor([[ 1.5410, -0.2934, -2.1788]])
>>> torch.randn(1, 3)
tensor([[ 0.5684, -1.0845, -1.3986]])
>>> torch.manual_seed(0)
<torch._C.Generator object at 0x10576b6f0>
>>> torch.randn(1, 3)
tensor([[ 1.5410, -0.2934, -2.1788]])

You need to set the seed before each stochastic call to ensure reproducibility. From my understanding, that is why we need to set the seed in stochastic tests (e.g. here, where the seed is set before each forward pass). I am still not sure why setting the seed once is not sufficient, but I think it is because the generator state advances (the seed gets 'consumed') whenever a stochastic operation is performed (this has also been observed in the Jukebox integration, I think).

For the last point, yes! I think we should change the decorator's name to make it more understandable.

@stas00 (Contributor), Sep 12, 2022

Ah, I missed that nuance about forward - thank you for explaining.

Then reset_seed_before_every_forward would be more appropriate.

But there is already a mechanism for doing this - it is called register_forward_hook:
https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html?highlight=hook#torch.jit.ScriptModule.register_forward_hook

so you don't need any cool hacks to accomplish that.

e.g. see:

self.register_forward_hook()

Collaborator

Very interesting; both options work, and maybe register_forward_pre_hook(hook) is more appropriate.
IMO, having a decorator allows anyone to use it without modifying their test function, but we might end up with a lot of decorators.

We could define the following:

# testing_utils.py
def pre_hook_set_seed(seed):
    def pre_hook(module, input):
        set_seed(seed)
    return pre_hook

And call

# test_modeling_vilt.py
handle = model.register_forward_pre_hook(pre_hook_set_seed(seed))
...
handle.remove()

I guess it's pretty similar but requires a bit more code. 😄
@stas00, what do you think would work best long term?

@stas00 (Contributor), Sep 13, 2022

The PyTorch call is designed for this specific functionality, so it's a much cleaner solution, IMHO.

You can just as well implement it as a context manager, i.e. enter/exit around the above sample.

Also, I was thinking that if ViLT needs this functionality often, perhaps the hook should be added directly into its modeling file?

But testing_utils is perfect too if it's generic and other models might need it.
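
A hedged sketch of the context-manager variant built on the pre-forward hook (the helper name is hypothetical, not part of the PR):

```python
from contextlib import contextmanager

import torch


@contextmanager
def seed_before_every_forward(model, seed):
    # Register a pre-forward hook that resets the seed, and remove it on exit.
    def pre_hook(module, inputs):
        torch.manual_seed(seed)

    handle = model.register_forward_pre_hook(pre_hook)
    try:
        yield model
    finally:
        handle.remove()

# Usage: with seed_before_every_forward(model, 0): outputs = model(**batch)
```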

Collaborator

This seems a bit over-engineered for three tests where we could just insert two set_seed calls to ensure the same results (even registering a pre-forward hook seems over-engineered, TBH). Or am I missing something?

Collaborator

You are right 😅 But since ViTMAE also requires that hack, we thought that, long term, it made more sense to have an easy-to-use decorator! You have to insert the set_seed call before every forward pass (so inside the for loop doing the testing) for each of the tests, which is why the original idea of having a decorator followed Yih-dar's comment.
Anyway, we will do whatever you think works better!

Collaborator

While I was excited by the idea at first, reading Stas' comments and thinking more about it, I think we should stick to something simple :-)

@stas00 (Contributor), Sep 15, 2022

Why does it have to be a test decorator? What would be wrong with calling it from the test?

Also, I was thinking that if this might be used by users, e.g. to debug their code, wouldn't it be a good idea to make this functionality part of the model itself?

So basically one would call:

model = from_pretrained(...)
model.reset_seed_before_every_forward(seed=42)  # this will inject the pre-hook

And now it's accessible to tests and users too. So it's no longer a hack to solve testing issues, but an official debug feature.

If multiple models require that, it could be a mixin class they inherit from, so it won't add any noise to the modeling file.
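
A hedged sketch of the suggested mixin (class and method names are assumptions for illustration; it relies on the module's own register_forward_pre_hook):

```python
import torch


class ResetSeedDebugMixin:
    # Meant to be mixed into an nn.Module subclass (e.g. a PreTrainedModel).
    def reset_seed_before_every_forward(self, seed=42):
        def pre_hook(module, inputs):
            torch.manual_seed(seed)

        # Keep the handle so the hook can be removed later with handle.remove().
        self._reset_seed_hook_handle = self.register_forward_pre_hook(pre_hook)
        return self._reset_seed_hook_handle
```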

Collaborator

Also fine with this.

@younesbelkada (Contributor Author) commented Sep 17, 2022

Thank you all for your comments!
I reverted the changes related to the context manager and kept the tests as simple as possible.
I can confirm all the slow tests pass now (except for this test, but that is expected, as stated in the comment above).
I propose to potentially address the stochasticity issue in a follow-up PR and to keep this PR focused on its main goal: adding accelerate support for ViLT.

@sgugger (Collaborator) left a comment

It looks like the overridden tests are just adding two lines to set the seed. That modification can definitely go in the base test (sorry if I was unclear in my earlier comments) :-)

- add the attribute `reset_seed_before_every_forward` on `accelerate` tests
- this supports stochastic tests
@younesbelkada (Contributor Author) commented Sep 20, 2022

@sgugger thanks for your comment, and sorry for my late reply!
I should have addressed the proposed suggestion in commit 43b087a and can confirm the slow tests for this model are passing. I hesitated between this solution and adding an attribute on ModelTesterMixin, but I thought this solution is simpler to understand for future users (ModelTesterMixin already has a lot of attributes).

I can also take care of opening a follow-up PR to update the ViTMAE tests so that these changes are consistent across stochastic model tests inside transformers.

@sgugger (Collaborator) left a comment

I didn't mean adding a new argument to the test. We can have the seed set in the tests for all models. :-)

tests/test_modeling_common.py (three outdated review threads, resolved)
younesbelkada and others added 2 commits September 21, 2022 22:15
- set seed before every forward on `accelerate` tests
- `make fixup`
@younesbelkada (Contributor Author)

Perfect, thanks @sgugger!
I should have addressed the changes now. Let me know if you still need any modifications.

Comment on lines 519 to 532
@require_accelerate
@require_torch_gpu
def test_cpu_offload(self):
    super().test_cpu_offload()

@require_accelerate
@require_torch_gpu
def test_disk_offload(self):
    super().test_disk_offload()

@require_accelerate
@require_torch_multi_gpu
def test_model_parallelism(self):
    super().test_model_parallelism()
Collaborator

None of this is needed anymore ;-p

Contributor Author

Oops, yes, ahah, I forgot to remove them.

Contributor Author

Should be fixed in 19515d2

- tests will be correctly called from the base class
@sgugger (Collaborator) left a comment

Thanks a lot!

@younesbelkada merged commit 4d0f8c0 into huggingface:main on Sep 22, 2022
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022