Full finetune FSDP2 recipe #1287
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1287. Note: links to docs will display an error until the docs builds have been completed.

✅ No failures as of commit ba1cb92 with merge base 00bbd53. This comment was automatically generated by Dr. CI and updates every 15 minutes.
recipes/full_finetune_distributed.py
Outdated
```python
if cfg.get("fsdp_cpu_offload", False):
    # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
    # speed up when benchmarking fused AdamW on CPU
    utils.set_torch_num_threads()
```
Setting the number of threads is a critical lesson we learned with CPU offloading; the perf impact is huge. Curious why this is being removed? Do we not need CPU offloading anymore?
Ah thanks, yes this should stay in. Will update
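For context, a minimal sketch of what a helper like `utils.set_torch_num_threads` can do (this is an illustrative reimplementation, not the actual torchtune code): divide the available CPU cores across ranks so a fused CPU optimizer gets intra-op parallelism without ranks oversubscribing each other.

```python
import os

import torch


def set_torch_num_threads() -> None:
    # torchrun sets WORLD_SIZE; fall back to 1 for single-process runs.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Give each rank an equal slice of the cores, with a floor of 1.
    num_threads = max(1, (os.cpu_count() or 1) // world_size)
    torch.set_num_threads(num_threads)


set_torch_num_threads()
```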
Thanks for driving this! The code looks good. Happy to see the similarity between LoRA and full finetuning, and that much of the code can be reused.

I did notice the perf difference in the snapshot (wandb). Would love to learn more from profiler traces and compare FSDP1 and FSDP2 side by side.

I did not stamp only because this PR seems to be for discussion purposes.
Thanks @weifengpy for the review! We intend to land this (hopefully soon); mainly I left the PR in draft status because there is some additional testing I'd like to do. Re the perf difference: this run is for Llama3, which does some slightly custom sharding. Let me re-run on Llama2 and see if the results persist. The suggestion to profile to understand the difference makes sense to me; while I'm at it I may even integrate the profiler from our LoRA recipes into full finetune (since this is something folks have been asking about anyways).
recipes/full_finetune_distributed.py
Outdated
```python
correctly creating the checkpoint dict and passing to the checkpointer.
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Model weights with key MODEL_KEY
```
```diff
- - Model weights with key MODEL_KEY
+ - Model weights with key utils.MODEL_KEY
```
recipes/full_finetune_distributed.py
Outdated
```python
if enable_activation_checkpointing and ac_mode is None:
    utils.set_activation_checkpointing(
        model, auto_wrap_policy={modules.TransformerDecoderLayer}
    )

fsdp_kwargs = {}
if fsdp_cpu_offload:
    fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
```
should we expose the sharding strategy like in #1024
We can, but imo at this stage it's a nice-to-have rather than a must-have
It does feel wrong to remove a feature that was just recently contributed though. It feels like the goal of moving this out of dev is to integrate it with everything we already have.
It hasn't landed yet though
My vote would be to expose the sharding strategy: it's a really easy lever to trade off memory for performance, and it's unfortunate that we still haven't landed this for FSDP1.
To follow up on this, there is not a 1:1 mapping between FSDP1 sharding strategy and FSDP2 APIs. But
(1) for single-node training (all we currently support), the only relevant sharding strategies should be FULL_SHARD, SHARD_GRAD_OP, and NO_SHARD. HYBRID_SHARD and _HYBRID_SHARD_ZERO2 are only distinct from these in the case of multi-node training.
(2) FSDP2 does not support NO_SHARD, in that case they just redirect users to DDP
So then we are just down to FULL_SHARD and SHARD_GRAD_OP, which correspond to setting reshard_after_forward to True and False, respectively. So I can expose this flag in the recipe.
Thanks to @weifengpy for helping to clear this up!
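The mapping described above can be sketched as follows (the helper name and strategy strings are illustrative, not the recipe's actual API):

```python
def reshard_flag(sharding_strategy: str) -> bool:
    """Translate an FSDP1-style sharding strategy into FSDP2's reshard_after_forward flag."""
    mapping = {
        # Free sharded params after forward; re-gather in backward (ZeRO-3-like).
        "FULL_SHARD": True,
        # Keep gathered params until backward finishes (ZeRO-2-like, uses more memory).
        "SHARD_GRAD_OP": False,
    }
    if sharding_strategy not in mapping:
        # e.g. NO_SHARD: FSDP2 points users to DDP instead.
        raise ValueError(f"No FSDP2 analogue for {sharding_strategy}")
    return mapping[sharding_strategy]
```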
sounds good. feel free to request another review when it's ready
great! I am happy to take a deeper look at the trace once we have it. hopefully to gain new knowledge to make FSDP2 better |
Thanks for this update! I just added some comments and general questions as this is my first time really looking at FSDP2.
recipes/full_finetune_distributed.py
Outdated
```python
    the right dtype
b. All ranks call ``load_state_dict`` without peaking CPU RAM since
    full state dicts are loaded with ``torch.load(mmap=True)``
c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping ``nn.Module``
```
I don't think we should say "instead of" this should just explain how we shard, not compare to FSDP1
"We register (pre) forward hooks" -> why do we need this or what does this add to the docstring? I'd be willing to bet not many users know what this means
Yeah these are both fair points, lemme update
```python
    the correct device.
a. To minimize GPU peak memory, we initialize the model on meta device with
    the right dtype
b. All ranks call ``load_state_dict`` without peaking CPU RAM since
```
Why do all ranks load state dict now?
Because we shard as we load. Previously we only loaded on rank 0 because we relied on sync_module_states to broadcast to all the other ranks. I think this way should be more memory-efficient because we never need a full copy of the state dict on rank 0
```python
log.info(
    f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
    "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
```
Where is FSDP being enabled? Does _is_rank_zero only get set to true if init_process_group is called? If so, that's not very clear from the code.
Sorry, not sure I fully understand this question. The `_is_rank_zero` here is just preventing log spew, and the `init_process_group` logic remains unchanged from the existing version of the recipe.
recipes/full_finetune_distributed.py
Outdated
```python
for m in reversed(list(model.modules())):
    # For large vocab size, we can additionally shard
    # the token embeddings and output projections individually
    if memory_efficient_fsdp_wrap:
```
Can we just rename this to "shared_embedding" or something more descriptive?
That's not what this is though. This is just for the case where we want to shard token embeddings and output projections in their own unit (instead of just in the top-level model sharding). We found this saves a fair bit of memory for Llama3. I do agree the naming is a bit awkward, definitely open to other suggestions here
Since we're discussing this, can I say that having model-specific functionality (this was included for Llama3) infiltrating the recipe always made little sense to me. Branding it with something as generic as "memory_efficient_fsdp_wrap" makes this worse. Maybe leaking this into the recipe is unavoidable, but I think we can abstract this functionality better.

One idea is to pass the modules we want to wrap separately as strings to the utility which does the wrapping. This would be very similar to how we decide the layers for LoRA. It also does away with the need to call this "memory_efficient".
So a couple thoughts here:
(1) This is an optimization that we only apply to Llama 3, but it should be generally useful for any model with a very large vocab size. Ofc I agree that as we've named it none of these points are obvious.. my initial intention here was to be as non-BC-breaking as possible, but if you guys want a rename of this I am happy to do it.
(2) On passing modules we want to wrap separately as a string: I think this is an interesting idea -- if we iterate over named modules we could probably just do a check on the FQN? The only concern I have is that it could be a bit easy for someone to shoot themselves in the foot here.. we are giving a lot more freedom over configurations that we have not really tried out ourselves.
Either way I do plan to move this into a separate utility. I'm still on the fence about how exactly to expose the individual wrapping of these layers, will try a couple approaches and see what makes sense here.
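The FQN-matching idea floated here can be sketched roughly as below (the helper name and config name are hypothetical, used only to illustrate the approach):

```python
import torch.nn as nn


def fqn_shard_condition(custom_sharded_layers):
    """Build a shard condition matching user-named modules by fully qualified name."""

    def _condition(name: str, module: nn.Module) -> bool:
        return name in custom_sharded_layers

    return _condition


cond = fqn_shard_condition(["tok_embeddings", "output"])

# Check against (name, module) pairs as yielded by model.named_modules().
toy_model = nn.ModuleDict(
    {"tok_embeddings": nn.Embedding(8, 4), "output": nn.Linear(4, 8)}
)
matched = [name for name, module in toy_model.named_modules() if cond(name, module)]
```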
recipes/full_finetune_distributed.py
Outdated
```python
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

# Shard the model with FSDP
for m in reversed(list(model.modules())):
```
Would it make sense to have utils for these two things? Also, whether as utils or still here, it might be cleaner to follow if you looped through the modules list 2 times, once for ac and once for distributed
Yeah I think we should have utils. And to be clear, we are not actually doing any AC wrapping here, that happens above. Here we are just checking which modules have been wrapped by AC already.
```python
@@ -590,7 +582,7 @@ def recipe_main(cfg: DictConfig) -> None:
        "Distributed finetune recipe should be run via a distributed launcher."
        "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
    )

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
```
I know this is old, but technically I think we only want nccl if device is cuda. The default for init_process_group just does both which works really well.
What other devices do we support right now though?

> The default for init_process_group just does both which works really well

Not sure I understand this.. do you mean we should just be calling `init_process_group()`?
recipes/full_finetune_distributed.py
Outdated
```python
@@ -590,7 +582,7 @@ def recipe_main(cfg: DictConfig) -> None:
        "Distributed finetune recipe should be run via a distributed launcher."
        "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
    )

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
```
Can you explain what this does?
Oh yeah, this is a good question. This was originally used to save memory when we were using FSDP with `sync_module_states=True` (see #843). I suspect it shouldn't be needed with our current FSDP2 usage, but I can also run a quick sanity check here. @weifengpy can you confirm whether we hit `dist._broadcast_coalesced` in any part of `fully_shard`?
We can remove setting TORCH_NCCL_AVOID_RECORD_STREAMS in FSDP2 because we got rid of record streams. `dist._broadcast_coalesced` was from state dict loading in FSDP1, and it's gone as well.
Do the recipe tests need any updates? I'd guess not? Can you also test checkpoint reads and writes and make sure memory looks fine?
recipes/full_finetune_distributed.py
Outdated
```python
    StateDictType,
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
```
I might be misremembering but there was a discussion at some point around depending on a public API instead of something under "_checkpoint". Do you remember what that was?
If we're thinking about the same thing, it may have been the suggestion to move to torch.utils.checkpoint instead of relying on distributed APIs for AC. If it's all the same to you I would punt on that here
recipes/full_finetune_distributed.py
Outdated
```python
@@ -192,37 +189,31 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:

def setup(self, cfg: DictConfig) -> None:
    """
    Sets up the recipe state correctly. This includes setting recipe attributes based
    on the ``resume_from_checkpoint`` flag.
    Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),
```
This sentence reads funny - just duplicating "recipe state" twice.
```diff
 self._model = self._setup_model(
     cfg_model=cfg.model,
     enable_activation_checkpointing=cfg.enable_activation_checkpointing,
     memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
     fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
-    model_state_dict=ckpt_dict[utils.MODEL_KEY],
+    model_state_dict=checkpoint_dict[utils.MODEL_KEY],
```
[Misplaced comment] I never fully understood `memory_efficient_fsdp_wrap`, but do we still need that when we migrate to FSDP2?
Yes I think we do still need some version of it, otherwise our peak memory will be higher on Llama3 models
recipes/full_finetune_distributed.py
Outdated
```python
if isinstance(m, modules.TransformerDecoder):
    fully_shard(m.output, **fsdp_kwargs)
# TransformerDecoderLayer is wrapped by CheckpointWrapper
# when enable_activation_checkpointing
```
```diff
- # when enable_activation_checkpointing
+ # when enable_activation_checkpointing is set to True
```
recipes/full_finetune_distributed.py
Outdated
```python
with utils.set_default_dtype(self._dtype), self._device:
    for m in model.modules():
        # RoPE is not covered in state dict
        if isinstance(m, modules.RotaryPositionalEmbeddings):
            m.reset_parameters()
```
How should I conceptually parse this block? IIUC, at this stage all of the model params are init on meta device (`load_from_full_model_state_dict` is where we bring them over to GPU unless I'm misremembering). So at this point we're constructing the RoPE buffers on CUDA? Is there a reason this is happening before we load the state dict? Or doesn't matter?

Also, we called this `reset_parameters` because of an idiosyncrasy with FSDP1. Do we need to carry that over? Can we just make `rope_init` a public API? It's really hard to reason about `reset_parameters` if you don't know the details.
> Is there a reason this is happening before we load the state dict? Or doesn't matter?

It shouldn't matter because the RoPE buffers are not in the state dict anyways. This becomes more relevant for LoRA though, where we may have existing adapter weights to load. In that case, we either move off meta device on state dict load or on the usual param init, but not both. But even in this case, since we only do one or the other, I think the order shouldn't matter (lmk if this makes sense or not though).

> Can we just make rope_init a public API?

Yeah I think we can do this
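A toy illustration of the meta-device flow under discussion (not torchtune code): params start on the meta device, storage is materialized with `to_empty()`, real values then come from `load_state_dict` for params, while buffers absent from the checkpoint (like RoPE caches) must be re-initialized by a dedicated method such as the proposed `rope_init()`.

```python
import torch
import torch.nn as nn

# Construct on meta: no memory allocated, only shapes/dtypes recorded.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

# Materialize uninitialized storage on the target device.
layer = layer.to_empty(device="cpu")

# Real values come from the checkpoint (stand-in state dict here); any buffer
# not covered by the state dict would need explicit re-initialization.
layer.load_state_dict(nn.Linear(4, 4).state_dict())
```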
```python
if isinstance(m, modules.RotaryPositionalEmbeddings):
    m.reset_parameters()

utils.load_from_full_model_state_dict(
```
Add a brief comment here saying what this is actually doing; it helps make sure this function is self-explanatory, which is how we intended it to be.
```diff
 if opt_state_dict:
-    opt_state_dict = FSDP.optim_state_dict_to_load(
-        self._model, optimizer, opt_state_dict
+    utils.load_from_full_optimizer_state_dict(
```
@weifengpy IIRC there were some APIs related to optimizer and state_dict that were in the process of being moved over to distributed? Is that still the plan? If so, we can just import these from distributed instead of from utils?
I worked with @mori360 last month to upstream the implementation into the PTD API. Numerically it's on par now, but we need to verify memory behavior before the formal migration (pytorch/pytorch#128745).
Yeah, recipe tests are all passing. Will update with tests for checkpoints and memory results.
Codecov Report

Attention: Patch coverage is

```
@@             Coverage Diff             @@
##             main    #1287       +/-   ##
===========================================
- Coverage   68.31%   27.10%   -41.22%
===========================================
  Files         255      258        +3
  Lines       11796    11915      +119
===========================================
- Hits         8059     3229     -4830
- Misses       3737     8686     +4949
```

☔ View full report in Codecov by Sentry.
This PR migrates our full finetune recipe onto FSDP2.
Main changes

- Adds a new utility `shard_model` for sharding our models with FSDP2. The recommended way to shard FSDP2 models is by iterating over modules in a bottom-up fashion. As a result, this utility takes in a list of `shard_conditions`. Each shard condition is a callable taking a module name and the module itself (i.e. a pair from `model.named_modules()`) and returning True if the module should be sharded (and False otherwise). We use a list of shard conditions to more easily decouple different things we might want to check (e.g. whether a module is using activation checkpointing, whether we are doing any custom sharding, or whether LoRA is applied (coming in a subsequent PR)). If any of the `shard_conditions` return True, the module will be sharded. `shard_model` will also take any FSDP-related fields like `cpu_offload` and `reshard_after_forward` (~fka `sharding_strategy`).
- Removes the `memory_efficient_fsdp_wrap` flag. Instead, we can provide an optional config `custom_sharded_layers`, which is more flexible. For our existing Llama3 8B sharding, we can pass `['tok_embeddings', 'output']` to tell FSDP to shard these modules separately.

Other smaller changes:

- Moves RoPE init off of `reset_parameters()` and onto `rope_init()`
- Removes logic in `load_model_from_full_state_dict` that's no longer needed
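The selection logic behind `shard_model` can be sketched as below. This is a hedged sketch: the real utility calls `fully_shard` on each matching module (which needs an initialized process group), so here we only collect the matches to show the bottom-up traversal order.

```python
import torch.nn as nn


def modules_to_shard(model: nn.Module, shard_conditions):
    """Collect names of modules that any shard condition selects, bottom-up."""
    selected = []
    # reversed(named_modules()) visits children before parents, the
    # recommended bottom-up order for FSDP2's fully_shard.
    for name, module in reversed(list(model.named_modules())):
        if any(cond(name, module) for cond in shard_conditions):
            selected.append(name)
    return selected


model = nn.Sequential(nn.Embedding(16, 4), nn.Linear(4, 16))
conds = [lambda name, module: isinstance(module, (nn.Embedding, nn.Linear))]
print(modules_to_shard(model, conds))  # ['1', '0'] — children first, bottom-up
```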
Apart from green CI,
Test Llama3 8B perf and memory
On both this PR and main, run
Results are as follows (note that tok/sec plots are smoothed in all figures for clarity)
Other model tests
Note: the loss curves for Gemma 7B differ from those with FSDP1. However, Gemma 7B loss curves appear to be non-reproducible with FSDP1, see #1303.
Phi3
Command:
Gemma 2B
Qwen2 0.5B
Test checkpoint resume
First run
Then run
Functionality tests
Custom wrapping
Tried out with Llama2 7B, can see some reduction in reserved memory with no perf impact
Selective AC
(Not testing memory or perf here, just making sure the run succeeds, since this is still an experimental feature)