
Adding classifier checkpointing utils - renaming peft/peft_utils.py (
SalmanMohammadi authored Aug 19, 2024
1 parent 8bb3a6f commit 3c580fc
Showing 19 changed files with 255 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_utilities.rst
@@ -21,6 +21,7 @@ checkpointing, please see the :ref:`checkpointing deep-dive <understand_checkpoi
FullModelMetaCheckpointer
FullModelTorchTuneCheckpointer
ModelType
update_state_dict_for_classifier

.. _dist_label:

20 changes: 11 additions & 9 deletions recipes/dev/lora_finetune_fsdp2.py
@@ -26,11 +26,11 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
@@ -434,13 +434,15 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
collate_fn=(
partial(
utils.padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
),
)

if self._is_rank_zero:
2 changes: 1 addition & 1 deletion recipes/lora_dpo_distributed.py
@@ -28,7 +28,7 @@
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
disable_adapter,
get_adapter_params,
get_merged_lora_ckpt,
2 changes: 1 addition & 1 deletion recipes/lora_dpo_single_device.py
@@ -22,7 +22,7 @@
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
disable_adapter,
get_adapter_params,
get_merged_lora_ckpt,
2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed.py
@@ -27,7 +27,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
2 changes: 1 addition & 1 deletion recipes/lora_finetune_single_device.py
@@ -20,7 +20,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
32 changes: 14 additions & 18 deletions recipes/ppo_full_finetune_single_device.py
@@ -177,7 +177,7 @@ def setup(self, cfg: DictConfig) -> None:
self._value_model,
self._reward_model,
self._ref_policy_model,
) = self._setup_model(
) = self._setup_models(
cfg_model=cfg.policy_model,
cfg_reward_value_model=cfg.reward_and_value_model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -394,7 +394,7 @@ def _setup_checkpointers(
reward_checkpointer,
)

def _setup_model(
def _setup_models(
self,
cfg_model: DictConfig,
cfg_reward_value_model: DictConfig,
@@ -426,24 +426,20 @@ def _setup_model(
policy_model.load_state_dict(policy_state_dict)
ref_policy_model.load_state_dict(ref_policy_state_dict)

reward_missing, reward_unexpected = reward_model.load_state_dict(
reward_model_state_dict, strict=False
)
value_missing, value_unexpected = value_model.load_state_dict(
value_model_state_dict, strict=False
# since we should be loading a classifier checkpoint into
# a classifier model, this function should just ensure
# output.weight appears in both the state_dict and the model's parameters,
# and remove output.bias from the state_dict if found
utils.update_state_dict_for_classifier(
reward_model_state_dict, reward_model.named_parameters()
)
reward_model.load_state_dict(reward_model_state_dict)

# some extra validation for HF classifier checkpoints with a `score.bias` present
assert (
reward_missing == value_missing == []
), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts."

if reward_unexpected or value_unexpected:
# the only unexpected keys should be when pre-trained HF models were saved with
# bias=True in final classification layers. This happens when training a reward model with TRL.
assert (
reward_unexpected == value_unexpected == ["output.bias"]
), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts."
# same as above
utils.update_state_dict_for_classifier(
value_model_state_dict, value_model.named_parameters()
)
value_model.load_state_dict(value_model_state_dict)

# Validate models were loaded in with the expected dtype.
utils.validate_expected_param_dtype(
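For context on the change above: a minimal sketch of what a helper like update_state_dict_for_classifier might look like, inferred from the recipe comments here and from the tests added later in this commit. This is not the torchtune implementation, which lives in the checkpointing utils and may differ in detail.

from typing import Dict, Iterable, Tuple

import torch


def update_state_dict_for_classifier(
    state_dict: Dict[str, torch.Tensor],
    model_named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
    force_override: bool = False,
) -> None:
    # Mutates state_dict in place so it can be loaded into a classifier model.
    output_key = "output.weight"
    named_params = dict(model_named_parameters)
    assert output_key in state_dict, "Expected output.weight in state_dict"
    assert (
        output_key in named_params
    ), "Expected output.weight in model_named_parameters"
    # Drop a classification bias if the checkpoint carries one (e.g. HF reward
    # models trained with TRL save bias=True on the final score layer).
    state_dict.pop("output.bias", None)
    # Keep the model's own classification head when the checkpoint head has a
    # different shape (i.e. a base LM checkpoint) or when the caller forces it.
    if (
        state_dict[output_key].shape != named_params[output_key].shape
        or force_override
    ):
        state_dict[output_key] = named_params[output_key]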
3 changes: 1 addition & 2 deletions tests/torchtune/models/llama2/test_lora_llama2.py
@@ -16,8 +16,7 @@
from torchtune.models.llama2 import llama2, lora_llama2
from torchtune.models.llama2._component_builders import lora_llama2_self_attention
from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt
from torchtune.modules.peft import get_merged_lora_ckpt, LoRALinear
from torchtune.utils.seed import set_seed

RANK = 4
3 changes: 1 addition & 2 deletions tests/torchtune/models/phi3/test_lora_phi3.py
@@ -15,8 +15,7 @@
from torchtune import utils
from torchtune.models.phi3 import lora_phi3, phi3
from torchtune.models.phi3._component_builders import lora_phi3_self_attention
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt
from torchtune.modules.peft import get_merged_lora_ckpt, LoRALinear
from torchtune.utils.seed import set_seed

RANK = 4
@@ -11,12 +11,12 @@

from torch import nn
from torchtune.models.llama2 import llama2, lora_llama2
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
AdapterModule,
disable_adapter,
get_adapter_params,
get_merged_lora_ckpt,
LoRALinear,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
@@ -14,7 +14,7 @@
from torch import randn

from torchtune.models import gemma, llama2, mistral
from torchtune.modules.peft.peft_utils import (
from torchtune.modules.peft import (
get_adapter_params,
get_lora_module_names,
validate_missing_and_unexpected_for_lora,
148 changes: 148 additions & 0 deletions tests/torchtune/utils/_checkpointing/test_utils.py
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy

import pytest
import torch
from torchtune.models.llama2 import llama2, llama2_classifier
from torchtune.utils import update_state_dict_for_classifier

N_LAYERS = 3
IN_DIM = 5
OUT_DIM = 10
VOCAB_SIZE = 50
NUM_HEADS = 4
NUM_KV_HEADS = 2
EMBED_DIM = 64
MAX_SEQ_LEN = 64
NUM_CLASSES = 6


class TestUpdateStateDictForClassifer:
@pytest.fixture()
def llama2_state_dict(self):
model = llama2(
vocab_size=VOCAB_SIZE,
num_layers=N_LAYERS,
num_heads=NUM_KV_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
)
return model.state_dict()

@pytest.fixture()
def llama2_classifier_model(self):
return llama2_classifier(
num_classes=NUM_CLASSES,
vocab_size=VOCAB_SIZE,
num_layers=N_LAYERS,
num_heads=NUM_KV_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
)

def test_bias_in_classifier_checkpoint_is_removed(self, llama2_classifier_model):
# construct bogus state dict with output.bias included
state_dict_with_bias = llama2_classifier_model.state_dict().copy()
state_dict_with_bias["output.bias"] = torch.tensor([NUM_CLASSES])

# function should remove output.bias
update_state_dict_for_classifier(
state_dict_with_bias, llama2_classifier_model.named_parameters()
)

assert "output.bias" not in state_dict_with_bias

def test_loading_base_checkpoint_into_classifier(
self, llama2_state_dict, llama2_classifier_model
):
# grabbing the expected output.weight - the correct outcome here
# is for all weights aside from output.weight to be loaded in
# from the base model, so output.weight will remain in its rand init state
expected_output_weight = llama2_classifier_model.state_dict()[
"output.weight"
].clone()

# update the state dict to load with the classifier's output.weight
update_state_dict_for_classifier(
llama2_state_dict, llama2_classifier_model.named_parameters()
)

# load in all the base params
llama2_classifier_model.load_state_dict(llama2_state_dict)

# now we can assert that output.weight was unchanged
output_weight = llama2_classifier_model.state_dict()["output.weight"]
assert torch.equal(expected_output_weight, output_weight)

def test_assertion_error_when_missing_output_in_state_dict(
self, llama2_state_dict, llama2_classifier_model
):
llama2_state_dict.pop("output.weight")
with pytest.raises(
AssertionError, match="Expected output.weight in state_dict"
):
update_state_dict_for_classifier(
llama2_state_dict, llama2_classifier_model.named_parameters()
)

def test_assertion_error_when_missing_output_in_model_named_parameters(
self, llama2_state_dict, llama2_classifier_model
):
named_params = [
(k, v)
for (k, v) in llama2_classifier_model.named_parameters()
if k != "output.weight"
]
with pytest.raises(
AssertionError, match="Expected output.weight in model_named_parameters"
):
update_state_dict_for_classifier(llama2_state_dict, named_params)

def test_loading_classifier_weights(self, llama2_classifier_model):
state_dict_to_load = deepcopy(llama2_classifier_model.state_dict())
state_dict_to_load["output.weight"] = torch.ones_like(
state_dict_to_load["output.weight"]
)

update_state_dict_for_classifier(
state_dict_to_load, llama2_classifier_model.named_parameters()
)
llama2_classifier_model.load_state_dict(state_dict_to_load)

model_state_dict = llama2_classifier_model.state_dict()

assert set(model_state_dict.keys()) == set(state_dict_to_load.keys())
assert torch.equal(
model_state_dict["output.weight"],
torch.ones_like(model_state_dict["output.weight"]),
)

def test_loading_classifier_weights_force_override(self, llama2_classifier_model):
state_dict_to_load = deepcopy(llama2_classifier_model.state_dict())
state_dict_to_load["output.weight"] = torch.ones_like(
state_dict_to_load["output.weight"]
)

expected_output_weight = llama2_classifier_model.state_dict()[
"output.weight"
].clone()

update_state_dict_for_classifier(
state_dict_to_load, llama2_classifier_model.named_parameters(), True
)
llama2_classifier_model.load_state_dict(state_dict_to_load)

model_state_dict = llama2_classifier_model.state_dict()

assert set(model_state_dict.keys()) == set(state_dict_to_load.keys())
assert torch.equal(model_state_dict["output.weight"], expected_output_weight)


#
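As a rough end-to-end illustration of the flow these tests exercise (model sizes reuse the test constants above; this snippet is not part of the commit): load a base Llama2 checkpoint into a classifier while keeping the randomly initialized classification head.

from torchtune.models.llama2 import llama2, llama2_classifier
from torchtune.utils import update_state_dict_for_classifier

base_model = llama2(
    vocab_size=50, num_layers=3, num_heads=4, num_kv_heads=2,
    embed_dim=64, max_seq_len=64,
)
classifier = llama2_classifier(
    num_classes=6, vocab_size=50, num_layers=3, num_heads=4,
    num_kv_heads=2, embed_dim=64, max_seq_len=64,
)

state_dict = base_model.state_dict()
# Drops output.bias if present and swaps in the classifier's own output.weight,
# since the base checkpoint's [vocab_size, embed_dim] head cannot be loaded
# into the [num_classes, embed_dim] classification head.
update_state_dict_for_classifier(state_dict, classifier.named_parameters())
classifier.load_state_dict(state_dict)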
3 changes: 1 addition & 2 deletions tests/torchtune/utils/test_distributed.py
@@ -28,8 +28,7 @@
from torchtune.models.llama2._component_builders import llama2, lora_llama2
from torchtune.models.llama3._component_builders import llama3
from torchtune.modules import TransformerSelfAttentionLayer
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params
from torchtune.modules.peft import get_adapter_params, LoRALinear, set_trainable_params


class TestDistributed:
8 changes: 6 additions & 2 deletions torchtune/modules/peft/__init__.py
@@ -4,16 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .lora import LoRALinear
from .peft_utils import ( # noqa
from ._utils import ( # noqa
AdapterModule,
disable_adapter,
get_adapter_params,
get_lora_module_names,
get_merged_lora_ckpt,
LORA_ATTN_MODULES,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
validate_state_dict_for_lora,
)
from .lora import LoRALinear

__all__ = [
"LoRALinear",
@@ -23,4 +25,6 @@
"validate_missing_and_unexpected_for_lora",
"validate_state_dict_for_lora",
"disable_adapter",
"get_merged_lora_ckpt",
"get_lora_module_names",
]
File renamed without changes.
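Taken together with the rename of peft/peft_utils.py to peft/_utils.py (presumably the "File renamed without changes." entry above), the practical effect for downstream code is that PEFT helpers are now imported from the package root, matching the recipe and test updates earlier in this diff:

# Before this commit:
#   from torchtune.modules.peft import LoRALinear
#   from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params

# After this commit, everything is re-exported from the package root:
from torchtune.modules.peft import (
    get_adapter_params,
    get_merged_lora_ckpt,
    LoRALinear,
    set_trainable_params,
)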
2 changes: 1 addition & 1 deletion torchtune/modules/peft/lora.py
@@ -12,7 +12,7 @@

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401
from torchtune.modules.peft.peft_utils import AdapterModule
from torchtune.modules.peft import AdapterModule


class LoRALinear(nn.Module, AdapterModule):
2 changes: 2 additions & 0 deletions torchtune/utils/__init__.py
@@ -10,6 +10,7 @@
FullModelMetaCheckpointer,
FullModelTorchTuneCheckpointer,
ModelType,
update_state_dict_for_classifier,
)

from ._device import get_device
@@ -72,6 +73,7 @@
from .seed import set_seed

__all__ = [
"update_state_dict_for_classifier",
"get_memory_stats",
"FSDPPolicyType",
"log_memory_stats",