Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When profile_memory, also export and save the snapshot.pickle for lora_finetune_single_device.py #1382

Merged
merged 2 commits into from
Aug 21, 2024

Conversation

janeyx99
Copy link
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • enable the snapshot.pickle to be exported when profile_memory is True.
  • the above required adding attributes to self for recipes/lora_finetune_single_device.py
  • needing to export the snapshot.pickle does mean that it takes a little longer to export.

Previously, you only had the html from the profiler, which is not interactive:
image

Now, you get that in addition to:
image

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

Copy link

pytorch-bot bot commented Aug 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1382

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e797f3b with merge base 3c580fc (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 20, 2024
Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this! Can you fix the linter complaints? After that this should be good to go

Copy link
Contributor

@felipemello1 felipemello1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ebsmothers, wdyt? We were missing the memory snapshot. My only concern is hardcoding torch.cuda if we plan to support xpu

Edit: you already approved it :P

@@ -579,6 +584,9 @@ def train(self) -> None:
):
break

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Memory profiling
if curr_epoch == 0 and self.profiler_profile_memory and idx == (self.profiler_wait_steps + self.profiler_warmup_steps):
torch.cuda.memory._record_memory_history()

@@ -579,6 +584,9 @@ def train(self) -> None:
):
break

if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps:
torch.cuda.memory._record_memory_history()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is if in the future we support other backend, like intel. It makes me think that this should be a function, instead of being hardcoded as "torch.cuda"

cc: @ebsmothers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is if in the future we support other backend, like intel. It makes me think that this should be a function, instead of being hardcoded as "torch.cuda"

cc: @ebsmothers

I think it's OK, idk that memory snapshot would be supported on XPU anyways. We can do a check on device + profile_memory combo just to be safe though

@@ -626,6 +634,9 @@ def train(self) -> None:
num_tokens = 0
t0 = time.perf_counter()

if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps:
# Stop memory profiling
if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps:

@felipemello1
Copy link
Contributor

should we just delete the HTML to save time, since it takes a while to save? Or is it useful?

@@ -328,6 +328,11 @@ def _setup_profiler(

log.info(f" Profiler config after instantiation: {profiler_cfg}")

self.profiler_wait_steps = profiler_cfg["wait_steps"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be throwing an error in our recipe tests, can you just quickly inspect profiler_cfg after running e.g.

pytest tests/recipes/test_lora_finetune_single_device.py -m integration_test -k 'test_loss'

I thought these fields would be defined based on this, but seems like something weird is happening here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's because profiling isn't enabled for tests, so it returns here

return DummyProfiler(), DictConfig({"enabled": False})

@janeyx99
Copy link
Contributor Author

should we just delete the HTML to save time, since it takes a while to save? Or is it useful?

Writing the html is very light so it doesn't make a tangible difference, I'm not sure if people are using it today so it'd be bad to just remove it if there were actual users.

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 0% with 10 lines in your changes missing coverage. Please review.

Project coverage is 70.56%. Comparing base (3c580fc) to head (e797f3b).
Report is 5 commits behind head on main.

Files Patch % Lines
recipes/lora_finetune_single_device.py 0.00% 9 Missing ⚠️
torchtune/utils/_profiler.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1382      +/-   ##
==========================================
- Coverage   72.57%   70.56%   -2.01%     
==========================================
  Files         271      272       +1     
  Lines       12736    12895     +159     
==========================================
- Hits         9243     9100     -143     
- Misses       3493     3795     +302     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@facebook-github-bot
Copy link

@janeyx99 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@janeyx99 janeyx99 merged commit 9e65fa9 into pytorch:main Aug 21, 2024
20 checks passed
@janeyx99 janeyx99 deleted the snapshot-pickle branch August 21, 2024 16:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants