enable QLoRA + FSDP2 #909
Changes from 91 commits
@@ -0,0 +1,91 @@
# Config for single device QLoRA with lora_finetune_single_device.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama2/7B_qlora_single_device
[Review comment on the launch command above: "will replace"]
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config 7B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

# Model Arguments
model:
  _component_: torchtune.models.llama2.qlora_llama2_7b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-7b-hf/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-hf
  checkpoint_files: [
    pytorch_model-00001-of-00002.bin,
    pytorch_model-00002-of-00002.bin
  ]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/Llama-2-7b-hf
  model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

fsdp:
  cpu_offload: False
[Review thread on `fsdp.cpu_offload`:]
- "comparing with ..."
- "Strictly speaking you probably do not need to put ..."
- "I know Less has been working on ... Does it make sense to you?"
- "Ah yeah this is a good point, makes sense to me. Thanks for clarifying"

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False

# Logging
output_dir: /tmp/qlora_finetune_output/
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.utils.profiler
  enabled: False
  output_dir: ${output_dir}/torchtune_perf_tracing.json
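
For reference, a minimal sketch of instantiating the model component from this config by hand, on the meta device as the recipe does. This is illustrative only; it assumes the qlora_llama2_7b builder is importable as shown.

# Hedged sketch: build the QLoRA Llama2 7B model with the hyperparameters from the config above.
import torch
from torchtune.models.llama2 import qlora_llama2_7b

with torch.device("meta"):
    model = qlora_llama2_7b(
        lora_attn_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        lora_rank=8,
        lora_alpha=16,
    )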
@@ -9,7 +9,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn

import torch
@@ -32,7 +32,6 @@
    get_lora_module_names,
    get_merged_lora_ckpt,
    set_trainable_params,
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

@@ -214,6 +213,7 @@ def setup(self, cfg: DictConfig) -> None:
                if self._resume_from_checkpoint
                else None
            ),
            cfg_fsdp=cfg.fsdp if hasattr(cfg, "fsdp") else None,
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -265,6 +265,7 @@ def _setup_model(
        enable_activation_checkpointing: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
        cfg_fsdp: Optional[Union[DictConfig, None]] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:

@@ -290,25 +291,6 @@ def _setup_model(
        with utils.set_default_dtype(self._dtype), torch.device("meta"):
            model = config.instantiate(cfg_model)

        if self._is_rank_zero:
            # The model contains LoRA params which won't have any matching keys in
            # the state dict. As a result, we need to load with strict=False.
            # Before loading the state dict, ensure the state dict keys for the base
            # model and adapters (if available) match the keys in the full LoRA model
            # This is a good sanity check to prevent silent errors
            validate_state_dict_for_lora(
[Review thread on the removed validate_state_dict_for_lora call (a hedged sketch of the suggested alternative follows after this hunk):]
- "Oh yeah this is what I was alluding to in the previous PR. You can try validate_missing_and_unexpected_for_lora as in the single-device recipe now that we have the correct FQN. If it doesn't work out of the box feel free to leave this as-is, we can come back to refactor after."
                lora_attn_modules=cfg_model.lora_attn_modules,
                apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
                apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
                full_model_state_dict_keys=model.state_dict().keys(),
                lora_state_dict_keys=(
                    lora_weights_state_dict.keys()
                    if lora_weights_state_dict is not None
                    else None
                ),
                base_model_state_dict_keys=base_model_state_dict.keys(),
            )

        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

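As suggested in the review thread above, here is a hedged sketch of what switching to validate_missing_and_unexpected_for_lora could look like, modeled on the single-device recipe. The import path and keyword names are assumptions and may not match this torchtune revision exactly; the surrounding names (model, cfg_model, base_model_state_dict, lora_weights_state_dict) come from the recipe context.

# Hedged sketch: validate LoRA/base key coverage from the load_state_dict results.
from torchtune.modules.peft.peft_utils import validate_missing_and_unexpected_for_lora

base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
if lora_weights_state_dict is not None:
    lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
else:
    lora_missing, lora_unexpected = None, None

validate_missing_and_unexpected_for_lora(
    lora_attn_modules=cfg_model.lora_attn_modules,
    apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
    apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
    base_missing=base_missing,
    base_unexpected=base_unexpected,
    lora_missing=lora_missing,
    lora_unexpected=lora_unexpected,
)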
@@ -317,34 +299,39 @@ def _setup_model(
            model, auto_wrap_policy={modules.TransformerDecoderLayer}
        )

        fsdp_kwargs = {}
        if cfg_fsdp and cfg_fsdp.cpu_offload:
            from torch.distributed._composable.fsdp import CPUOffloadPolicy

            fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
        # iterating from lowerer modules to higher
        # eg grouping lora adapters before transformer block
        for m in reversed(list(model.modules())):
            if isinstance(m, nn.Linear) and m.weight.requires_grad:
                fully_shard(m)
                fully_shard(m, **fsdp_kwargs)
            # TransformerDecoderLayer is wrapped by CheckpointWrapper
            # when enable_activation_checkpointing
            if enable_activation_checkpointing:
                if isinstance(m, CheckpointWrapper):
                    fully_shard(m)
                    fully_shard(m, **fsdp_kwargs)
            else:
                if isinstance(m, modules.TransformerDecoderLayer):
                    fully_shard(m)

        fully_shard(model)
                    fully_shard(m, **fsdp_kwargs)
        fully_shard(model, **fsdp_kwargs)

        if lora_weights_state_dict:
            utils.load_from_full_model_state_dict(
                model, lora_weights_state_dict, self._device, self._is_rank_zero
            )

        with utils.set_default_dtype(self._dtype), self._device:
            lora_device = "cpu" if cfg_fsdp and cfg_fsdp.cpu_offload else self._device
            for m in model.modules():
                if isinstance(m, LoRALinear) and not lora_weights_state_dict:
                    # lora may not be covered in state dict
                    # if finetune for the 1st time
                    m.lora_a.to_empty(device=self._device)
                    m.lora_b.to_empty(device=self._device)
                    m.lora_a.to_empty(device=lora_device)
                    m.lora_b.to_empty(device=lora_device)
                    m.initialize_parameters()
                # RoPE is not covered in state dict
                if isinstance(m, modules.RotaryPositionalEmbeddings):
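The interleaved old/new lines above can be hard to read, so here is a hedged, simplified sketch of the new wrapping order: shard the trainable LoRA linears first, then each transformer layer, then the root model, optionally with CPU offload. It folds the activation-checkpointing branch into a single layer_cls check, assumes torch.distributed is already initialized, and shard_qlora_model is a hypothetical helper rather than torchtune API.

# Hedged sketch of FSDP2 per-module wrapping with an optional CPU offload policy.
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard


def shard_qlora_model(model: nn.Module, layer_cls: type, cpu_offload: bool = False) -> None:
    fsdp_kwargs = {"offload_policy": CPUOffloadPolicy()} if cpu_offload else {}
    # Walk modules leaf-first so the small trainable LoRA linears get their own
    # FSDP parameter groups before their parent transformer layers are wrapped.
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Linear) and m.weight.requires_grad:
            fully_shard(m, **fsdp_kwargs)
        elif isinstance(m, layer_cls):
            fully_shard(m, **fsdp_kwargs)
    # Wrap the root last so any remaining parameters are also sharded.
    fully_shard(model, **fsdp_kwargs)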
@@ -590,6 +577,8 @@ def train(self) -> None:
                logits = logits.transpose(1, 2)
                # Compute loss
                loss = self._loss_fn(logits, labels)
                # free logits otherwise it peaks backward memory
                del logits

                loss = loss / self._gradient_accumulation_steps
                running_loss += loss
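The added del logits drops the largest activation (batch x seq x vocab) before backward to cut peak memory. A small self-contained sketch of the pattern; the sizes and the stand-in model are illustrative only.

# Self-contained sketch of freeing logits before backward.
import torch
import torch.nn as nn

vocab, batch, seq, hidden = 32_000, 2, 512, 64
head = nn.Linear(hidden, vocab)                  # stand-in for the decoder output head
hidden_states = torch.randn(batch, seq, hidden)
labels = torch.randint(0, vocab, (batch, seq))

logits = head(hidden_states)                     # [batch, seq, vocab], the largest activation
loss = nn.CrossEntropyLoss()(logits.transpose(1, 2), labels)
del logits                                       # drop the reference before backward to reduce peak memory
loss.backward()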
@@ -17,26 +17,3 @@ def clone(func, *args, **kwargs):
    in precision.
    """
    return to_nf4(args[0][0].get_original_weight())

[Review comment on the removed inplace_copy below: "starting from TorchAO==0.2.0, we implemented ..." (the rest of the comment was cut off; a hedged sketch of the torchao-side copy_ support follows after this file's diff)]
@nf4_tensor_impl([torch.ops.aten.copy_.default])
def inplace_copy(func, *args, **kwargs):
    """
    Performs an inplace copy of an incoming tensor into the tensor
    being copied into. The inplace tensor is given by args[0][1] and the
    tensor being copied into is given by args[0][0]. The copy is performed
    by copying over all attributes. This method would have to be updated
    if additional attributes are added to NF4Tensor.
    """
    dest_tensor = args[0][0]  # tensor we are inplace copying into
    ref_tensor = to_nf4(
        args[0][1].to(dest_tensor.device)
    )  # TODO check if nf4 tensor takes in device arg
    dest_tensor.block_size = ref_tensor.block_size
    dest_tensor.n_blocks = ref_tensor.n_blocks
    dest_tensor.scaler_block_size = ref_tensor.scaler_block_size
    dest_tensor.quantized_scalers = ref_tensor.quantized_scalers
    dest_tensor.quantization_factor = ref_tensor.quantization_factor
    dest_tensor.scaler_mean = ref_tensor.scaler_mean
    dest_tensor.quantized_data = ref_tensor.quantized_data
    dest_tensor.nf4 = ref_tensor.nf4
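For context, a minimal sketch of the torchao-side behavior the review comment refers to, under the assumption that torchao>=0.2.0 registers aten.copy_ for NF4Tensor so this torchtune-side override is no longer needed. The tensor shapes are illustrative.

# Hedged sketch: in-place copy into an NF4Tensor handled by torchao itself.
import torch
from torchao.dtypes.nf4tensor import to_nf4

weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = to_nf4(weight)                      # quantize to NF4

# Copy from a plain high-precision tensor: torchao quantizes the source and
# overwrites the NF4 metadata and data in place.
new_weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight.copy_(new_weight)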
@@ -8,7 +8,7 @@
import logging
import os
from itertools import chain
from typing import Any, Callable, Dict, Set, Tuple, Type
from typing import Any, Callable, cast, Dict, Set, Tuple, Type

import torch
import torch.distributed as dist

@@ -19,6 +19,7 @@
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Optimizer
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune import modules
from torchtune.modules.peft.lora import (
    _lora_a_init_params,

@@ -280,12 +281,39 @@ def load_from_full_model_state_dict(
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        # `.to(dtype)` ensures same dtype when `assign=True`
        sharded_tensor = distribute_tensor(
            full_tensor.to(sharded_meta_param.dtype),
            sharded_meta_param.device_mesh,
            sharded_meta_param.placements,
        )
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        if isinstance(sharded_meta_param._local_tensor, NF4Tensor):
            full_tensor = to_nf4(full_tensor)
            # replicating logic from `_fsdp_param.py` `_init_sharded_param`
            # otherwise `distribute_tensor(DTensor(local=NF4))`
            # requires dispatching `c10d.scatter_`
            # long-term solution is `swap_tensor`
            mesh = sharded_meta_param.device_mesh
            if mesh.ndim > 1:
                raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}")
            shard_mesh_dim = 0
            shard_world_size = mesh.size(shard_mesh_dim)
            shard_rank = cast(
                torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim)
            ).rank()
            chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank]
            sharded_param = full_tensor.new_zeros(chunk.size())
            sharded_param[: chunk.size(0)].copy_(chunk)
            sharded_tensor = DTensor(
                sharded_param,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
                shape=sharded_meta_param.size(),
                dtype=sharded_meta_param.dtype,
                requires_grad=sharded_meta_param.requires_grad,
                stride=sharded_meta_param.stride(),
            )
[Review thread on lines +285 to +310:]
- "If it's not too much work, is it possible to extend the unit test you added in #855 to cover this case as well? As is I find it a bit hard to follow and want to make sure we have a reliable sanity check in case anything breaks in the future"
- "you're right. will try to come up with a unit test to cover NF4"
- "@ebsmothers I added ..." (the rest of the reply was cut off)
        else:
            sharded_tensor = distribute_tensor(
                full_tensor,
[Review thread on this line:]
- "No longer need to convert dtype as before?"
- "it's moved to line 284"
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    # choose `assign=True` since we cannot call `copy_` on meta tensor
    model.load_state_dict(sharded_sd, strict=False, assign=True)
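To make the hand-rolled NF4 sharding above easier to follow, here is a standalone sketch of the same per-rank dim-0 chunking on a 1-D mesh. It uses the public DTensor.from_local instead of the private DTensor constructor in the diff, assumes torch.distributed is initialized, and all names are illustrative rather than torchtune API; with uneven chunks the inferred global shape may differ from the real helper's.

# Illustrative sketch: each rank keeps its own dim-0 chunk and wraps it as a DTensor
# without dispatching any collective, which is why distribute_tensor is avoided for NF4.
import torch
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import DeviceMesh


def manually_shard_dim0(full_tensor: torch.Tensor, mesh: DeviceMesh) -> DTensor:
    world_size = mesh.size(0)
    rank = mesh.get_group(0).rank()
    # Keep only the chunk FSDP2 would assign to this rank.
    chunk = list(torch.chunk(full_tensor, world_size, dim=0))[rank]
    local = full_tensor.new_zeros(chunk.size())
    local[: chunk.size(0)].copy_(chunk)
    # Wrap the local shard without any communication; layout matches FSDP2's Shard(0).
    return DTensor.from_local(local, mesh, [Shard(0)], run_check=False)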
[Review thread on the new config file:]
- "Should we rename this config as 7B_qlora_fsdp2.yaml to match the LoRA one?"
- "good point! updating now"