
Enable Configurable FSDP Sharding #1024

Merged
16 commits merged into pytorch:main on Aug 10, 2024

Conversation

@tambulkar (Contributor) commented May 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#1014

Changelog

What are the changes made in this PR?
Add FSDP sharding options to the config

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Example CLI Commands for Sanity Checking

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora fsdp_sharding_strategy=test_invalid (this is expected to fail, since test_invalid is not a valid strategy)
tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD
tune run --nproc_per_node 2 lora_dpo_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD
tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full fsdp_sharding_strategy=FULL_SHARD
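For context on the invalid-strategy case above: a minimal sketch, not torchtune's actual implementation, of how a config string like fsdp_sharding_strategy could be resolved into the FSDP enum, which is why an unknown name such as test_invalid errors out (the helper name below is hypothetical).

# Minimal sketch (not torchtune's actual code): resolving a sharding-strategy
# string from the config into the corresponding FSDP enum member.
from torch.distributed.fsdp import ShardingStrategy


def resolve_sharding_strategy(name: str) -> ShardingStrategy:  # hypothetical helper
    """Look up an FSDP sharding strategy by name, e.g. "FULL_SHARD" or "NO_SHARD"."""
    try:
        return ShardingStrategy[name]
    except KeyError:
        valid = ", ".join(s.name for s in ShardingStrategy)
        raise ValueError(
            f"Invalid fsdp_sharding_strategy: {name!r}. Expected one of: {valid}"
        ) from None


# resolve_sharding_strategy("NO_SHARD")     -> ShardingStrategy.NO_SHARD
# resolve_sharding_strategy("test_invalid") -> ValueError (the expected failure
#                                              in the first command above)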

@pytorch-bot (bot) commented May 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1024

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1d0f3c5 with merge base f6ddfcc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 27, 2024
@tambulkar changed the title from add-recipe-updates to Enable Configurable FSDP Sharing May 27, 2024
@tambulkar changed the title from Enable Configurable FSDP Sharing to Enable Configurable FSDP Sharding May 27, 2024
@rohan-varma rohan-varma self-requested a review May 30, 2024 06:02
@rohan-varma (Member) left a comment

This is a great idea - ZeRO-2 is used in a bunch of use cases, and in general configurable sharding enables a better speed/memory tradeoff. cc @ebsmothers to sign off

Also, would you mind adding testing details / details about speedups seen on your workloads? Thanks!

@tambulkar (Contributor, Author) commented Jun 2, 2024

[Figure: GPU utilization over time for each sharding strategy]
I have attached the GPU utilization from some small test runs. My setup was:

  1. Gemma 2b
  2. 1% of alpaca training data
  3. 1 epoch
  4. batch size 4
  5. LoRA distributed
  6. 4 x RTX 4090

It feels a little weird to me that FULL_SHARD didn't really use less memory than NO_SHARD (DDP-like), for example. Maybe someone else could double-check with a test run on their end. The speed of the runs makes sense to me, though:

NO_SHARD < _HYBRID_SHARD_ZERO2 < SHARD_GRAD_OP < HYBRID_SHARD < FULL_SHARD
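For reference, a short summary of what each option shards under standard FSDP semantics (general PyTorch behavior, not something measured in this PR):

from torch.distributed.fsdp import ShardingStrategy

# Standard FSDP semantics for each option (general PyTorch behavior, not
# measurements from this PR):
ShardingStrategy.FULL_SHARD           # shard params, grads, and optimizer states (ZeRO-3-like)
ShardingStrategy.SHARD_GRAD_OP        # shard grads and optimizer states; params stay unsharded
                                      # between forward and backward (ZeRO-2-like)
ShardingStrategy.NO_SHARD             # replicate everything (DDP-like)
ShardingStrategy.HYBRID_SHARD         # FULL_SHARD within a node, replicate across nodes
ShardingStrategy._HYBRID_SHARD_ZERO2  # SHARD_GRAD_OP within a node, replicate across nodes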

@tambulkar tambulkar marked this pull request as ready for review June 2, 2024 22:05
@tambulkar (Contributor, Author) commented

Also, any advice on how to unit test this? tests/torchtune/config/* doesn't seem like the correct place to test it.

@ebsmothers (Contributor) commented

Hi @tambulkar thanks for the PR! I agree with @rohan-varma, I think this is something we want to support. I have no major concerns with the changes themselves.

Re testing: since this is really only exposed at the recipe level, I agree that it probably doesn't make sense to add a unit test under config/. One option is to update our existing recipe tests: e.g. for the full finetune recipe you could update the parametrization here to pass in sharding strategy and just add one other sharding strategy for one of the models (no need to test them all). In that case I think the loss should be the same as with the default config (since the data parallel portion should be unchanged). Btw you can run these locally via e.g. pytest -m integration_test tests/recipes/test_full_finetune_distributed (assuming you have >1 GPU in your dev environment. If not lmk and we can figure something out)
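A rough sketch of what extending that parametrization could look like (the test and fixture names here are hypothetical; the real recipe tests under tests/recipes/ build their commands differently):

import pytest

# Hypothetical sketch of adding a sharding strategy to an existing recipe
# test's parametrization; not the actual torchtune test code.
@pytest.mark.integration_test
@pytest.mark.parametrize(
    "config, fsdp_sharding_strategy",
    [
        ("llama2/7B_full", "FULL_SHARD"),
        ("llama2/7B_full", "SHARD_GRAD_OP"),  # one extra strategy is enough
    ],
)
def test_loss(config, fsdp_sharding_strategy, tmpdir):
    cmd = (
        f"tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed "
        f"--config {config} "
        f"fsdp_sharding_strategy={fsdp_sharding_strategy} "
        f"output_dir={tmpdir}"
    ).split()
    # ... launch cmd and compare the resulting losses against the expected
    # values; they should match the default config, since the data-parallel
    # behavior is unchanged.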

Re the GPU utilization you're seeing: I agree it's a bit counterintuitive. I can take a look on my end as well. A couple of quick questions: (1) are you just using the default gemma/2B_lora config here (with the only overrides being the ones you described)? (2) How are you generating the figure? Is it from WandB's native logging, or something else?

@tambulkar (Contributor, Author) commented Jun 4, 2024

Hi @ebsmothers, yeah, I used the default gemma 2b lora config with just the overrides mentioned above. I didn't see the WandB integration, so I actually just generated the graph by logging the GPU utilization using pynvml and graphing it myself.

Unfortunately I don't really have a great multi-GPU setup - I just used TensorDock to spin up a multi-GPU machine to test this out.
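For reference, a minimal sketch of how GPU utilization could be polled with pynvml as described above (the actual script used to produce the figure was not shared):

# Minimal sketch of polling GPU utilization with pynvml, roughly as described
# above; the actual script used to produce the figure was not shared.
import time

import pynvml

pynvml.nvmlInit()
handles = [
    pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(pynvml.nvmlDeviceGetCount())
]

with open("gpu_util.csv", "w") as f:
    f.write("timestamp,gpu,util_pct,mem_used_mib\n")
    for _ in range(600):  # poll once per second for ~10 minutes
        now = time.time()
        for i, h in enumerate(handles):
            util = pynvml.nvmlDeviceGetUtilizationRates(h).gpu
            mem_mib = pynvml.nvmlDeviceGetMemoryInfo(h).used / 1024**2
            f.write(f"{now:.1f},{i},{util},{mem_mib:.0f}\n")
        f.flush()
        time.sleep(1.0)

pynvml.nvmlShutdown()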

@ebsmothers (Contributor) commented

Hey @tambulkar sorry I am just getting back to this. I ran a quick test on my end via

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora \
metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=lora-debug \
log_peak_memory_stats=True epochs=1 max_steps_per_epoch=100 

which should be pretty similar to your setup, and added fsdp_sharding_strategy=SHARD_GRAD_OP etc. for the other sharding strategies. Using torchtune's peak memory allocated logging, I see the below logged from rank 0:
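(For reference, a minimal sketch of the underlying measurement; torchtune's log_peak_memory_stats wraps its own utilities, this is just the core PyTorch call being compared across strategies.)

import torch

# Minimal sketch of the underlying peak-memory measurement on a single rank;
# torchtune's log_peak_memory_stats wraps its own utilities around this.
torch.cuda.reset_peak_memory_stats()
# ... run some training steps ...
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak memory allocated: {peak_gib:.2f} GiB")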

[Screenshot: peak memory allocated on rank 0 for each sharding strategy]

So in my case it's

FULL_SHARD = HYBRID_SHARD < NO_SHARD < SHARD_GRAD_OP = _HYBRID_SHARD_ZERO2

This is closer to what I'd expect than the results you got, but still the whole NO_SHARD < SHARD_GRAD_OP seems a bit strange to me. Gemma is a bit of a weird case too since there are tied weights (unlike most of our other models). I just kicked off a run with Llama2-7B instead and it appears that SHARD_GRAD_OP < NO_SHARD does hold (over the few iterations I checked).

So anyways this is a long way to say that I think this is working as expected. Can you update the PR summary with the commands used to run the three recipes (mainly as a sanity check that nothing will be obviously broken, we have CI but don't have coverage on the DPO recipe yet)? After that I think this is good to merge.

@codecov-commenter commented Jun 10, 2024

Codecov Report

Attention: Patch coverage is 30.00000% with 7 lines in your changes missing coverage. Please review.

Project coverage is 68.52%. Comparing base (f6ddfcc) to head (1d0f3c5).
Report is 2 commits behind head on main.

Files Patch % Lines
tests/recipes/test_full_finetune_distributed.py 33.33% 2 Missing ⚠️
tests/recipes/test_lora_finetune_distributed.py 50.00% 2 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 1 Missing ⚠️
recipes/lora_dpo_distributed.py 0.00% 1 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1024      +/-   ##
==========================================
+ Coverage   68.31%   68.52%   +0.20%     
==========================================
  Files         255      258       +3     
  Lines       11796    11903     +107     
==========================================
+ Hits         8059     8157      +98     
- Misses       3737     3746       +9     

☔ View full report in Codecov by Sentry.

@joecummings (Contributor) commented

@tambulkar Are you interested in finishing up this PR?

@tambulkar (Contributor, Author) commented

Hey @joecummings, sorry about that, I forgot about this - will finish it up today.

@tambulkar (Contributor, Author) commented Jul 31, 2024

The LoRA fine-tune recipe works fine, but LoRA DPO and full fine-tune seem to have some issues on my end.

tune run --nproc_per_node 2 lora_dpo_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD

fails with

    chosen_input_ids = [torch.tensor(ex["chosen_input_ids"]) for ex in batch]
KeyError: 'chosen_input_ids'

and

tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD

fails with

RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
	Missing key(s) in state_dict: "layers.0.attn.q_proj.lora_a.weight", "layers.0.attn.q_proj.lora_b.weight", "layers.0.attn.v_proj.lora_a.weight", "layers.0.attn.v_proj.lora_b.weight", "layers.1.attn.q_proj.lora_a.weight", "layers.1.attn.q_proj.lora_b.weight", "layers.1.attn.v_proj.lora_a.weight", "layers.1.attn.v_proj.lora_b.weight", "layers.2.attn.q_proj.lora_a.weight", "layers.2.attn.q_proj.lora_b.weight", "layers.2.attn.v_proj.lora_a.weight", "layers.2.attn.v_proj.lora_b.weight", "layers.3.attn.q_proj.lora_a.weight", "layers.3.attn.q_proj.lora_b.weight", "layers.3.attn.v_proj.lora_a.weight", "layers.3.attn.v_proj.lora_b.weight", "layers.4.attn.q_proj.lora_a.weight", "layers.4.attn.q_proj.lora_b.weight", "layers.4.attn.v_proj.lora_a.weight", "layers.4.attn.v_proj.lora_b.weight", "layers.5.attn.q_proj.lora_a.weight", "layers.5.attn.q_proj.lora_b.weight", "layers.5.attn.v_proj.lora_a.weight", "layers.5.attn.v_proj.lora_b.weight", "layers.6.attn.q_proj.lora_a.weight", "layers.6.attn.q_proj.lora_b.weight", "layers.6.attn.v_proj.lora_a.weight", "layers.6.attn.v_proj.lora_b.weight", "layers.7.attn.q_proj.lora_a.weight", "layers.7.attn.q_proj.lora_b.weight", "layers.7.attn.v_proj.lora_a.weight", "layers.7.attn.v_proj.lora_b.weight", "layers.8.attn.q_proj.lora_a.weight", "layers.8.attn.q_proj.lora_b.weight", "layers.8.attn.v_proj.lora_a.weight", "layers.8.attn.v_proj.lora_b.weight", "layers.9.attn.q_proj.lora_a.weight", "layers.9.attn.q_proj.lora_b.weight", "layers.9.attn.v_proj.lora_a.weight", "layers.9.attn.v_proj.lora_b.weight", "layers.10.attn.q_proj.lora_a.weight", "layers.10.attn.q_proj.lora_b.weight", "layers.10.attn.v_proj.lora_a.weight", "layers.10.attn.v_proj.lora_b.weight", "layers.11.attn.q_proj.lora_a.weight", "layers.11.attn.q_proj.lora_b.weight", "layers.11.attn.v_proj.lora_a.weight", "layers.11.attn.v_proj.lora_b.weight", "layers.12.attn.q_proj.lora_a.weight", "layers.12.attn.q_proj.lora_b.weight", "layers.12.attn.v_proj.lora_a.weight", "layers.12.attn.v_proj.lora_b.weight", "layers.13.attn.q_proj.lora_a.weight", "layers.13.attn.q_proj.lora_b.weight", "layers.13.attn.v_proj.lora_a.weight", "layers.13.attn.v_proj.lora_b.weight", "layers.14.attn.q_proj.lora_a.weight", "layers.14.attn.q_proj.lora_b.weight", "layers.14.attn.v_proj.lora_a.weight", "layers.14.attn.v_proj.lora_b.weight", "layers.15.attn.q_proj.lora_a.weight", "layers.15.attn.q_proj.lora_b.weight", "layers.15.attn.v_proj.lora_a.weight", "layers.15.attn.v_proj.lora_b.weight", "layers.16.attn.q_proj.lora_a.weight", "layers.16.attn.q_proj.lora_b.weight", "layers.16.attn.v_proj.lora_a.weight", "layers.16.attn.v_proj.lora_b.weight", "layers.17.attn.q_proj.lora_a.weight", "layers.17.attn.q_proj.lora_b.weight", "layers.17.attn.v_proj.lora_a.weight", "layers.17.attn.v_proj.lora_b.weight", "layers.18.attn.q_proj.lora_a.weight", "layers.18.attn.q_proj.lora_b.weight", "layers.18.attn.v_proj.lora_a.weight", "layers.18.attn.v_proj.lora_b.weight", "layers.19.attn.q_proj.lora_a.weight", "layers.19.attn.q_proj.lora_b.weight", "layers.19.attn.v_proj.lora_a.weight", "layers.19.attn.v_proj.lora_b.weight", "layers.20.attn.q_proj.lora_a.weight", "layers.20.attn.q_proj.lora_b.weight", "layers.20.attn.v_proj.lora_a.weight", "layers.20.attn.v_proj.lora_b.weight", "layers.21.attn.q_proj.lora_a.weight", "layers.21.attn.q_proj.lora_b.weight", "layers.21.attn.v_proj.lora_a.weight", "layers.21.attn.v_proj.lora_b.weight", "layers.22.attn.q_proj.lora_a.weight", "layers.22.attn.q_proj.lora_b.weight", "layers.22.attn.v_proj.lora_a.weight", 
"layers.22.attn.v_proj.lora_b.weight", "layers.23.attn.q_proj.lora_a.weight", "layers.23.attn.q_proj.lora_b.weight", "layers.23.attn.v_proj.lora_a.weight", "layers.23.attn.v_proj.lora_b.weight", "layers.24.attn.q_proj.lora_a.weight", "layers.24.attn.q_proj.lora_b.weight", "layers.24.attn.v_proj.lora_a.weight", "layers.24.attn.v_proj.lora_b.weight", "layers.25.attn.q_proj.lora_a.weight", "layers.25.attn.q_proj.lora_b.weight", "layers.25.attn.v_proj.lora_a.weight", "layers.25.attn.v_proj.lora_b.weight", "layers.26.attn.q_proj.lora_a.weight", "layers.26.attn.q_proj.lora_b.weight", "layers.26.attn.v_proj.lora_a.weight", "layers.26.attn.v_proj.lora_b.weight", "layers.27.attn.q_proj.lora_a.weight", "layers.27.attn.q_proj.lora_b.weight", "layers.27.attn.v_proj.lora_a.weight", "layers.27.attn.v_proj.lora_b.weight", "layers.28.attn.q_proj.lora_a.weight", "layers.28.attn.q_proj.lora_b.weight", "layers.28.attn.v_proj.lora_a.weight", "layers.28.attn.v_proj.lora_b.weight", "layers.29.attn.q_proj.lora_a.weight", "layers.29.attn.q_proj.lora_b.weight", "layers.29.attn.v_proj.lora_a.weight", "layers.29.attn.v_proj.lora_b.weight", "layers.30.attn.q_proj.lora_a.weight", "layers.30.attn.q_proj.lora_b.weight", "layers.30.attn.v_proj.lora_a.weight", "layers.30.attn.v_proj.lora_b.weight", "layers.31.attn.q_proj.lora_a.weight", "layers.31.attn.q_proj.lora_b.weight", "layers.31.attn.v_proj.lora_a.weight", "layers.31.attn.v_proj.lora_b.weight".

I downloaded llama3 using

tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct

not sure what the issue is exactly

@felipemello1 (Contributor) commented Jul 31, 2024

Re:

tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD

You are using a LoRA config for full_finetune_distributed. Running the command below with lora_finetune_distributed should work:

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD

@tambulkar (Contributor, Author) commented Jul 31, 2024

@felipemello1 that worked, thanks - is there a Llama3 DPO config I should use?

@felipemello1 (Contributor) commented Aug 1, 2024

My guess is that you need to change the dataset. @SalmanMohammadi @RdoubleA, can you confirm/share your thoughts on why this fails:

tune run --nproc_per_node 2 lora_dpo_distributed --config llama3/8B_lora fsdp_sharding_strategy=NO_SHARD

fails with

    chosen_input_ids = [torch.tensor(ex["chosen_input_ids"]) for ex in batch]
KeyError: 'chosen_input_ids'

@SalmanMohammadi (Collaborator) commented Aug 1, 2024

Thanks for this PR @tambulkar - this looks super cool.

@tambulkar, could you share the dataset you're using? It might be a case of just mapping the columns correctly.

@felipemello1 (Contributor) commented Aug 1, 2024

could you share the dataset you're using

I think it would be the default in the config, @SalmanMohammadi

https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3/8B_lora.yaml#L47

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset

alpaca_cleaned_dataset = partial(alpaca_dataset, source="yahma/alpaca-cleaned")
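For illustration, a hedged sketch of why the KeyError shows up with that dataset (not torchtune's actual collate code): a DPO-style collate indexes preference-pair keys like chosen_input_ids / rejected_input_ids, which an instruct dataset such as alpaca_cleaned_dataset does not produce.

import torch

# Hedged sketch, not torchtune's actual collate: a DPO-style collate expects
# preference-pair keys in each example, which instruct datasets don't provide.
def dpo_collate_sketch(batch):
    chosen = [torch.tensor(ex["chosen_input_ids"]) for ex in batch]
    rejected = [torch.tensor(ex["rejected_input_ids"]) for ex in batch]
    return chosen, rejected

instruct_example = {"tokens": [1, 2, 3], "labels": [1, 2, 3]}  # no chosen_*/rejected_* keys
# dpo_collate_sketch([instruct_example])  -> KeyError: 'chosen_input_ids'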

@SalmanMohammadi (Collaborator) commented

Hmm, that's an instruct dataset - it should be using this, no?

I see @tambulkar is using a config for 8B_lora, whereas they should copy this config and adapt it to Llama3 8B.

@tambulkar (Contributor, Author) commented

Good catch - it seems to get further when I use

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchanged_paired_dataset
  max_seq_len: 1024
seed: null
shuffle: True
batch_size: 4

as the dataset with the 8B LoRA config, but I get some NCCL failures - could be my setup, I'm using RunPod.

@SalmanMohammadi (Collaborator) commented Aug 2, 2024

Very silly question, since I'm not familiar with distributed debugging.

Is the loss in your config the same as in the DPO config?

@tambulkar (Contributor, Author) commented Aug 2, 2024

Good call @SalmanMohammadi but even when I use

loss:
 _component_: torchtune.modules.loss.DPOLoss
 beta: 0.1
 label_smoothing: 0

in my config, I still get the NCCL failures - probably a version thing with the pod I am using.

@tambulkar (Contributor, Author) commented

@felipemello1 @SalmanMohammadi is there anything else to include here?

@felipemello1 (Contributor) commented Aug 9, 2024

I guess we just need to make sure that the DPO script runs without NCCL failures, is that right? You had errors on your machine, so it's not clear if it's the recipe or the machine.

I can run your recipe on my machine to see if it's fine. If it's not too much work for you, maybe you can run the DPO recipe from main to confirm that you also see these failures?

thanks again for the PR! :) @tambulkar

@tambulkar (Contributor, Author) commented Aug 9, 2024

@felipemello1 The NCCL errors went away on a new machine I spun up.

tune run --nproc_per_node 2 lora_dpo_distributed --config ./my_custom_config.yaml fsdp_sharding_strategy=SHARD_GRAD_OP

starts running now; I just get OOM with 2 x RTX 4090 - might still be worth running on your end as well. The OOM happens on main for me too, which I find surprising given the numbers in the README.md:

tune run --nproc_per_node 2 lora_dpo_distributed --config ./my_custom_config.yaml

My config is the llama3/8B_lora with

loss:
 _component_: torchtune.modules.loss.DPOLoss
 beta: 0.1
 label_smoothing: 0

and

dataset:
  _component_: torchtune.datasets.stack_exchanged_paired_dataset
  max_seq_len: 1024
seed: null
shuffle: True
batch_size: 4

@SalmanMohammadi (Collaborator) commented

Can you please add a quick comment in the docs at the top of each recipe (e.g. here) about this?

Just a minimal 1-2 sentences explaining what the config parameter you're adding does, and the different options we can use.

It'll help a lot with keeping track of which features we support as we start to document them more comprehensively.
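For example, a note along these lines near the top of each distributed recipe could work (hypothetical wording; the exact docstring text added in this PR may differ):

# Hypothetical wording for the recipe-level docstring note; the exact text
# merged in this PR may differ.
"""
    - FSDP sharding. This recipe exposes ``fsdp_sharding_strategy`` in the config,
      which controls how FSDP shards model parameters, gradients, and optimizer
      states across ranks (e.g. FULL_SHARD, SHARD_GRAD_OP, NO_SHARD, HYBRID_SHARD,
      _HYBRID_SHARD_ZERO2).
"""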

@tambulkar (Contributor, Author) commented

Thanks for the feedback @SalmanMohammadi, updated the docstrings.

@felipemello1 merged commit 2522c41 into pytorch:main on Aug 10, 2024
20 checks passed