Streaming offloading in (q)lora single device #1443

Merged (11 commits into pytorch:main on Sep 16, 2024)

Conversation

@janeyx99 (Contributor) commented Aug 28, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

Add a streaming activation-offloading API to the lora finetune single device script and relevant recipes. All changes live in the newly added context manager, which streams activations to CPU and back with heuristic synchronization. One thing that may look odd: for peak performance, we don't want to offload the final output computation of TransformerDecoder, because it does very little compute on very large tensors, so offloading it is actually detrimental (the time to offload far exceeds the compute available to overlap it with). In this PR, I therefore disable offloading there by wrapping that code in a no-op context manager. This shouldn't affect any existing recipe even when no saved-tensor hooks are applied (AC still runs), but I can gate the logic to run only when activation offloading is enabled.

This change depends on pytorch/pytorch#134728 to make sure the post node hooks actually run.
For offloading to work with NF4Tensor, we also need to modify AO to give NF4 more overloads. There will be a PR coming up for that.
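
As a rough illustration of the approach (a minimal sketch with assumed names, not the exact torchtune implementation): activations are intercepted via torch.autograd.graph.saved_tensors_hooks, copied to CPU during forward, and restored on demand during backward, while a NoOpManager-style context installs pass-through hooks so a region such as the final output projection is left untouched. The real context manager additionally overlaps the copies with compute using CUDA streams, events, and pinned memory; this sketch stays synchronous for clarity.

import torch
from torch.autograd.graph import saved_tensors_hooks


class OffloadActivationsSketch(saved_tensors_hooks):
    """Offload saved activations to CPU in forward, bring them back in backward."""

    def __init__(self, min_offload_numel: int = 1024) -> None:
        def pack(tensor: torch.Tensor):
            # Tiny or CPU tensors are cheaper to keep as-is than to round-trip.
            if not tensor.is_cuda or tensor.numel() < min_offload_numel:
                return tensor
            return (tensor.device, tensor.cpu())

        def unpack(packed):
            if isinstance(packed, torch.Tensor):
                return packed
            device, cpu_tensor = packed
            return cpu_tensor.to(device)

        super().__init__(pack, unpack)


class NoOpManagerSketch(saved_tensors_hooks):
    """Locally override an outer offloading context with identity hooks."""

    def __init__(self) -> None:
        super().__init__(lambda t: t, lambda t: t)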

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

I have run this on llama3 with seq len 8k and batch size 2 to observe the memory savings, with a <1% perf hit per step. The command:

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device \
optimizer_in_bwd=False \
enable_activation_checkpointing=True \
optimizer._component_=bitsandbytes.optim.PagedAdamW8bit \
compile=False \
model.lora_attn_modules="[q_proj,v_proj]" \
model.apply_lora_to_mlp=False \
model.apply_lora_to_output=False \
model.lora_rank=8 \
model.lora_alpha=16 \
dataset.source=Yukang/LongAlpaca-12k \
dataset.packed=False \
dataset.split=train[:10%] \
dataset.train_on_input=True \
tokenizer.max_seq_len=8192 \
metric_logger=torchtune.utils.metric_logging.StdoutLogger \
metric_logger.project=recipe_profiling \
log_every_n_steps=1 \
log_peak_memory_stats=True \
gradient_accumulation_steps=1 \
max_steps_per_epoch=4 \
epochs=1 \
batch_size=2 \
metric_logger.name=llama3__qlora__seqlen_8192__act_ckpt_True__act_off_True__bs2 \
profiler.enabled=True \
profiler.profile_memory=True \
profiler.with_stack=True \
profiler.wait_steps=0 \
profiler.warmup_steps=2 \
profiler.active_steps=2 \
profiler.num_cycles=1
  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

With activation offloading, we use only 10.1GiB peak memory for batch size 2, seq len 8k, compared to 13.2GiB for batch size 1.

[image: peak memory comparison screenshot]

Runtime with offloading was only ~20ms slower than without for a 23s step, i.e. a <1% slowdown for llama3-8B. This is because almost all of the communication is overlapped.
[INTERNAL ONLY] https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/janeyx_731fbe41-9354-4f14-b308-83166146c376_qlora_AC_stream_AO_noop_bs2_llama3_4steps.json.gz

pytorch-bot bot commented Aug 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1443

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8c91a32 with merge base df29d8a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Aug 28, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@janeyx99 marked this pull request as draft on August 28, 2024 22:45.
@janeyx99 changed the title from "[DRAFT] non-streaming offloading in single device" to "[DRAFT][WIP] streaming offloading in single device" on Aug 28, 2024.
@janeyx99 force-pushed the activations-offloading-streams branch from 8b3aeca to f88c44a on August 30, 2024 16:33.
@janeyx99 changed the title to "[DRAFT] Streaming offloading in single device" on Sep 1, 2024.
@janeyx99 changed the title to "[DRAFT] Streaming offloading in (q)lora single device" on Sep 1, 2024.
@janeyx99 marked this pull request as ready for review on September 1, 2024 07:23.
@janeyx99 changed the title to "Streaming offloading in (q)lora single device" on Sep 1, 2024.

codecov-commenter commented Sep 1, 2024

Codecov Report

Attention: Patch coverage is 0% with 144 lines in your changes missing coverage. Please review.

Project coverage is 26.84%. Comparing base (68d4f3e) to head (65936bc).
Report is 1 commit behind head on main.

Files with missing lines                        Patch %   Lines
torchtune/training/_activation_offloading.py    0.00%     134 Missing ⚠️
recipes/lora_finetune_single_device.py          0.00%     10 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (68d4f3e) and HEAD (65936bc). Click for more details.

HEAD has 4 fewer uploads than BASE:
Flag    BASE (68d4f3e)    HEAD (65936bc)
        5                 1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1443       +/-   ##
===========================================
- Coverage   73.08%   26.84%   -46.24%     
===========================================
  Files         286      287        +1     
  Lines       13828    13953      +125     
===========================================
- Hits        10106     3746     -6360     
- Misses       3722    10207     +6485     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@felipemello1 (Contributor) left a comment:

Thanks for the PR! Very exciting results!!

I did a first pass, but I dont think that you should make changes just yet, until we are sure that streaming is the way to go. I can do some testing on my end this week, but we need to check the trade-off vs compile. If compile + no offloading > no compile + offloading, we may need to go back to the drawing board.

However, since the PR is up, maybe this would make it easier for folks from pytorch to investigate streaming support? cc: @weifengpy

On a second note, if we do proceed with streaming, I would like to brainstorm a bit about how we can avoid touching the transformer class.

thanks again!

# shape: [b, seq_len, out_dim]
output = self.output(h).float()
# Disable activations offloading locally
from torchtune.utils._activations_offloading.offload_ctx_mgr import NoOpManager
@felipemello1 (Contributor), Sep 2, 2024:

:O

is there any way that we can do it outside of the transformer?

for example:
NoOpWrapper(model.output)

I think that our transformer class is becoming a bit too crazy haha

Contributor:

or, in the context manager, ignore a module if module in list_ignore_module

torchtune/utils/_activations_offloading/offload_ctx_mgr.py Outdated Show resolved Hide resolved
self.curr_autograd_node = None

# platform util functions
def verify_sufficient_virtual_memory():
Contributor:

is there a motivation to define so many functions within init? i.e. why not do it as a method of the class or as another function outside of the class?

@janeyx99 (Author):

It's cuz everything gets accessed from the init, so everything got defined here for linearity. I don't have any real preference though. Is there a reason to move them out? If so, where would be best? (In the class but as a class method or completely out of the class.)

@janeyx99 (Author):

Oh actually, I know why everything's in the init lol. It's cuz pack can't take in any variables other than the tensor --> so we can't make that a class method that takes in self. And technically we could move it to the global space and have OffloadActivations write a bunch of metadata to global, but that doesn't seem better.
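
A tiny illustration (assumed names, not torchtune code) of why pack/unpack end up defined inside __init__: saved_tensors_hooks only ever calls pack(tensor) and unpack(packed), so any per-instance bookkeeping has to be captured from the enclosing scope rather than passed as extra arguments.

import torch
from torch.autograd.graph import saved_tensors_hooks


def make_stashing_hooks():
    stash = {}     # per-"instance" state captured by the closures below
    next_id = [0]

    def pack(tensor: torch.Tensor) -> int:
        tensor_id = next_id[0]
        next_id[0] += 1
        stash[tensor_id] = tensor    # a real implementation would offload to CPU here
        return tensor_id             # autograd stores only this id

    def unpack(tensor_id: int) -> torch.Tensor:
        return stash.pop(tensor_id)  # a real implementation would copy back from CPU

    return saved_tensors_hooks(pack, unpack)


x = torch.randn(4, requires_grad=True)
with make_stashing_hooks():
    y = (x * x).sum()
y.backward()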

# First, sync back and dereference previously offloaded tensors
# as the offloading should be done sufficiently long ago.
for k in [x for x in self.fwd_stash.keys()]:
    if k <= tensor_id - self.max_fwd_stash_size:
Contributor:

as a rule of thumb, if we can replace "x", "k" and "ev" with meaningful var names, I think it is worth doing so

@janeyx99 (Author) commented Sep 3, 2024

@felipemello1 thanks for the review.

If compile + no offloading > no compile + offloading, we may need to go back to the drawing board.

What does > mean here? Are we talking runtime? memory usage? both?

@felipemello1 (Contributor):

If compile + no offloading > no compile + offloading, we may need to go back to the drawing board.

What does > mean here? Are we talking runtime? memory usage? both?

memory wise, i dont think that compile can be better. So probably the correct comparison should be:

"if compile + some offloading method that works with compile > no compile + stream offloading, we may need to go back to the drawing board."

What do you think? E.g. if using tensor hooks is 15% slower than stream, but compile makes it 2x faster, and saves more memory on top of that, then it may be better to use tensor hooks.

@weifengpy (Contributor):

However, since the PR is up, maybe this would make it easier for folks from pytorch to investigate streaming support? cc: @weifengpy

Hey Felipe, Do you mean torch.compile + streaming? that would be another Will Feng

@felipemello1 (Contributor):

Sorry! I meant @yf225

@yf225 (Contributor) commented Sep 3, 2024

I believe with per-layer compile, compile + offloading should functionally work (since we are doing offloading outside of the compiled graph). But yes performance testing is still needed.

@ebsmothers (Contributor) left a comment:

This is awesome! Mostly just have a bunch of noob questions 😃

# managing cpu memory
self.use_pin_memory: bool = use_pin_memory
self.virtual_memory_safe_pct = (
    0.60  # we should not exceed this percentage of memory
Contributor:

How is this value determined? Also I guess we don't actually do anything with it, just warn?

@janeyx99 (Author):

Yea, this is arbitrary + was Less's idea. I guess it would print fair warnings until someone's RAM becomes busted lol
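
For reference, a hedged sketch of the kind of psutil-based RAM check being discussed (the function name and threshold handling are assumptions, not the final torchtune code):

import warnings

import psutil


def verify_sufficient_virtual_memory(safe_pct: float = 0.60) -> None:
    """Warn if host virtual memory usage already exceeds the safety threshold."""
    vm = psutil.virtual_memory()
    used_pct = (vm.total - vm.available) / vm.total
    if used_pct > safe_pct:
        warnings.warn(
            f"Host RAM usage is at {used_pct:.0%}, above the {safe_pct:.0%} safety "
            "threshold; offloading more activations to CPU may degrade performance."
        )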

# -------- core pack / unpack work --------
def pack_tensor(activation: torch.Tensor) -> int:
    # activations are passed in during forward pass - from here we take over and return a unique id
    if self.is_first_forward_call:
Contributor:

Noob q: what is the "first" here referring to?

@janeyx99 (Author):

So a super contrived fwd could be thought of as a collection of ops, aka a collection of nodes:

def fwd(z):
    a = z * z
    b = torch.sin(a)
    c = b + 1
    return c

In every step, you can imagine the fwd gets called once. So the first forward pass is the first step overall in training. And the first forward call is the first operation in each fwd, so here, it would be the a = z * z.

Comment on lines 593 to 600
maybe_offload_ctx = OffloadActivations(use_streams=True)
noop_ctx = NoOpManager()
self._model.output.register_forward_pre_hook(
    lambda *args: noop_ctx.__enter__()
)
self._model.output.register_forward_hook(
    lambda *args: noop_ctx.__exit__(), always_call=True
)
Contributor:

Just so I understand, this is our hack to ensure that output activations do not get offloaded, right? (Nice hack though.) Is this just cause those are actually too hefty, or are there other reasons for this (e.g. interaction with our chunked output layer)?

@janeyx99 (Author):

The compute for them is super small compared to offloading them (cuz they're chunky), and they're used pretty much right away so we'd waste time moving them to CPU and waiting for them to come back. I'll add a comment for this


# If we're on a new node, mark prev node's tensors to be freed later
if (
    graph_id == self.curr_graph_id
Contributor:

Maybe this is just cause I don't understand what _current_graph_task_id represents, but when would equality not hold here?

@janeyx99 (Author):

Different backwards could maybe be occurring at the same time--this makes sure that we're tracking for the right graph. I don't know many details on this and followed @albanD's advice here.

Comment on lines 218 to 220
if unpack_tensor_id in self.fwd_stash:
    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
    brought_back_from_cpu = False
Contributor:

Sorry I don't fully follow this; do we need to keep on CPU in this case cause it's one of the last max_fwd_stash_size tensors?

@janeyx99 (Author):

We don't need to keep it on CPU, but we don't have preemption to know whether it's one of the last in the fwd, so we'd already have kicked off the moving to CPU. Here, we're just taking advantage of the fact that the fwd_stash must be 1 large, and we're telling the bwd to just pull from the stash if there's still something in it instead of waiting for the tensor to be pulled from CPU to GPU.
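
A simplified sketch of the unpack path described above (assumed structure, not the actual method): if the activation is still sitting in the small forward stash, its GPU copy has not been dropped yet, so we reuse it directly instead of waiting on a CPU-to-GPU round trip.

def get_activation(unpack_tensor_id, fwd_stash, cpu_store, device):
    if unpack_tensor_id in fwd_stash:
        # The copy to CPU was issued, but the GPU original is still alive
        # in the stash; use it and skip the round trip.
        gpu_tensor, _copy_event = fwd_stash.pop(unpack_tensor_id)
        return gpu_tensor
    # Otherwise the tensor only lives on CPU now; bring it back.
    cpu_tensor = cpu_store.pop(unpack_tensor_id)
    return cpu_tensor.to(device, non_blocking=True)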

@felipemello1 (Contributor):

I see some conflicts. Please ping me after they are solved and I will review the PR again, so we can land it.

@janeyx99 force-pushed the activations-offloading-streams branch from 3dc8147 to 9a91b9b on September 6, 2024 14:52.
@janeyx99 force-pushed the activations-offloading-streams branch 4 times, most recently from 38a3098 to 65936bc on September 9, 2024 19:27.
@felipemello1 (Contributor) left a comment:

first pass. I will read more carefully later

in order to overlap the extra communication with the computation to hide the extra runtime. As the communication
workload is variable depending on the number and size of tensors being offloaded, it is common to not offload every
single activation. In fact, once can use offloading in conjunction with activations checkpointing, where all
activations will either be recomputed later in the backward or brought back from the CPU.
Contributor:

As far as I remember, this only works if activation_checkpointing is True. Is that still right? If so, we should probably update this doc and have the recipes raise an error or set AC=True automatically.

Another option, which i would prefer, is to investigate allowing offloading without AC, since streaming seems promising

nit: I believe you meant "one can use offloading" instead of "once can use offloading

@janeyx99 (Author):

What do you mean by "this" in the first sentence? Activations offloading works when AC is false as well, it's just super slow.

Contributor:

i remember trying to use only offloading, with AC=False, and it broke. Maybe it's not the case anymore?

@janeyx99 (Author):

Yea it shouldn't break! It should just be reaaaaally slow

*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``.
Contributor:

i dont think that this will stay exclusive to lora finetuning. Maybe you are suggesting that this should be updated when it becomes available in full_single_device?

PS: do you plan to do it? (its fine if the answer is no, just checking :) )

Contributor:

Can we also add a note about offloading in single device vs FSDP, since in distributed it is another command?

@janeyx99 (Author):

Yea I was thinking we should mention lora finetuning single device for now in particular as it's only enabled for those. It would be important to widen this when more support is added --> I'm happy to do it.

In distributed, it should be the same command, but I haven't tested at all, so did not want to promise anything for that.

Contributor:

In distributed, it should be the same command

I thought that for FSDP we use fsdp_cpu_offload and let fsdp do the offloading. Unless you think it makes sense to have two different offloading strategies

Contributor:

FSDP's CPU offload will just offload parameters, gradients, and optimizer states though, right? Not activations. Do we expect that the two will work together? I can't immediately think of any reason why they wouldn't but maybe there's something obvious I'm missing

Contributor:

oh, I missed that FSDP doesn't offload activations. That's nice! I thought it was a solved issue for FSDP and we were just improving single device

@janeyx99 (Author):

They theoretically should work together but who knows man, we can't say so til we test! I'm suspicious that there will be unforeseen results with intranode comms, so the scope of this PR is just single device.

noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
model.output.register_forward_hook(
    lambda *args: noop_ctx.__exit__(), always_call=True
Contributor:

we need to test it with qwen or gemma, which have tied embeddings. They don't have model.output, since they use TiedTransformerDecoder, so this would fail. We would need to add something like if hasattr(model, "output").

The issue is that they use output = F.linear(h, self.tok_embeddings.weight).

In this PR (#1527), I created a regular Python class (not an nn.Module) to replace the F.linear logic, so we can get rid of the TiedTransformerDecoder:

class TiedLinear:
    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.tied_module.weight)

model = TransformerDecoder(output_projection=TiedLinear(embeddings))

I imagine that hooks only work with nn.Module. Is that true? I didn't try to make TiedLinear an nn.Module because it doesn't have its own weights, and I didn't want FSDP and other wrappers to interact with it.

Contributor:

we should probably add this NoOp logic to some utils, but no strong opinion

@janeyx99 (Author):

Ah yea, I agree we need to test these. The hooks do only work with module (since they're module hooks) so that would require some design...e.g., we enable adding hooks to this class you added, or we make it an nn.Module still but exclude it from FSDP/other wrappers. Or something else 😛

Contributor:

Oh yeah this tied embedding business is actually a pretty important point. As a hack we can maybe do something like if hasattr(model, "output") and isinstance(model.output, nn.Module) to gate the registration of these hooks?

Would have to think more about whether we can get away with making TiedLinear an nn.Module though.. half the reason we're doing things this way is to avoid the breaking of references for tied weights when FSDP calls to_empty. Is there an easy way to exclude an nn.Module from all FSDP's usual hook registration etc?

Anyways before I get too far down that rabbit hole, I'd propose just going with the hack (provided it works). It's dumb, simple, and explicit, which I like.
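
A minimal sketch of the gating hack discussed above (assumed helper name, not the final torchtune code): only register the NoOp hooks when the output projection is a real nn.Module, so tied-embedding models without a module at model.output still work.

import torch.nn as nn


def maybe_skip_offloading_on_output(model: nn.Module, noop_ctx) -> None:
    """Register hooks that disable offloading around the output projection, if present."""
    output = getattr(model, "output", None)
    if isinstance(output, nn.Module):
        output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
        output.register_forward_hook(
            lambda *args: noop_ctx.__exit__(), always_call=True
        )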

but is a limited resource. Default: True.

use_streams (bool): Whether or not to use streams for performance optimization where
the communications get overlapped with the computation. Default: True.
@felipemello1 (Contributor), Sep 10, 2024:

we should probably add some safeguards here for torch 2.4.

@janeyx99 (Author):

100% I think it'd have to be 2.5 too, since the change needed is in nightlies and not even 2.4.

Contributor:

Oh yeah let's definitely add a version check then. What will the behavior be if someone tries to run on 2.4? Will it fully error out? And is it only a concern when use_streams=True?

@felipemello1 (Contributor), Sep 11, 2024:

my understanding is that use_streams=False works in 2.4, so I don't see why we should raise an error. A warning plus defaulting to use_streams=False sounds reasonable to me

@janeyx99 (Author):

Yea, use_streams=False will work with 2.4. It's only with use_streams=True where we should error. Is there already another feature gated by torch version that I should copy?

And yes, I will default to False for use_streams for this PR
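
A hedged sketch of the kind of version gate being discussed (the helper name and exact cutoff are assumptions; since the needed fix currently lives in nightlies, the real check may want to accept 2.5 dev builds as well):

import logging

import torch
from torch.torch_version import TorchVersion

log = logging.getLogger(__name__)


def resolve_use_streams(requested: bool, min_version: str = "2.5.0") -> bool:
    """Fall back to non-streaming offloading on torch versions that lack the fix."""
    if requested and TorchVersion(torch.__version__) < min_version:
        log.warning(
            "Streaming activation offloading requires torch >= %s; "
            "falling back to use_streams=False.",
            min_version,
        )
        return False
    return requested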

@janeyx99 force-pushed the activations-offloading-streams branch from 65936bc to f1178c7 on September 11, 2024 17:16.
@@ -25,6 +25,7 @@ dependencies = [
"numpy<=1.26.4", # Pin here until https:/tensorflow/tensorboard/issues/6869 is addressed
"tqdm",
"omegaconf",
"psutil",
@janeyx99 (Author):

@ebsmothers is this okay? This is a new requirement as we use psutil to check cpu RAM usage and warn on too much usage.

@janeyx99 (Author):

FYI: psutil does not pull in other deps!

@@ -222,6 +239,8 @@ def setup(self, cfg: DictConfig) -> None:
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+           enable_activation_offloading=cfg.get("enable_activation_offloading", False),
+           offload_with_streams=cfg.get("offload_with_streams", False),
Contributor:

I think that this is the last thing needed before approving the PR: I don't like having two arguments for offloading. What do you think about making enable_activation_offloading have 3 options?

False / "with_stream" / "without_stream"

or something like that?

Contributor:

One q here: is there any case that I wouldn't wanna use streams provided I'm on a sufficiently recent PyTorch version? If there are no feature gaps I'm inclined to just have a single enable_activation_offloading flag that will run with streams if possible and otherwise run the vanilla version.

@janeyx99 (Author):

We haven't tested streams use on other hardware or with distributed or with other shape models, so I was hesitant to make it the default. Since we're only landing the lora finetuning recipe on single device with this PR, it is fine to just not include that flag option for now.

I've removed it, but it would be pretty easy to add it back if required in the future.

@ebsmothers (Contributor) left a comment:

I think the only remaining thing here is sorting out versioning/how we expose streaming vs. non-streaming. But that's a minor point (and seems like we're in agreement on it anyways). Other than that this looks great. Can't wait to have this available for people to try out!

self.s1.wait_stream(self.s0)

stream = self.s1 if self.use_streams else self.s0
with torch.cuda.stream(stream):
@janeyx99 (Author):

the following code is also new @ebsmothers @felipemello1

We add a similar version check for ao. This one is less "essential" than the torch check because someone will only run into this if they play around with AC and start offloading NF4Tensors. I've landed the torchao counterpart here pytorch/ao#881.

@ebsmothers merged commit 7045e96 into pytorch:main on Sep 16, 2024
17 checks passed
Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)

7 participants