From 8b55ce71fb23abfdbdcf5b215abf328fffc4d022 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Wed, 16 Oct 2024 12:50:53 -0700
Subject: [PATCH] Update QAT: add grad clipping, torch.compile, collate fn

**Summary:** Update the qat_distributed recipe to match the
full_finetune_distributed recipe. This commit adds features to QAT such as
gradient clipping, torch.compile, and a user-configurable collate function
for data pre-processing.

**Test Plan:** TBD
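Gradient clipping follows the standard PyTorch pattern: clip once per
optimizer step, after gradient accumulation. A minimal sketch of the added
behavior (illustrative names, not the recipe's exact code; see the diff
below for the real wiring):

```python
import torch

def step_with_optional_clipping(model, optimizer, clip_grad_norm=None):
    # `clip_grad_norm` mirrors the new config knob; None disables clipping.
    grad_norm = None
    if clip_grad_norm is not None:
        # clip_grad_norm_ returns the total norm computed *before* clipping,
        # so clip_grad_norm='inf' logs the norm without altering gradients.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=float(clip_grad_norm)
        )
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return grad_norm  # surfaced as "grad_norm" in the metric logs
```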
---
 recipes/configs/llama3/8B_qat_full.yaml |  4 +-
 recipes/qat_distributed.py              | 97 ++++++++++++++-----------
 2 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml
index ff4d9c319..8cb568967 100644
--- a/recipes/configs/llama3/8B_qat_full.yaml
+++ b/recipes/configs/llama3/8B_qat_full.yaml
@@ -63,7 +63,7 @@ device: cuda
 
 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: True
+custom_sharded_layers: ['tok_embeddings', 'output']
 
 # Reduced precision
 dtype: bf16
@@ -72,6 +72,6 @@ dtype: bf16
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+output_dir: /tmp/full-llama3-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: False
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index eb2e44fae..95b84cf76 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 import sys
 import time
 
@@ -21,7 +20,8 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_packed, padded_collate_sft
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
@@ -50,7 +50,7 @@ class QATRecipeDistributed(FTRecipeInterface):
         to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.
 
     - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
-        is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+        is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
         done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
         ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
         DDP is currently not supported. Training on CPU is not supported.
@@ -93,6 +93,10 @@ class QATRecipeDistributed(FTRecipeInterface):
 
     - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
 
+    - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+        ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+        ``clip_grad_norm='inf'``.
+
     For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
     has example commands for how to kick-off training.
 
@@ -102,6 +106,7 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
     """
 
     def __init__(self, cfg: DictConfig) -> None:
@@ -135,9 +140,6 @@ def __init__(self, cfg: DictConfig) -> None:
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
-        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
-            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
-        ]
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None
 
@@ -148,6 +150,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -217,7 +220,7 @@ def setup(self, cfg: DictConfig) -> None:
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
-        self._model_compile = cfg.get("compile", False)
+        self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -240,30 +243,25 @@ def setup(self, cfg: DictConfig) -> None:
 
         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-            self._loss_fn = torch.compile(self._loss_fn, backend=backend)
-        log.info("Loss is initialized.")
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")
 
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
         self._sampler, self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
+            collate_fn=collate_name,
         )
 
         # Finally update the recipe state which can only be correctly set after all of the
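The ad-hoc ``torch.compile`` branches removed above are replaced by the shared
``training.compile_loss``/``training.compile_model`` helpers. The chunked-loss
special case those branches handled is roughly the following (a sketch of the
idea, not torchtune's exact helper):

```python
import os
import torch

def compile_loss_sketch(loss_fn, verbose: bool = False):
    # Rough sketch of the special-casing a shared compile-loss helper
    # needs to preserve (mirrors the removed per-recipe code above).
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if verbose:
        print("Compiling loss with torch.compile...")
    if loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
        # Compiling the whole chunked loss would trace across the chunk loop
        # and forfeit its memory savings, so only the inner cross-entropy
        # (+ upcasting) is compiled.
        loss_fn.compute_cross_entropy = torch.compile(
            loss_fn.compute_cross_entropy, backend=backend
        )
    else:
        loss_fn = torch.compile(loss_fn, backend=backend)
    return loss_fn
```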
@@ -388,6 +386,9 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
 
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -459,7 +460,12 @@ def _is_layer_fqn(s: str) -> bool:
         # This method will convert the full model state dict into a sharded state
         # dict and load into the model
         training.load_from_full_model_state_dict(
-            model, model_state_dict, self._device, self._is_rank_zero, strict=True
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
         )
 
         # Ensure no params and buffers are on meta device
@@ -497,6 +503,7 @@ def _setup_data(
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        collate_fn: str,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
@@ -507,15 +514,20 @@ def _setup_data(
 
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
-                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                config.instantiate(single_cfg_dataset, self._tokenizer)
                 for single_cfg_dataset in cfg_dataset
             ]
             ds = ConcatDataset(datasets=datasets)
             packed = False
         else:
-            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
             packed = cfg_dataset.get("packed", False)
 
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
@@ -526,14 +538,12 @@ def _setup_data(
             # dropping last avoids shape issues with compile + flex attention
             drop_last=True,
             collate_fn=partial(
-                padded_collate_sft,
+                collate_fn,
                 padding_idx=self._tokenizer.pad_id,
                 ignore_idx=self._loss_fn.ignore_index,
             )
             if not packed
-            else partial(
-                padded_collate_packed,
-            ),
+            else padded_collate_packed,
         )
 
         if self._is_rank_zero:
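The ``collate_fn`` config entry is a dotted path that gets resolved to a
callable and partially applied with tokenizer- and loss-specific indices;
packed datasets skip it in favor of ``padded_collate_packed``. A sketch of
the resolution step using a generic importlib-based lookup (the recipe
itself uses torchtune's private ``_get_component_from_path``):

```python
import importlib
from functools import partial

def resolve_component(dotted_path: str):
    # "torchtune.data.padded_collate_sft" -> the padded_collate_sft function
    module_path, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

collate_name = "torchtune.data.padded_collate_sft"  # the config default
if "left_pad_sequence" in collate_name:
    # Left-padded batches are for generation/inference, not SFT training.
    raise RuntimeError("left_pad_sequence collator is only for inference.")

collate = resolve_component(collate_name)
# padding_idx/ignore_idx values here are illustrative; the recipe pulls them
# from the tokenizer and the loss function respectively.
dataloader_collate = partial(collate, padding_idx=0, ignore_idx=-100)
```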
@@ -564,12 +574,14 @@ def save_checkpoint(
             cpu_state_dict = training.get_full_model_state_dict(
                 self._model,
                 self._is_rank_zero,
+                device=self._device,
             )
 
             if intermediate_checkpoint:
                 opt_state_dict = training.get_full_optimizer_state_dict(
                     self._optimizer,
                     self._is_rank_zero,
+                    device=self._device,
                 )
             else:
                 opt_state_dict = None
@@ -642,13 +654,6 @@ def train(self) -> None:
             ):
                 torch.cuda.memory._record_memory_history()
 
-            # Both are shape [b, s]
-            tokens, labels = batch["tokens"], batch["labels"]
-            # Get the attention mask and position ids from the dataset if they
-            # exist. Currently, only sample packing in PackedDataset returns these
-            mask = batch.get("mask", None)  # shape [b, s, s]
-            input_pos = batch.get("input_pos", None)  # shape [b, s]
-
             # Optionally wait N steps before enabling fake quant
             if self._fake_quant_after_n_steps is not None:
                 if self.global_step == 0:
@@ -670,15 +675,13 @@ def train(self) -> None:
                     )
                     self._model.apply(enable_fq)
 
-            tokens = tokens.to(self._device)
-            num_tokens += tokens.numel()
-            labels = labels.to(self._device)
-            mask = mask.to(self._device) if mask is not None else None
-            input_pos = (
-                input_pos.to(self._device) if input_pos is not None else None
-            )
+            utils.batch_to_device(batch, self._device)
+            num_tokens += batch["tokens"].numel()
+
+            # Shape [b, s], needed for the loss not the model
+            labels = batch.pop("labels")
 
-            logits = self._model(tokens, mask=mask, input_pos=input_pos)
+            logits = self._model(**batch)
 
             # Shift labels to compute loss
             # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -692,6 +695,7 @@ def train(self) -> None:
 
             # Compute loss
             loss = self._loss_fn(logits, labels)
+
             # free logits otherwise it peaks backward memory
             del logits
 
@@ -701,6 +705,11 @@ def train(self) -> None:
 
             # Step with optimizer
             if (idx + 1) % self._gradient_accumulation_steps == 0:
+                if self._clip_grad_norm is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                        self._model.parameters(),
+                        max_norm=float(self._clip_grad_norm),
+                    )
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)
 
@@ -728,6 +737,8 @@ def train(self) -> None:
                         log_dict.update(
                             training.get_memory_stats(device=self._device)
                         )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
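The train-loop change above replaces per-tensor device moves with a single
batch move and forwards the whole batch dict, so the recipe no longer has to
spell out which tensors (mask, input_pos) a dataset emits. A minimal sketch
of the pattern, assuming a generic ``loss_fn`` with an ``ignore_index``
attribute:

```python
import torch

def forward_and_loss(model, batch, device, loss_fn):
    # Move every tensor in the batch once, instead of per-key .to() calls
    # (utils.batch_to_device does this in-place in the recipe).
    batch = {k: v.to(device) for k, v in batch.items()}
    # Labels feed the loss, not the model, so pop them before **-unpacking;
    # whatever remains (tokens, plus mask/input_pos for packed data) goes in.
    labels = batch.pop("labels")
    logits = model(**batch)
    # Shift so position t predicts token t+1, padding the final position.
    labels = torch.hstack(
        (labels[..., 1:], torch.full_like(labels[..., :1], loss_fn.ignore_index))
    )
    return loss_fn(logits, labels)
```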