
enable LoRA + FSDP2 #855

Merged · 58 commits merged into pytorch:main on Jun 3, 2024
Conversation

@weifengpy (Contributor) commented Apr 24, 2024

how to run it

  • 7B: tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN> && tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/7B_lora
  • 70B: tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token <HF_TOKEN> && tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_lora

recipe tests: pytest tests/recipes/test_lora_finetune_fsdp2.py -m integration_test
unit test: pytest tests/torchtune/utils/test_distributed.py -k test_state_dict

Highlights

  • +12% tokens per second (see wandb snapshot) with +0.8% memory reserved
  • 3.2x speedup in model init: from 250s to 77s

FSDP2 changes

  • initialize models on device=meta instead of cpu
  • apply activation checkpointing before registering the FSDP hooks (a minimal sketch of this ordering follows)
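For illustration, a minimal sketch of this init order, assuming a hypothetical build_model() constructor and an nn.ModuleList of layers (the recipe's actual code differs in details):

import torch
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 entry point in torch 2.4
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# 1) build the model on the meta device: no real memory is allocated yet
with torch.device("meta"):
    model = build_model()  # hypothetical constructor

# 2) wrap each layer with activation checkpointing first ...
for idx, layer in enumerate(model.layers):
    model.layers[idx] = checkpoint_wrapper(layer)

# 3) ... then register FSDP2 hooks, per layer and finally on the root module
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)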

checkpoint changes

  • loading: load the HF checkpoint with mmap=True and convert plain tensors into DTensors
  • saving: gather the full state dict on rank-0 CPU and convert DTensors back into plain tensors
  • resume finetuning from previous checkpoints (both directions are sketched below)
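A hedged sketch of both conversions (function names and signatures are illustrative, not the exact utilities added in this PR; the mmap=True HF load itself is omitted):

import torch.nn as nn
from torch.distributed._tensor import DTensor, distribute_tensor

def load_full_into_sharded(model: nn.Module, full_sd: dict) -> None:
    """Shard a full (plain-tensor) state dict into an FSDP2-wrapped model."""
    meta_sharded_sd = model.state_dict()  # values are DTensors, still on meta
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        meta_param = meta_sharded_sd[name]
        # chunk the full tensor according to the DTensor's mesh and placements
        sharded_sd[name] = nn.Parameter(
            distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)
        )
    # assign=True because we cannot copy_ into meta tensors
    model.load_state_dict(sharded_sd, strict=False, assign=True)

def gather_full_on_rank0(model: nn.Module, is_rank_zero: bool) -> dict:
    """All-gather DTensor shards into a plain-tensor state dict on rank-0 CPU."""
    full_sd = {}
    for name, value in model.state_dict().items():
        if isinstance(value, DTensor):
            value = value.full_tensor()  # all-gather across the FSDP mesh
        if is_rank_zero:
            full_sd[name] = value.cpu()
    return full_sd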

pytorch-bot commented Apr 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/855

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8fbbc4b with merge base 135cf2e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Apr 24, 2024 (authors need to sign the CLA before a PR can be reviewed)
@weifengpy changed the title from "enable LoRA + FSDP2" to "[WIP] enable LoRA + FSDP2" on Apr 24, 2024
@weifengpy weifengpy marked this pull request as draft April 24, 2024 06:47
@awgu left a comment

I think one main question for discussion is what the torchtune folks feel about defining an explicit initializer for the RoPE theta buffer, which would be required to do meta-device init (which should speed up initialization!).

(Resolved review threads: recipes/lora_finetune_distributed.py, torchtune/modules/peft/lora.py, torchtune/modules/position_embeddings.py)
@weifengpy weifengpy mentioned this pull request May 1, 2024
model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

for m in reversed(list(model.modules())):
Contributor:

My noob understanding here is that you want to account for the Linear module before the TransformerDecoderLayer - is that right? If so, is there a better way to do this? It wasn't immediately obvious to me that reversing this list gets me that outcome - but maybe that's because I haven't played around with the modules function? Disregard if this is a pretty standard way to achieve this

Reply:

I wanted to quickly mention my opinion:

What we want is a post-order traversal of the modules so that we visit lower modules before higher ones. nn.Module.modules() always gives reverse post-order, so reversing it gives post-order.

FSDP1 hides all of this under the auto_wrap_policy argument, which does the post-order traversal for the user. Personally, I did not want to have an auto_wrap_policy argument for FSDP2 because the name is a misnomer: auto wrapping is not automatically choosing which modules to wrap for the user -- rather, it just performs a post-order traversal and applies a user-defined policy to decide whether each module is wrapped. What I am open to, though, is having some other higher-level utility (say, apply_fsdp()) that does the same thing as auto_wrap_policy but lives outside of fully_shard itself.
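For illustration, a hedged sketch of that idea; apply_fsdp is the hypothetical helper mentioned above, not an existing API:

from typing import Callable

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

def apply_fsdp(model: nn.Module, policy: Callable[[nn.Module], bool]) -> None:
    # nn.Module.modules() yields the root first, so reversing it visits
    # children before their parents (the post-order described above)
    for module in reversed(list(model.modules())):
        if policy(module):
            fully_shard(module)
    fully_shard(model)  # finally shard the root

# usage, mirroring the old auto_wrap_policy:
# apply_fsdp(model, lambda m: isinstance(m, modules.TransformerDecoderLayer))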

Contributor:

Thanks for this thorough explanation, it helps to understand this code block a lot. I agree it'd be nice to have an apply_fsdp utility but not a major concern here. Can we add code comments here to get the point you described across? (Basically that we are iterating over lower-order modules to higher-level modules and wrapping individual transformer layers. And I assume the separate wrapping of trainable LoRA weights is more related to the point you mentioned today about lower memory, rather than the flat param rationale of grads being allocated per shard?) Users will take this recipe code as a starting point so the more explicit we are here the easier they'll find it to extend.

utils.load_from_full_model_state_dict(
    model, lora_weights_state_dict, self._device, self._is_rank_zero
)

Contributor:

Got you, thanks so much for the explanation! I think it would be super helpful to document here, in the form of comments, the relationship between:

  • the modules on which we call fully_shard
  • init on meta device
  • calling initialize_parameters and reset_parameters

Also I think there was a technical reason with FSDP1 to call the function reset_parameters. Is that still true? Or can we standardize this with initialize_parameters in the modules code? Happy to chat about this offline!
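For illustration, a hedged sketch of how those pieces fit together (build_model and device are placeholders, the torchtune import paths are assumptions, and the RoPE init method name follows the discussion above rather than any final API):

import torch
from torch.distributed._composable.fsdp import fully_shard
from torchtune.modules import RotaryPositionalEmbeddings  # assumed import paths
from torchtune.modules.peft import LoRALinear

device = torch.device("cuda")

# 1) build on meta: parameters and buffers exist but own no storage
with torch.device("meta"):
    model = build_model()  # placeholder constructor

# 2) register FSDP2 hooks while everything is still on meta (no allocation yet)
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# 3) base weights arrive via the sharded state-dict load; only params/buffers
#    that are NOT in the checkpoint need explicit materialization + init
for m in model.modules():
    if isinstance(m, LoRALinear):                  # lora_a / lora_b not in base ckpt
        m.lora_a.to_empty(device=device)
        m.lora_b.to_empty(device=device)
        m.initialize_parameters()
    if isinstance(m, RotaryPositionalEmbeddings):  # RoPE buffer not in ckpt
        m.reset_parameters()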


for m in reversed(list(model.modules())):
    if (
        isinstance(m, nn.Linear)
Contributor:

Does this need to be LoRALinear? Or am I misunderstanding?

Contributor Author:

The nn.Linear modules with requires_grad=True are lora_a and lora_b.

FSDP1 on trunk wraps lora_a and lora_b separately, so we do the same wrapping here for parity. But as you mentioned, we could wrap LoRALinear instead so that lora_a and lora_b are communicated together. A sketch of that alternative follows.
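For illustration, a hedged sketch of that alternative inside the recipe's sharding loop (import paths and the surrounding model variable are assumptions):

from torch.distributed._composable.fsdp import fully_shard
from torchtune import modules
from torchtune.modules.peft import LoRALinear

for m in reversed(list(model.modules())):       # post-order: children before parents
    if isinstance(m, LoRALinear):
        fully_shard(m)                          # lora_a + lora_b share one comm group
    elif isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
fully_shard(model)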

        m.lora_a.to_empty(device=self._device)
        m.lora_b.to_empty(device=self._device)
        m.initialize_parameters()
    if isinstance(m, modules.RotaryPositionalEmbeddings):
Contributor:

Just to clarify, we special handle RoPE because the buffer is not being loaded from a state dict, right?

Contributor Author:

that's correct

Contributor:

Similar comment here, let's document what's happening so that users can easily understand why we initialize these modules separately.

@@ -242,6 +250,79 @@ def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
return lora_wrap_fsdp


def load_from_full_model_state_dict(
Contributor:

This is a bit unfortunate but we have two different meanings of "full model" IIUC:

  • One is related to FSDP i.e. full vs sharded - is that right?
  • Other is LoRA i.e. full model vs LoRA adapters

The current function is a bit confusing since we pass in the adapter state_dict and this is referred to as full_sd in the function itself.

Contributor Author:

Good catch. You're right, "full" here is the opposite of "sharded". What do you think about renaming it to "local", i.e. load_from_local_model_state_dict?

Contributor:

Can we add docstrings for the functions added in this file?

for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    sharded_tensor = distribute_tensor(
        full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements
Contributor:

Where are the device_mesh and placements coming from?

Contributor Author:

It comes from fully_shard(model, mesh). We use the default mesh, where every rank participates in FSDP, since no 2D/3D parallelism is involved.

After fully_shard(model, mesh), model.parameters() are converted from plain tensors to DTensors carrying that mesh.
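For illustration, a hedged standalone example of where the mesh and placements live (run under torchrun so the default process group exists; single-host assumptions only):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

with torch.device("meta"):
    model = nn.Linear(16, 16)
fully_shard(model)  # no explicit mesh: default 1-D mesh over all ranks

p = next(model.parameters())
assert isinstance(p, DTensor)
print(p.device_mesh, p.placements)  # mesh over all FSDP ranks, (Shard(0),)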

@weifengpy (Contributor Author)

Synced with @kartikayk on action items:

  • add docstrings for meta init, state dict, and migration to DCP
  • place the FSDP2 recipe in recipes/dev/
  • attach the wandb loss curve
  • address pending feedback on unit tests

@ebsmothers (Contributor) left a comment

This is looking great! Most of my comments are minor and around documentation.

I saw there was also some discussion around the location for this recipe. I am inclined to agree with the point you and @kartikayk discussed around putting it in recipes/dev. I don't want that to be a long-term home for it, but also want to make sure we don't break folks who are using the current recipe and aren't on the latest version of PyTorch yet.


Comment on lines 349 to 351
# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha
Contributor:

Just FYI, this might need a merge; I think these are now defined in the version of the recipe we have in main.

Comment on lines 54 to 56
@pytest.mark.skipif(
    version.parse(torch.__version__).base_version < "2.4.0", reason=""
)
Contributor:

To make sure I understand, our distributed LoRA recipe will now only work on torch >= 2.4.0, is that correct?

Contributor Author:

That's correct. FSDP2 will be released in 2.4.0, although it has been available in the nightlies for a while.

Contributor:

Thanks for clarifying! In that case I think it makes sense to create dev/recipes/lora_finetune_fsdp2.py (or something like that) in the short-term, then migrate to replace recipes/lora_finetune_distributed.py once we've socialized and gotten enough users onto >= 2.4.0. Let me know how that sounds to you.

Contributor Author:

Proposing a middle ground: type hint with "FSDPModule" as a string, without try...except:

def load_from_full_model_state_dict(
    model: "FSDPModule",

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
Contributor:

We do also have gpu_test in torchtune, can we use that here for the sake of consistency?


    )
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, strict=False, assign=True)
Contributor:

If we catch missing and unexpected keys from load_state_dict with strict=False, what format will the keys be in? Previously with FSDP1 the keys contained all the info about FSDP wrapping, e.g. model.layers.0._fsdp_flat_param.attn.q_proj.weight (probably not exactly right but something like that). Will that still be the case here?

Contributor Author:

For FSDP2, they are clean FQNs without any FSDP prefix, for example layers.0.attn.q_proj.lora_a.weight.

FSDP2 keeps the FQNs clean because 1) fully_shard registers hooks instead of wrapping the nn.Module, and 2) fully_shard sets module.__class__ = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) (https://fburl.com/i20yr3s2).

Contributor:

This is great! I think this means we can actually do validation of LoRA state dict load more cleanly (note that we actually have two separate utilities for this for the single-device vs distributed case because of this FSDP prefix issue). Not a concern for this PR but this will allow us to clean up our code a bit
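For illustration, a hedged sketch of what that cleaner validation could look like (illustrative only, not part of this PR):

from typing import List

def validate_lora_adapter_load(missing: List[str], unexpected: List[str]) -> None:
    """Sanity-check the result of loading only LoRA adapter weights."""
    # keys come from model.load_state_dict(adapter_sd, strict=False, assign=True);
    # with FSDP2 they are plain FQNs such as "layers.0.attn.q_proj.lora_a.weight"
    if unexpected:
        raise RuntimeError(f"Unexpected keys in adapter state dict: {unexpected}")
    # only base-model weights should be reported missing; a missing lora_a/lora_b
    # key means the adapter state dict did not cover every adapter
    missing_adapters = [k for k in missing if "lora_a" in k or "lora_b" in k]
    if missing_adapters:
        raise RuntimeError(f"Adapter weights missing from state dict: {missing_adapters}")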

Comment on lines 307 to 321
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
    if pid not in state:
        continue
    param_state = state[pid]
    full_param_state = full_state[full_pid]
    for attr, full_tensor in full_param_state.items():
        sharded_tensor = param_state[attr]
        if isinstance(sharded_tensor, DTensor):
            param_state[attr] = distribute_tensor(
                full_tensor,
                sharded_tensor.device_mesh,
                sharded_tensor.placements,
            )
        else:
            param_state[attr] = full_tensor
Contributor:

Might be useful to add some comments for this code block

@weifengpy weifengpy marked this pull request as draft May 31, 2024 20:51
weifengpy and others added 5 commits May 31, 2024 16:59
@weifengpy weifengpy marked this pull request as ready for review June 1, 2024 01:13
@weifengpy weifengpy requested a review from ebsmothers June 1, 2024 02:09
@ebsmothers (Contributor) left a comment

A couple small comments on the configs but otherwise no major concerns. Really excited to see this in our library!

Comment on lines 9 to 14
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Contributor:

Probably need to do a find and replace of lora_finetune_distributed -> lora_finetune_fsdp2 in all three config files

Comment on lines 1 to 2
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama2 13B model
Contributor:

Can update this header to mention that this config is for the recipe using FSDP2 (I know the config file is the same, but nice visibility to just explicitly call it out at the top of the file)

Contributor Author:

Good catch! Updated to lora_finetune_fsdp2 and mentioned FSDP2 in the config headers.

@weifengpy weifengpy merged commit 71741df into pytorch:main Jun 3, 2024
29 checks passed
weifengpy added a commit to weifengpy/torchtune that referenced this pull request Jun 4, 2024
Comment on lines +320 to +322
# iterating from lower modules to higher
# e.g. grouping lora adapters before the transformer block
for m in reversed(list(model.modules())):

By the way, another option here is to just make two passes if that is clearer.

for module in model.modules():
    if isinstance(module, LoRALinear):  # LoRA adapters first (children)
        fully_shard(module)
for module in model.modules():
    if isinstance(module, modules.TransformerDecoderLayer):  # then transformer blocks
        fully_shard(module)

maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024