Streaming offloading in (q)lora single device #1443
Conversation
Thanks for the PR! Very exciting results!!
I did a first pass, but I don't think that you should make changes just yet, until we are sure that streaming is the way to go. I can do some testing on my end this week, but we need to check the trade-off vs compile. If compile + no offloading > no compile + offloading, we may need to go back to the drawing board.
However, since the PR is up, maybe this would make it easier for folks from pytorch to investigate streaming support? cc: @weifengpy
On a second note, if we do proceed with streaming, I would like to brainstorm a bit about how we can avoid touching the transformer class.
thanks again!
torchtune/modules/transformer.py
Outdated
# shape: [b, seq_len, out_dim]
output = self.output(h).float()
# Disable activations offloading locally
from torchtune.utils._activations_offloading.offload_ctx_mgr import NoOpManager
:O
is there any way that we can do it outside of the transformer?
for example:
NoOpWrapper(model.output)
I think that our transformer class is becoming a bit too crazy haha
or in the context manager ignore a module if module in list_ignore_module
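To make the ignore-list suggestion concrete, here's a torch-free sketch of how the context manager's pack hook could become a pass-through while an ignored module is running. All names here (OffloadManagerSketch, module_pre_hook, etc.) are invented for illustration; they are not torchtune APIs, and plain hashable objects stand in for modules:

```python
class OffloadManagerSketch:
    """Sketch: offload every activation except those produced while a
    module from `ignore_modules` is running (e.g. model.output)."""

    def __init__(self, ignore_modules=()):
        self.ignore_modules = set(ignore_modules)
        self._active_ignored = 0  # >0 means pack is currently a no-op
        self.offloaded = {}       # stand-in for the CPU-side store
        self.next_id = 0

    def module_pre_hook(self, module, *args):
        if module in self.ignore_modules:
            self._active_ignored += 1

    def module_post_hook(self, module, *args):
        if module in self.ignore_modules:
            self._active_ignored -= 1

    def pack(self, activation):
        if self._active_ignored > 0:
            return activation  # no-op: leave this activation on GPU
        tid = self.next_id
        self.next_id += 1
        self.offloaded[tid] = activation  # stand-in for the D2H copy
        return tid
```

In real code the pre/post hooks would be registered on the ignored modules via register_forward_pre_hook / register_forward_hook, which is essentially what the NoOpManager hack later in this thread does from the outside.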
self.curr_autograd_node = None

# platform util functions
def verify_sufficient_virtual_memory():
is there a motivation to define so many functions within init? i.e. why not do it as a method of the class or as another function outside of the class?
It's cuz everything gets accessed from the init, so everything got defined here for linearity. I don't have any real preference though. Is there a reason to move them out? If so, where would be best? (In the class but as a class method or completely out of the class.)
Oh actually, I know why everything's in the init lol. It's cuz pack can't take in any variables other than the tensor --> so we can't make that a class method that takes in self. And technically we could move it to the global space and have OffloadActivations write a bunch of metadata to global, but that doesn't seem better.
# First, sync back and dereference previously offloaded tensors
# as the offloading should be done sufficiently long ago.
for k in [x for x in self.fwd_stash.keys()]:
    if k <= tensor_id - self.max_fwd_stash_size:
as a rule of thumb, if we can replace "x", "k" and "ev" with meaningful var names, I think it is worth doing so
@felipemello1 thanks for the review.
What does > mean here? Are we talking runtime? memory usage? both?
memory wise, I don't think that compile can be better. So probably the correct comparison should be: "if compile + some offloading method that works with compile > no compile + stream offloading, we may need to go back to the drawing board." What do you think? E.g. if using tensor hooks is 15% slower than stream, but compile makes it 2x faster, and saves more memory on top of that, then it may be better to use tensor hooks.
Hey Felipe, do you mean torch.compile + streaming? that would be another Will Feng
Sorry! I meant @yf225
I believe with per-layer compile, compile + offloading should functionally work (since we are doing offloading outside of the compiled graph). But yes performance testing is still needed.
This is awesome! Mostly just have a bunch of noob questions 😃
# managing cpu memory
self.use_pin_memory: bool = use_pin_memory
self.virtual_memory_safe_pct = (
    0.60  # we should not exceed this percentage of memory
How is this value determined? Also I guess we don't actually do anything with it, just warn?
Yea, this is arbitrary + was Less's idea. I guess it would print fair warnings until someone's RAM becomes busted lol
# -------- core pack / unpack work --------
def pack_tensor(activation: torch.Tensor) -> int:
    # activations are passed in during forward pass - from here we take over and return a unique id
    if self.is_first_forward_call:
Noob q: what is the "first" here referring to?
So a super contrived fwd could be thought of as a collection of ops, aka a collection of nodes:

def fwd(z):
    a = z * z
    b = sin(a)
    c = b + 1

In every step, you can imagine the fwd gets called once. So the first forward pass is the first step overall in training, and the first forward call is the first operation in each fwd, so here, it would be a = z * z.
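A torch-free sketch of this "first forward call" bookkeeping (all names invented for illustration, not the PR's exact code): the pack hook resets its per-step state only on the first call of each forward pass, and the end of backward re-arms that flag for the next step:

```python
class PackBookkeeping:
    """Sketch: assign each saved activation a unique id per forward pass."""

    def __init__(self):
        self.is_first_forward_call = True
        self.tracker = {}  # tensor_id -> activation (stand-in for metadata)
        self.next_id = 0

    def pack_tensor(self, activation):
        if self.is_first_forward_call:
            # First op of a new forward pass: reset per-step state so ids
            # from the previous step don't leak into this one.
            self.is_first_forward_call = False
            self.tracker.clear()
            self.next_id = 0
        tid = self.next_id
        self.next_id += 1
        self.tracker[tid] = activation
        return tid

    def end_backward(self):
        # Backward is done; the next pack call starts a fresh forward.
        self.is_first_forward_call = True
```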
maybe_offload_ctx = OffloadActivations(use_streams=True)
noop_ctx = NoOpManager()
self._model.output.register_forward_pre_hook(
    lambda *args: noop_ctx.__enter__()
)
self._model.output.register_forward_hook(
    lambda *args: noop_ctx.__exit__(), always_call=True
)
Just so I understand, this is our hack to ensure that output activations do not get offloaded, right? (Nice hack though.) Is this just cause those are actually too hefty, or are there other reasons for this (e.g. interaction with our chunked output layer)?
The compute for them is super small compared to offloading them (cuz they're chunky), and they're used pretty much right away so we'd waste time moving them to CPU and waiting for them to come back. I'll add a comment for this
# If we're on a new node, mark prev node's tensors to be freed later
if (
    graph_id == self.curr_graph_id
Maybe this is just cause I don't understand what _current_graph_task_id represents, but when would equality not hold here?
Different backwards could maybe be occurring at the same time--this makes sure that we're tracking for the right graph. I don't know many details on this and followed @albanD's advice here.
if unpack_tensor_id in self.fwd_stash:
    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
    brought_back_from_cpu = False
Sorry I don't fully follow this; do we need to keep on CPU in this case cause it's one of the last max_fwd_stash_size tensors?
We don't need to keep it on CPU, but we don't have preemption to know whether it's one of the last in the fwd, so we'd already have kicked off the moving to CPU. Here, we're just taking advantage of the fact that the fwd_stash must be 1 large, and we're telling the bwd to just pull from the stash if there's still something in it instead of waiting for the tensor to be pulled from CPU to GPU.
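The stash-first unpack described here can be sketched as a small torch-free function (names and the tuple layout are assumptions for illustration, not the PR's exact API): if the tensor is still in the fwd_stash, its GPU copy never died, so we reuse it instead of waiting on the CPU-to-GPU transfer.

```python
def unpack_sketch(tensor_id, fwd_stash, cpu_store, to_gpu):
    """Return (tensor, brought_back_from_cpu) for a saved activation."""
    if tensor_id in fwd_stash:
        # Offloading was kicked off, but the original GPU tensor is still
        # alive in the stash: reuse it and skip the copy-back entirely.
        return fwd_stash[tensor_id][0], False
    # Otherwise the GPU copy is gone; bring the tensor back from CPU.
    return to_gpu(cpu_store[tensor_id]), True
```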
I see some conflicts. Please ping me after they are solved and I will review the PR again, so we can land it.
first pass. I will read more carefully later
in order to overlap the extra communication with the computation to hide the extra runtime. As the communication
workload is variable depending on the number and size of tensors being offloaded, it is common to not offload every
single activation. In fact, once can use offloading in conjunction with activations checkpointing, where all
activations will either be recomputed later in the backward or brought back from the CPU.
As far as I remember, this only works if activation_checkpointing is True. Is that still right? If so, we should probably update this doc and add to the recipes to raise an error or set AC=True automatically.
Another option, which I would prefer, is to investigate allowing offloading without AC, since streaming seems promising
nit: I believe you meant "one can use offloading" instead of "once can use offloading
What do you mean by "this" in the first sentence? Activations offloading works when AC is false as well, it's just super slow.
I remember trying to use only offloading, with AC=False, and it broke. Maybe it's not the case anymore?
Yea it shouldn't break! It should just be reaaaaally slow
*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``.
I don't think that this will stay exclusive to lora finetuning. Maybe you are suggesting that this should be updated when it becomes available in full_single_device?
PS: do you plan to do it? (it's fine if the answer is no, just checking :) )
Can we also add a note about offloading in single device vs FSDP, since in distributed it is another command?
Yea I was thinking we should mention lora finetuning single device for now in particular as it's only enabled for those. It would be important to widen this when more support is added --> I'm happy to do it.
In distributed, it should be the same command, but I haven't tested at all, so did not want to promise anything for that.
In distributed, it should be the same command
I thought that for FSDP we use fsdp_cpu_offload and let fsdp do the offloading. Unless you think it makes sense to have two different offloading strategies
FSDP's CPU offload will just offload parameters, gradients, and optimizer states though, right? Not activations. Do we expect that the two will work together? I can't immediately think of any reason why they wouldn't but maybe there's something obvious I'm missing
oh, I missed that FSDP doesn't offload activations. That's nice! I thought it was a solved issue for FSDP and we were just improving single device
They theoretically should work together but who knows man, we can't say so til we test! I'm suspicious that there will be unforeseen results with intranode comms, so the scope of this PR is just single device.
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
model.output.register_forward_hook(
    lambda *args: noop_ctx.__exit__(), always_call=True
)
we need to test it with qwen or gemma, which have tied embeddings. They don't have model.output, since they use TiedTransformerDecoder, so this would fail. We would need to add something like if hasattr(model, "output").
The issue is that they use output = F.linear(h, self.tok_embeddings.weight).
in this PR (#1527), I created a regular python class (no nn.Module) to replace the F.linear logic, so we can get rid of the TiedTransformerDecoder
class TiedLinear:
    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.tied_module.weight)

model = TransformerDecoder(output_projection=TiedLinear(embeddings))

I imagine that hooks only work with nn.Module. Is that true? I didn't try to make the TiedLinear an nn.Module because it doesn't have its own weights, and I didn't want FSDP and other wrappers to interact with it.
we should probably add this NoOp logic to some utils, but no strong opinion
Ah yea, I agree we need to test these. The hooks do only work with module (since they're module hooks) so that would require some design...e.g., we enable adding hooks to this class you added, or we make it an nn.Module still but exclude it from FSDP/other wrappers. Or something else 😛
Oh yeah this tied embedding business is actually a pretty important point. As a hack we can maybe do something like if hasattr(model, "output") and isinstance(model.output, nn.Module) to gate the registration of these hooks?
Would have to think more about whether we can get away with making TiedLinear an nn.Module though.. half the reason we're doing things this way is to avoid the breaking of references for tied weights when FSDP calls to_empty. Is there an easy way to exclude an nn.Module from all FSDP's usual hook registration etc?
Anyways before I get too far down that rabbit hole, I'd propose just going with the hack (provided it works). It's dumb, simple, and explicit, which I like.
but is a limited resource. Default: True.

use_streams (bool): Whether or not to use streams for performance optimization where
    the communications get overlapped with the computation. Default: True.
we should probably add some safeguards here for torch 2.4.
100% I think it'd have to be 2.5 too, since the change needed is in nightlies and not even 2.4.
Oh yeah let's definitely add a version check then. What will the behavior be if someone tries to run on 2.4? Will it fully error out? And is it only a concern when use_streams=True?
my understanding is that use_streams=False works in 2.4, so I don't see why we should raise an error. A warning plus defaulting to use_streams=False sounds reasonable to me
Yea, use_streams=False will work with 2.4. It's only with use_streams=True where we should error. Is there already another feature gated by torch version that I should copy?
And yes, I will default to False for use_streams for this PR
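One way the warn-and-fall-back behavior discussed above could look, as a torch-free sketch (resolve_use_streams is a hypothetical helper; in the recipe torch_version would be torch.__version__). It takes the version as a string so the gate itself is easy to test:

```python
import warnings


def resolve_use_streams(use_streams: bool, torch_version: str) -> bool:
    """Streams need torch >= 2.5 (the required autograd change is only in
    nightlies / 2.5); plain offloading works on 2.4, so we warn and fall
    back instead of erroring."""
    # "2.4.1+cu121" -> (2, 4); local/dev suffixes are ignored
    major_minor = tuple(int(p) for p in torch_version.split("+")[0].split(".")[:2])
    if use_streams and major_minor < (2, 5):
        warnings.warn(
            "use_streams=True requires torch>=2.5; falling back to use_streams=False"
        )
        return False
    return use_streams
```

Raising instead of warning would just be a matter of swapping warnings.warn for a RuntimeError, per whichever behavior the recipe settles on.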
@@ -25,6 +25,7 @@ dependencies = [
    "numpy<=1.26.4", # Pin here until https:/tensorflow/tensorboard/issues/6869 is addressed
    "tqdm",
    "omegaconf",
    "psutil",
@ebsmothers is this okay? This is a new requirement as we use psutil to check cpu RAM usage and warn on too much usage.
FYI: psutil does not pull in other deps!
@@ -222,6 +239,8 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
    cfg_model=cfg.model,
    enable_activation_checkpointing=cfg.enable_activation_checkpointing,
    enable_activation_offloading=cfg.get("enable_activation_offloading", False),
    offload_with_streams=cfg.get("offload_with_streams", False),
I think that this is the last thing to approve the PR: I don't like having two arguments for offloading. What do you think about making enable_activation_offloading have 3 options?
False / "with_stream" / "without_stream"
or something like that?
One q here: is there any case that I wouldn't wanna use streams provided I'm on a sufficiently recent PyTorch version? If there are no feature gaps I'm inclined to just have a single enable_activation_offloading flag that will run with streams if possible and otherwise run the vanilla version.
We haven't tested streams use on other hardware or with distributed or with other shape models, so I was hesitant to make it the default. Since we're only landing the lora finetuning recipe on single device with this PR, it is fine to just not include that flag option for now.
I've removed it but it would be p easy to add it back if required in the future.
I think the only remaining thing here is sorting out versioning/how we expose streaming vs. non-streaming. But that's a minor point (and seems like we're in agreement on it anyways). Other than that this looks great. Can't wait to have this available for people to try out!
self.s1.wait_stream(self.s0)

stream = self.s1 if self.use_streams else self.s0
with torch.cuda.stream(stream):
the following code is also new @ebsmothers @felipemello1
We add a similar version check for ao. This one is less "essential" than the torch check because someone will only run into this if they play around with AC and start offloading NF4Tensors. I've landed the torchao counterpart here pytorch/ao#881.
Context
Changelog
What are the changes made in this PR?
Add a streaming activations offloading API to the lora finetune single device script and relevant recipes. All changes are in the newly added context manager, which streams activations to CPU and back with heuristic synchronization. The one thing that may be weird: for peak performance, we don't want to offload the last output computation of TransformerDecoder, because it does very little compute on super large memory, so offloading is actually detrimental: the time to offload is much longer than the compute available to overlap it with. So in this PR, I disable offloading by wrapping a no-op context manager around the code. This shouldn't have an impact on any existing recipe even if no saved tensor hooks were applied (as it still does AC), but I can gate the logic on only when activation_offloading is enabled.
This change depends on pytorch/pytorch#134728 to make sure the post node hooks actually run.
For offloading to work with NF4Tensor, we also need to modify AO to give NF4 more overloads. There will be a PR coming up for that.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
I have run this on llama3 with seq len 8k, bs2 to observe memory savings with <1% perf hit per step.
run pre-commit hooks and linters (pre-commit install)
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/modules/vision_transformer.py, line 285 (at 6a7951f)
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models
With activations offloading, we use only 10.1GiB peak memory for batchsize 2 seqlen 8k, compared to 13.2GiB for batchsize 1.
Runtime with offloading was only 20ms slower per 23s step compared to without, meaning <1% slowdown for llama3-8B. This is because almost all the comm is overlapped.
[INTERNAL ONLY] https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/janeyx_731fbe41-9354-4f14-b308-83166146c376_qlora_AC_stream_AO_noop_bs2_llama3_4steps.json.gz