[Break BC] Create training directory, move checkpointing #1432
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1432
Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 9dfe15f with merge base 929a45a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##             main    #1432       +/-   ##
===========================================
+ Coverage   27.00%   72.24%   +45.23%
===========================================
  Files         268      269        +1
  Lines       12923    12930        +7
===========================================
+ Hits         3490     9341     +5851
+ Misses       9433     3589     -5844

View full report in Codecov by Sentry.
A few small comments, but modulo resolving merge conflicts and (still) green CI, this looks good to me.
torchtune/training/__init__.py
Outdated
__all__ = [
    "FullModelHFCheckpointer",
    "FullModelMetaCheckpointer",
    "FullModelTorchTuneCheckpointer",
    "ModelType",
    "Checkpointer",
    "update_state_dict_for_classifier",
    "ADAPTER_CONFIG",
    "ADAPTER_KEY",
    "EPOCHS_KEY",
    "MAX_STEPS_KEY",
    "MODEL_KEY",
    "OPT_KEY",
    "RNG_KEY",
    "SEED_KEY",
    "STEPS_KEY",
    "TOTAL_EPOCHS_KEY",
]
This is different than what we do now, right? Right now we don't include checkpointer APIs in the parent's __all__. Not saying it's wrong to do it this way, but just wanna understand the rationale for the change.
mainly to keep it consistent, but not opposed to removing these
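For context on what listing these names buys: `__all__` in the parent `__init__.py` controls what a wildcard import exports. A minimal, self-contained sketch with a toy in-memory module (names borrowed from the list above; `MODEL_KEY`'s value is an assumption for illustration, not the real torchtune package):

```python
import sys
import types

# Toy stand-in for torchtune.training, built in memory for illustration.
mod = types.ModuleType("training_demo")
exec(
    "__all__ = ['FullModelHFCheckpointer', 'MODEL_KEY']\n"
    "class FullModelHFCheckpointer: ...\n"
    "MODEL_KEY = 'model'\n"
    "class _Internal: ...\n",
    mod.__dict__,
)
sys.modules["training_demo"] = mod

# A star import honors __all__: listed names come through, private ones don't.
ns = {}
exec("from training_demo import *", ns)
print(sorted(n for n in ns if not n.startswith("__")))
# ['FullModelHFCheckpointer', 'MODEL_KEY']
```

So re-exporting here mainly affects discoverability of `from torchtune.training import *` and what docs tooling treats as public.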
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Union

-from torchtune.utils._checkpointing._checkpointer import (
+from torchtune.training.checkpointing._checkpointer import (
Where'd my underscore go
@@ -381,7 +381,7 @@ looks something like this:
 checkpointer:

   # checkpointer to use
-  _component_: torchtune.utils.FullModelHFCheckpointer
+  _component_: torchtune.training.FullModelHFCheckpointer
L430 too
damn, good eye
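Since the move touches every config, misses like the one above are easy to make; a mechanical rewrite over the config text catches them all at once. A hedged sketch (hypothetical helper, not part of this PR):

```python
import re

# Hypothetical migration helper (not part of this PR): rewrite old
# torchtune.utils checkpointer component paths to torchtune.training.
_OLD_PATH = re.compile(r"torchtune\.utils\.(FullModel\w+Checkpointer)")

def migrate_config_text(text: str) -> str:
    return _OLD_PATH.sub(r"torchtune.training.\1", text)

cfg = """\
checkpointer:
  # checkpointer to use
  _component_: torchtune.utils.FullModelHFCheckpointer
"""
print(migrate_config_text(cfg))
```

Running something like this over every YAML in the configs directory (plus a grep for any stragglers) is one way to back the "will run a few to make sure none break" check.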
@@ -13,13 +13,13 @@

 import torch
 from safetensors.torch import save_file
-from torchtune import utils
+from torchtune import training
Not a huge deal (given the scope of other stuff you're doing here), but it's a little weird to me that we use training.MODEL_KEY etc. when all these things are literally defined in the local directory. Feels like needless indirection to me.
eh, later problem
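For readers following the MODEL_KEY discussion: these constants are just the string keys used when assembling and resuming a checkpoint state dict. A toy sketch with assumed key values (the real constants live in torchtune.training; the values here are illustrative):

```python
# Assumed key values for illustration only; the real constants are
# defined in torchtune.training (referenced as training.MODEL_KEY etc.).
MODEL_KEY = "model"
OPT_KEY = "opt"
EPOCHS_KEY = "epochs_run"

def build_checkpoint_dict(model_state, opt_state, epochs_run):
    # Using shared constants keeps the save and resume paths in
    # agreement on the dict layout.
    return {MODEL_KEY: model_state, OPT_KEY: opt_state, EPOCHS_KEY: epochs_run}

ckpt = build_checkpoint_dict({"w": [0.0]}, {"step": 10}, epochs_run=2)
print(sorted(ckpt))  # ['epochs_run', 'model', 'opt']
```

Whether they are imported from the parent namespace or the local directory, the point is that both sides of a save/load round trip use the same names.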
Context
The big kahuna of refactors.
Motivation is discussed extensively in #1421. Here, we only move the checkpointing directory into training. This should break most cycles since it depends on models. Unfortunately, this touches all configs... will run a few to make sure none break.
All references to torchtune.utils.FullModelXXXCheckpointer now become torchtune.training.FullModelXXXCheckpointer.
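This PR breaks BC outright, as the title says. If a deprecation window were wanted instead, PEP 562's module-level __getattr__ could forward the old torchtune.utils names with a warning. A self-contained sketch using in-memory stand-in modules (not what this PR does; the module names are stand-ins):

```python
import types
import warnings

# In-memory stand-ins for the two modules (illustration only).
training = types.ModuleType("training")
training.FullModelHFCheckpointer = type("FullModelHFCheckpointer", (), {})

utils = types.ModuleType("utils")

def _utils_getattr(name):
    # PEP 562: a module-level __getattr__ runs when normal lookup fails,
    # so old utils.* accesses can be forwarded with a deprecation warning.
    if hasattr(training, name):
        warnings.warn(
            f"torchtune.utils.{name} moved to torchtune.training.{name}",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(training, name)
    raise AttributeError(name)

utils.__getattr__ = _utils_getattr

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cls = utils.FullModelHFCheckpointer  # old path resolves, with a warning
```

Given that #1421 already settled on a clean break, this is only worth it if downstream users need a release to migrate.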
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- run pre-commit hooks (install first via pre-commit install)
- run unit tests (pytest tests)
- run recipe integration tests (pytest tests -m integration_test)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring: torchtune/torchtune/modules/vision_transformer.py, line 285 (commit 6a7951f)
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models