Attempting to test automatically the `_keys_to_ignore` #20042

Conversation
The documentation is not available anymore as the PR was closed or merged.
```diff
@@ -1153,7 +1155,7 @@ def forward(
 class BertLMHeadModel(BertPreTrainedModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
```
Let me review this slowly 🙏 and verify a few things. But do you think `r"predictions.decoder.bias"` is a mistake and should be `r"cls.predictions.decoder.bias"`? I am going to check myself anyway.
A re match is done, so while the exact name is indeed `r"cls.predictions.decoder.bias"`, this works. But it would be great to fix it, just in case one day a weight named `predictions.decoder.bias` that should not be ignored appears ;-)
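A minimal sketch of that matching behavior (assuming the loader applies the ignore patterns with a regex search over each key name, as the comment above implies):

```python
import re

# The exact key is "cls.predictions.decoder.bias", but a regex *search*
# (not a full-string match) also succeeds for the shorter pattern.
pattern = r"predictions.decoder.bias"
key = "cls.predictions.decoder.bias"
print(re.search(pattern, key) is not None)  # True

# The risk mentioned above: an unrelated key could match too. A hypothetical
# "other.predictions.decoder.bias" would also be silently ignored.
print(re.search(pattern, "other.predictions.decoder.bias") is not None)  # True
```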
Maybe. I'm not sure, but the test does yell if we're hiding a valid key (we don't try to yell when we have an unused `_keys_to_ignore`). Want me to try and add it to the test?
Interesting way of testing those attributes!
```diff
@@ -1538,6 +1538,36 @@ def check_same_values(layer_1, layer_2):
         # # Check that the embedding layer and decoding layer are the same in size and in value
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
+        with tempfile.TemporaryDirectory() as d:
```
Add a comment here on what's checked? Also, maybe it should be in its own test?
Moving to its own test + comment on how it works.
src/transformers/modeling_utils.py
Outdated
```diff
@@ -2422,7 +2422,8 @@ def _fix_key(key):

     if remove_prefix_from_model:
         expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
-        expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
+        _prefix = f"{prefix}."
+        expected_keys = [".".join(s.split(".")[1:]) if s.startswith(_prefix) else s for s in expected_keys]
```
This would fail for `esm`/`esmfold`: since `esm` is a prefix of `esmfold`, it would remove the `esmfold.` prefix instead of the intended `esm.`. This was silent before, but with the new test addition, the following modification would fail to get `key[-1]` when trying to identify missing keys.
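A small sketch of the failure mode described here (module names are illustrative, not the real ESM checkpoint keys):

```python
prefix = "esm"
expected_keys = ["esm.encoder.weight", "esmfold.trunk.weight"]

# Buggy: "esmfold.trunk.weight".startswith("esm") is True, so the first
# component gets stripped from a key that does not belong to the prefix.
buggy = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
print(buggy)  # ['encoder.weight', 'trunk.weight']  <- "esmfold." wrongly removed

# Fixed: require the full "esm." prefix before stripping anything.
_prefix = f"{prefix}."
fixed = [".".join(s.split(".")[1:]) if s.startswith(_prefix) else s for s in expected_keys]
print(fixed)  # ['encoder.weight', 'esmfold.trunk.weight']
```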
Seems like the right fix, thanks!
Ok, there was a second place to update the fix to make the `splinter` test pass.
```diff
@@ -1485,11 +1485,24 @@ def test_correct_missing_keys(self):
         base_model_prefix = model.base_model_prefix

         if hasattr(model, base_model_prefix):
```
This is an important new modification.
This test tries to verify that we're correctly warning when we're saving the "core" of a model and failing to save the head parameters, and that we yell when loading the head from this.
This test is great, and different from the newly added one.
But since we fixed a lot of `Decoder` heads which are merely a copy of the embeddings matrix, they are a head, but they DO NOT have new extra parameters, hence this test would fail.
There is a `self.test_missing_keys` flag to deactivate that particular test on the model testing. But it would deactivate ALL potential heads (as far as I understand), potentially opening gaps in our testing.
So I modified this test to try and detect those edge cases. I don't like putting logic like that in a test, as it makes reasoning hard, but I failed to see another way.
This test now verifies that the head does have extra parameters before actually doing the verification.
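A rough sketch of what that detection can look like (a hedged approximation of the idea, not the exact code added in this PR):

```python
def head_has_extra_parameters(model) -> bool:
    """Sketch: does the model-with-head own parameters beyond its base model?"""
    base_model_prefix = model.base_model_prefix
    if not hasattr(model, base_model_prefix):
        return False
    base_model = getattr(model, base_model_prefix)
    base_param_names = {f"{base_model_prefix}.{n}" for n, _ in base_model.named_parameters()}
    model_param_names = {n for n, _ in model.named_parameters()}
    # Tied heads contribute no names of their own: named_parameters()
    # deduplicates shared Parameters, so a head that is a pure copy of the
    # embedding matrix leaves this difference empty.
    return bool(model_param_names - base_param_names)
```

The missing-keys verification would then only run when this helper returns `True`.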
Make sure you account in the test for upcoming models which have weights with the same name in the model with head and the base model (Niels is adding one in #19981)
I don't see which weights are sharing names. Seems relatively regular to me: https://github.com/huggingface/transformers/pull/19981/files#diff-a2ff7478a3e58383ae065b28d9e5e4e1e1bace51b32bc91b6eb086fabb0bd702R448-R470
" tying its weights", | ||
) | ||
|
||
# self.assertEqual( |
I deactivated this part because `token_type_ids` and `position_ids` would yell pretty much everywhere.
I have no idea of the perfect solution for this. I think `position_ids` should be a non-parameter if it's not learned, or a parameter if it is, or something.
Given how many models it affects, we could add an escape hatch for just these peculiar names.
But I think flagging too-inclusive `_keys_to_ignore` can provide good value, especially since we're using regexps, so it's easy to make a mistake, or to modify the model later in such a way that it doesn't reflect properly anymore in those keys.
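For context, a hedged sketch of the `position_ids`-as-non-parameter idea (my reading of the suggestion, not what this PR implements): a non-persistent buffer never enters the `state_dict`, so it can never be reported as a missing key.

```python
import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_len=512, dim=16):
        super().__init__()
        self.word_embeddings = nn.Embedding(1000, dim)
        # persistent=False keeps the buffer out of state_dict(); it is simply
        # recreated at construction time instead of being loaded from disk.
        self.register_buffer(
            "position_ids", torch.arange(max_len).unsqueeze(0), persistent=False
        )

print("position_ids" in ToyEmbeddings().state_dict())  # False
```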
@ydshieh I am not able to reproduce the failure. It's this failure: https://app.circleci.com/jobs/github/huggingface/transformers/608396
Just a few comments - only for the test `test_tied_model_weights_key_ignore` so far.
```diff
@@ -1539,6 +1552,44 @@ def check_same_values(layer_1, layer_2):
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    def test_tied_model_weights_key_ignore(self):
```
I would suggest making the name more generic (i.e. not using `tied_model` in the name). The check in this test is not really specific to tied weights, although they are one of the main sources of missing keys. `test_model_weights_key_ignore` + a comment to explain its objective would be fine.
In my mind it is only about tied weights (because it's the only mechanism I know of that can create the weight sharing). Happy to update the name.
```python
                # We are nuking ALL weights on file, so every parameter should
                # yell on load. We're going to detect if we yell too much, or too little.
                with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
                    torch.save({}, f)
```
A comment here would be nice, as this is unusual.
I gave you the wrong link (on an old commit; the comment is actually there). Is it readable?
Oh, yes! Sorry, I actually hadn't seen the comment `We are nuking ALL weights on file, so every parameter should ...`, which is quite clear already.
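For readers following along, a condensed sketch of the flow around that comment (an approximation of the test, assuming a classic `pytorch_model.bin` checkpoint; the tiny model name is just an example):

```python
import os
import tempfile

import torch
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d)
    # Nuke ALL weights: every real parameter now has to be reported missing.
    torch.save({}, os.path.join(d, "pytorch_model.bin"))
    _, infos = BertForMaskedLM.from_pretrained(d, output_loading_info=True)

# Keys filtered by _keys_to_ignore_on_load_missing never show up here; by
# comparing this set with the model's actual parameters we can detect ignore
# lists that are too inclusive or not inclusive enough.
print(sorted(infos["missing_keys"])[:5])
```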
```python
        prefix = f"{model_reloaded.base_model_prefix}."
        params = dict(model_reloaded.named_parameters())
        params.update(dict(model_reloaded.named_buffers()))
```
You mentioned on DM:

> They are listed as two different things during `from_pretrained` but only appear once when doing `model.named_parameters()`

I haven't 100% understood it (despite having some rough feeling about it). Could you explain:

- why not use `model_reloaded.state_dict()`? (from your words, I guess we will get 2 things, and it will match the missing keys, which will pass the test when the situation should fail)
- give an example, say what happens to bert or gpt2 for this block against current `main`?

In any case, a comment (inside the code) should be added here regarding the usage of `named_parameters` and `named_buffers` 🙏.
> why not using model_reloaded.state_dict() -- (from your word, I guess we will get 2 things, and it will match the missing keys, which will pass the test when the situation should fail)

Exactly! Actually we could use `state_dict()` and check iteratively for tensors which are the same (for instance using `tensor.data_ptr()`) to detect the duplicates.

```python
model = GPT2LMHeadModel.from_pretrained("gpt2")
"lm_head.weight" in model.state_dict().keys()       # True
"lm_head.weight" in dict(model.named_parameters())  # False: named_parameters() deduplicates tied weights
```

```python
In [6]: model.lm_head.weight.data_ptr()
Out[6]: 139901378371648

In [9]: model.transformer.wte.weight.data_ptr()
Out[9]: 139901378371648  # Same PTR, it's the same DATA! We would need to check the stride too to be 100% accurate.
```

I'll add a comment.
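A sketch of what that fancier `data_ptr()` check could look like (illustrative only; as noted above, a fully accurate version would also compare strides and shapes):

```python
from collections import defaultdict

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Group state_dict entries by the address of their underlying storage:
# tied tensors share the same data pointer.
by_ptr = defaultdict(list)
for name, tensor in model.state_dict().items():
    by_ptr[tensor.data_ptr()].append(name)

tied_groups = [names for names in by_ptr.values() if len(names) > 1]
print(tied_groups)  # expect something like [['transformer.wte.weight', 'lm_head.weight']]
```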
As long as using `named_parameters`/`named_buffers` works, I wouldn't request this fancier check :-)
* Adding the convert scripts which will now prevent converting models in case they would trigger warnings on the `transformers` side. Even if the model is perfectly fine, core maintainers fear an influx of opened issues. This is perfectly legit. On the `transformers` side, fixes are on the way: huggingface/transformers#20042. We can wait for this PR to hit `main` before communicating super widely. In the meantime this conversion script will now prevent converting models that would trigger such warnings (so the output of the script **will** depend on the `transformers` freshness).
* Adding a nicer diff for the error when reloading.
tests/test_modeling_common.py
Outdated
```python
            self.assertEqual(
                extra_missing,
                set(),
                f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing} when"
```
Forgot to mention, we could remove `when tying its weights` here.
Amazing work! Some styling comments and parts that seem orthogonal to this PR.
src/transformers/modeling_utils.py
Outdated
```diff
-        module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
+        module_keys = module_keys.union(
+            set([".".join(key.split(".")[:-2]) for key in names if key and key[-1].isdigit()])
+        )
```
Why the change here? Can we revert if it's just styling?
It's not just styling: I'm adding `if key and key[-1]` vs `if key[-1]`; it should prevent failing even if there is a bad prefix.
I still could revert though. I left it to prevent further failure, but we're already in a bad state if this fails so...
`if key` -> `if key is not None` or `if len(key) > 0` (not sure what you are testing).
GitHub made it hard to read the diff, sorry.
It was `""`, so `len(key) > 0` it is.
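A tiny illustration of what the guard prevents (hypothetical key list):

```python
names = ["layer.0", ""]

# Without the guard, ""[-1] raises IndexError; `key and` short-circuits it.
safe = [key for key in names if key and key[-1].isdigit()]
print(safe)  # ['layer.0']
```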
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
Outdated
```diff
-        assert list(hidden_states.size()) == [
-            batch_size,
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
         )
```
Seems unrelated, and if we are going to change it, let's make it a real test and an exception raised :-p
I'm pretty sure it's my black that does this. However it should be 22.3, like `transformers` expects. Not sure where the diff comes from.
And `fixup` doesn't revert those :(
```diff
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
```
Same as before.
Done.
Black is not happy about the old code; it forces me to use this new layout.
I tried to revert this.
Force-pushed from `5c25bf8` to `72493aa`.
Impressive work! The changes in common look good to me.
The tests look quite robust as well.
@ydshieh This test now fails in the CI: `tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py::Wav2Vec2ConformerModelTest::test_save_load_fast_init_from_base`. However, I can't seem to be able to reproduce it locally. Do you mind trying, to see if it's my setup failing or the CI?
Force-pushed from `ea9a45c` to `3354c8e`.
Merging. Fingers crossed :)
Attempting to test automatically the `_keys_to_ignore` (#20042)

* Attempting to test automatically the `_keys_to_ignore`.
* Style.
* First fix pass.
* Moving test on its own.
* Another batch.
* Second round removing BatchNorm
* Fixing layoutlmv{2,3} + support older Python.
* Disable miss missing warning.
* Removing dodgy additions.
* Big pass.
* mbart.
* More corrections.
* Fixup.
* Updating test_correct_missing_keys
* Add escape hatch for when the head has no extra params so doesn't need the missing keys check.
* Fixing test.
* Greener.
* Green ! (except for weird splinter bug).
* Adding a test about `named_parameters` usage.
* Shorten message.
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* After rebase modifications.
* More explicit condition checking.
* Fixing slow tests issues.
* Remove extra pdb.
* Remove print.
* Attempt to make failure consistent + fixing roc_bert.
* Removing the seed (all tests passing with it).

Co-authored-by: Sylvain Gugger <[email protected]>
Hey!
Here is the snippet to reproduce the error:
I did not follow this PR entirely, but I will dig into it now and see what exactly caused the issue 💪
```diff
@@ -1244,12 +1246,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
         r"encoder.version",
         r"decoder.version",
         r"lm_head.weight",
-        r"model.encoder.embed_positions.weights",
```
I realized these are removed. Is this intended (I see that it's one of the only models where some keys are totally removed)? Maybe it has been done by mistake, but I am not sure.
@Narsil would know best
Seems I forgot to re-add this.
This was removed because of the test (which is now commented out) which tries to remove `_keys_to_ignore` entries that are masking real weights.
This `SinusoidalPositionEmbeddings` seems to be a buffer and not a parameter, but something is wrong somewhere. Since it's an `nn.Module`, maybe it still sees `weights` as being a `Parameter` somehow, idk.
We can safely revert this part anyway.
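A toy illustration of the buffer-vs-parameter distinction at play here (not the actual M2M100 embedding class):

```python
import torch
from torch import nn

class SinusoidalToy(nn.Module):
    def __init__(self, n_positions=16, dim=4):
        super().__init__()
        # Registered as a buffer: present in state_dict() (so it is a loadable
        # key that can go "missing"), but absent from named_parameters().
        self.register_buffer("weights", torch.randn(n_positions, dim))

m = SinusoidalToy()
print("weights" in dict(m.named_buffers()))     # True
print("weights" in dict(m.named_parameters()))  # False
print("weights" in m.state_dict())              # True
```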
Thanks a lot for the explanation! 🚀 The PR to revert is in #20381.
What does this PR do?

This adds a new part of the `tied_weights` test that aims at detecting automatically when `_keys_to_ignore` is incorrectly set.

`_keys_to_ignore` aims to ignore weights that are supposed to be tied in the final model, meaning it's OK if the parameter is missing from the on-disk weights. The weights are really empty during the load, but they end up being tied afterwards, so we should ignore them during the load if they are missing.

The test also aims to detect `_keys_to_ignore` entries that might have been set but could be misleading because the parameters are actually NOT tied anymore.
Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.