Full finetune FSDP2 recipe #1287

Merged: 9 commits merged on Aug 11, 2024
Changes from 4 commits
212 changes: 102 additions & 110 deletions recipes/full_finetune_distributed.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time

@@ -15,25 +16,21 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
Contributor: I might be misremembering but there was a discussion at some point around depending on a public API instead of something under "_checkpoint". Do you remember what that was?

Contributor Author: If we're thinking about the same thing, it may have been the suggestion to move to torch.utils.checkpoint instead of relying on distributed APIs for AC. If it's all the same to you I would punt on that here

CheckpointWrapper,
)
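As a concrete reference for the public-API suggestion above, here is a minimal sketch of activation checkpointing through torch.utils.checkpoint rather than the distributed checkpoint_wrapper import; the layer and tensor shapes are placeholders, not torchtune code.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4)  # stand-in for a decoder layer
x = torch.randn(2, 16, 64, requires_grad=True)

# Recompute the layer's activations during backward instead of storing them;
# use_reentrant=False selects the recommended non-reentrant implementation.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()
```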

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils.activations import apply_selective_activation_checkpointing

from tqdm import tqdm


log = utils.get_logger("DEBUG")


@@ -43,8 +40,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
distributed training and can be run on a single node (1 to 8 GPUs).

Features:
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU
is not supported.
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
supported.

- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
@@ -92,10 +89,10 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):

Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
"""

def __init__(self, cfg: DictConfig) -> None:

self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(cfg.dtype, device=self._device)

@@ -192,37 +189,31 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:

def setup(self, cfg: DictConfig) -> None:
"""
Sets up the recipe state correctly. This includes setting recipe attributes based
on the ``resume_from_checkpoint`` flag.
Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),
Contributor: This sentence reads funny - just duplicating "recipe state" twice.

model, tokenizer, loss, optimizer, sampler, and dataloader.
"""
if self._is_rank_zero:
self._metric_logger = config.instantiate(cfg.metric_logger)

# log config with parameter override
self._metric_logger.log_config(cfg)

ckpt_dict = self.load_checkpoint(cfg.checkpointer)
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
ebsmothers marked this conversation as resolved.

# ``_setup_model`` handles initialization and loading the state dict. This method
# should be called before ``_setup_optimizer`` since transforming the optimizer
# state dict requires the model
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
model_state_dict=ckpt_dict[utils.MODEL_KEY],
model_state_dict=checkpoint_dict[utils.MODEL_KEY],
Contributor: [Misplaced Comment] I never fully understood memory_efficient_fsdp_wrap, but do we still need that when we migrate to FSDP2?

Contributor Author: Yes I think we do still need some version of it, otherwise our peak memory will be higher on Llama3 models

ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
)

self._tokenizer = config.instantiate(cfg.tokenizer)

# _setup_optimizer should take in ckpt_dict only if training is resumed from
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
opt_state_dict=ckpt_dict[utils.OPT_KEY]
opt_state_dict=checkpoint_dict[utils.OPT_KEY]
if self._resume_from_checkpoint
else None,
)
@@ -266,37 +257,21 @@ def _setup_model(
) -> nn.Module:
"""
Model initialization has some important considerations:
a. To minimize GPU peak memory, we load the model on CPU with the right
dtype. To ensure that we don't instantiate ``world_size`` number of models,
we initialize on meta_device for all ranks other than rank 0.
b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
model weights from checkpoint.
c. While wrapping the model with FSDP, we set ``sync_module_states``
to TRUE and broadcast module params and buffers from rank 0.
d. The ``device_id`` param ensures that the FSDP initialization happens on
the correct device.
a. To minimize GPU peak memory, we initialize the model on meta device with
the right dtype
b. All ranks call ``load_state_dict`` without spiking CPU RAM usage, since
Contributor: Why do all ranks load state dict now?

Contributor Author: Because we shard as we load. Previously we only loaded on rank 0 because we relied on sync_module_states to broadcast to all the other ranks. I think this way should be more memory-efficient because we never need a full copy of the state dict on rank 0

full state dicts are loaded with ``torch.load(mmap=True)``
c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`
Contributor: I don't think we should say "instead of" - this should just explain how we shard, not compare to FSDP1

Contributor: "We register (pre) forward hooks" -> why do we need this or what does this add to the docstring? I'd be willing to bet not many users know what this means

Contributor Author: Yeah these are both fair points, lemme update

"""
if self._is_rank_zero:
log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
init_start = time.perf_counter()

with utils.set_default_dtype(self._dtype):
model = config.instantiate(cfg_model)

if self._is_rank_zero:
log.info(
f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
Contributor: Where is FSDP being enabled? Does _is_rank_zero only get set to true if init_process_group is called? If so that's not very clear from the code.

Contributor Author: Sorry, not sure I fully understand this question. The _is_rank_zero here is just preventing log spew, and the init_process_group logic remains unchanged from the existing version of the recipe

)
init_start = time.perf_counter()

# Load both the model weights. This should happen only on Rank 0
model.load_state_dict(model_state_dict)

else:
# For non-zero ranks, load the model on meta device
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)

if self._dtype == torch.bfloat16:
model = model.to(torch.bfloat16)
with utils.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg_model)
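To illustrate the meta-device initialization above, here is a minimal sketch; the default_dtype context manager is a stand-in for torchtune's utils.set_default_dtype, and the nn.Linear stands in for config.instantiate(cfg_model).

```python
import torch
from contextlib import contextmanager

@contextmanager
def default_dtype(dtype: torch.dtype):
    # stand-in for utils.set_default_dtype
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)

# Parameters are created on the meta device in bf16: no real memory is allocated,
# so every rank can "build" the model cheaply before sharding and loading weights.
with default_dtype(torch.bfloat16), torch.device("meta"):
    model = torch.nn.Linear(4096, 4096)  # stand-in for config.instantiate(cfg_model)

assert model.weight.is_meta and model.weight.dtype == torch.bfloat16
```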

# We currently have two versions of activation checkpointing in this recipe
# for testing and BC purposes. ``enable_activation_checkpointing`` controls
@@ -314,41 +289,51 @@ def _setup_model(
ac_option,
)

# Wrap the model with FSDP. This will ensure that the model is sharded
# across all available GPUs.
model = FSDP(
module=model,
auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy(
memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
modules_to_wrap={modules.TransformerDecoderLayer},
),
cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
mixed_precision=None,
# Ensure we broadcast params and buffers from rank 0
sync_module_states=True,
# Initialize empty modules on all non-zero ranks
param_init_fn=(
lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if not self._is_rank_zero
else None
),
)

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

# original activation checkpointing (full) - flip the condition above
if enable_activation_checkpointing and ac_mode is None:
utils.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

fsdp_kwargs = {}
Contributor: [Misplaced comment] Do we need to remove the logic related to enable_activation_checkpointing above? I don't think we have auto_wrap_policy anymore? Or am I misunderstanding?

Contributor Author: Oh yeah I can see how this is confusing. A bit awkward, but auto_wrap_policy for AC != auto_wrap_policy from our dear old friend FSDP. You can see the usage here
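To make the distinction concrete, here is a minimal sketch of AC-style wrapping, which is separate from FSDP sharding; the check_fn mirrors what utils.set_activation_checkpointing does with TransformerDecoderLayer, though the exact torchtune internals may differ.

```python
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def wrap_for_ac(model: nn.Module, layer_cls: type) -> None:
    # Wrap every matching submodule in a CheckpointWrapper so its activations
    # are recomputed in backward; this is orthogonal to FSDP sharding.
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda m: isinstance(m, layer_cls),
    )
```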

if fsdp_cpu_offload:
Contributor: Apologies if I'm projecting, but I think this could use some comments around what exactly we are offloading to CPU.

fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
Contributor (@felipemello1, Aug 8, 2024): should we expose the sharding strategy like in #1024?

Contributor Author: We can, but imo at this stage it's a nice-to-have rather than a must-have

Contributor: It does feel wrong to remove a feature that was just recently contributed though. It feels like the goal of moving this out of dev is to integrate it with everything we already have.

Contributor Author: It hasn't landed yet though

Contributor: My vote would be to expose the sharding strategy - it's a really easy lever to trade off memory for performance, and it's unfortunate that we still haven't landed this for FSDP1.

Contributor Author: To follow up on this, there is not a 1:1 mapping between FSDP1 sharding strategy and FSDP2 APIs. But:

(1) for single-node training (all we currently support), the only relevant sharding strategies should be FULL_SHARD, SHARD_GRAD_OP, and NO_SHARD. HYBRID_SHARD and _HYBRID_SHARD_ZERO2 are only distinct from these in the case of multi-node training.
(2) FSDP2 does not support NO_SHARD; in that case they just redirect users to DDP.

So then we are just down to FULL_SHARD and SHARD_GRAD_OP, which correspond to setting reshard_after_forward to True and False, respectively. So I can expose this flag in the recipe.

Thanks to @weifengpy for helping to clear this up!
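Here is a minimal sketch of that mapping, assuming the recipe exposes a string flag; the helper name and flag are illustrative, not existing torchtune APIs.

```python
from torch import nn
from torch.distributed._composable.fsdp import fully_shard

def shard_with_strategy(module: nn.Module, sharding_strategy: str = "FULL_SHARD", **fsdp_kwargs) -> None:
    if sharding_strategy == "FULL_SHARD":
        # ZeRO-3 style: free the unsharded params after forward, regather in backward
        reshard_after_forward = True
    elif sharding_strategy == "SHARD_GRAD_OP":
        # ZeRO-2 style: keep unsharded params alive between forward and backward
        reshard_after_forward = False
    else:
        # NO_SHARD has no FSDP2 equivalent; users are pointed at DDP instead
        raise ValueError(f"Unsupported sharding strategy: {sharding_strategy}")
    fully_shard(module, reshard_after_forward=reshard_after_forward, **fsdp_kwargs)
```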


# Shard the model with FSDP
Contributor: You already know this, but the block below is giving way too much away in terms of readability. Probably need some well designed utilities here. But apart from utilities, I think you can structure this in a way which is much easier to understand. Remember, torchtune is one of the only libraries with FSDP2 integration. We should showcase this well

Contributor Author: Yeah this is the top thing on my plate to refactor here

for m in reversed(list(model.modules())):
Contributor: Would it make sense to have utils for these two things? Also, whether as utils or still here, it might be cleaner to follow if you looped through the modules list 2 times, once for AC and once for distributed

Contributor Author: Yeah I think we should have utils. And to be clear, we are not actually doing any AC wrapping here, that happens above. Here we are just checking which modules have been wrapped by AC already.

# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m, **fsdp_kwargs)
# For large vocab size, we can additionally shard
# the token embeddings and output projections individually
if memory_efficient_fsdp_wrap:
Contributor: Can we just rename this to "shared_embedding" or something more descriptive?

Contributor Author: That's not what this is though. This is just for the case where we want to shard token embeddings and output projections in their own unit (instead of just in the top-level model sharding). We found this saves a fair bit of memory for Llama3. I do agree the naming is a bit awkward, definitely open to other suggestions here

Contributor: Since we're discussing this, can I say that having model-specific functionality (this was included for Llama3) infiltrating the recipe always made little sense to me. Branding it with something as generic as "memory_efficient_fsdp_wrap" makes this worse. Maybe leaking this into the recipe is unavoidable, but I think we can abstract this functionality better.

One idea is to pass the modules we want to wrap separately as strings to the utility which does the wrapping. This would be very similar to how we decide the layers for LoRA. It also does away with the need to call this "memory_efficient".

Contributor Author: So a couple thoughts here:

(1) This is an optimization that we only apply to Llama 3, but it should be generally useful for any model with a very large vocab size. Of course I agree that as we've named it none of these points are obvious.. my initial intention here was to be as non-BC-breaking as possible, but if you guys want a rename of this I am happy to do it.

(2) On passing modules we want to wrap separately as a string: I think this is an interesting idea -- if we iterate over named modules we could probably just do a check on the FQN? The only concern I have is that it could be a bit easy for someone to shoot themselves in the foot here.. we are giving a lot more freedom over configurations that we have not really tried out ourselves.

Either way I do plan to move this into a separate utility. I'm still on the fence about how exactly to expose the individual wrapping of these layers, will try a couple approaches and see what makes sense here.

if isinstance(m, nn.Embedding):
fully_shard(m, **fsdp_kwargs)
if isinstance(m, modules.TransformerDecoder):
fully_shard(m.output, **fsdp_kwargs)
else:
if isinstance(m, modules.TransformerDecoderLayer):
ebsmothers marked this conversation as resolved.
fully_shard(m, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
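As a consolidated view of the wrapping loop above, here is a minimal sketch of the kind of utility discussed in the thread; the function name and the ac_enabled/shard_embeddings/cpu_offload flags are illustrative, not existing torchtune APIs.

```python
from torch import nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

def shard_model(
    model: nn.Module,
    layer_cls: type,
    *,
    ac_enabled: bool = False,
    shard_embeddings: bool = False,
    cpu_offload: bool = False,
) -> None:
    fsdp_kwargs = {"offload_policy": CPUOffloadPolicy()} if cpu_offload else {}
    # Walk submodules deepest-first so each decoder layer becomes its own FSDP2
    # unit before the root call below wraps whatever parameters remain.
    for m in reversed(list(model.modules())):
        is_layer = isinstance(m, CheckpointWrapper) if ac_enabled else isinstance(m, layer_cls)
        if is_layer:
            fully_shard(m, **fsdp_kwargs)
        elif shard_embeddings and isinstance(m, nn.Embedding):
            # For large vocabularies, sharding the embeddings (and, in the recipe,
            # the output projection) as separate units lowers peak memory.
            fully_shard(m, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)
```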

with utils.set_default_dtype(self._dtype), self._device:
for m in model.modules():
# RoPE is not covered in state dict
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()
Contributor: How should I conceptually parse this block? IIUC, at this stage all of the model params are init on meta device (load_from_full_model_state_dict is where we bring them over to GPU unless I'm misremembering). So at this point we're constructing the RoPE buffers on CUDA? Is there a reason this is happening before we load the state dict? Or doesn't matter?

Also, we called this reset_parameters because of an idiosyncrasy with FSDP1. Do we need to carry that over? Can we just make rope_init a public API? It's really hard to reason about reset_parameters if you don't know the details.

Contributor Author:

> Is there a reason this is happening before we load the state dict? Or doesn't matter?

It shouldn't matter because the RoPE buffers are not in the state dict anyway. This becomes more relevant for LoRA though, where we may have existing adapter weights to load. In that case, we either move from meta device on state dict load or on the usual param init, but not both. But even in this case, since we only do one or the other, I think the order shouldn't matter (lmk if this makes sense or not though)

> Can we just make rope_init a public API?

Yeah I think we can do this


utils.load_from_full_model_state_dict(
Contributor: Add a brief comment here saying what this is actually doing - helps make sure this function is self-explanatory, which is how we intended it to be

model, model_state_dict, self._device, self._is_rank_zero, strict=True
)
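For reference, here is a minimal sketch of what a load-from-full-state-dict helper does under FSDP2: the full checkpoint is memory-mapped and each tensor is re-sharded into the DTensor layout that fully_shard set up on this rank. Names and arguments are illustrative, not the exact torchtune implementation.

```python
import torch
from torch import nn
from torch.distributed._tensor import distribute_tensor

def load_full_state_dict_sharded(model: nn.Module, checkpoint_path: str, device: torch.device) -> None:
    # mmap=True keeps the full state dict disk-backed instead of copying it into RAM
    full_sd = torch.load(checkpoint_path, mmap=True, map_location="cpu")
    sharded_sd = {}
    for name, meta_param in model.state_dict().items():
        full_tensor = full_sd[name].to(device)
        # Re-shard to the DTensor layout that fully_shard established on this rank
        sharded_sd[name] = nn.Parameter(
            distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)
        )
    # assign=True because the existing params are still on the meta device
    model.load_state_dict(sharded_sd, assign=True)
```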

# Ensure no params and buffers are on meta device
utils.validate_no_params_on_meta_device(model)

if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
)
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)

@@ -360,17 +345,13 @@ def _setup_model(
def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
"""
Set up the optimizer. This method also handles transforming the state dict
for FSDP.
"""
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
opt_state_dict = FSDP.optim_state_dict_to_load(
self._model, optimizer, opt_state_dict
utils.load_from_full_optimizer_state_dict(
Contributor: @weifengpy IIRC there were some APIs related to optimizer and state_dict that were in the process of being moved over to distributed? Is that still the plan? If so, we can just import these from distributed instead of from utils?

Contributor (@weifengpy, Aug 8, 2024): I worked with @mori360 last month to upstream the implementation into the PTD API. Numerically it's on par now, but we need to verify memory behavior before formal migration: pytorch/pytorch#128745

optimizer,
opt_state_dict,
self._device,
)
optimizer.load_state_dict(opt_state_dict)
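For reference, the upstreamed PTD path mentioned above would look roughly like the sketch below; availability and exact behavior depend on the PyTorch version, and this is not what the recipe currently calls.

```python
from torch import nn
from torch.optim import Optimizer
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_optimizer_state_dict,
)

def load_full_optimizer_state(model: nn.Module, optimizer: Optimizer, full_sd: dict) -> None:
    # full_sd is an unsharded optimizer state dict (e.g. read from a checkpoint file);
    # the options tell PTD to treat it as a full, CPU-resident state dict and shard it
    # onto this rank's optimizer.
    set_optimizer_state_dict(
        model,
        optimizer,
        optim_state_dict=full_sd,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
```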

if self._is_rank_zero:
log.info("Optimizer is initialized.")
@@ -391,21 +372,17 @@ def _setup_data(

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, self._tokenizer)
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
packed = cfg_dataset.get("packed", False)

sampler = DistributedSampler(
ds,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
seed=0,
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
dataloader = DataLoader(
dataset=ds,
@@ -425,32 +402,48 @@ def _setup_data(

return sampler, dataloader

def save_checkpoint(self, epoch: int) -> None:
def save_checkpoint(
self,
epoch: int,
) -> None:
"""
Save state dict to file. The recipe save_checkpoint method is responsible for
correctly creating the checkpoint dict and passing to the checkpointer.
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Model weights with key MODEL_KEY
Contributor: Suggested change: "Model weights with key MODEL_KEY" -> "Model weights with key utils.MODEL_KEY"

- Relevant recipe state if training is not complete

Checkpointer will save the model weights and recipe state in
different checkpoint files. To correctly resume training from an intermediate checkpoint,
the model weights and recipe state must be provided.
"""
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
cpu_state_dict = utils.get_full_model_state_dict(
self._model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state_dict = self._model.state_dict()
opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
self._is_rank_zero,
)

if intermediate_checkpoint:
opt_state_dict = utils.get_full_optimizer_state_dict(
self._optimizer,
self._is_rank_zero,
)
else:
opt_state_dict = None
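For reference, here is a minimal sketch of what a full-model-state-dict helper has to do under FSDP2, where every parameter is a DTensor shard; this is illustrative, not the exact utils.get_full_model_state_dict implementation.

```python
import torch
from torch.distributed._tensor import DTensor

def gather_full_state_dict(model: torch.nn.Module, is_rank_zero: bool) -> dict:
    full_sd = {}
    for name, tensor in model.state_dict().items():
        if isinstance(tensor, DTensor):
            # all-gather the shards so every rank sees the full tensor
            tensor = tensor.full_tensor()
        if is_rank_zero:
            # keep the consolidated copy on CPU, and only on rank 0, to cap GPU memory
            full_sd[name] = tensor.cpu()
    return full_sd
```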

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})

# if training is in-progress, checkpoint the optimizer state as well
if epoch + 1 < self.total_epochs:
# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
checkpoint_dict.update(
{
utils.OPT_KEY: opt_state_dict,
@@ -464,13 +457,12 @@ def save_checkpoint(self, epoch: int) -> None:
self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=(epoch + 1 < self.total_epochs),
intermediate_checkpoint=intermediate_checkpoint,
)

def train(self) -> None:
"""
The core training loop. Supports training on subsets of the dataset using the
``max_steps_per_epoch``.
The core training loop.
"""
# clean up before training begins
utils.cleanup_before_training()
@@ -573,7 +565,7 @@ def train(self) -> None:
def cleanup(self) -> None:
if self._is_rank_zero:
self._metric_logger.close()
torch.distributed.destroy_process_group()
destroy_process_group()


@config.parse
@@ -590,7 +582,7 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
Contributor: Can you explain what this does?

Contributor Author: Oh yeah this is a good question. This was originally used to save memory when we were using FSDP with sync_module_states=True (see #843). I suspect it shouldn't be needed with our current FSDP2 usage but can also run a quick sanity check here. @weifengpy can you confirm whether we hit dist._broadcast_coalesced in any part of fully_shard?

Contributor: We can remove setting TORCH_NCCL_AVOID_RECORD_STREAMS in FSDP2 because we get rid of record_stream. dist._broadcast_coalesced was from state dict loading for FSDP1 and it's gone as well

init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Contributor: I know this is old, but technically I think we only want nccl if device is cuda. The default for init_process_group just does both which works really well.

Contributor Author: What other devices do we support right now though?

> The default for init_process_group just does both which works really well

Not sure I understand this.. do you mean we should just be calling init_process_group()?
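A minimal sketch of the reviewer's suggestion, assuming the usual torchrun-provided environment variables are set: with no backend argument, recent PyTorch versions register both a CPU (gloo) and a CUDA (nccl) backend and route collectives by tensor device.

```python
from torch.distributed import init_process_group

# Roughly equivalent to init_process_group(backend="cpu:gloo,cuda:nccl") on a CUDA host;
# on a CPU-only host only gloo is registered.
init_process_group()
```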

if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
3 changes: 2 additions & 1 deletion torchtune/utils/_distributed.py
@@ -289,6 +289,7 @@ def load_from_full_model_state_dict(
full_sd: Dict[str, Any],
device: torch.device,
is_rank_zero: bool,
strict: bool = False,
):
"""
Converting full state dict into a sharded state dict
@@ -355,7 +356,7 @@ def load_from_full_model_state_dict(
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=False, assign=True)
return model.load_state_dict(sharded_sd, strict=strict, assign=True)


def get_full_model_state_dict(
Expand Down
Loading