Merge Gemma recipe with full finetune #668

Merged 4 commits into main from rafiayub/gemma_fix on Apr 15, 2024
Conversation

@RdoubleA (Contributor) commented Apr 9, 2024

Context

The primary reason Gemma had its own recipe was weight tying, where the output projection shares the token embedding weights. This replicates the behavior of ReversibleEmbedding in Keras, where the embedding weight can be used to project back from the output dim to the input dim. It also had implications for FSDP wrapping and initializing on the meta device; see #630 and #616 for more discussion on that.

We can actually achieve the same "weight tying" by getting rid of the output projection altogether and using the embedding weight directly for the output (shout-out @pbontrager):

output = F.linear(h, self.tok_embeddings.weight).float()

This is more akin to how it's done in GemmaCausalLM in Keras, where there is no output projection and the token embedding weight is used directly.
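
For illustration, a minimal sketch of the tied-output pattern (the module name and the elided layers are hypothetical, not the actual GemmaTransformerDecoder):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedOutputDecoder(nn.Module):
    """Illustrative decoder skeleton: no separate output projection,
    the token embedding weight is reused to produce the final logits."""

    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # ... attention / MLP layers elided ...

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)  # [bsz, seq_len, embed_dim]
        # ... run h through the transformer layers ...
        # Project back to vocab size with the embedding weight itself,
        # mirroring Keras' ReversibleEmbedding behavior.
        return F.linear(h, self.tok_embeddings.weight).float()  # [bsz, seq_len, vocab_size]
```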

Changelog

  • Remove output projection from GemmaTransformerDecoder
  • Remove gemma_full_finetune_distributed.py recipe
  • Remove load_shared_weights_utils and save_shared_weights_utils
  • Remove special GEMMA cases in HF checkpointer
  • Remove various unused imports throughout torchtune/models/

Test plan

This run had nearly equivalent loss values to the gemma recipe on main:
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma/2B_full max_steps_per_epoch=5

1|5|Loss: 0.9243088364601135:   0%| | 5/6501
...
2|5|Loss: 1.2039588689804077:   0%| | 5/6501
...
3|5|Loss: 1.597070574760437:   0%| | 5/6501

tune run --nnodes 1 --nproc_per_node 4 gemma_full_finetune_distributed --config gemma/2B_full max_steps_per_epoch=5

1|5|Loss: 0.9225602149963379:   0%| | 5/6501
...
2|5|Loss: 1.204840898513794:   0%| | 5/6501
...
3|5|Loss: 1.5983972549438477:   0%| | 5/6501

Comparison with HF implementation:

[image]

pytorch-bot (bot) commented Apr 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/668

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 706848b with merge base ff594c2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 9, 2024
@kartikayk (Contributor) left a comment


This looks really clean - thanks for making the change! I'll wait for you to debug the increasing loss and also add a fwd/bwd comparison with the reference implementation before accepting.

Also update the README and cite the original author?
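
A forward-parity check along the lines requested above might look like the sketch below. This is an assumption-laden sketch: the torchtune builder name (gemma_2b) and the forward signature are taken from the torchtune API as I understand it, and in practice the converted checkpoint would need to be loaded into the torchtune model before the comparison is meaningful.

```python
# Hypothetical logits-parity check against the Hugging Face reference.
# Assumes matching weights have already been loaded into both models.
import torch
from torchtune.models.gemma import gemma_2b
from transformers import AutoModelForCausalLM

tokens = torch.randint(0, 256_000, (1, 16))  # Gemma vocab size is 256k

tt_model = gemma_2b().eval()
hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").eval()

with torch.no_grad():
    tt_logits = tt_model(tokens)
    hf_logits = hf_model(tokens).logits

# With identical weights the max absolute difference should be near zero.
print(torch.max(torch.abs(tt_logits - hf_logits)))
```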

@RdoubleA RdoubleA mentioned this pull request Apr 11, 2024
@RdoubleA RdoubleA merged commit 3f93b25 into main Apr 15, 2024
27 checks passed
@RdoubleA RdoubleA deleted the rafiayub/gemma_fix branch April 15, 2024 15:48