Generalize configs and add Llama2 13B + Mistral 7B #571
Conversation
@@ -66,7 +66,7 @@ def main(parser):
     epilog=textwrap.dedent(
         """\
         examples:
-            $ tune cp lora_finetune_distributed.yaml ./my_custom_llama2_lora.yaml
+            $ tune cp llama2/7B_lora.yaml ./my_custom_llama2_lora.yaml
@joecummings I think prepending the model name (e.g. llama2) is fine and in fact will be critical as we add more models. I tested this and it should work with our pkg structure. Let me know what you think.
recipes/README.md
@@ -6,7 +6,7 @@
 Recipes are the primary entry points for TorchTune users. These can be thought of as end-to-end pipelines for training and optionally evaluating LLMs. Each recipe consists of three components:

-- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/full_finetune_distributed.yaml) and command-line overrides
+- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/7B_full.yaml) and command-line overrides
I think this path is wrong?
Oh yeah, sorry, let me fix this.
# tune --nnodes 1 --nproc_per_node 1 full_finetune_distributed \
#     --config full_finetune_distributed \
#     checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> ...
# This config assumes that you've run the following command before launching
Why add this huge block in the config?
Currently there's no documentation on the configs at all. Once we have live docs available, we can move this information there. But for now, I'd like to give users some understanding of when and how to use each config.
I think there's still some information in the README on configs and we can make that more clear. I think cluttering up the configs can be overwhelming.
Why would it be overwhelming? Isn't it just documentation? I don't think we would be able to add config-level info to the README?
@@ -161,7 +161,7 @@ will list out all the locations where an error was found.

 .. code-block:: bash

-    tune validate --config recipes/configs/full_finetune_single_device.yaml batch_size=4
+    tune validate --config recipes/configs/llama2/7B_full.yaml batch_size=4
One thing to watch out for: there can be lots of issues when filenames start with a number. I've definitely seen it as a problem with Python imports, maybe it will be ok with YAML files? But something to keep in mind
This is interesting - what sort of issues? But yeah, I don't expect us to be importing the configs anymore?
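For anyone following this thread, a minimal sketch of the failure mode being discussed; the file paths are purely illustrative:

```python
# Why filenames starting with a digit break Python imports but are fine
# for path-based YAML loading. Paths here are hypothetical examples.

# A module named 7B_full.py cannot be imported with a plain import
# statement, because Python identifiers cannot begin with a digit:
#
#     import 7B_full   # SyntaxError before anything even runs
#
# importlib sidesteps the grammar and can load such a module by name string:
import importlib

mod = importlib.import_module("7B_full")  # works if 7B_full.py is on sys.path

# YAML configs are opened by file path, so the leading digit never matters:
import yaml

with open("recipes/configs/llama2/7B_full.yaml") as f:
    cfg = yaml.safe_load(f)
```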
# This config should be used with 2+ GPUs. Single device full fine-tuning
# requires several memory optimizations which are exposed through
# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B
# in that config.
I guess implicit in our choice of naming here is that >1 device is kind of now the "default", right? While I understand that we are doing more memory optimizations in the single device recipes now, we've obviously seen that FSDP comes with its own nuances too. So I do wonder if it's now hard for someone to just come in and say "give me a simple single-device recipe to get started on"
This is also a bit weird for QLoRA imo where we currently only support single device
Yeah, this is a good question. I did this primarily for two reasons:
- With the distributed CI testing sorted out, I removed the constraint on the distributed recipes. We can now run those on a single device, but without the memory optimizations.
- As we go to larger models, the single device setting will be less frequent. So when I thought about the default, distributed seemed like the more natural one.
Does this make sense?
torchtune/_cli/ls.py
full_finetune_distributed.py    llama2/7B_full, llama2/13B_full
lora_finetune_distributed.py    llama2/7B_lora, llama2/13B_lora
Happy to finally see more than one config per recipe. Nit: maybe we can split the configs over separate lines? E.g.
RECIPE                        CONFIG
full_finetune_distributed.py  llama2/7B_full
                              llama2/13B_full
lora_finetune_distributed.py  llama2/7B_lora
                              llama2/13B_lora
Yep, my vision was for this to go over multiple lines, and I think that's actually how it works right now?
Yup, sorry I just need to update this example.
Awesome to see this! Wonder if we have a parity check for the 13B model, comparing to HF forward outputs and/or Meta model outputs?
@rohan-varma yup! I don't think I'll compare with the Meta implementation since I don't have the code for sharded checkpointing in a decent state (will do that comparison in a follow-up PR). But I'll add a comparison with the HF implementation.
Added numerical parity checks and e2e eval comparisons for Llama2 13B to the context section. Thanks @joecummings for the help on this!
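A parity check of this flavor might look like the sketch below; the HF model id and tolerances are assumptions, and the torchtune weight-conversion/loading step is elided:

```python
# Hedged sketch of a forward-pass parity check between torchtune's Llama2
# 13B and the HF implementation. Assumes the torchtune model has already
# been loaded with weights converted from the same HF checkpoint.
import torch
from torchtune.models.llama2 import llama2_13b
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
tokens = torch.randint(0, 32_000, (1, 128))  # random token batch

hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
hf_model.eval()

tt_model = llama2_13b()  # loading the converted state dict elided here
tt_model.eval()

with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tt_logits = tt_model(tokens)

# With identical weights, the logits should agree to within numerical noise
torch.testing.assert_close(tt_logits, hf_logits, atol=1e-4, rtol=1e-4)
```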
@@ -60,7 +60,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
     The following configs can be used to run this recipe:
     >>> tune ls
     RECIPE                     CONFIG
     lora_finetune_distributed  lora_finetune_distributed
Have you actually run this command? I think it goes on multiple lines?
Awesome change! LGTM
):
    raise RuntimeError("Full bf16 training is not supported on this hardware.")

world_size, rank = utils.get_world_size_and_rank()
Now that we have our distributed tests working as expected, I'm removing this constraint.
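For context, the guard retained in the diff above is a bf16 capability check. A minimal sketch of such a check, assuming CUDA (torchtune's actual helper may check more, e.g. NCCL versions):

```python
import torch

def verify_bf16_support() -> None:
    # Full-bf16 training needs a bf16-capable accelerator (Ampere or newer
    # for NVIDIA GPUs); torch exposes this via torch.cuda.is_bf16_supported()
    if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        raise RuntimeError("Full bf16 training is not supported on this hardware.")
```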
 model_type: LLAMA2
 resume_from_checkpoint: False

 # Fine-tuning arguments
-batch_size: 2
+batch_size: 32
These config changes will lead to high memory consumption for Mistral, right?
Context
All relevant TorchTune components (recipes, checkpointers, and the CLI) can easily be generalized to models beyond Llama2 7B. In this PR, I show that adding Llama2 13B and Mistral 7B is as simple as adding some new builder functions. Llama2 13B currently assumes HF-format checkpoints (see configs) since I still need to add support for Meta's sharded checkpoints. I'll do this in a follow-up PR.
Accompanying this addition are some cosmetic changes which make the repo more user-friendly for multiple models. This includes better organizing our configs and shortening their names so it's not cumbersome to type them out, and better organizing our models.
I manually tested all of the commands in the different READMEs and docstrings (except QLoRA, which might currently be impacted by the torchao-nightly change).
Training speed for 13B is quite competitive. Quick comparisons showed us to be 2.5x faster than some competitors without any change to recipe code. The next section shows the correctness checks.
Note: Mistral 7B requires some data preprocessing changes for Alpaca fine-tuning. I'll follow up with those changes in a separate PR. In this PR, I add support for the model and show numerical parity with HF.
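To make "as simple as adding some new builder functions" concrete, here is a minimal sketch of the builder pattern. The hyperparameters are Llama2 13B's published architecture values, and the `llama2()` component-builder signature is assumed to match torchtune's; treat this as illustrative rather than the PR's exact code:

```python
# Sketch of the builder-function pattern: a new model size is added by
# calling the generic llama2 component builder with that size's
# architecture hyperparameters (Llama2 13B values shown).
from torchtune.models.llama2._component_builders import llama2
from torchtune.modules import TransformerDecoder

def llama2_13b() -> TransformerDecoder:
    """Builder for a Llama2 13B model."""
    return llama2(
        vocab_size=32_000,
        num_layers=40,
        num_heads=40,
        num_kv_heads=40,
        embed_dim=5120,
        intermediate_dim=13_824,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )
```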
Changelog
- Add `llama2_13b` and `lora_llama2_13b` builder functions to support the Llama2 13B model
- Move all configs (except `alpaca_generate.yaml`) from `configs/` to `configs/llama2/`
- Move `_convert_weights` out of `models/llama2` and make this a public file which I expect users to use (see the sketch after this list)
- Add `_component_builders.py` and `_model_builders.py` under `models/mistral` to support the `mistral_7b` model
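With `_convert_weights` made public, usage might look like the sketch below; the module path and the `hf_to_tune` name/signature are assumptions about the public API this PR introduces, not confirmed code:

```python
# Hedged sketch: convert an HF-format Llama2 13B state dict into
# torchtune's native format, then load it into the torchtune model.
import torch
from torchtune.models import convert_weights
from torchtune.models.llama2 import llama2_13b

hf_state_dict = torch.load("llama2-13b-hf/pytorch_model.bin", map_location="cpu")

# Remap HF parameter names/layouts to torchtune's (13B: 40 heads, dim 5120)
tune_state_dict = convert_weights.hf_to_tune(
    hf_state_dict, num_heads=40, num_kv_heads=40, dim=5120
)

model = llama2_13b()
model.load_state_dict(tune_state_dict)
```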
Correctness Checks for Llama2 13B
Numeric parity of Llama2 13B with HF's implementation

Eval using Eleuther's harness on `truthfulqa_mc2`: Baseline: 36.9% vs. Finetuned: 47.1%

[Loss curve: loss is comparable to what some forums shared]
Correctness Checks for Mistral 7B
Test plan