
enable LoRA + FSDP2 #855

Merged
merged 58 commits on Jun 3, 2024

Changes from 7 commits
Commits (58)
e5826a1
enable LoRA + FSDP2
weifengpy Apr 24, 2024
64fc870
reset params for lora weights and rope
weifengpy Apr 24, 2024
0cd21c6
support lora weights checkpoint and checkpoint utils
weifengpy Apr 24, 2024
589191e
fix lora meta device bug
weifengpy Apr 24, 2024
c801f26
save optim state dict
weifengpy Apr 25, 2024
19a2d70
mark TODO
weifengpy Apr 25, 2024
441da10
optimizer foreach=True for DTensor
weifengpy Apr 25, 2024
750b9e5
clip grad norm
weifengpy Apr 25, 2024
3d632d5
switch to ptd state dict api
weifengpy Apr 26, 2024
cb3abb3
add profiler
weifengpy May 1, 2024
e68804a
use torchao copy_
weifengpy May 1, 2024
d6af9a2
enable saving checkpoint
weifengpy May 1, 2024
b616394
optimizer state dict: load on rank0 and broadcast
weifengpy May 1, 2024
a400497
import Optimizer
weifengpy May 1, 2024
e9de63c
resume training
weifengpy May 3, 2024
05d3895
prepare for full test
weifengpy May 3, 2024
7a5bb80
prepare for full test
weifengpy May 3, 2024
64bf49c
remove profiler
weifengpy May 3, 2024
cb1bba4
passed integration test
weifengpy May 4, 2024
ac516e9
remove unnecessary change
weifengpy May 4, 2024
bfde704
Merge branch 'main' into fsdp2
weifengpy May 4, 2024
102db31
bring back state dict validation
weifengpy May 4, 2024
0b66651
align indent on comment
weifengpy May 4, 2024
672aabb
remove unused import
weifengpy May 4, 2024
6af2723
switch to ptd state dict and keep self implemented in record
weifengpy May 8, 2024
42ad99c
clean unused code
weifengpy May 8, 2024
74f6175
remove cuda value error
weifengpy May 8, 2024
f1b8a5e
comment on to_empty
weifengpy May 8, 2024
36e6829
fix memory issues by switching model state dict api
weifengpy May 8, 2024
08cd1fd
clean for review
weifengpy May 8, 2024
559bc4d
Merge branch 'main' into fsdp2
weifengpy May 8, 2024
2333134
fix linter
weifengpy May 9, 2024
49a0364
fix checkpoint loading
weifengpy May 9, 2024
dc2ce02
expecttest CI dependency
weifengpy May 9, 2024
0a604aa
ci dependency
weifengpy May 9, 2024
fa83140
fix CI issue
weifengpy May 10, 2024
4b5a895
Merge branch 'pytorch:main' into fsdp2
weifengpy May 10, 2024
a2e34ec
support resuming training
weifengpy May 14, 2024
6142031
update docstring
weifengpy May 14, 2024
7607e14
remove dependency on broadcast_from_rank0
weifengpy May 14, 2024
1899beb
remove the need for model.to(device)
weifengpy May 15, 2024
c1cfabb
wrap lora and TransformerBlock
weifengpy May 17, 2024
d7382ae
require torch version 2.4.0
weifengpy May 17, 2024
d1ff53b
FSDP(CheckpointWrapper(model))
weifengpy May 22, 2024
1eb9e87
remove model.to()
weifengpy May 29, 2024
695e959
add docstrings and remove dependency on dcp
weifengpy May 31, 2024
e10f638
remove try...catch FSDPModule
weifengpy Jun 1, 2024
b1e3d30
Merge branch 'main' into fsdp2
weifengpy Jun 1, 2024
944a723
fsdp2 as dev recipe
weifengpy Jun 1, 2024
ac5f7aa
restore lora_finetune_distributed
weifengpy Jun 1, 2024
d769626
test cudnn ci error
weifengpy Jun 2, 2024
f90c3cc
test CI error
weifengpy Jun 3, 2024
42ef49a
address CI error for setting seed
weifengpy Jun 3, 2024
170de94
add back pytest
weifengpy Jun 3, 2024
f8a7018
add expecttest
weifengpy Jun 3, 2024
a3b2f3e
pytest 7.4.0
weifengpy Jun 3, 2024
1a692b3
add dev/recipe
weifengpy Jun 3, 2024
8fbbc4b
update yaml with lora_finetune_fsdp2
weifengpy Jun 3, 2024
139 changes: 58 additions & 81 deletions recipes/lora_finetune_distributed.py
@@ -17,23 +17,24 @@

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torch.optim.optimizer import _foreach_supported_types
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
get_adapter_params,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

# use foreach on CUDA
if DTensor not in _foreach_supported_types:
_foreach_supported_types.append(DTensor)

from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -277,86 +278,62 @@ def _setup_model(
the correct device.
"""

if self._device.type != "cuda":
raise ValueError(
f'FSDP needs device="cuda" but found device={self._device.type}'
)

if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
Contributor

Suggested change
log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
log.info("FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...")

init_start = time.perf_counter()

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
Contributor

Sorry, I'm not able to comment above, but shouldn't the docstring of this function be updated since we're no longer initializing on CPU?

Contributor Author

The docstring used to say "Instantiating Model on CPU" (on the left) and I removed the mention of CPU. I did not mention the meta device because the timer now measures meta init + checkpoint loading. Happy to improve it if you are referring to this docstring.

Contributor Author

Oh, I just got your point. I've updated the docstring for _setup_model.

model = config.instantiate(cfg_model)

log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
)
# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

Contributor Author

if isinstance(m, modules.TransformerDecoderLayer): is the equivalent of the auto_wrap_policy in FSDP1
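
To make the equivalence concrete, here is a rough side-by-side sketch (assuming model and torchtune's modules are in scope; the two blocks are alternatives, not meant to be applied to the same model instance):

# FSDP1: the wrapping decision is hidden inside an auto_wrap_policy callable.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

fsdp1_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({modules.TransformerDecoderLayer}),
)

# FSDP2: the same decision is an explicit loop over submodules, then the root.
from torch.distributed._composable.fsdp import fully_shard

for m in model.modules():
    if isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
fully_shard(model)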

# Load both the base model weights and (if available) the adapter weights. Both
# of this should happen only on Rank 0
model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict:
model.load_state_dict(lora_weights_state_dict, strict=False)
for m in model.modules():
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)
fully_shard(model)
Contributor

Sorry for the noob question, but can you help me understand what's going on here? Why do I need to fully_shard each TransformerDecoderLayer and then call fully_shard on the model?

An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to SHARD_GRAD_OP with FSDP2?

Contributor Author

In FSDP1, we wrap each TransformerDecoderLayer and then the root model as well. It's black-boxed inside auto_wrap_policy=utils.lora_fsdp_wrap_policy(modules_to_wrap={modules.TransformerDecoderLayer}).

In FSDP2, we un-black-box it into this for-loop. If you prefer, this can be factored into a util function in torchtune so that users call something like util.fully_shard(model, modules_to_wrap); see the sketch below.

Personally I'm biased towards the un-black-boxed approach, since people can modify the for-loop to achieve different wrapping.
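
A minimal sketch of what such a helper might look like (hypothetical; shard_model and its signature are not part of this PR):

from typing import Set, Type

from torch import nn
from torch.distributed._composable.fsdp import fully_shard


def shard_model(model: nn.Module, modules_to_wrap: Set[Type[nn.Module]]) -> None:
    """Shard every submodule whose type is in modules_to_wrap, then the root model."""
    wrap_types = tuple(modules_to_wrap)
    for m in model.modules():
        if isinstance(m, wrap_types):
            fully_shard(m)
    fully_shard(model)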

Contributor Author

The equivalent of SHARD_GRAD_OP in FSDP2 is reshard_after_forward=False. Do you want it exposed as a config in the .yaml?

fully_shard(model, reshard_after_forward=False)

Contributor

Thanks for the explanation! I love the un-black-boxed approach here - it just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.


else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
utils.load_from_full_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)
if lora_weights_state_dict:
utils.load_from_full_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)

Contributor Author

Pros and cons of meta init: the pro is a 4.5x speedup during model init and thus a shorter TTFB; the con is that the user needs to call initialize_parameters on LoRALinear explicitly to move those weights from meta to GPU.

Contributor

Is this because these params are not being loaded from the checkpoint? Or do I misunderstand?

If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from a checkpoint (e.g. when resuming training)?

Contributor Author

You are right. When finetuning from an original HF checkpoint, lora_weights_state_dict = None.

When resuming training, lora_weights_state_dict is not None and we avoid calling m.initialize_parameters() again.

Contributor

Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between:

  • the modules on which we call fully_shard
  • init on the meta device
  • calling initialize_parameters and reset_parameters

Also, I think there was a technical reason with FSDP1 to call the function reset_parameters. Is that still true? Or can we standardize on initialize_parameters in the modules code? Happy to chat about this offline!

Contributor Author

Good point! I will add a comment to explain fully_shard, meta init, and reset/initialize_parameters.

FSDP1 calls reset_parameters for the exact same reason FSDP2 calls reset/initialize_parameters: RoPE buffers are not covered in checkpoints, and lora_a/lora_b are not covered in checkpoints when resume_training=False.

It's just that FSDP1 has a contract to call the overridden nn.Module.reset_parameters through FSDP(model, param_init_fn=...), whereas FSDP2 does not impose overriding reset_parameters, so the method can be named reset_parameters or initialize_parameters.
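
Condensing this thread into the kind of commented code the review asks for - a sketch that mirrors the new recipe code below, with the recipe's names assumed - only the tensors a checkpoint never contains are materialized and re-initialized by hand:

with utils.set_default_dtype(self._dtype), self._device:
    for m in model.modules():
        # LoRA a/b matrices only exist in a checkpoint when resuming a previous
        # LoRA run, so initialize them here otherwise.
        if isinstance(m, LoRALinear) and not lora_weights_state_dict:
            # to_empty() gives the meta-device params real storage first, since
            # kaiming_uniform_ is an in-place op and cannot run on meta tensors.
            m.to_empty(device=self._device)
            m.initialize_parameters()
        # RoPE caches are buffers that never appear in checkpoints, so they are
        # always rebuilt after meta init.
        if isinstance(m, modules.RotaryPositionalEmbeddings):
            m.reset_parameters()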

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
with utils.set_default_dtype(self._dtype), self._device:
for m in model.modules():
if isinstance(m, LoRALinear) and not lora_weights_state_dict:
# to_empty is needed since kaiming_uniform_ is inplace
m.to_empty(device=self._device)
m.initialize_parameters()
if isinstance(m, modules.RotaryPositionalEmbeddings):
Contributor

Just to clarify, we special-case RoPE because the buffer is not being loaded from a state dict, right?

Contributor Author

that's correct

Contributor

Similar comment here, let's document what's happening so that users can easily understand why we initialize these modules separately.

m.reset_parameters()

model = model.to(self._dtype)

# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha

# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

model = FSDP(
module=model,
auto_wrap_policy=utils.lora_fsdp_wrap_policy(
modules_to_wrap={modules.TransformerDecoderLayer}
),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if enable_activation_checkpointing:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)
if self._is_rank_zero:
log.info(
f"Model init and checkpoint loading took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)

@@ -372,6 +349,7 @@ def _setup_optimizer(
if opt_state_dict:
# Note: technically we should check _contains_fsdp for
# just the state dict of the adapter cfg, but should be equivalent
# TODO: implement local -> DTensor
opt_state_dict = utils.transform_opt_state_dict(
opt_state_dict, self._model, optimizer
)
@@ -451,22 +429,21 @@ def save_checkpoint(
intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
if intermediate_checkpoint:
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
else:
opt_state_dict = None

cpu_state_dict = utils.get_full_model_state_dict(
self._model, self._is_rank_zero
)

if intermediate_checkpoint:
opt_state_dict = utils.get_full_optimizer_state_dict(
self._optimizer, self._is_rank_zero
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
3 changes: 3 additions & 0 deletions torchtune/utils/__init__.py
@@ -16,9 +16,12 @@
from ._device import get_device
from ._distributed import ( # noqa
contains_fsdp,
get_full_model_state_dict,
get_full_optimizer_state_dict,
get_world_size_and_rank,
init_distributed,
is_distributed,
load_from_full_state_dict,
lora_fsdp_wrap_policy,
prepare_model_for_fsdp_with_meta_device,
validate_no_params_on_meta_device,
72 changes: 70 additions & 2 deletions torchtune/utils/_distributed.py
@@ -8,23 +8,25 @@
import logging
import os
from itertools import chain
from typing import Callable, Dict, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union

import torch
import torch.distributed as dist
import torch.distributed._composable.fsdp
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Optimizer
from torchtune.modules.peft.lora import (
_lora_a_init_params,
_lora_b_init_params,
LoRALinear,
)

from torchtune.utils._device import _validate_device_from_env, get_device
from torchtune.utils.logging import get_logger

@@ -297,3 +299,69 @@ def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
return isinstance(module, tuple(modules_to_wrap))

return lora_wrap_fsdp


def load_from_full_state_dict(
model: torch.distributed._composable.fsdp.FSDP,
full_sd: Dict[str, Any],
device: torch.device,
is_rank_zero: bool,
):
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
if is_rank_zero:
full_tensor = full_tensor.detach().to(device)
else:
full_tensor = torch.empty(
sharded_meta_param.size(),
device=device,
dtype=sharded_meta_param.dtype,
)
torch.distributed.broadcast(full_tensor, src=0)
sharded_tensor = distribute_tensor(
full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements
Contributor

Where is the device_mesh and placements information coming from?

Contributor Author

It's from fully_shard(model, mesh). We are using the default mesh where every rank participates in FSDP, since there is no 2D/3D parallelism involved.

After fully_shard(model, mesh), model.parameters() are converted from plain tensors to DTensors carrying that mesh; see the sketch after this function.
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
model.load_state_dict(sharded_sd, strict=False, assign=True)
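
As a small illustration of the point above (build_model is a hypothetical stand-in for the recipe's config.instantiate call, and torchtune's modules is assumed to be imported): after fully_shard, each parameter is already a DTensor carrying the device_mesh and placements that distribute_tensor reuses.

import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor

with torch.device("meta"):
    model = build_model()  # hypothetical constructor; the recipe uses config.instantiate
for m in model.modules():
    if isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
fully_shard(model)

for param in model.parameters():
    # param.device_mesh and param.placements are exactly what
    # load_from_full_state_dict passes to distribute_tensor.
    assert isinstance(param, DTensor)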
Contributor

If we catch missing and unexpected keys from load_state_dict with strict=False, what format will the keys be in? Previously with FSDP1 the keys contained all the info about FSDP wrapping, e.g. model.layers.0._fsdp_flat_param.attn.q_proj.weight (probably not exactly right, but something like that). Will that still be the case here?

Contributor Author

For FSDP2, they are clean FQNs without any FSDP prefix, for example layers.0.attn.q_proj.lora_a.weight.

FSDP2 is clean because 1) fully_shard registers hooks instead of wrapping the nn.Module, and 2) fully_shard sets module.__class__ = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) https://fburl.com/i20yr3s2

Contributor

This is great! I think this means we can actually validate the LoRA state dict load more cleanly (note that we currently have two separate utilities for this for the single-device vs. distributed case because of the FSDP prefix issue). Not a concern for this PR, but this will allow us to clean up our code a bit.
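
For reference, a minimal illustration of that point (key names are examples from this thread, not captured output):

# With FSDP2 the module tree keeps its original names, so a non-strict load
# reports plain FQNs that can be matched against adapter/base key lists directly.
missing, unexpected = model.load_state_dict(sharded_sd, strict=False, assign=True)
# e.g. missing may contain "layers.0.attn.q_proj.lora_a.weight" rather than an
# FSDP1-style "..._fsdp_wrapped_module..." key.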



def get_full_model_state_dict(
model: torch.distributed._composable.fsdp.FSDP,
is_rank_zero: bool,
) -> Dict[str, Any]:
sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if is_rank_zero:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
return cpu_state_dict


def get_full_optimizer_state_dict(
opt: Optimizer,
is_rank_zero: bool,
) -> Dict[str, Any]:
sharded_sd = opt.state_dict()
sharded_state = sharded_sd["state"]
full_state = {}
for group_id, sharded_group in sharded_state.items():
group_state = {}
for attr, sharded_tensor in sharded_group.items():
if isinstance(sharded_tensor, DTensor):
full_tensor = sharded_tensor.full_tensor()
else:
full_tensor = sharded_tensor
if is_rank_zero:
group_state[attr] = full_tensor.cpu()
else:
del full_tensor
full_state[group_id] = group_state
return {
"param_groups": sharded_sd["param_groups"],
"state": full_state,
}