Support Optimizer-in-the-backward #1737

Closed

Conversation


@mori360 mori360 commented Oct 1, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Enable Optimizer-in-the-backward for full_finetune_distributed

Changelog

  • Update full_finetune_distributed to enable Optimizer-in-the-backward
  • Update test_full_finetune_distributed with an _optimizer_in_bwd config
  • Update test_distributed to test running with/without optimizer-in-the-backward, and to verify behavior after saving and loading the state_dict.

Test plan

  • Test running with optimizer_in_the_backward: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 optimizer_in_bwd=True
  • Test running optimizer_in_the_backward with resume_from_checkpoint: tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full fsdp_cpu_offload=False max_steps_per_epoch=2 epochs=10 optimizer_in_bwd=True resume_from_checkpoint=True checkpointer.recipe_checkpoint=/tmp/Llama-2-7b-hf/recipe_state.pt checkpointer.checkpoint_files=[hf_model_0001_1.pt,hf_model_0002_1.pt]
  • Verify that running with Optimizer-in-the-backward yields the same loss, model_state_dict, and optimizer_state_dict, and that the model still matches after saving and loading: pytest tests/torchtune/training/test_distributed.py -k test_optimizer_in_backward

Memory cost analysis:
With each layer's gradients costing 193MB of memory, the original (left) case reaches its peak memory at the 31st layer, after accumulating 193MB of gradients for each of the 30 preceding layers.
The right case, with optimizer-in-the-backward, frees this memory during the backward pass and reaches a lower peak.
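As a rough back-of-the-envelope check of the numbers above (gradient memory only; activations and optimizer states are not counted): without optimizer-in-the-backward, the accumulated gradients alone peak at about 193 MB × 30 ≈ 5.8 GB, whereas with optimizer-in-the-backward only on the order of a single layer's gradients (≈ 193 MB) needs to be alive at any point during backward.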
(figure: peak memory comparison, with vs. without optimizer-in-the-backward)

Training time and loss analysis:
(figure: training time and loss comparison)


pytorch-bot bot commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1737

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit f639b6d with merge base f639b6d:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 1, 2024

codecov-commenter commented Oct 2, 2024

Codecov Report

Attention: Patch coverage is 1.97368% with 149 lines in your changes missing coverage. Please review.

Project coverage is 25.44%. Comparing base (7cf656b) to head (207b1b1).
Report is 21 commits behind head on main.

Files with missing lines Patch % Lines
tests/torchtune/training/test_distributed.py 2.88% 101 Missing ⚠️
recipes/full_finetune_distributed.py 0.00% 48 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1737       +/-   ##
===========================================
- Coverage   69.33%   25.44%   -43.89%     
===========================================
  Files         305      305               
  Lines       15892    16089      +197     
===========================================
- Hits        11018     4094     -6924     
- Misses       4874    11995     +7121     

☔ View full report in Codecov by Sentry.

@mori360 mori360 marked this pull request as ready for review October 7, 2024 22:24
@mori360 mori360 marked this pull request as draft October 7, 2024 23:02
@mori360 mori360 marked this pull request as ready for review October 8, 2024 21:56
@weifengpy
Contributor

Could we draw loss curves in Weights & Biases to show that the numerics are the same with/without optimizer-in-the-backward?

@mori360
Author

mori360 commented Oct 10, 2024

Could we draw loss curves in Weights & Biases to show that the numerics are the same with/without optimizer-in-the-backward?

The loss curves have been added under the comments section in the third column of the table on the right-hand side.

@mori360 mori360 marked this pull request as ready for review October 10, 2024 00:59
recipes/full_finetune_distributed.py
optimizer,
opt_state_dict,
self._device,
if not optimizer_in_bwd:
Contributor

This is a nit, but let's switch the order of this if/else so that it's if optimizer_in_bwd. Then it's closer to what's in the single-device recipe and easier to compare between the two

Author

We have the initial zero-out in the single-device recipe at:

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
    self._optimizer.zero_grad()
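
For comparison, here is a minimal plain-PyTorch sketch of the flipped `if optimizer_in_bwd` ordering the reviewer suggests (illustrative only; the function name and the use of AdamW are assumptions, not the recipe's actual code):

```python
import torch

def setup_optimizer(model: torch.nn.Module, lr: float, optimizer_in_bwd: bool):
    """Return either a dict of per-parameter optimizers or a single optimizer."""
    if optimizer_in_bwd:
        # One optimizer instance per parameter; .step()/.zero_grad() are driven
        # from hooks during backward rather than from the training loop.
        return {p: torch.optim.AdamW([p], lr=lr) for p in model.parameters()}
    # Regular path: a single optimizer over all parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    optimizer.zero_grad()  # zero out the gradients before starting training
    return optimizer
```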

recipes/full_finetune_distributed.py
Comment on lines 708 to 710
raise NotImplementedError(
    "Gradient clipping is not supported after optimizer-in-the-backward."
)
Contributor

Is this actually due to optimizer in backward, or something else? I don't think we have such a check in the single-device recipe

Author

It's due to optimizer-in-backward.
Optimizer-in-backward calls .step() and .zero_grad() during loss.backward(), so the gradients are already None by the time torch.nn.utils.clip_grad_norm_ is called, and gradient clipping cannot be applied.
The single-device recipe has the same issue if optimizer_in_bwd=True and clip_grad_norm is set.
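
A self-contained sketch of the conflict, using plain PyTorch post-accumulate-grad hooks (this mirrors the general optimizer-in-backward pattern, not the recipe's exact implementation; requires a torch version that provides register_post_accumulate_grad_hook):

```python
import torch

model = torch.nn.Linear(8, 8)
# One optimizer per parameter, stepped and zeroed from inside backward.
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def step_and_free(param: torch.Tensor) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad()  # frees the gradient (sets it to None) inside backward

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_and_free)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# All gradients were consumed and freed during backward, so there is nothing
# left for torch.nn.utils.clip_grad_norm_ to clip.
assert all(p.grad is None for p in model.parameters())
```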

@@ -722,6 +775,27 @@ def train(self) -> None:

self._profiler.stop()

def get_lr_scheduler(self) -> float:
Contributor

If we are gonna have this as a utility I would take it out of this recipe since it's equally applicable to the single-device case. Also I wouldn't call it get_lr_scheduler, since we're really getting the current lr, not the scheduler itself.

Author

Single-device gets the LR under the assumption that all optimizers have the same LR.

Shall we apply the logic here to the single-device recipe as well, rename it to get_lr, and move it to /torchtune/utils?
Or should we just take the same assumption as single-device?

Contributor

Technically I think this was also previously the assumption in this recipe, right? Since we just log the LR from the first param group. So we should be able to maintain the same logic for both
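
A minimal sketch of what such a shared get_lr utility could look like (hypothetical name and placement; it reads the LR from the first param group and, for the optimizer-in-backward case, checks that all per-parameter optimizers agree):

```python
from typing import Dict, Union
import torch

def get_lr(optimizer: Union[torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]]) -> float:
    """Return the current learning rate, assuming a single LR across param groups."""
    if isinstance(optimizer, dict):
        # optimizer-in-backward: one optimizer per parameter
        lrs = {opt.param_groups[0]["lr"] for opt in optimizer.values()}
        if len(lrs) > 1:
            raise RuntimeError(f"Expected a single learning rate, found {lrs}")
        return lrs.pop()
    # regular optimizer: log the LR from the first param group
    return optimizer.param_groups[0]["lr"]
```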

version.parse(torch.__version__).base_version < "2.5.0",
reason="torch >= 2.5 required",
)
def test_optimizer_in_backward(self):
Contributor

I'm wondering why we create a whole new test in here rather than adding a distributed test case in the existing TestOptimInBackward test case (or if we want to inherit from FSDPTest, a new distributed version of that class in the same file). Because testing optimizer-in-backward as part of a class that is otherwise meant to test our fully_shard + state dict save and load logic feels a bit strange to me.

Author

The test here wants to prove that models running with optimizer-in-backward can reach the same performance as models running without it.
The state_dict saving and loading parts want to test optimizer-in-backward's wrapper, which is a bit different from a traditional optimizer.
Neither is covered in TestOptimInBackward.
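
For illustration, a small sketch of the kind of save/load round-trip the state_dict part exercises, written against plain per-parameter optimizers rather than torchtune's actual wrapper (names are assumptions):

```python
import torch

model = torch.nn.Linear(8, 8)
optim_dict = {
    name: torch.optim.AdamW([param], lr=1e-3)
    for name, param in model.named_parameters()
}

# Saving: unlike a single optimizer, there is one state_dict per parameter,
# so they are collected into a dict keyed by parameter name.
saved = {name: opt.state_dict() for name, opt in optim_dict.items()}

# Loading: rebuild the per-parameter optimizers and restore each one.
restored = {
    name: torch.optim.AdamW([param], lr=1e-3)
    for name, param in model.named_parameters()
}
for name, opt in restored.items():
    opt.load_state_dict(saved[name])
```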

Contributor

If we want to do such an end-to-end test, maybe we can add a recipe test case instead? See e.g. test_full_finetune_distributed.py: with such a test case you can run the full end-to-end recipe on a small test model and set optimizer_in_bwd=True directly from the config.
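
For illustration, a rough sketch of the shape such a recipe-level test could take (this is not torchtune's actual recipe-test harness, which runs a small test model with config overrides; the command here simply mirrors the CLI invocation from the test plan above):

```python
import subprocess

def test_full_finetune_distributed_optimizer_in_bwd():
    """Smoke-test the distributed recipe with optimizer_in_bwd=True via the tune CLI."""
    cmd = [
        "tune", "run", "--nproc_per_node", "2", "full_finetune_distributed",
        "--config", "llama2/7B_full",
        "fsdp_cpu_offload=False",
        "max_steps_per_epoch=2",
        "optimizer_in_bwd=True",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    assert result.returncode == 0, result.stderr
```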
