
Attempting to test automatically the _keys_to_ignore. #20042

Merged: 28 commits merged into huggingface:main on Nov 9, 2022

Conversation

Narsil (Contributor) commented on Nov 3, 2022

What does this PR do?

This adds a new part of the tied_weights test that aims to detect automatically
when _keys_to_ignore is incorrectly set.

_keys_to_ignore is meant to ignore weights that are supposed to be tied in the
final model, meaning it is OK if the parameter is missing from the on-disk weights.
Those weights really are absent during the load, but they end up being tied afterwards,
so we should ignore them during the load if they are missing.

The test also aims to detect _keys_to_ignore entries that might have been set but
could be misleading because the parameters are actually NOT tied anymore.
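To make this concrete, here is a minimal toy sketch (a hypothetical module, not the actual transformers code) of what weight tying means for checkpoint loading:

import torch
from torch import nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the head to the embeddings: both names refer to the same tensor.
        self.lm_head.weight = self.embeddings.weight

model = TinyLM()
# The tied weight shows up in the state_dict under both names ...
print("lm_head.weight" in model.state_dict())              # True
# ... but only once among the real parameters.
print("lm_head.weight" in dict(model.named_parameters()))  # False
# So "lm_head.weight" may legitimately be absent from a checkpoint:
# it is re-created when the weights are tied after loading.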

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented on Nov 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -1153,7 +1155,7 @@ def forward(
class BertLMHeadModel(BertPreTrainedModel):

  _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
Collaborator:

Let me review this slowly 🙏 and verify a few things. But do you think r"predictions.decoder.bias" is a mistake and should be r"cls.predictions.decoder.bias" ?

I am going to check myself anyway.

Collaborator:

A re match is done, so while the exact name is indeed r"cls.predictions.decoder.bias", this works. But it would be great to fix it, just in case one day a weight named predictions.decoder.bias that should not be ignored appears ;-)
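(To illustrate the re match point, here is a rough, simplified sketch of the filtering; an approximation, not the exact modeling_utils code:)

import re

_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
missing_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

# Each pattern is searched (not full-matched) against the key, so the unanchored
# r"predictions.decoder.bias" also matches "cls.predictions.decoder.bias".
still_missing = [
    k for k in missing_keys
    if not any(re.search(pat, k) for pat in _keys_to_ignore_on_load_missing)
]
print(still_missing)  # ['cls.predictions.decoder.weight']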

Narsil (PR author):

Maybe. I'm not sure, but the test does yell if we're hiding a valid key (I don't try to yell when we have an unused _keys_to_ignore). Want me to try and add that to the test?

@sgugger (Collaborator) left a comment:

Interesting way of testing those attributes!


@@ -1538,6 +1538,36 @@ def check_same_values(layer_1, layer_2):
# # Check that the embedding layer and decoding layer are the same in size and in value
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
with tempfile.TemporaryDirectory() as d:
Collaborator:

Add a comment here on what's checked? Also maybe it should be in its own test?

Narsil (PR author):

Moving to its own test + comment on how it works.

@@ -2422,7 +2422,8 @@ def _fix_key(key):

if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
- expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
+ _prefix = f"{prefix}."
+ expected_keys = [".".join(s.split(".")[1:]) if s.startswith(_prefix) else s for s in expected_keys]
Narsil (PR author):

This would fail for esm/esmfold: since esm is a string prefix of esmfold, it would strip the esmfold. prefix instead of the intended esm. prefix.

This was silent before, but with the new test addition, the modification below would fail when accessing key[-1] while trying to identify missing keys.
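(A small illustration of the failure mode, with made-up key names:)

prefix = "esm"
expected_keys = ["esm.encoder.layer.0.weight", "esmfold.trunk.block.0.weight"]

# Old check: "esmfold..." startswith "esm" is True, so the wrong segment gets stripped.
old = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
print(old)  # ['encoder.layer.0.weight', 'trunk.block.0.weight']  <- "esmfold." wrongly removed

# New check: only strip keys that really start with "esm."
_prefix = f"{prefix}."
new = [".".join(s.split(".")[1:]) if s.startswith(_prefix) else s for s in expected_keys]
print(new)  # ['encoder.layer.0.weight', 'esmfold.trunk.block.0.weight']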

Collaborator:

Seems like the right fix, thanks!

Narsil (PR author):

Ok, there was a second place to update the fix to make the splinter test pass.

@@ -1485,11 +1485,24 @@ def test_correct_missing_keys(self):
base_model_prefix = model.base_model_prefix

if hasattr(model, base_model_prefix):

Narsil (PR author):

This is an important new modification.

This test tries to verify that we correctly warn when we save only the "core" of a model, without its head parameters, and that we yell when loading the head model from that checkpoint.

This test is great, and different from the newly added one.
But since we fixed a lot of decoder heads which are merely a copy of the embeddings matrix, they are a head, but they DO NOT have any new extra parameters, hence this test would fail.

There is a self.test_missing_keys flag to deactivate that particular test in a model's test suite. But it would deactivate the check for ALL potential heads (as far as I understand), potentially opening gaps in our testing.

So I modified this test to try and detect those edge cases. I don't like putting logic like that in a test, as it makes reasoning harder, but I failed to see another way.

The test now verifies that the head actually has extra parameters before doing the verification.
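(Roughly, the guard looks like the following; a simplified, runnable illustration of the idea, not the exact test code:)

from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64, is_decoder=True,
)
model = BertLMHeadModel(config)
prefix = model.base_model_prefix  # "bert"

# Parameters that live outside the base model, i.e. in the head.
head_param_names = {n for n, _ in model.named_parameters() if not n.startswith(f"{prefix}.")}
print(sorted(head_param_names))
# Only when this set is non-empty does saving the bare base model and then
# loading the head model have to warn about genuinely missing head weights.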

Collaborator:

Make sure you account in the test for upcoming models which have weights with the same name in the model with head and the base model (Niels is adding one in #19981)

" tying its weights",
)

# self.assertEqual(
Narsil (PR author):

I deactivated this part because token_type_ids and position_ids would yell pretty much everywhere.

I have no idea of the perfect solution for this.
I think position_ids should be a non-parameter (buffer) if it's not learned, or a parameter if it is, or something like that.
Given how many models it affects, we could add an escape hatch for just these peculiar names.

But I think flagging _keys_to_ignore entries that are actually too inclusive can provide good value, especially since we're using regexps, so it's easy to make a mistake, or to modify the model later in such a way that it is no longer reflected properly in those keys.
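(For context on the position_ids point: in BERT-style models it is registered as a non-learned buffer, roughly like this toy sketch, which is why it keeps showing up in missing-key reports even though there is nothing to load:)

import torch
from torch import nn

class TinyEmbeddings(nn.Module):
    def __init__(self, max_positions=512, hidden=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(100, hidden)
        # A buffer, not a Parameter: present in the state_dict but never learned,
        # so warning that it is "missing" from a checkpoint is mostly noise.
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))

emb = TinyEmbeddings()
print("position_ids" in dict(emb.named_buffers()))     # True
print("position_ids" in dict(emb.named_parameters()))  # False
print("position_ids" in emb.state_dict())              # True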

Narsil (PR author) commented on Nov 4, 2022

@ydshieh is it normal that the splinter test is failing?

FAILED tests/models/splinter/test_modeling_splinter.py::SplinterModelTest::test_save_load_fast_init_from_base - AssertionError: 3069.73388671875 not less than or equal to 0.001 : splinter_qass.query_start_transform.dense.weight not identical

ydshieh (Collaborator) commented on Nov 4, 2022

@Narsil

I am not able to reproduce the splinter test failure you mentioned above with current main on a GCP GPU VM. Could you provide more information about your environment and how you launched the test?

Narsil (PR author) commented on Nov 4, 2022

Narsil added a commit to huggingface/safetensors that referenced this pull request on Nov 4, 2022:
in case they will trigger warnings on the `transformers` side.
Even if the model is perfectly fine, core maintainers fear an influx
of opened issues.

This is perfectly legit.
On the `transformers` side fixes are on the way: huggingface/transformers#20042

We can wait for this PR to hit `main` before communicating super widely.

In the meantime this conversion script will now prevent converting
models that would trigger such warnings (so the output of the script
**will** depend on the `transformers` freshness).
@ydshieh (Collaborator) left a comment:

Just a few comments - only for the test test_tied_model_weights_key_ignore so far.

@@ -1539,6 +1552,44 @@ def check_same_values(layer_1, layer_2):
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

def test_tied_model_weights_key_ignore(self):
Collaborator:

I would suggest making the name more generic (i.e. not using tied_model in the name).

The check in this test is not really specific to tied weights, although they are one of the main sources of missing keys.

test_model_weights_key_ignore plus a comment to explain its objective would be fine.

Narsil (PR author):

In my mind it is only about tied weights (because it's the only mechanism I know that can create the weight sharing).

Happy to update the name

# We are nuking ALL weights on file, so every parameter should
# yell on load. We're going to detect if we yell too much, or too little.
with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
torch.save({}, f)
Collaborator:

A comment here would be nice, as this is unusual.

Narsil (PR author):

I gave you the wrong link (on an old commit, the comment is actually there).

Is it readable?

Collaborator:

Oh, yes! Sorry, I actually hadn't noticed "We are nuking ALL weights on file, so every parameter should ...", which is quite clear already.
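(Putting the pieces together, the flow of the new check is roughly the following; a simplified sketch using gpt2, not the exact test, and it assumes the checkpoint file is pytorch_model.bin as it was at the time of this PR:)

import os, tempfile
import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=1, n_head=2, n_embd=32, vocab_size=100)
model = GPT2LMHeadModel(config)

with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d)
    # Nuke ALL weights on file: every real parameter should now be reported missing.
    torch.save({}, os.path.join(d, "pytorch_model.bin"))
    reloaded, infos = GPT2LMHeadModel.from_pretrained(d, output_loading_info=True)

# named_parameters + named_buffers: tied aliases appear only once here,
# whereas state_dict() would list them under every name.
params = dict(reloaded.named_parameters())
params.update(dict(reloaded.named_buffers()))

# Reported missing keys that do not correspond to any real parameter/buffer
# are tied aliases that should have been covered by _keys_to_ignore_on_load_missing.
extra_missing = set(infos["missing_keys"]) - set(params)
print(extra_missing)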


prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
ydshieh (Collaborator) commented on Nov 4, 2022:

You mentioned on DM:

They are listed as two different things during from_pretrained but only appear once when doing model.named_parameters()

I haven't 100% understood it (despite having some rough feeling about it). Could you explain:

  • why not use model_reloaded.state_dict()? (From your wording, I guess we would get 2 entries there, which would match the missing keys and make the test pass in a situation where it should fail.)
  • Give an example of what happens for bert or gpt2 in this block against current main.

In any case, a comment (inside the code) should be added here regarding the usage of named_parameters and named_buffers 🙏.

Narsil (PR author):

why not using model_reloaded.state_dict() -- (from your word, I guess we will get 2 things, and it will match the missing keys, which will pass the test when the situation should fail)

Exactly! Actually, we could use state_dict() and check iteratively for tensors that are the same (for instance using tensor.data_ptr()) to detect the duplicates.

model = GPT2LMHeadModel.from_pretrained("gpt2")
"lm_head.weight" in model.state_dict().keys()       # True
"lm_head.weight" in dict(model.named_parameters())  # False

In [6]: model.lm_head.weight.data_ptr()
Out[6]: 139901378371648

In [9]: model.transformer.wte.weight.data_ptr()
Out[9]: 139901378371648  # Same pointer, it's the same DATA! We would need to check the stride too to be 100% accurate.

I'll add a comment.
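(A rough sketch of the data_ptr idea mentioned above, ignoring strides/offsets as noted; not something the test actually relies on:)

from collections import defaultdict
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Group state_dict entries by their underlying storage pointer:
# names sharing a pointer are aliases of the same (tied) tensor.
by_ptr = defaultdict(list)
for name, tensor in model.state_dict().items():
    by_ptr[tensor.data_ptr()].append(name)

tied_groups = [names for names in by_ptr.values() if len(names) > 1]
print(tied_groups)  # e.g. [['transformer.wte.weight', 'lm_head.weight']]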

Collaborator:

As long as using named_parameters/named_buffers works, I wouldn't request this fancier check :-)

Narsil added a commit to huggingface/safetensors that referenced this pull request on Nov 4, 2022:

* Adding the convert scripts which will now prevent converting models

in case they will trigger warnings on the `transformers` side.
Even if the model is perfectly fine, core maintainers fear an influx
of opened issues.

This is perfectly legit.
On the `transformers` side fixes are on the way: huggingface/transformers#20042

We can wait for this PR to hit `main` before communicating super widely.

In the meantime this conversion script will now prevent converting
models that would trigger such warnings (so the output of the script
**will** depend on the `transformers` freshness).

* Adding a nicer diff for the error when reloading.
self.assertEqual(
extra_missing,
set(),
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing} when"
Collaborator:

Forgot to mention: we could remove "when tying its weights" here.

@sgugger (Collaborator) left a comment:

Amazing work! Some styling comments and parts that seem orthogonal to this PR.


Comment on lines 2644 to 2647
- module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
+ module_keys = module_keys.union(
+     set([".".join(key.split(".")[:-2]) for key in names if key and key[-1].isdigit()])
+ )
Collaborator:

Why the change here? Can we revert if it's just styling?

Narsil (PR author):

It's not just styling: I'm adding `if key and key[-1]` instead of `if key[-1]`; it should prevent failing even if there is a bad prefix.

I could still revert it, though. I left it in to prevent further failures, but we're already in a bad state if this fails, so...

Collaborator:

`if key` -> `if key is not None` or `if len(key) > 0`? (Not sure which you are testing.)

GitHub made it hard to read the diff, sorry.

Narsil (PR author):

It was `""`, so `len(key) > 0` it is.

src/transformers/models/bart/modeling_bart.py (outdated; resolved)
src/transformers/models/bart/modeling_bart.py (outdated; resolved)
src/transformers/models/longt5/modeling_longt5.py (outdated; resolved)
src/transformers/models/mvp/modeling_mvp.py (outdated; resolved)
Comment on lines -862 to 865
assert list(hidden_states.size()) == [
batch_size,
ngram_sequence_length,
hidden_size,
], (
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
f" {hidden_states.shape}"
)
Collaborator:

Seems unrelated, and if we are going to change it, let's make it a real test and an exception raised :-p

Narsil (PR author):

I'm pretty sure it's my black that does this. However, it should be 22.3, like `transformers` expects. Not sure where the diff comes from.

And fixup doesn't revert those :(

ngram_sequence_length,
hidden_size,
], (
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
Collaborator:

Same as before.

Narsil (PR author):

Done.

Narsil (PR author):

Black is not happy about the old code; it forces me to use this new layout.

I tried to revert this.


@LysandreJik (Member) left a comment:

Impressive work! The changes in common look good to me.
The tests look quite robust as well.

Narsil (PR author) commented on Nov 8, 2022

@ydshieh This test now fails in the CI: tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py::Wav2Vec2ConformerModelTest::test_save_load_fast_init_from_base

However, I can't seem to reproduce it locally. Do you mind trying it, to see whether it's my setup failing or the CI?

Narsil (PR author) commented on Nov 9, 2022

Merging.

fingers crossed :)

@Narsil Narsil merged commit bac2d29 into huggingface:main Nov 9, 2022
@Narsil Narsil deleted the tied_weights_warning_check branch November 9, 2022 15:03
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Nov 14, 2022
…20042)

* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable miss missing warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys

* Add escape hatch for when the head has no extra params so doesn't need

the missing keys check.

* Fixing test.

* Greener.

* Green ! (except for weird splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed  (all tests passing with it).

Co-authored-by: Sylvain Gugger <[email protected]>
younesbelkada (Contributor) commented on Nov 21, 2022

Hey!
I am not sure if it is because of this PR but loading NLLB (that is affected by this PR) now gives:

│ /home/younes_huggingface_co/debug_issues/code/transformers/src/transformers/modeling_utils.py:24 │
│ 59 in _load_pretrained_model                                                                     │
│                                                                                                  │
│   2456 │   │   │   for key in missing_keys:                                                      │
│   2457 │   │   │   │   if key.startswith(prefix):                                                │
│   2458 │   │   │   │   │   key = ".".join(key.split(".")[1:])                                    │
│ ❱ 2459 │   │   │   │   param = model_state_dict[key]                                             │
│   2460 │   │   │   │   if param.device == torch.device("meta"):                                  │
│   2461 │   │   │   │   │   if not load_in_8bit:                                                  │
│   2462 │   │   │   │   │   │   set_module_tensor_to_device(model, key, "cpu", torch.empty(*para  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'encoder.embed_positions.weights'

Here is the snippet to reproduce the error:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

src_lang = "eng_Latn"
tgt_lang = "spa_Latn"

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang=src_lang)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M",
                                              device_map= "auto")

I did not follow this PR entirely, but I will dig into it now and see what exactly caused the issue 💪

cc @Narsil @sgugger

@@ -1244,12 +1246,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
r"model.encoder.embed_positions.weights",
younesbelkada (Contributor) commented on Nov 21, 2022:

I realized these are removed, is this intended? (I see that it's one of the only models where some keys are totally removed.) Maybe it was done by mistake, but I am not sure.

Collaborator:

@Narsil would know best

Narsil (PR author):

Seems I forgot to re-add this.

This was removed because of the (now commented-out) part of the test which tries to flag _keys_to_ignore entries that are masking real weights.

This SinusoidalPositionEmbeddings seems to be a buffer and not a parameter, but something is wrong somewhere.
Since it's an nn.Module, maybe it still sees the weights as being a Parameter somehow, idk.

We can safely revert this part anyway.
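(On the buffer-vs-parameter point, the generic PyTorch behaviour is the following; a toy sketch, not the actual M2M100 code: only persistent buffers land in the state_dict and can therefore be reported as missing:)

import torch
from torch import nn

class ToyPositional(nn.Module):
    def __init__(self, n=4, d=2):
        super().__init__()
        table = torch.randn(n, d)
        self.register_buffer("weights", table)                               # persistent
        self.register_buffer("weights_np", table.clone(), persistent=False)  # non-persistent

m = ToyPositional()
print("weights" in m.state_dict())     # True  -> absent from a checkpoint => reported missing
print("weights_np" in m.state_dict())  # False -> never reported missing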

younesbelkada (Contributor) commented on Nov 22, 2022:

Thanks a lot for the explanation! 🚀 The PR to revert this is #20381.

mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022