When profile_memory, also export and save the snapshot.pickle for lora_finetune_single_device.py #1382
Changes from 1 commit
```diff
@@ -328,6 +328,11 @@ def _setup_profiler(

         log.info(f" Profiler config after instantiation: {profiler_cfg}")

+        self.profiler_wait_steps = profiler_cfg["wait_steps"]
+        self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+        self.profiler_active_steps = profiler_cfg["active_steps"]
+        self.profiler_profile_memory = profiler_cfg["profile_memory"]
+
         return profiler

     def _setup_model(
```
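For orientation, here is a minimal sketch of the resolved config these lines index into, and the step arithmetic the recipe derives from it. The values are illustrative assumptions, not taken from the PR:

```python
# Illustrative profiler config; values are assumptions, not from the PR.
profiler_cfg = {
    "enabled": True,
    "profile_memory": True,
    "wait_steps": 5,
    "warmup_steps": 5,
    "active_steps": 2,
}

# Memory-history recording starts once the wait and warmup phases finish...
start_step = profiler_cfg["wait_steps"] + profiler_cfg["warmup_steps"]  # 10
# ...and stops after the active window has also elapsed.
stop_step = start_step + profiler_cfg["active_steps"]  # 12
```

These two boundaries correspond exactly to the start and stop conditions added in the two `train()` hunks below.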
```diff
@@ -579,6 +584,9 @@ def train(self) -> None:
             ):
                 break

+            if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps:
+                torch.cuda.memory._record_memory_history()
+
             batch = {k: v.to(self._device) for k, v in batch.items()}
             num_tokens += batch["tokens"].numel()
```

**Review comment:** The issue is that if we support another backend in the future, like Intel, this will break. It makes me think this should be a function instead of being hardcoded as `torch.cuda`. cc: @ebsmothers

**Reply:** I think it's OK; I don't know that the memory snapshot would be supported on XPU anyway. We can do a check on the device + `profile_memory` combo just to be safe, though.
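As a sketch of what the reviewer's suggestion could look like (the helper name and placement are hypothetical, not torchtune API), the CUDA-only private call could be gated behind a device check so it no-ops on other backends:

```python
from typing import Optional

import torch


def record_memory_history(device: torch.device, enabled: Optional[str] = "all") -> None:
    """Hypothetical helper: gate the CUDA-only private memory-history API
    behind a device check, per the review discussion above.

    Pass enabled="all" to start recording and enabled=None to stop;
    on non-CUDA devices (e.g. XPU) this simply does nothing.
    """
    if device.type == "cuda":
        torch.cuda.memory._record_memory_history(enabled=enabled)
```

The recipe's two call sites would then become `record_memory_history(self._device)` and `record_memory_history(self._device, enabled=None)`.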
```diff
@@ -626,6 +634,9 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()

+            if curr_epoch == 0 and self.profiler_profile_memory and idx == self.profiler_wait_steps + self.profiler_warmup_steps + self.profiler_active_steps:
+                torch.cuda.memory._record_memory_history(enabled=None)
+
             # Step the profiler
             # Note we are stepping each batch, which might not include optimizer step in the trace
             # if the schedule cycle doesn't align with gradient accumulation.
```
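Stopping the recording with `enabled=None` does not by itself export anything; to actually save the `snapshot.pickle` the PR title describes, the history has to be dumped before recording stops. A minimal sketch using PyTorch's private snapshot API, where `output_dir` is a hypothetical variable standing in for wherever the recipe keeps its output path:

```python
import os

import torch

# Sketch: dump the recorded allocation history to snapshot.pickle,
# then stop recording. `output_dir` is an assumed placeholder.
snapshot_path = os.path.join(output_dir, "snapshot.pickle")
torch.cuda.memory._dump_snapshot(snapshot_path)
torch.cuda.memory._record_memory_history(enabled=None)
```

The resulting pickle can be inspected with PyTorch's memory visualizer at https://pytorch.org/memory_viz.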
**Review comment:** This seems to be throwing an error in our recipe tests. Can you quickly inspect `profiler_cfg` after running, e.g., …? I thought these fields would be defined based on this, but something weird seems to be happening here.

**Reply:** Ah, it's because profiling isn't enabled for tests, so it returns early here: `torchtune/torchtune/utils/_profiler.py`, line 262 (commit `2f8ed7a`).
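Based only on the comments above, a plausible reading is that the early return hands back a config without the scheduling keys, so `profiler_cfg["wait_steps"]` raises a `KeyError` in tests. A defensive sketch using `.get()` with fallback defaults; the defaults are assumptions, not necessarily the fix the maintainers chose:

```python
from omegaconf import DictConfig


def read_profiler_schedule(profiler_cfg: DictConfig) -> tuple[int, int, int, bool]:
    """Sketch: read the profiler schedule tolerantly, so a minimal config
    (e.g. when profiling is disabled in tests) does not raise a KeyError.
    The zero/False defaults are assumptions for illustration."""
    return (
        profiler_cfg.get("wait_steps", 0),
        profiler_cfg.get("warmup_steps", 0),
        profiler_cfg.get("active_steps", 0),
        profiler_cfg.get("profile_memory", False),
    )
```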