Update QAT: add grad clipping, torch.compile, collate fn
**Summary:** Update the qat_distributed recipe to match the
full_finetune_distributed recipe. This commit adds features to
QAT such as gradient clipping, torch.compile, and a user-configurable
collate function for data pre-processing.

**Test Plan:** TBD
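
For reference, a minimal sketch of how the options touched by this change might appear in a recipe config. The key names (`compile`, `clip_grad_norm`, `collate_fn`, `custom_sharded_layers`) are taken from the diff below; the specific values are illustrative, not defaults.

```yaml
# Illustrative config snippet (values are examples, not defaults)
compile: True                                    # torch.compile the model and loss
clip_grad_norm: 1.0                              # set to 'inf' to only log the grad norm
collate_fn: torchtune.data.padded_collate_sft    # user-configurable collate function
custom_sharded_layers: ['tok_embeddings', 'output']  # replaces memory_efficient_fsdp_wrap
```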
andrewor14 committed Oct 16, 2024
1 parent 7d29c21 commit 8b55ce7
Showing 2 changed files with 56 additions and 45 deletions.
4 changes: 2 additions & 2 deletions recipes/configs/llama3/8B_qat_full.yaml
@@ -63,7 +63,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True
custom_sharded_layers: ['tok_embeddings', 'output']

# Reduced precision
dtype: bf16
@@ -72,6 +72,6 @@ dtype: bf16
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
output_dir: /tmp/full-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
97 changes: 54 additions & 43 deletions recipes/qat_distributed.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time

@@ -21,7 +20,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -50,7 +50,7 @@ class QATRecipeDistributed(FTRecipeInterface):
to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
@@ -93,6 +93,10 @@ class QATRecipeDistributed(FTRecipeInterface):
- Logging. Terminal, Disk, WandB and TensorBoard are all supported.
- Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
``clip_grad_norm='inf'``.
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
@@ -102,6 +106,7 @@ class QATRecipeDistributed(FTRecipeInterface):
Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``left_pad_sequence`` is set as the data collator.
"""

def __init__(self, cfg: DictConfig) -> None:
@@ -135,9 +140,6 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
]
self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
self._quantizer_mode = None

@@ -148,6 +150,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0
self._clip_grad_norm = cfg.get("clip_grad_norm", None)

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
@@ -217,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None:

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

self._model_compile = cfg.get("compile", False)
self._compile = cfg.get("compile", False)
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -240,30 +243,25 @@ def setup(self, cfg: DictConfig) -> None:

# initialize loss
self._loss_fn = config.instantiate(cfg.loss)
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")

if self._compile:
training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
# set num_output_chunks for model
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
if self._model_compile:
log.info("Compiling loss with torch.compile...")
# For CEWithChunkedOutputLoss, if we compile the entire class
# we lose the benefits from the chunked loss.
# Therefore, we only compile the cross entropy function + upcasting
self._loss_fn.compute_cross_entropy = torch.compile(
self._loss_fn.compute_cross_entropy, backend=backend
)
else:
if self._model_compile:
log.info("Compiling loss with torch.compile...")
self._loss_fn = torch.compile(self._loss_fn, backend=backend)
log.info("Loss is initialized.")

if self._is_rank_zero:
log.info("Loss is initialized.")

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after both of these are initialized
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
self._sampler, self._dataloader = self._setup_data(
cfg_dataset=cfg.dataset,
shuffle=cfg.shuffle,
batch_size=cfg.batch_size,
collate_fn=collate_name,
)

# Finally update the recipe state which can only be correctly set after all of the
@@ -388,6 +386,9 @@ def _setup_model(
with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)

# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
# the older version of AC and this behavior is unchanged
@@ -459,7 +460,12 @@ def _is_layer_fqn(s: str) -> bool:
# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
model, model_state_dict, self._device, self._is_rank_zero, strict=True
model,
model_state_dict,
self._device,
self._is_rank_zero,
strict=True,
cpu_offload=fsdp_cpu_offload,
)

# Ensure no params and buffers are on meta device
@@ -497,6 +503,7 @@ def _setup_data(
cfg_dataset: DictConfig,
shuffle: bool,
batch_size: int,
collate_fn: str,
) -> Tuple[DistributedSampler, DataLoader]:
"""
All data related setup happens here. Currently this recipe only supports the
@@ -507,15 +514,20 @@

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
@@ -526,14 +538,12 @@
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=partial(
padded_collate_sft,
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
),
else padded_collate_packed,
)

if self._is_rank_zero:
@@ -564,12 +574,14 @@ def save_checkpoint(
cpu_state_dict = training.get_full_model_state_dict(
self._model,
self._is_rank_zero,
device=self._device,
)

if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = None
@@ -642,13 +654,6 @@ def train(self) -> None:
):
torch.cuda.memory._record_memory_history()

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
# exist. Currently, only sample packing in PackedDataset returns these
mask = batch.get("mask", None) # shape [b, s, s]
input_pos = batch.get("input_pos", None) # shape [b, s]

# Optionally wait N steps before enabling fake quant
if self._fake_quant_after_n_steps is not None:
if self.global_step == 0:
@@ -670,15 +675,13 @@
)
self._model.apply(enable_fq)

tokens = tokens.to(self._device)
num_tokens += tokens.numel()
labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
input_pos.to(self._device) if input_pos is not None else None
)
utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -692,6 +695,7 @@

# Compute loss
loss = self._loss_fn(logits, labels)

# free logits otherwise it peaks backward memory
del logits

@@ -701,6 +705,11 @@

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

@@ -728,6 +737,8 @@ def train(self) -> None:
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
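
As a reading aid, here is a hedged sketch of the collate-function wiring that the `_setup_data` hunks above introduce, lifted out of the recipe class. `_get_component_from_path`, `padded_collate_packed`, and the `padding_idx`/`ignore_idx` arguments come from the diff; the standalone helper and its name are illustrative.

```python
from functools import partial

from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed


def resolve_collate_fn(collate_fn_path: str, pad_id: int, ignore_idx: int, packed: bool):
    """Mirror of the new collate handling in _setup_data (simplified)."""
    # The left-padding collator is meant for inference only, so it is rejected here.
    if "left_pad_sequence" in collate_fn_path:
        raise RuntimeError("left_pad_sequence collator is only for inference.")
    collate_fn = _get_component_from_path(collate_fn_path)
    if packed:
        # Packed datasets use the packed collator directly.
        return padded_collate_packed
    # Otherwise bind padding/ignore indices so the DataLoader can call it per batch.
    return partial(collate_fn, padding_idx=pad_id, ignore_idx=ignore_idx)
```

The result is passed as `collate_fn=` to the `DataLoader`, matching the dataloader construction in the `_setup_data` hunk above.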

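Similarly, a hedged sketch of the reworked training step: moving the batch with `utils.batch_to_device`, popping `labels`, calling the model with `**batch`, and the new optional gradient clipping with `grad_norm` logging. Label shifting, chunked-loss handling, gradient accumulation, fake-quant toggling, and distributed details are omitted, and the helper below is illustrative rather than part of the recipe.

```python
import torch
from torchtune import utils


def train_step(model, loss_fn, optimizer, batch, device, clip_grad_norm=None):
    """Simplified single optimizer step mirroring the updated train() loop."""
    # Move every tensor in the batch dict to the training device in one call.
    utils.batch_to_device(batch, device)

    # Labels are needed for the loss, not the model.
    labels = batch.pop("labels")
    logits = model(**batch)

    # The real recipe shifts labels and supports chunked outputs; here the
    # loss_fn is assumed to accept (logits, labels) directly.
    loss = loss_fn(logits, labels)
    # Free logits before backward to reduce peak memory, as the recipe does.
    del logits
    loss.backward()

    grad_norm = None
    if clip_grad_norm is not None:
        # clip_grad_norm='inf' effectively just records the norm without clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=float(clip_grad_norm)
        )
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach(), grad_norm
```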
