
Missing lm_head.weight key when using Gemma 7B distributed LoRA recipe with gemma-7b-it #1122

Closed · aubreyjstrier opened this issue Jun 26, 2024 · 7 comments · Fixed by #1169
Labels: bug (Something isn't working), question (Further information is requested)

@aubreyjstrier

Hi,

I'm having an issue with using the distributed LoRA recipe with instruction-tuned Gemma 7B via the torchtune CLI. The unexpected behavior is very similar to the bug raised in issue #1062 and solved in PR #1064, but instead of the lm_head.weight key being missing from the state_dict, it's missing from the weight_map field of the checkpointer.

After training is complete, torchtune attempts to split the state_dict beginning at line 506 and indexes into self._weight_map at line 508, at which point it errors:

split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
for key, weight in state_dict[utils.MODEL_KEY].items():
    cpt_idx = self._weight_map[key]  ### fails here
    if cpt_idx not in split_state_dicts:
        split_state_dicts[cpt_idx] = {}
    split_state_dicts[cpt_idx].update({key: weight})

It raises: [rank0]: KeyError: 'lm_head.weight'

I'm using the default config from the recipes folder, except that the checkpointer reads from and outputs to a directory where I've downloaded gemma-7b-it straight from the HF hub. This is on the latest build of torchtune.

Thanks in advance for your help!

@SalmanMohammadi (Collaborator)

Hey @aubreyjstrier. Thanks for raising this.

I think you're right about it being similar to the earlier issue. Re-reading my earlier fix and some of the HF code, I don't actually think we needed those changes. Hugging Face models should generally be loaded and saved using from_pretrained and save_pretrained, which take care of tied embedding weights; the issue at hand was loading weights directly into the model. The Gemma 2B checkpoint doesn't have a key for an output projection.

At first glance: I think the reason it's failing here is that we build a map of key : {checkpoint_file_id} for the model state dict, so we know to save tensor_0 in checkpoint 0001 and so on. When saving the checkpoint, we don't have an entry for lm_head since we didn't load one from the HF checkpoint in the first place.
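
Roughly, the bookkeeping looks like this (a sketch, not the exact checkpointer code; checkpoint_files here is a placeholder for the downloaded HF shards):

import torch

weight_map = {}
for file_idx, checkpoint_file in enumerate(checkpoint_files, start=1):
    shard = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
    for key in shard:
        # Remember which source file each key came from, e.g. "0001".
        weight_map[key] = f"{file_idx:04d}"

# At save time every key in the converted state dict is looked up in
# weight_map. "lm_head.weight" never appears in the Gemma HF shards
# (tied embeddings), so weight_map["lm_head.weight"] raises KeyError.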

I'll have a closer look later but I would be tempted to revert the changes in my PR with some additional testing. cc @pbontrager @ebsmothers

@pbontrager (Contributor)

@SalmanMohammadi can you share which PR you're referring to? I assume we can update the checkpoint mapping but I want to see the change.

@SalmanMohammadi (Collaborator)

> @SalmanMohammadi can you share which PR you're referring to? I assume we can update the checkpoint mapping but I want to see the change.

#1064

I think the issue that PR was solving was due to the way the user was loading weights into a HF model (i.e. not using from_pretrained), rather than anything to do with our checkpointing.

@felipemello1 added the bug (Something isn't working) and question (Further information is requested) labels on Jul 2, 2024
@joecummings (Contributor)

I'm running into this issue as well. Will update this issue as I figure things out.

@joecummings (Contributor)

Reverting #1064 gets the following error trace

Traceback (most recent call last):
  File "/home/jrcummings/.conda/envs/joe-torchtune/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 179, in _run_cmd
    self._run_single_device(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jrcummings/projects/joe-torchtune/recipes/lora_finetune_single_device.py", line 648, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/recipes/lora_finetune_single_device.py", line 643, in recipe_main
    recipe.train()
  File "/home/jrcummings/projects/joe-torchtune/recipes/lora_finetune_single_device.py", line 625, in train
    self.save_checkpoint(epoch=curr_epoch)
  File "/home/jrcummings/projects/joe-torchtune/recipes/lora_finetune_single_device.py", line 517, in save_checkpoint
    self._checkpointer.save_checkpoint(
  File "/home/jrcummings/projects/joe-torchtune/torchtune/utils/_checkpointing/_checkpointer.py", line 535, in save_checkpoint
    ] = convert_weights.tune_to_peft_adapter_weights(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/models/convert_weights.py", line 282, in tune_to_peft_adapter_weights
    value = _permute_lora_matrix(value, num_heads)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/models/convert_weights.py", line 274, in _permute_lora_matrix
    t.view(n_heads, head_dim // 2, 2, rank)
RuntimeError: shape '[16, 96, 2, 8]' is invalid for input of size 32768

@SalmanMohammadi (Collaborator)

> Reverting #1064 gets the following error trace

This also looks like a different error from the one the original poster hit, which is strange. FYI, the small test I wrote for Gemma checkpointing passes if I use the original checkpointing logic: https://gist.github.com/SalmanMohammadi/e59ff24add75d37a2b81eeaccbff057c.

@ebsmothers (Contributor)

Just getting caught up on some of this. I think we should treat this as two separate pieces.

(1) We should revert #1064. I don't think the usage of load_state_dict in #1062 is the correct way to load from a torchtune model with tied weights into an HF model with tied weights. Imo we should adhere to the format actually provided on the hub, which does not contain lm_head.weight, so I think we were doing the correct thing before. The reason HF's from_pretrained works with this checkpoint is their usage of _tied_weights_keys (as pointed out by @SalmanMohammadi in the original PR). But I don't think we should try to replicate this on our end; better to adhere to the contract that we return the same state dict format at the end of training as the one we got in (i.e. one that doesn't contain lm_head.weight). If we want to load a torchtune checkpoint directly into the CausalLM version of Gemma (which has patched in an extra weight on the backend), it's natural to ask the user to write the corresponding glue code themselves (see the first sketch below).

(2) The new error seems to be coming from our PEFT integration. I already chatted with @joecummings about this a bit, but because Gemma 7B does not satisfy num_heads * head_dim = embed_dim, we may need to change this function for saving our LoRA weights into PEFT format. I imagine we will want to just explicitly pass head_dim instead of inferring it like we currently do (see the second sketch below). Note that we only need to permute the LoRA B matrix (exercise to the reader to figure out why), so we don't need to worry about permuting the A matrix, which now will not have any knowledge of num_heads or head_dim.
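
For (1), a minimal sketch of the kind of glue code such a user could write, assuming a single consolidated HF-format output file (the file name below is hypothetical): because the saved checkpoint ties lm_head to the input embeddings and therefore has no lm_head.weight key, the embedding weight is mirrored into lm_head before loading into the CausalLM model.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical output file name; adjust to the checkpointer's actual output.
state_dict = torch.load("hf_model_0001_0.pt", map_location="cpu", weights_only=True)

# The hub format (and torchtune's output) has no lm_head.weight because the
# output projection is tied to the input embeddings, so mirror it explicitly.
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]

model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it")
model.load_state_dict(state_dict)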
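
For (2), a sketch of the direction described above (not the final fix), passing head_dim explicitly instead of inferring it as embed_dim // num_heads. With the inferred value of 3072 // 16 = 192, the view in the traceback is [16, 96, 2, 8], i.e. 24576 elements, against a LoRA B matrix of 4096 * 8 = 32768 elements, hence the shape error; with Gemma 7B's actual head_dim of 256 the shapes line up. The body below mirrors the view call from the traceback plus the usual rope permutation pattern, so treat it as an illustration rather than torchtune's exact code.

import torch

def _permute_lora_matrix(t: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    # t is the LoRA B matrix for q_proj/k_proj, shape (n_heads * head_dim, rank).
    rank = t.shape[-1]
    return (
        t.view(n_heads, head_dim // 2, 2, rank)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, rank)
    )

# Gemma 7B attention dims: 16 query heads, head_dim 256, embed_dim 3072.
lora_b = torch.randn(16 * 256, 8)
assert _permute_lora_matrix(lora_b, n_heads=16, head_dim=256).shape == lora_b.shape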
