Skip to content

Commit

Permalink
enable QLoRA + FSDP2 (#909)
Browse files Browse the repository at this point in the history
Co-authored-by: Kartikay Khandelwal <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Rafi Ayub <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Salman Mohammadi <[email protected]>
Co-authored-by: Rohan Varma <[email protected]>
Co-authored-by: Optimox <[email protected]>
Co-authored-by: Tanish Ambulkar <[email protected]>
Co-authored-by: Botao Chen <[email protected]>
Co-authored-by: solitude-alive <[email protected]>
Co-authored-by: christobill <[email protected]>
Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: Evan Smothers <[email protected]>
  • Loading branch information
14 people authored Jun 5, 2024
1 parent 41a3f92 commit f9cb9e6
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 66 deletions.
90 changes: 90 additions & 0 deletions recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Config for multi-device QLoRA in lora_finetune_fsdp2.py
# using a Llama2 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token <HF_TOKEN>
#
# This config needs 8 GPUs to run
# # tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_qlora
#

# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_70b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 16
lora_alpha: 32

tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-70b-hf/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-70b-hf
checkpoint_files: [
pytorch_model-00001-of-00015.bin,
pytorch_model-00002-of-00015.bin,
pytorch_model-00003-of-00015.bin,
pytorch_model-00004-of-00015.bin,
pytorch_model-00005-of-00015.bin,
pytorch_model-00006-of-00015.bin,
pytorch_model-00007-of-00015.bin,
pytorch_model-00008-of-00015.bin,
pytorch_model-00009-of-00015.bin,
pytorch_model-00010-of-00015.bin,
pytorch_model-00011-of-00015.bin,
pytorch_model-00012-of-00015.bin,
pytorch_model-00013-of-00015.bin,
pytorch_model-00014-of-00015.bin,
pytorch_model-00015-of-00015.bin,
]
recipe_checkpoint: null
output_dir: /tmp/Llama-2-70b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False

# Logging
output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
84 changes: 84 additions & 0 deletions recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Config for single device QLoRA with lora_finetune_fsdp2.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_fsdp2 --config llama2/7B_qlora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_fsdp2 --config llama2/7B_qlora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_7b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16

tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin
]
adapter_checkpoint: null
recipe_checkpoint: null
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False

# Logging
output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
66 changes: 31 additions & 35 deletions recipes/dev/lora_finetune_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn

import torch
Expand All @@ -32,7 +32,7 @@
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

Expand Down Expand Up @@ -214,6 +214,7 @@ def setup(self, cfg: DictConfig) -> None:
if self._resume_from_checkpoint
else None
),
cfg_fsdp=cfg.fsdp if hasattr(cfg, "fsdp") else None,
)
self._tokenizer = config.instantiate(cfg.tokenizer)

