Missing lm_head.weight key when using Gemma 7B distributed LoRA recipe with gemma-7b-it #1122
Comments
Hey @aubreyjstrier, thanks for raising this. I think you're right about it being similar to the earlier issue. Re-reading my earlier fix and some of the HF code, I don't actually think we needed those changes: Huggingface models should generally be loaded and saved using `from_pretrained`/`save_pretrained`. At first glance, I think the reason it's failing here is because of how we build the `weight_map` of checkpoint keys. I'll have a closer look later, but I would be tempted to revert the changes in my PR with some additional testing. cc @pbontrager @ebsmothers
@SalmanMohammadi can you share which PR you're referring to? I assume we can update the checkpoint mapping, but I want to see the change.
I think the issue that PR was solving was due to the way the user was loading weights into an HF model (i.e. not using `from_pretrained`).
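For reference, here is a minimal sketch of the load/save flow being referred to, using the standard `transformers` API (the model id and output path below are only illustrative):

```python
# Minimal sketch of the standard HF load/save flow (ids and paths illustrative).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it")
# ... fine-tune or otherwise modify the weights ...
model.save_pretrained("./gemma-7b-it-finetuned")  # writes config + weight files
```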
I'm running into this issue, as well. Will update this Issue as I figure things out. |
Reverting #1064 gets the following error trace
This also looks like a different error from the original poster's, which is strange. FYI, the small test I wrote for Gemma checkpointing passes if I use the original checkpointing logic: https://gist.github.com/SalmanMohammadi/e59ff24add75d37a2b81eeaccbff057c
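(Not the linked gist itself, but for anyone skimming, the kind of round-trip check it performs can be sketched roughly like this, assuming the `safetensors` package; the key name and file name are made up:)

```python
# Rough sketch of a save/load round-trip check on a state_dict (illustrative).
import torch
from safetensors.torch import save_file, load_file

original = {"model.embed_tokens.weight": torch.randn(4, 8)}
save_file(original, "roundtrip.safetensors")
restored = load_file("roundtrip.safetensors")

assert original.keys() == restored.keys()
for key in original:
    torch.testing.assert_close(original[key], restored[key])
```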
Just getting caught up on some of this. I think we should treat this as two separate pieces.
(1) We should revert #1064. I don't think the usage of …
(2) The new error seems to be coming from our PEFT integration. I already chatted with @joecummings about this a bit, but because Gemma 7B does not satisfy …
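For context on why `lm_head.weight` can be absent in the first place: Gemma ties the output head to the input embeddings, which is visible in the HF config. A quick, illustrative check (assumes the `transformers` package):

```python
# Gemma ties lm_head to embed_tokens, so the checkpoint stores no separate
# lm_head.weight tensor. Quick check via the HF config (illustrative).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-7b-it")
print(config.tie_word_embeddings)  # True for Gemma
```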
Hi,
I'm having an issue with using the distributed LoRA recipe with instruction-tuned Gemma 7B via the torchtune CLI. The unexpected behavior is very similar to the bug raised in issue #1062 and solved in PR #1064, but instead of the lm_head.weight key being missing from the state_dict, it's missing from the weight_map field of the checkpointer.
After training is complete, torchtune attempts to split the state_dict beginning at line 506 and indexes into `self._weight_map` at line 508, at which point it raises:

[rank0]: KeyError: 'lm_head.weight'
I'm using the default config from the recipes folder, except that the checkpointer reads from and outputs to a directory where I've downloaded gemma-7b-it straight from the HF hub. This is on the latest build of torchtune.
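A quick way to confirm what the checkpoint index actually contains (illustrative snippet; it assumes a safetensors-sharded download and a local directory name that matches the model id):

```python
# Inspect the sharded checkpoint index: its "weight_map" lists every saved
# tensor, and for gemma-7b-it there is no lm_head.weight entry (illustrative).
import json

with open("gemma-7b-it/model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

print("lm_head.weight" in weight_map)  # expected: False
```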
Thanks in advance for your help!