Full finetune FSDP2 recipe #1287
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time
@@ -15,25 +16,21 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils.activations import apply_selective_activation_checkpointing

from tqdm import tqdm


log = utils.get_logger("DEBUG")
@@ -43,8 +40,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
distributed training and can be run on a single node (1 to 8 GPUs).

Features:
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
supported.

- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
@@ -92,10 +89,10 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):

Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
"""

def __init__(self, cfg: DictConfig) -> None:

self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
@@ -192,37 +189,31 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:

def setup(self, cfg: DictConfig) -> None:
"""
Sets up the recipe state correctly. This includes setting recipe attributes based
on the ``resume_from_checkpoint`` flag.
Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),

Review comment: This sentence reads funny - just duplicating "recipe state" twice.

model, tokenizer, loss, optimizer, sampler, and dataloader.
"""
if self._is_rank_zero:
self._metric_logger = config.instantiate(cfg.metric_logger)

# log config with parameter override
self._metric_logger.log_config(cfg)

ckpt_dict = self.load_checkpoint(cfg.checkpointer)
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

# ``_setup_model`` handles initialization and loading the state dict. This method
# should be called before ``_setup_optimizer`` since transforming the optimizer
# state dict requires the model
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
model_state_dict=ckpt_dict[utils.MODEL_KEY],
model_state_dict=checkpoint_dict[utils.MODEL_KEY],

Review comment: [Misplaced comment] I never fully understood
Reply: Yes, I think we do still need some version of it, otherwise our peak memory will be higher on Llama3 models.

ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
)

self._tokenizer = config.instantiate(cfg.tokenizer)

# _setup_optimizer should take in ckpt_dict only if training is resumed from
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
opt_state_dict=ckpt_dict[utils.OPT_KEY]
opt_state_dict=checkpoint_dict[utils.OPT_KEY]
if self._resume_from_checkpoint
else None,
)
@@ -266,37 +257,21 @@ def _setup_model(
) -> nn.Module:
"""
Model initialization has some important considerations:
a. To minimize GPU peak memory, we load the model on CPU with the right
dtype. To ensure that we don't instantiate ``world_size`` number of models,
we initialize on meta_device for all ranks other than rank 0.
b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
model weights from checkpoint.
c. While wrapping the model with FSDP, we set ``sync_module_states``
to TRUE and broadcast module params and buffers from rank 0.
d. The ``device_id`` param ensures that the FSDP initialization happens on
the correct device.
a. To minimize GPU peak memory, we initialize the model on meta device with
the right dtype
b. All ranks call ``load_state_dict`` without peaking CPU RAM since

Review comment: Why do all ranks load state dict now?
Reply: Because we shard as we load. Previously we only loaded on rank 0 because we relied on sync_module_states to broadcast to all the other ranks. I think this way should be more memory-efficient because we never need a full copy of the state dict on rank 0.

full state dicts are loaded with ``torch.load(mmap=True)``
c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`

Review comment: I don't think we should say "instead of" - this should just explain how we shard, not compare to FSDP1.
Review comment: "We register (pre-)forward hooks" -> why do we need this, or what does this add to the docstring? I'd be willing to bet not many users know what this means.
Reply: Yeah, these are both fair points, lemme update.

"""
if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
init_start = time.perf_counter()

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)

if self._is_rank_zero:
log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."

Review comment: Where is FSDP being enabled? Does _is_rank_zero only get set to true if init_process_group is called? If so, that's not very clear from the code.
Reply: Sorry, not sure I fully understand this question.

)
init_start = time.perf_counter()

# Load both the model weights. This should happen only on Rank 0
model.load_state_dict(model_state_dict)

else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
@@ -314,41 +289,51 @@ def _setup_model(
ac_option,
)

# Wrap the model with FSDP. This will ensure that the model is sharded
# across all available GPUs.
model = FSDP(
module=model,
auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy(
memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
modules_to_wrap={modules.TransformerDecoderLayer},
),
cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

# original activation checkpointing (full) - flip the condition above
if enable_activation_checkpointing and ac_mode is None:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)
fsdp_kwargs = {}

Review comment: [Misplaced comment] Do we need to remove the logic related to
Reply: Oh yeah, I can see how this is confusing. A bit awkward, but auto_wrap_policy for AC != auto_wrap_policy from our dear old friend FSDP. You can see the usage here.
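Since the AC wrapping is easy to confuse with FSDP wrapping, here is a hedged sketch of what the activation-checkpointing side amounts to, assuming ``model`` is the instantiated ``TransformerDecoder``. The recipe actually goes through ``utils.set_activation_checkpointing``; the direct PyTorch calls below are shown only to illustrate what that wrapping does.

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torchtune import modules

# Wrap every TransformerDecoderLayer in a CheckpointWrapper so its activations
# are recomputed during the backward pass instead of being kept in memory.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, modules.TransformerDecoderLayer),
)
```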
if fsdp_cpu_offload:

Review comment: Apologies if I'm projecting, but I think this could use some comments around what exactly we are offloading to CPU.

fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

Review comment: Should we expose the sharding strategy like in #1024?
Reply: We can, but imo at this stage it's a nice-to-have rather than a must-have.
Reply: It does feel wrong to remove a feature that was just recently contributed though. It feels like the goal of moving this out of dev is to integrate it with everything we already have.
Reply: It hasn't landed yet though.
Reply: My vote would be to expose the sharding strategy - it's a really easy lever to trade off memory for performance, and it's unfortunate that we still haven't landed this for FSDP1.
Reply: To follow up on this, there is not a 1:1 mapping between FSDP1 sharding strategies and FSDP2 APIs. But for single-node training (all we currently support), the only relevant sharding strategies should be FULL_SHARD, SHARD_GRAD_OP, and NO_SHARD; HYBRID_SHARD and _HYBRID_SHARD_ZERO2 are only distinct from these in the case of multi-node training. So then we are just down to FULL_SHARD and SHARD_GRAD_OP, which correspond to setting reshard_after_forward to True and False, respectively. So I can expose this flag in the recipe. Thanks to @weifengpy for helping to clear this up!
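To make that concrete, a minimal sketch of how those two knobs could be threaded through ``fully_shard``. The ``reshard_after_forward`` recipe flag and the ``fsdp_cpu_offload`` variable are assumptions based on the thread above, not something already in this diff.

```python
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

fsdp_kwargs = {}
if fsdp_cpu_offload:
    # Parameters, gradients, and optimizer state live in pinned CPU memory
    # and are streamed to the GPU for compute.
    fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

# reshard_after_forward=True  -> free unsharded params after forward (FULL_SHARD-like)
# reshard_after_forward=False -> keep them until backward (SHARD_GRAD_OP-like: faster, more memory)
fsdp_kwargs["reshard_after_forward"] = reshard_after_forward

for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
```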
# Shard the model with FSDP

Review comment: You already know this, but the block below is giving away too much in terms of readability. It probably needs some well-designed utilities. But apart from utilities, I think you can structure this in a way which is much easier to understand. Remember, torchtune is one of the only libraries with FSDP2 integration. We should showcase this well.
Reply: Yeah, this is the top thing on my plate to refactor here.

for m in reversed(list(model.modules())):

Review comment: Would it make sense to have utils for these two things? Also, whether as utils or still here, it might be cleaner to follow if you looped through the modules list two times, once for AC and once for distributed.
Reply: Yeah, I think we should have utils. And to be clear, we are not actually doing any AC wrapping here, that happens above. Here we are just checking which modules have been wrapped by AC already.

# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m, **fsdp_kwargs)
# For large vocab size, we can additionally shard
# the token embeddings and output projections individually
if memory_efficient_fsdp_wrap:

Review comment: Can we just rename this to "shared_embedding" or something more descriptive?
Reply: That's not what this is though. This is just for the case where we want to shard token embeddings and output projections in their own unit (instead of just in the top-level model sharding). We found this saves a fair bit of memory for Llama3. I do agree the naming is a bit awkward, definitely open to other suggestions here.
Reply: Since we're discussing this, can I say that having model-specific functionality (this was included for Llama3) infiltrating the recipe always made little sense to me. Branding it with something as generic as "memory_efficient_fsdp_wrap" makes this worse. Maybe leaking this into the recipe is unavoidable, but I think we can abstract this functionality better. One idea is to pass the modules we want to wrap separately as strings to the utility which does the wrapping. This would be very similar to how we decide the layers for LoRA. It also does away with the need to call this "memory_efficient".
Reply: So a couple thoughts here: (1) This is an optimization that we only apply to Llama 3, but it should be generally useful for any model with a very large vocab size. Of course I agree that as we've named it none of these points are obvious; my initial intention here was to be as non-BC-breaking as possible, but if you want a rename I am happy to do it. (2) On passing modules we want to wrap separately as strings: I think this is an interesting idea -- if we iterate over named modules we could probably just do a check on the FQN? The only concern I have is that it could be a bit easy for someone to shoot themselves in the foot here; we are giving a lot more freedom over configurations that we have not really tried out ourselves. Either way I do plan to move this into a separate utility. I'm still on the fence about how exactly to expose the individual wrapping of these layers, will try a couple approaches and see what makes sense here.
if isinstance(m, nn.Embedding):
fully_shard(m, **fsdp_kwargs)
if isinstance(m, modules.TransformerDecoder):
fully_shard(m.output, **fsdp_kwargs)
else:
if isinstance(m, modules.TransformerDecoderLayer):
fully_shard(m, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
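Picking up the two suggestions above (move this loop into a utility, and select modules by FQN or type instead of a ``memory_efficient_fsdp_wrap`` flag), one possible shape for such a helper is sketched below. The name ``shard_model``, the ``shard_conditions`` argument, and the assumption that ``model`` is already instantiated are all hypothetical, not part of torchtune.

```python
from typing import Callable, List

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torchtune import modules


def shard_model(
    model: nn.Module,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    **fsdp_kwargs,
) -> None:
    # Shard every submodule matching any condition; reversed() so that
    # children are sharded before their parents.
    for name, module in reversed(list(model.named_modules())):
        if any(cond(name, module) for cond in shard_conditions):
            fully_shard(module, **fsdp_kwargs)
    # Finally shard the root module, which picks up any remaining parameters.
    fully_shard(model, **fsdp_kwargs)


# Usage: shard each decoder layer, plus (for large-vocab models) the token
# embedding and output projection in their own units, selected by FQN.
shard_model(
    model,
    shard_conditions=[
        lambda name, m: isinstance(m, modules.TransformerDecoderLayer),
        lambda name, m: name in ("tok_embeddings", "output"),
    ],
)
```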
with utils.set_default_dtype(self._dtype), self._device:
for m in model.modules():
# RoPE is not covered in state dict
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()

Review comment: How should I conceptually parse this block? IIUC, at this stage all of the model params are init on meta device.
Reply: It shouldn't matter because the RoPE buffers are not in the state dict anyways. This becomes more relevant for LoRA though, where we may have existing adapter weights to load. In that case, we either move from meta device on state dict load or on the usual param init, but not both. But even in this case, since we only do one or the other, I think the order shouldn't matter (lmk if this makes sense or not though).
Reply: Yeah, I think we can do this.
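To answer the "how should I conceptually parse this block" question in one place, the overall meta-device flow in this recipe is roughly the following. This is a condensed restatement using the recipe's own names; the ordering, not the exact code, is the point.

```python
# 1. Build the model on meta device: no parameter memory is allocated yet.
with utils.set_default_dtype(self._dtype), torch.device("meta"):
    model = config.instantiate(cfg_model)

# 2. Register FSDP2 hooks: parameters become sharded DTensors, still on meta.
for m in model.modules():
    if isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
fully_shard(model)

# 3. Materialize anything the checkpoint will NOT cover (e.g. RoPE buffers)
#    directly on the training device.
with utils.set_default_dtype(self._dtype), self._device:
    for m in model.modules():
        if isinstance(m, modules.RotaryPositionalEmbeddings):
            m.reset_parameters()

# 4. Load the checkpoint, which materializes the remaining sharded parameters
#    on the training device.
utils.load_from_full_model_state_dict(
    model, model_state_dict, self._device, self._is_rank_zero, strict=True
)
```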
utils.load_from_full_model_state_dict(

Review comment: Add a brief comment here saying what this is actually doing - helps make sure this function is self-explanatory, which is how we intended it to be.

model, model_state_dict, self._device, self._is_rank_zero, strict=True
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)
@@ -360,17 +345,13 @@ def _setup_model(
def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
"""
Set up the optimizer. This method also handles transforming the state dict
for FSDP.
"""
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
opt_state_dict = FSDP.optim_state_dict_to_load(
self._model, optimizer, opt_state_dict
utils.load_from_full_optimizer_state_dict(

Review comment: @weifengpy IIRC there were some APIs related to optimizer and state_dict that were in the process of being moved over to distributed? Is that still the plan? If so, can we just import these from distributed instead of from utils?
Reply: I worked with @mori360 last month to upstream the implementation into the PTD API. Numerically it's on par now, but we need to verify memory behavior before formal migration: pytorch/pytorch#128745

optimizer,
opt_state_dict,
self._device,
)
optimizer.load_state_dict(opt_state_dict)
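Since the thread above points at the upstreamed PTD API (pytorch/pytorch#128745), this is roughly what loading a full optimizer state dict into a sharded optimizer looks like with ``torch.distributed.checkpoint.state_dict``. Treat it as a sketch of the migration target mentioned there (availability depends on the PyTorch version), not as what ``utils.load_from_full_optimizer_state_dict`` does today; ``full_opt_state_dict`` is a placeholder name.

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_optimizer_state_dict,
)

# Takes a *full* (unsharded) optimizer state dict and redistributes each
# state tensor to match the sharded parameters owned by this rank.
set_optimizer_state_dict(
    model,
    optimizer,
    optim_state_dict=full_opt_state_dict,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
```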

if self._is_rank_zero:
log.info("Optimizer is initialized.")
@@ -391,21 +372,17 @@ def _setup_data(

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, self._tokenizer)
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
ds,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
seed=0,
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
dataloader = DataLoader(
dataset=ds,
@@ -425,32 +402,48 @@ def _setup_data(

return sampler, dataloader

def save_checkpoint(self, epoch: int) -> None:
def save_checkpoint(
self,
epoch: int,
) -> None:
"""
Save state dict to file. The recipe save_checkpoint method is responsible for
correctly creating the checkpoint dict and passing to the checkpointer.
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Model weights with key MODEL_KEY
- Relevant recipe state if training is not complete

Checkpointer will save the model weights and recipe state in
different checkpoint files. To correctly resume training from an intermediate checkpoint,
the model weights and recipe state must be provided.
"""
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
cpu_state_dict = utils.get_full_model_state_dict(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = utils.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
)
else:
opt_state_dict = None
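For readers unfamiliar with what the ``get_full_*_state_dict`` utilities have to do here: with FSDP2 each parameter is a DTensor shard, so saving means all-gathering each one and keeping the result only on rank 0, offloaded to CPU. Below is a minimal sketch of the model-side gather under that assumption (error handling and quantized/NF4 cases omitted; not the exact torchtune implementation).

```python
import torch
from torch.distributed._tensor import DTensor


def gather_cpu_state_dict(model: torch.nn.Module, is_rank_zero: bool) -> dict:
    cpu_state_dict = {}
    for name, param in model.state_dict().items():
        if isinstance(param, DTensor):
            # All-gather this parameter's shards into a regular full tensor.
            param = param.full_tensor()
        if is_rank_zero:
            # Only rank 0 keeps the full copy, moved off the GPU.
            cpu_state_dict[name] = param.cpu()
    return cpu_state_dict
```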

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})

# if training is in-progress, checkpoint the optimizer state as well
if epoch + 1 < self.total_epochs:
# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
checkpoint_dict.update(
{
utils.OPT_KEY: opt_state_dict,
@@ -464,13 +457,12 @@ def save_checkpoint(self, epoch: int) -> None:
self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=(epoch + 1 < self.total_epochs),
intermediate_checkpoint=intermediate_checkpoint,
)

def train(self) -> None:
"""
The core training loop. Supports training on subsets of the dataset using the
``max_steps_per_epoch``.
The core training loop.
"""
# clean up before training begins
utils.cleanup_before_training()

@@ -573,7 +565,7 @@ def train(self) -> None:
def cleanup(self) -> None:
if self._is_rank_zero:
self._metric_logger.close()
torch.distributed.destroy_process_group()
destroy_process_group()


@config.parse
@@ -590,12 +582,8 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

Review comment: Can you explain what this does?
Reply: Oh yeah, this is a good question. This was originally used to save memory when we were using FSDP with
Reply: We can remove setting TORCH_NCCL_AVOID_RECORD_STREAMS in FSDP2 because we get rid of record_stream.
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")

Review comment: I know this is old, but technically I think we only want nccl if the device is cuda. The default for init_process_group just does both, which works really well.
Reply: What other devices do we support right now though?
Reply: Not sure I understand this - do you mean we should just be calling
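The change being discussed would look something like the snippet below, relying on recent PyTorch picking ``cpu:gloo,cuda:nccl`` when no backend is given. Whether the recipe actually switches to this is left open in the thread above.

```python
from torch.distributed import init_process_group

# With no backend argument, recent PyTorch initializes both gloo (CPU) and
# nccl (CUDA) process groups and routes collectives by tensor device.
init_process_group()
```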
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
utils.set_torch_num_threads()

Review comment: Setting the number of threads is a critical lesson we learned in CPU offloading; the perf impact is huge. Curious why this is being removed - do we not need CPU offloading anymore?
Reply: Ah thanks, yes this should stay in. Will update.
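Since the thread settles on keeping this, here is a rough sketch of what a ``set_torch_num_threads``-style helper does; the exact heuristic in torchtune may differ. This version simply divides the visible CPU cores across ranks and assumes the process group is already initialized.

```python
import os

import torch
import torch.distributed as dist


def set_torch_num_threads() -> None:
    # Give each rank an even share of the CPU cores so the fused CPU AdamW
    # step can use intra-op parallelism without oversubscribing the host.
    num_threads = max(1, (os.cpu_count() or 1) // dist.get_world_size())
    torch.set_num_threads(num_threads)
```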

config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
Review comment: I might be misremembering, but there was a discussion at some point around depending on a public API instead of something under "_checkpoint". Do you remember what that was?
Reply: If we're thinking about the same thing, it may have been the suggestion to move to torch.utils.checkpoint instead of relying on distributed APIs for AC. If it's all the same to you, I would punt on that here.