Compile model+loss for LoRA single device recipe (#1296)
gau-nernst authored Aug 10, 2024
1 parent 2522c41 commit 00bbd53
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions recipes/lora_finetune_single_device.py
@@ -384,7 +384,9 @@ def _setup_model(
if compile_model:
log.info("Compiling model with torch.compile...")
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
model.compile(backend=backend)
self._loss_step_original = self._loss_step
self._loss_step = torch.compile(self._loss_step, backend=backend)

if self._device.type == "cuda":
memory_stats = utils.get_memory_stats(device=self._device)
utils.log_memory_stats(memory_stats)
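
The hunk above picks the torch.compile backend from the TORCH_COMPILE_BACKEND environment variable (defaulting to "inductor") and wraps the recipe's loss step with torch.compile, so the model forward and the loss are captured in one compiled graph. Below is a minimal sketch of that pattern, assuming a recent PyTorch 2.x release; the toy nn.Linear model and the loss_step function are hypothetical stand-ins for the recipe's self._model and self._loss_step.

import os

import torch
import torch.nn as nn

model = nn.Linear(16, 8)          # stand-in for the recipe's model
loss_fn = nn.CrossEntropyLoss()

def loss_step(tokens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # The forward pass and the loss are traced together in one compiled graph.
    logits = model(tokens)
    return loss_fn(logits, labels)

backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
# The recipe also compiles the model in place (model.compile(backend=backend))
# and keeps an uncompiled reference, mirrored here.
loss_step_original = loss_step
loss_step = torch.compile(loss_step, backend=backend)

tokens = torch.randn(4, 16)
labels = torch.randint(0, 8, (4,))
print(loss_step(tokens, labels).item())
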
@@ -526,6 +528,26 @@ def save_checkpoint(self, epoch: int) -> None:
adapter_only=self._save_adapter_weights_only,
)

def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
# exist. Currently, only sample packing in PackedDataset returns these
mask = batch.get("mask", None) # shape [b, s, s]
input_pos = batch.get("input_pos", None) # shape [b, s]

logits = self._model(tokens, mask=mask, input_pos=input_pos)
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits

return loss
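
The new _loss_step applies the standard next-token objective: drop the last logit and the first label so that tokens < n predict token n, move the vocabulary dimension to position 1 as nn.CrossEntropyLoss expects, and free the logits before backward to cut peak memory. Below is a self-contained sketch of the shift-and-transpose part, with random tensors standing in for a real model output and illustrative sizes.

import torch
import torch.nn as nn

b, s, vocab = 2, 8, 32                                  # illustrative sizes
logits = torch.randn(b, s, vocab, requires_grad=True)   # model output, shape [b, s, vocab]
labels = torch.randint(0, vocab, (b, s))                # targets, shape [b, s]

# Shift so that tokens < n predict n.
shifted_logits = logits[..., :-1, :].contiguous()       # [b, s-1, vocab]
shifted_labels = labels[..., 1:].contiguous()           # [b, s-1]

# nn.CrossEntropyLoss expects the class dimension second: [b, vocab, s-1].
loss = nn.CrossEntropyLoss()(shifted_logits.transpose(1, 2), shifted_labels)
loss.backward()
print(loss.item())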

def train(self) -> None:
"""
The core training loop.
@@ -557,31 +579,10 @@ def train(self) -> None:
):
break

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
# exist. Currently, only sample packing in PackedDataset returns these
mask = batch.get("mask", None) # shape [b, s, s]
input_pos = batch.get("input_pos", None) # shape [b, s]

tokens = tokens.to(self._device)
num_tokens += tokens.numel()
labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
input_pos.to(self._device) if input_pos is not None else None
)

logits = self._model(tokens, mask=mask, input_pos=input_pos)
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
batch = {k: v.to(self._device) for k, v in batch.items()}
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
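
With the loss computation factored out, the inner loop only has to move the batch tensors to the device, call the (now possibly compiled) loss step, and scale the loss for gradient accumulation. Below is a hedged sketch of such a loop; loss_step, optimizer, loader, device, and grad_accum_steps are illustrative stand-ins for the recipe's attributes, and the optimizer stepping shown is a simplified assumption rather than the recipe's exact logic.

import torch

def train_epoch(loss_step, optimizer, loader, device, grad_accum_steps=4):
    running_loss = 0.0
    num_tokens = 0
    for idx, batch in enumerate(loader):
        # Move every tensor in the batch (tokens, labels, optional mask / input_pos).
        batch = {k: v.to(device) for k, v in batch.items()}
        num_tokens += batch["tokens"].numel()

        loss = loss_step(batch) / grad_accum_steps
        running_loss += loss.item()
        loss.backward()

        # Step only once per grad_accum_steps micro-batches.
        if (idx + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
    return running_loss, num_tokens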
