Enable Configurable FSDP Sharding #1024
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1024. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 1d0f3c5 with merge base f6ddfcc. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is a great idea - ZeRO-2 is used in a bunch of use cases, and in general configurable sharding enables a better speed/memory tradeoff. cc @ebsmothers to sign off
Also, would you mind adding testing details / details about speedups seen on your workloads? Thanks!
Also any advice on how to unit test this?
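For readers following along, the feature under discussion boils down to letting the recipe config pick the FSDP sharding strategy by name. A minimal sketch of that mapping, assuming a config field along the lines of `fsdp_sharding_strategy` (the exact field name in the recipe may differ):

```python
from torch.distributed.fsdp import ShardingStrategy

def parse_sharding_strategy(name: str) -> ShardingStrategy:
    # Look up the enum member by its string name, e.g. "FULL_SHARD" or "SHARD_GRAD_OP".
    return ShardingStrategy[name]

# The recipe would then pass this value as the `sharding_strategy` kwarg when wrapping
# the model with torch.distributed.fsdp.FullyShardedDataParallel.
print(parse_sharding_strategy("SHARD_GRAD_OP"))  # ShardingStrategy.SHARD_GRAD_OP (ZeRO-2 style)
```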
Hi @tambulkar thanks for the PR! I agree with @rohan-varma, I think this is something we want to support, and I have no major concerns with the changes themselves. Re testing: since this is really only exposed at the recipe level, I agree that it probably doesn't make sense to add a unit test. Re the GPU utilization you're seeing: I agree it's a bit counterintuitive, and I can take a look on my end as well. A couple quick questions: (1) are you just using the default config?
Hi @ebsmothers yeah, I used the default gemma 2b lora config with just the overrides mentioned above. I didn't see the WandB integration, so I actually just generated the graph by logging the GPU utilization myself. Unfortunately I don't really have a great multi-GPU setup - I just used TensorDock to spin up a multi-GPU machine to test this out.
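The exact call used for that logging isn't preserved in the thread; one common way to grab these numbers from PyTorch looks like this (note that `torch.cuda.utilization()` requires the pynvml package):

```python
import torch

if torch.cuda.is_available():
    # GPU busy percentage via NVML (requires `pip install pynvml`) and allocated memory in GB.
    util = torch.cuda.utilization()
    mem_gb = torch.cuda.memory_allocated() / 1e9
    print(f"gpu util: {util}%  allocated: {mem_gb:.2f} GB")
```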
Hey @tambulkar, sorry I am just getting back to this. I ran a quick test on my end with a setup that should be pretty similar to yours, sweeping over the sharding strategies. In my case the ordering came out as FULL_SHARD = HYBRID_SHARD < NO_SHARD < SHARD_GRAD_OP = _HYBRID_SHARD_ZERO2, which is closer to what I'd expect than the results you got. Anyways, this is a long way of saying that I think this is working as expected. Can you update the PR summary with the commands used to run the three recipes (mainly as a sanity check that nothing will be obviously broken - we have CI but don't have coverage on the DPO recipe yet)? After that I think this is good to merge.
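For readers less familiar with the FSDP options being compared above, here is a rough summary of what each strategy trades off (paraphrasing the PyTorch docs, not numbers measured in this PR):

```python
from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD          - ZeRO-3 style: shard parameters, gradients, and optimizer state
#                       (lowest memory, most communication).
# SHARD_GRAD_OP       - ZeRO-2 style: shard gradients and optimizer state; unsharded params
#                       are kept between forward and backward (more memory, less communication).
# NO_SHARD            - no sharding, DDP-like replication.
# HYBRID_SHARD        - FULL_SHARD within a node, replication across nodes.
# _HYBRID_SHARD_ZERO2 - SHARD_GRAD_OP within a node, replication across nodes.
for strategy in ShardingStrategy:
    print(strategy.name)
```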
Codecov Report
Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #1024     +/-   ##
==========================================
+ Coverage   68.31%    68.52%   +0.20%
==========================================
  Files         255       258       +3
  Lines       11796     11903     +107
==========================================
+ Hits         8059      8157      +98
- Misses       3737      3746       +9

☔ View full report in Codecov by Sentry.
@tambulkar Are you interested in finishing up this PR?
Hey @joecummings, sorry about that, I forgot about this - will finish it up today.
The lora fine-tune recipe works fine, but lora DPO and full fine-tune seem to have some issues on my end - both runs fail with errors. I downloaded llama3 beforehand, so I'm not sure what the issue is exactly.
You are using a lora config for full_finetune_distributed. Running the same command with the "lora_finetune_distributed" recipe should work.
@felipemello1 that worked, thanks - is there a llama3 DPO config I should use?
My guess is that you need to change the dataset. @SalmanMohammadi @RdoubleA, can you confirm/share your thoughts on why this fails?
Thanks for this PR @tambulkar - this looks super cool. Could you share the dataset you're using? It might be a case of just mapping the columns correctly.
I think it would be the default in the config, @SalmanMohammadi: https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3/8B_lora.yaml#L47, i.e. the dataset builder in torchtune/torchtune/datasets/_alpaca.py (line 68 at 288ff44).
Hmm, that's an instruct dataset - it should be using a preference dataset, no? I see @tambulkar is using a config for 8B_lora, whereas they should copy the existing LoRA DPO config and adapt it to Llama3 8B.
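As a rough illustration of why the default dataset doesn't fit here (the field names below are only illustrative, not the exact torchtune schema): instruct-tuning data has a single target response per prompt, while DPO needs paired preference data with a chosen and a rejected response.

```python
# Instruct/SFT-style row (alpaca-like): one prompt, one target output.
instruct_row = {
    "instruction": "Summarize the article.",
    "input": "The article says ...",
    "output": "A short summary ...",
}

# Preference/DPO-style row: the loss compares a preferred vs. a dispreferred completion.
preference_row = {
    "prompt": "Summarize the article.",
    "chosen": "A faithful, concise summary ...",
    "rejected": "An off-topic or low-quality reply ...",
}
```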
Good catch - it seems to get further when I switch the dataset.
Very silly question, since I'm not familiar with distributed debugging. Is the loss in your config the same as in the DPO config?
Good call @SalmanMohammadi, but even when I use the DPO loss in my config I still get the NCCL failures - probably a version thing with the pod I am using.
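One quick way to rule the environment in or out when chasing errors like this (a general debugging suggestion, not something from the thread) is to print the torch and NCCL versions on the pod:

```python
import torch
import torch.cuda.nccl
import torch.distributed as dist

print("torch:", torch.__version__)
print("nccl backend available:", dist.is_nccl_available())
if torch.cuda.is_available():
    # Returns the NCCL version torch was built against, e.g. (2, 18, 1).
    print("nccl version:", torch.cuda.nccl.version())
```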
@felipemello1 @SalmanMohammadi is there anything else to include here?
I guess we just need to make sure that the DPO script runs without NCCL failures, is that right? You had errors on your machine, so it's not clear if it's the recipe or the machine. I can run your recipe on my machine to see if it's fine. Thanks again for the PR! :) @tambulkar
@felipemello1 the NCCL errors went away on a new machine I spun up. My config is llama3/8B_lora with the overrides discussed above.
Can you please add a quick comment in the docs at the top of each recipe (e.g. here) about this? Just a minimal 1-2 sentences explaining what the config parameter you're adding does, and the different options we can use. It'll help a lot with keeping track of which features we support as we start to document them more comprehensively.
Thanks for the feedback @SalmanMohammadi, updated the docstrings.
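For reference, the kind of note being asked for at the top of each distributed recipe might read roughly like this (illustrative wording only; the config field name is an assumption, not the committed text):

```python
def recipe_main() -> None:
    """Distributed finetuning recipe (illustrative docstring excerpt only).

    FSDP sharding strategy is configurable via the ``fsdp_sharding_strategy`` field
    (default: ``FULL_SHARD``). Alternatives such as ``SHARD_GRAD_OP`` (ZeRO-2),
    ``NO_SHARD``, ``HYBRID_SHARD``, and ``_HYBRID_SHARD_ZERO2`` trade memory savings
    for speed; see ``torch.distributed.fsdp.ShardingStrategy``.
    """
    ...
```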
Context
What is the purpose of this PR? It adds a new feature.
Please link to any issues this PR addresses.
#1014
Changelog
What are the changes made in this PR?
Add FSDP sharding options to the config
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- run pre-commit hooks and linters (install via `pre-commit install` first)
- run unit tests via `pytest tests`
- run recipe integration tests via `pytest tests -m integration_test`
Example CLI Commands for Sanity Checking