[Break BC] Create training directory, move checkpointing #1432

Merged · 5 commits · Aug 30, 2024
19 changes: 19 additions & 0 deletions docs/source/api_ref_training.rst
@@ -4,6 +4,25 @@ torchtune.training

.. currentmodule:: torchtune.training

.. _checkpointing_label:

Checkpointing
-------------

torchtune offers checkpointers that allow seamless transitions between checkpoint formats, both for training and for interoperability with the rest of the ecosystem. For a comprehensive overview of
checkpointing, please see the :ref:`checkpointing deep-dive <understand_checkpointer>`.

.. autosummary::
:toctree: generated/
:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
FullModelTorchTuneCheckpointer
ModelType
update_state_dict_for_classifier


Reduced Precision
------------------

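Since this PR moves the checkpointing APIs from ``torchtune.utils`` to ``torchtune.training``, downstream imports need a one-line update. A minimal sketch of the migration, assuming the symbols are re-exported at the package level as the autosummary entries above suggest:

.. code-block:: python

    # Old import path, removed by this BC-breaking change:
    # from torchtune.utils import FullModelHFCheckpointer, ModelType

    # New import path after this PR:
    from torchtune.training import FullModelHFCheckpointer, ModelType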
23 changes: 2 additions & 21 deletions docs/source/api_ref_utilities.rst
@@ -1,28 +1,9 @@
=================
===============
torchtune.utils
=================
===============

.. currentmodule:: torchtune.utils


.. _checkpointing_label:

Checkpointing
-------------

torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of
checkpointing, please see the :ref:`checkpointing deep-dive <understand_checkpointer>`.

.. autosummary::
:toctree: generated/
:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
FullModelTorchTuneCheckpointer
ModelType
update_state_dict_for_classifier

.. _dist_label:

Distributed
22 changes: 11 additions & 11 deletions docs/source/deep_dives/checkpointer.rst
@@ -135,8 +135,8 @@ torchtune supports three different
each of which supports a different checkpoint format.


:class:`HFCheckpointer <torchtune.utils.FullModelHFCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`HFCheckpointer <torchtune.training.FullModelHFCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format which is compatible with the transformers
framework from Hugging Face. As mentioned above, this is the most popular format within the Hugging Face
@@ -167,7 +167,7 @@ The following snippet explains how the HFCheckpointer is setup in torchtune conf
checkpointer:

# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir above
@@ -205,8 +205,8 @@ The following snippet explains how the HFCheckpointer is setup in torchtune conf

|

:class:`MetaCheckpointer <torchtune.utils.FullModelMetaCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`MetaCheckpointer <torchtune.training.FullModelMetaCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format which is compatible with the original meta-llama
github repository.
@@ -237,7 +237,7 @@ The following snippet explains how the MetaCheckpointer is setup in torchtune co
checkpointer:

# checkpointer to use
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer

# directory with the checkpoint files
# this should match the output_dir above
@@ -265,8 +265,8 @@ The following snippet explains how the MetaCheckpointer is setup in torchtune co

|

:class:`TorchTuneCheckpointer <torchtune.utils.FullModelTorchTuneCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`TorchTuneCheckpointer <torchtune.training.FullModelTorchTuneCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format that is compatible with torchtune's
model definition. This does not perform any state_dict conversions and is currently used either
@@ -335,7 +335,7 @@ to the config file
checkpointer:

# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer

checkpoint_dir: <checkpoint_dir>

@@ -381,7 +381,7 @@ looks something like this:
checkpointer:

# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
Contributor: L430 too

Contributor Author: damn, good eye
_component_: torchtune.training.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir above
@@ -427,7 +427,7 @@ For this section we'll use the Llama2 13B model in HF format.
.. code-block:: python

import torch
from torchtune.utils import FullModelHFCheckpointer, ModelType
from torchtune.training import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b

# Set the right directory and files
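The diff collapses the rest of this snippet. For orientation, here is a hedged sketch of how the conversion plausibly continues under the new import path; the directory and checkpoint file names are placeholders, not values taken from the PR:

.. code-block:: python

    # Hypothetical paths and file names, shown only for illustration
    checkpoint_dir = "/tmp/Llama-2-13b-hf"
    pytorch_files = [
        "pytorch_model-00001-of-00003.bin",
        "pytorch_model-00002-of-00003.bin",
        "pytorch_model-00003-of-00003.bin",
    ]

    # Build the checkpointer; model_type tells it how to convert the
    # HF-format state dict into torchtune's model definition
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=checkpoint_dir,
        checkpoint_files=pytorch_files,
        output_dir=checkpoint_dir,
        model_type=ModelType.LLAMA2,
    )

    # load_checkpoint returns a dict containing the converted state dict
    sd = checkpointer.load_checkpoint()
    model = llama2_13b()
    model.load_state_dict(sd["model"])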
8 changes: 4 additions & 4 deletions docs/source/deep_dives/wandb_logging.rst
@@ -87,10 +87,10 @@ A suggested approach would be something like this:
description="Model checkpoint",
# you can add whatever metadata you want as a dict
metadata={
utils.SEED_KEY: self.seed,
utils.EPOCHS_KEY: self.epochs_run,
utils.TOTAL_EPOCHS_KEY: self.total_epochs,
utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
wandb_at.add_file(checkpoint_file)
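For context, the surrounding artifact-logging pattern in that tutorial looks roughly like the sketch below; the artifact name and ``checkpoint_file`` path are hypothetical, and only the ``metadata`` keys are taken from the diff:

.. code-block:: python

    import wandb

    # Hypothetical artifact name and checkpoint path, for illustration only
    wandb_at = wandb.Artifact(
        name=f"checkpoint_epoch_{self.epochs_run}",
        type="model",
        description="Model checkpoint",
        # you can add whatever metadata you want as a dict
        metadata={
            training.SEED_KEY: self.seed,
            training.EPOCHS_KEY: self.epochs_run,
            training.TOTAL_EPOCHS_KEY: self.total_epochs,
            training.MAX_STEPS_KEY: self.max_steps_per_epoch,
        },
    )
    wandb_at.add_file(checkpoint_file)
    wandb.log_artifact(wandb_at)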
4 changes: 2 additions & 2 deletions docs/source/tutorials/e2e_flow.rst
@@ -195,7 +195,7 @@ First, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoin
.. code-block:: yaml

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir specified during
@@ -262,7 +262,7 @@ Let's modify ``custom_generation_config.yaml`` to include the following changes.
.. code-block:: yaml

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir specified during
4 changes: 2 additions & 2 deletions docs/source/tutorials/llama3.rst
@@ -149,7 +149,7 @@ Next, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoint
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer

# directory with the checkpoint files
# this should match the output_dir specified during
@@ -203,7 +203,7 @@ Now we modify ``custom_generation_config.yaml`` to point to our checkpoint and t
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer

# directory with the checkpoint files
# this should match the output_dir specified during
4 changes: 2 additions & 2 deletions docs/source/tutorials/qat_finetune.rst
@@ -223,7 +223,7 @@ copy and make the following modifications to the quantization config:
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: <your QAT checkpoint dir>
checkpoint_files: [meta_model_0.pt]
recipe_checkpoint: null
@@ -269,7 +269,7 @@ integrated in torchtune. First, copy the evaluation config and make the followin
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
_component_: torchtune.training.FullModelTorchTuneCheckpointer
checkpoint_dir: <your quantized model checkpoint dir>
checkpoint_files: [meta_model_0-8da4w.pt]
recipe_checkpoint: null
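The many YAML changes below all follow from one mechanism: ``_component_`` paths are resolved dynamically from config strings, so a stale ``torchtune.utils.*`` value fails at runtime rather than at import time. A hedged sketch of that resolution step, using ``torchtune.config.instantiate`` on an OmegaConf node (the directory and file names are placeholders):

.. code-block:: python

    from omegaconf import OmegaConf
    from torchtune import config

    cfg = OmegaConf.create({
        "checkpointer": {
            "_component_": "torchtune.training.FullModelTorchTuneCheckpointer",
            "checkpoint_dir": "/tmp/ckpts",  # placeholder
            "checkpoint_files": ["meta_model_0-8da4w.pt"],
            "output_dir": "/tmp/ckpts",
            "model_type": "LLAMA3",
        }
    })

    # instantiate() imports the _component_ path and calls it with the
    # remaining keys as kwargs -- an old torchtune.utils path would raise here
    checkpointer = config.instantiate(cfg.checkpointer)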
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -31,7 +31,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -32,7 +32,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
2 changes: 1 addition & 1 deletion recipes/configs/dev/8B_full_experimental.yaml
@@ -35,7 +35,7 @@ model:
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3-8B/original/
checkpoint_files: [
consolidated.00.pth
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/13B_lora_fsdp2.yaml
@@ -33,7 +33,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-13b-hf/
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/70B_lora_fsdp2.yaml
@@ -28,7 +28,7 @@ tokenizer:
max_seq_len: null

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files: [
pytorch_model-00001-of-00015.bin,
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml
@@ -28,7 +28,7 @@ tokenizer:
max_seq_len: null

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files: [
pytorch_model-00001-of-00015.bin,
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
@@ -37,7 +37,7 @@ tokenizer:
max_seq_len: null

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml
@@ -36,7 +36,7 @@ tokenizer:
max_seq_len: null

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
2 changes: 1 addition & 1 deletion recipes/configs/eleuther_evaluation.yaml
@@ -8,7 +8,7 @@ model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_full.yaml
@@ -32,7 +32,7 @@ model:
_component_: torchtune.models.gemma.gemma_2b

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
checkpoint_files: [
model-00001-of-00002.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora.yaml
@@ -36,7 +36,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
checkpoint_files: [
model-00001-of-00002.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora_single_device.yaml
@@ -35,7 +35,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
checkpoint_files: [
model-00001-of-00002.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -35,7 +35,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
checkpoint_files: [
model-00001-of-00002.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_full.yaml
@@ -32,7 +32,7 @@ model:
_component_: torchtune.models.gemma.gemma_7b

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-7b/
checkpoint_files: [
model-00001-of-00004.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora.yaml
@@ -36,7 +36,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-7b/
checkpoint_files: [
model-00001-of-00004.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
@@ -35,7 +35,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-7b/
checkpoint_files: [
model-00001-of-00004.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -35,7 +35,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-7b/
checkpoint_files: [
model-00001-of-00004.safetensors,
2 changes: 1 addition & 1 deletion recipes/configs/generation.yaml
@@ -8,7 +8,7 @@ model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf/
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_full.yaml
@@ -23,7 +23,7 @@ model:
_component_: torchtune.models.llama2.llama2_13b

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-13b-hf/
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_lora.yaml
@@ -29,7 +29,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-13b-hf/
checkpoint_files: [
pytorch_model-00001-of-00003.bin,