
2/n - Make Gemma use regular TransformerDecoder #1553

Merged

Conversation

@felipemello1 felipemello1 commented Sep 12, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR was built on top of #1547 (comment), so the Qwen changes should disappear from here after that PR is merged.

Changelog

  • Updates Gemma to use the regular TransformerDecoder
  • Uses TiedLinear for the output projection (see the sketch after this list)
  • Uses a normalized embedding (GemmaNormEmbeddings) to put together the token embedding and the Gemma norm
  • Quick fix: sets LoRA alpha to 2x the LoRA rank
  • Reduces warmup steps to 10. I don't see why we need 100: LoRA training should be stable, and we don't have warmup in full finetuning. I did this after comparing a short run of full vs. LoRA finetuning and realizing that the LoRA loss didn't decrease at all for the first 50 steps.
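
A rough sketch (plain PyTorch, not the exact torchtune API; TiedOutputProjection is an illustrative name, and the hyperparameter values are hypothetical) of two of the ideas above: tying the output projection to the token embedding, and setting LoRA alpha to 2x the LoRA rank:

import torch
import torch.nn as nn


class TiedOutputProjection(nn.Module):
    """Output projection that reuses the token-embedding weight (weight tying)."""

    def __init__(self, embedding: nn.Embedding):
        super().__init__()
        self.embedding = embedding  # shared weights, no new parameters

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # [batch, seq, dim] @ [dim, vocab] -> [batch, seq, vocab]
        return torch.matmul(hidden, self.embedding.weight.t())


# Hypothetical LoRA hyperparameters following the "alpha = 2x rank" convention above.
lora_rank = 8
lora_alpha = 2 * lora_rank  # 16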

Test plan

Resume from checkpoint is working well.

tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config gemma/2B_full batch_size=8 max_steps_per_epoch=20 metric_logger=torchtune.training.metric_logging.WandBLogger gradient_accumulation_steps=1 epochs=2 compile=True
[image attached]


pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1553

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 175782f with merge base 7c51100 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2024
@felipemello1 felipemello1 changed the title Make Gemma use regular TransformerDecoder 2/n - Make Gemma use regular TransformerDecoder Sep 12, 2024
@joecummings joecummings left a comment


🫡

from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.utils import get_logger, torch_version_ge

log = get_logger("INFO")


def compile_model(
-    model: Union[TransformerDecoder, TiedEmbeddingTransformerDecoder],
+    model: TransformerDecoder,

Why wasn't this handled in 1/n?

felipemello1 (Author) replied:

I forgot, and I left a comment that I was doing it in 2/n.



class GemmaNormEmbeddings(nn.Embedding):
    def __init__(self, in_dim: int, out_dim: int):

Docstrings, especially to explain why this is a separate class.
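
For context, a minimal sketch of the idea behind a separate embedding class, assuming it exists mainly to apply Gemma's sqrt(embed_dim) scaling to the embedding output (the actual implementation is in torchtune/models/gemma/gemma_norm_embedding.py and may differ):

import torch
import torch.nn as nn


class GemmaNormEmbeddings(nn.Embedding):
    """Token embedding that also scales its output, as Gemma expects.

    Sketch only; see torchtune/models/gemma/gemma_norm_embedding.py for the real class.
    """

    def __init__(self, in_dim: int, out_dim: int):
        # in_dim: vocabulary size, out_dim: embedding dimension
        super().__init__(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embeddings = super().forward(x)
        # Gemma scales embeddings by sqrt(embed_dim) before the first decoder layer.
        return embeddings * (self.embedding_dim**0.5)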

@codecov-commenter

Codecov Report

Attention: Patch coverage is 70.83333% with 7 lines in your changes missing coverage. Please review.

Project coverage is 73.18%. Comparing base (7c51100) to head (4e1e8a6).

Files with missing lines                          Patch %   Missing
torchtune/models/gemma/_component_builders.py     66.66%    3 ⚠️
torchtune/models/gemma/gemma_norm_embedding.py    77.77%    2 ⚠️
torchtune/models/gemma/transformer.py              0.00%    2 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1553      +/-   ##
==========================================
- Coverage   73.32%   73.18%   -0.14%     
==========================================
  Files         288      289       +1     
  Lines       14133    14164      +31     
==========================================
+ Hits        10363    10366       +3     
- Misses       3770     3798      +28     

☔ View full report in Codecov by Sentry.

@felipemello1 felipemello1 merged commit 7dad2d6 into pytorch:main Sep 12, 2024
17 checks passed
@felipemello1 felipemello1 deleted the gemma_deprecate_tied_transformer branch September 12, 2024 20:25