enable QLoRA + FSDP2 #909

Merged: 100 commits (merged Jun 5, 2024)

Changes shown below are from 51 of the 100 commits.

Commits
e5826a1
enable LoRA + FSDP2
weifengpy Apr 24, 2024
64fc870
reset params for lora weights and rope
weifengpy Apr 24, 2024
0cd21c6
support lora weights checkpoint and checkpoint utils
weifengpy Apr 24, 2024
589191e
fix lora meta device bug
weifengpy Apr 24, 2024
c801f26
save optim state dict
weifengpy Apr 25, 2024
19a2d70
mark TODO
weifengpy Apr 25, 2024
441da10
optimizer foreach=True for DTensor
weifengpy Apr 25, 2024
750b9e5
clip grad norm
weifengpy Apr 25, 2024
3d632d5
switch to ptd state dict api
weifengpy Apr 26, 2024
cb3abb3
add profiler
weifengpy May 1, 2024
dfcdde3
qlora 7b config
weifengpy May 1, 2024
e68804a
use torchao copy_
weifengpy May 1, 2024
b6fad93
Merge pull request #1 from weifengpy/fsdp2
weifengpy May 1, 2024
d6af9a2
enable saving checkpoint
weifengpy May 1, 2024
7bbe522
Merge pull request #2 from weifengpy/fsdp2
weifengpy May 1, 2024
b616394
optimizer state dict: load on rank0 and broadcast
weifengpy May 1, 2024
a400497
import Optimizer
weifengpy May 1, 2024
e9de63c
resume training
weifengpy May 3, 2024
05d3895
prepare for full test
weifengpy May 3, 2024
7a5bb80
prepare for full test
weifengpy May 3, 2024
64bf49c
remove profiler
weifengpy May 3, 2024
cb1bba4
passed integration test
weifengpy May 4, 2024
ac516e9
remove unnecessary change
weifengpy May 4, 2024
bfde704
Merge branch 'main' into fsdp2
weifengpy May 4, 2024
102db31
bring back state dict validation
weifengpy May 4, 2024
0b66651
align indent on comment
weifengpy May 4, 2024
672aabb
remove unused import
weifengpy May 4, 2024
6af2723
switch to ptd state dict and keep self implemented in record
weifengpy May 8, 2024
42ad99c
clean unused code
weifengpy May 8, 2024
74f6175
remove cuda value error
weifengpy May 8, 2024
f1b8a5e
comment on to_empty
weifengpy May 8, 2024
36e6829
fix memory issues by switching model state dict api
weifengpy May 8, 2024
08cd1fd
clean for review
weifengpy May 8, 2024
559bc4d
Merge branch 'main' into fsdp2
weifengpy May 8, 2024
2333134
fix linter
weifengpy May 9, 2024
49a0364
fix checkpoint loading
weifengpy May 9, 2024
dc2ce02
expecttest CI dependency
weifengpy May 9, 2024
0a604aa
CI dependency
weifengpy May 9, 2024
fa83140
fix CI issue
weifengpy May 10, 2024
6203a1f
Merge branch 'main' into qlora
weifengpy May 10, 2024
4b5a895
Merge branch 'pytorch:main' into fsdp2
weifengpy May 10, 2024
1080e2c
Merge branch 'fsdp2' into qlora
weifengpy May 10, 2024
1a70498
rebase qlora
weifengpy May 10, 2024
cb862e9
rebase qlora
weifengpy May 10, 2024
21f5458
sync lora changes
weifengpy May 14, 2024
33773bd
push qlora for perf measurement
weifengpy May 14, 2024
483028b
fix meta init + cpu offloading
weifengpy May 15, 2024
cf42618
init RotaryPositionalEmbeddings in both fresh training and resume
weifengpy May 15, 2024
b519d50
import cpu offloading when needed
weifengpy May 17, 2024
8600ced
FSDP(CheckpointWrapper(Model))
weifengpy May 22, 2024
b2fd531
bring back cpu offloading
weifengpy May 22, 2024
bb8a8bc
remove model.to
weifengpy May 29, 2024
db71c5c
apply nf4 when loading model state dict
weifengpy May 30, 2024
16bf2de
move lora to cpu when cpu offloading
weifengpy May 30, 2024
df6e535
Update documentation tab to point to main instead of stable (#960)
kartikayk May 11, 2024
5f621e1
Update tokens_per_sec to tokens_per_sec_per_gpu (#956)
kartikayk May 11, 2024
7d92b1c
Delete init_weights_with_constant test util (#974)
ebsmothers May 13, 2024
588871e
Sample packing for map datasets with correct RoPE encoding and no cro…
RdoubleA May 15, 2024
1a5bf1a
Utilize compile on model rather than via torch API (#953)
joecummings May 15, 2024
ae7de20
Add better formatting for Eleuther eval results (#986)
joecummings May 16, 2024
23cea56
updating help docs for hf-token arg in download.py (#991)
SalmanMohammadi May 16, 2024
be06efa
Fix position embeddings for Phi3 when packing + nits (#992)
RdoubleA May 17, 2024
79ef995
Llama3-8b memory efficient full finetune (#990)
rohan-varma May 17, 2024
5f55c16
Fix Gemma 2B model forward call (#998)
joecummings May 17, 2024
b88fa2d
fix: lora dropout applied to all models (#995)
Optimox May 17, 2024
b47ee93
fix: different rope base between phi3 and lora_phi3 (#997)
Optimox May 17, 2024
d86b454
Add support for free generation tasks in evals (#975)
joecummings May 19, 2024
2b109f4
Filter out special tokens and placeholder tokens for Phi-3 (#983)
joecummings May 20, 2024
9bd07a6
TorchTune --> torchtune (#1007)
joecummings May 20, 2024
f5cb12e
Support for unstructured text corpus datasets for CPT (#868)
RdoubleA May 21, 2024
a2066f9
Save adapter config and remapped adapter weights for loading into PEF…
ebsmothers May 21, 2024
29d1761
Datasets tutorial improvements (#994)
RdoubleA May 21, 2024
3a01d7f
Fix TypeError: tuple indices must be integers or slices, not str issu…
tambulkar May 22, 2024
1d6b4a2
Add recipe test for llama3 (#929)
SLR722 May 23, 2024
c74c9a9
Fix the Gemma generation (#1016)
solitude-alive May 24, 2024
00f96ff
Update chat tutorial so that it works as is (#1004)
christobill May 28, 2024
62192df
[fix] llama3 70B_lora update commented instructions (#1030)
pbontrager May 30, 2024
ecd5e7e
Move nf4 op registration from utils to modules (#1035)
ebsmothers May 31, 2024
99c549b
feat: add gemma7b support (#971)
Optimox May 31, 2024
7d11a89
Llama3-70b: Full Finetune w/CPU offload + fused optimizer (#993)
rohan-varma Jun 1, 2024
0080795
enable LoRA + FSDP2 (#855)
weifengpy Jun 3, 2024
00360f7
Merge branch 'weifengpy-qlora' into qlora
weifengpy Jun 4, 2024
d8664a3
Merge branch 'main' into qlora
weifengpy Jun 4, 2024
f58f9b2
rebase
weifengpy Jun 4, 2024
b9bfd41
revert lora_finetune_distributed.py
weifengpy Jun 4, 2024
7a3d9a1
rebase and register recipe
weifengpy Jun 4, 2024
2835d2a
del logits to save memory
weifengpy Jun 4, 2024
559b81d
fix linter
weifengpy Jun 4, 2024
85f978b
gate NF4.copy_ on TorchAO==0.2.0
weifengpy Jun 4, 2024
dbae23c
improve torchao gating comment
weifengpy Jun 4, 2024
f4a8dfa
upgrade torchao to 0.2
weifengpy Jun 4, 2024
10e304d
gate torchao 0.2
weifengpy Jun 4, 2024
e117a21
replace with lora_finetune_fsdp2
weifengpy Jun 4, 2024
4bb5e0f
add llama2-70B
weifengpy Jun 4, 2024
174d916
replace with qlora and lora_finetune_fsdp2 in yaml
weifengpy Jun 4, 2024
5fdcefb
rename yaml to _fsdp2.yaml
weifengpy Jun 5, 2024
b878018
add unit test for nf4 state dict
weifengpy Jun 5, 2024
a8f1a9a
python 3.8 style dict union
weifengpy Jun 5, 2024
ae49684
validate lora sd missing
weifengpy Jun 5, 2024
cbb3da8
skip test if <2 gpu
weifengpy Jun 5, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -41,12 +41,13 @@ tune = "torchtune._cli.tune:main"
dev = [
"bitsandbytes>=0.43.0",
"pre-commit",
"pytest",
"pytest==7.4.0",
"pytest-cov",
"pytest-mock",
"pytest-integration",
"tensorboard",
"wandb",
"expecttest==0.1.6",
]

[tool.setuptools.dynamic]
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -61,6 +61,9 @@ lr_scheduler:
loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
171 changes: 80 additions & 91 deletions recipes/lora_finetune_distributed.py
@@ -9,29 +9,32 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn

import torch
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
from torch.distributed.checkpoint.state_dict import (
get_optimizer_state_dict,
StateDictOptions,
)

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -213,6 +216,7 @@ def setup(self, cfg: DictConfig) -> None:
if self._resume_from_checkpoint
else None
),
cfg_fsdp=cfg.fsdp if hasattr(cfg, "fsdp") else None,
)
self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -264,59 +268,69 @@ def _setup_model(
enable_activation_checkpointing: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
cfg_fsdp: Optional[Union[DictConfig, None]] = None,
) -> nn.Module:
"""
Model initialization has some important considerations:
a. To minimize GPU peak memory, we load the model on CPU with the right
dtype. To ensure that we don't instantiate ``world_size`` number of models,
we initialize on meta_device for all ranks other than rank 0.
b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
model weights from checkpoint.
c. While wrapping the model with FSDP, we set ``sync_module_states``
to TRUE and broadcast module params and buffers from rank 0.
d. The ``device_id`` param ensures that the FSDP initialization happens on
the correct device.
a. To minimize GPU peak memory, we initialize the model on meta device with
the right dtype
b. All ranks call ``load_state_dict`` without spiking CPU RAM, since
full state dicts are loaded with ``torch.load(mmap=True)``
c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`
"""

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
log.info(
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
)
init_start = time.perf_counter()

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
)
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

# Load both the base model weights and (if available) the adapter weights. Both
# of this should happen only on Rank 0
model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict:
model.load_state_dict(lora_weights_state_dict, strict=False)
fsdp_kwargs = {}
if cfg_fsdp and cfg_fsdp.cpu_offload:
from torch.distributed._composable.fsdp import CPUOffloadPolicy

else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

for m in reversed(list(model.modules())):
if isinstance(m, nn.Linear) and m.weight.requires_grad:
fully_shard(m, **fsdp_kwargs)
# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m, **fsdp_kwargs)
else:
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

if lora_weights_state_dict:
utils.load_from_full_model_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)

with utils.set_default_dtype(self._dtype), self._device:
for m in model.modules():
if isinstance(m, LoRALinear) and not lora_weights_state_dict:
m.lora_a.to_empty(device=self._device)
m.lora_b.to_empty(device=self._device)
m.initialize_parameters()
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()
utils.load_from_full_model_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
Review comment (Contributor Author): this will dequant NF4 to the original weight. For QLoRA, we may not want it.

Reply: Hmm, I would say that "we certainly do not want it".
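To make the concern above concrete: torchao's NF4 tensor dequantizes when it is converted back to a high-precision dtype, so a blanket `model.to(torch.bfloat16)` would materialize the full-precision base weights that QLoRA is trying to avoid keeping around. A small hedged illustration; the helper names (`to_nf4`, `get_original_weight`) follow torchao's NF4 utilities around the 0.1–0.2 releases and may differ in other versions:

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4  # torchao NF4 utilities (version-dependent path)

weight = torch.randn(256, 256, dtype=torch.bfloat16)
qweight = to_nf4(weight)                  # compact 4-bit NF4 representation of the weight
assert isinstance(qweight, NF4Tensor)

restored = qweight.get_original_weight()  # explicit dequantization back to a full bf16 tensor
print(restored.dtype, restored.shape)     # torch.bfloat16, (256, 256)
```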

@@ -325,39 +339,13 @@ def _setup_model(
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha

# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerDecoderLayer}
),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)
if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)
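The diff above replaces FSDP1's `auto_wrap_policy`/`sync_module_states` wrapping with per-module `fully_shard` calls on a meta-initialized model. Below is a minimal sketch of that flow, assuming a PyTorch build that ships the composable FSDP prototype (`torch.distributed._composable.fsdp`); the toy model, sizes, and launch details are illustrative and not the recipe's:

```python
import torch
import torch.nn as nn
from torch.distributed import get_rank, init_process_group
from torch.distributed._composable.fsdp import fully_shard


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 128) -> None:
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


def build_sharded_model(num_layers: int = 4) -> nn.Module:
    # 1. Instantiate on meta device: no parameter storage is allocated on any rank.
    with torch.device("meta"):
        model = nn.Sequential(*(ToyBlock() for _ in range(num_layers)))

    # 2. Register FSDP2 (pre-)forward hooks bottom-up: each block, then the root.
    #    CPU offload would be requested here via
    #    fully_shard(block, offload_policy=CPUOffloadPolicy()), as in the hunk above.
    for block in model:
        fully_shard(block)
    fully_shard(model)

    # 3. Materialize the sharded parameters on the local GPU and (re)initialize them.
    #    The recipe instead loads real checkpoint weights shard-by-shard via
    #    utils.load_from_full_model_state_dict.
    model.to_empty(device="cuda")
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.reset_parameters()
    return model


if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 fsdp2_sketch.py
    init_process_group(backend="nccl")
    torch.cuda.set_device(get_rank() % torch.cuda.device_count())
    sharded = build_sharded_model()
```

In the recipe, step 3 is where the base and LoRA weights are loaded (or re-initialized for fresh LoRA params and rotary embeddings), rather than calling `reset_parameters` on every linear layer as this toy does.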

@@ -371,12 +359,11 @@ def _setup_optimizer(
) -> Optimizer:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
# Note: technically we should check _contains_fsdp for
# just the state dict of the adapter cfg, but should be equivalent
opt_state_dict = utils.transform_opt_state_dict(
opt_state_dict, self._model, optimizer
utils.load_from_full_optimizer_state_dict(
optimizer,
opt_state_dict,
self._device,
)
optimizer.load_state_dict(opt_state_dict)

if self._is_rank_zero:
log.info("Optimizer and loss are initialized.")
@@ -461,17 +448,19 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
cpu_state_dict = utils.get_full_model_state_dict(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
if intermediate_checkpoint:
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
else:
opt_state_dict = None
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = get_optimizer_state_dict(
self._model,
self._optimizer,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
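For the save path above, here is a hedged sketch of gathering full, CPU-offloaded state dicts with the same `torch.distributed.checkpoint.state_dict` helpers. The recipe uses `utils.get_full_model_state_dict` for the model side, so `get_model_state_dict` below is an illustrative stand-in:

```python
import torch.nn as nn
from torch.optim import Optimizer
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)


def gather_full_checkpoint(
    model: nn.Module, optimizer: Optimizer, is_rank_zero: bool, save_optim: bool
):
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    # Every rank participates in gathering the DTensor shards; the recipe only
    # writes the resulting dicts to disk from rank 0.
    model_sd = get_model_state_dict(model, options=options)
    optim_sd = (
        get_optimizer_state_dict(model, optimizer, options=options) if save_optim else None
    )
    return {"model": model_sd, "optim": optim_sd} if is_rank_zero else None
```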