Expand Down Expand Up @@ -265,6 +266,7 @@ def _setup_model(
enable_activation_checkpointing: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
cfg_fsdp: Optional[Union[DictConfig, None]] = None,
) -> nn.Module:
"""
Model initialization has some important considerations:
Expand All @@ -290,25 +292,6 @@ def _setup_model(
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._is_rank_zero:
# The model contains LoRA params which won't have any matching keys in
# the state dict. As a result, we need to load with strict=False.
# Before loading the state dict, ensure the state dict keys for the base
# model and adapters (if available) match the keys in the full LoRA model
# This is a good sanity check to prevent silent errors
validate_state_dict_for_lora(
lora_attn_modules=cfg_model.lora_attn_modules,
apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
full_model_state_dict_keys=model.state_dict().keys(),
lora_state_dict_keys=(
lora_weights_state_dict.keys()
if lora_weights_state_dict is not None
else None
),
base_model_state_dict_keys=base_model_state_dict.keys(),
)

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)

Expand All @@ -317,47 +300,58 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

fsdp_kwargs = {}
if cfg_fsdp and cfg_fsdp.cpu_offload:
from torch.distributed._composable.fsdp import CPUOffloadPolicy

fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# iterating from lowerer modules to higher
# eg grouping lora adapters before transformer block
for m in reversed(list(model.modules())):
if isinstance(m, nn.Linear) and m.weight.requires_grad:
fully_shard(m)
fully_shard(m, **fsdp_kwargs)
# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m)
fully_shard(m, **fsdp_kwargs)
else:
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)

fully_shard(model)
fully_shard(m, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

if lora_weights_state_dict:
utils.load_from_full_model_state_dict(
lora_missing, lora_unexpected = utils.load_from_full_model_state_dict(
model, lora_weights_state_dict, self._device, self._is_rank_zero
)
else:
lora_missing, lora_unexpected = None, None

with utils.set_default_dtype(self._dtype), self._device:
lora_device = "cpu" if cfg_fsdp and cfg_fsdp.cpu_offload else self._device
for m in model.modules():
if isinstance(m, LoRALinear) and not lora_weights_state_dict:
# lora may not be covered in state dict
# if finetune for the 1st time
m.lora_a.to_empty(device=self._device)
m.lora_b.to_empty(device=self._device)
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
m.initialize_parameters()
# RoPE is not covered in state dict
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()

utils.load_from_full_model_state_dict(
base_missing, base_unexpected = utils.load_from_full_model_state_dict(
model, base_model_state_dict, self._device, self._is_rank_zero
)

# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha

validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

Expand Down Expand Up @@ -590,6 +584,8 @@ def train(self) -> None:
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
Expand Down
45 changes: 43 additions & 2 deletions tests/torchtune/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from torch.distributed import launcher

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune import modules, utils
from torchtune.models.llama2._component_builders import llama2, lora_llama2
from torchtune.models.llama3._component_builders import llama3
Expand Down Expand Up @@ -237,7 +239,7 @@ def test_lora_meta_device_init_fsdp(self):
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
lora_rank=4,
lora_alpha=1.0,
lora_alpha=8,
)
utils.prepare_model_for_fsdp_with_meta_device(lora)
for m in lora.modules():
Expand All @@ -257,7 +259,7 @@ def world_size(self) -> int:
return 2

@gpu_test(gpu_count=2)
def test_state_dict(self):
def test_lora_state_dict(self):
rank = self.rank
is_rank_zero = rank == 0
mlp_dim = 4
Expand Down Expand Up @@ -393,6 +395,45 @@ def test_state_dict(self):
for key, value in sharded_model_sd.items():
self.assertEqual(value, expected_sharded_model_sd[key])

@gpu_test(gpu_count=2)
def test_qlora_state_dict(self):
is_rank_zero = self.rank == 0
torch.manual_seed(42)
kwargs = {
"lora_attn_modules": ["q_proj", "v_proj", "k_proj", "output_proj"],
"apply_lora_to_mlp": True,
"apply_lora_to_output": False,
"vocab_size": 1024,
"num_layers": 3,
"num_heads": 4,
"num_kv_heads": 2,
"embed_dim": 1024,
"max_seq_len": 64,
"lora_rank": 4,
"lora_alpha": 1.0,
}
with torch.device("cuda"):
lora_kwargs = dict({"quantize_base": False}, **kwargs)
model_lora = lora_llama2(**lora_kwargs)
full_sd = model_lora.cpu().state_dict()
with torch.device("meta"):
qlora_kwargs = dict({"quantize_base": True}, **kwargs)
model_qlora = lora_llama2(**qlora_kwargs)
set_trainable_params(model_qlora, get_adapter_params(model_qlora))
for m in model_qlora.modules():
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m)
fully_shard(model_qlora)
utils.load_from_full_model_state_dict(
model_qlora, full_sd, "cuda", is_rank_zero
)
# LoRALinear base weights should be DTensor(NF4)
for name, module in model_qlora.named_modules():
if isinstance(module, LoRALinear):
self.assertTrue(isinstance(module.weight, DTensor))
self.assertTrue(isinstance(module.weight._local_tensor, NF4Tensor))
self.assertEqual(module.weight.device.type, "cuda")

def _broadcast_full_state_dict(self, full_sd):
result = []
if torch.distributed.get_rank() == 0:
Expand Down
Loading

0 comments on commit f9cb9e6

Please sign in to comment.