Update no_trainer scripts with new Accelerate functionalities #16617

Merged 8 commits on Apr 6, 2022
examples/pytorch/language-modeling/run_clm_no_trainer.py (89 changes: 80 additions & 9 deletions)
@@ -185,6 +185,23 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
    "--checkpointing_steps",
    type=str,
    default=None,
    help="How often the various states should be saved: pass an integer to save every n steps, or 'epoch' to save at the end of each epoch.",
)
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
    help="Path of a checkpoint folder from which training should resume.",
)
parser.add_argument(
    "--with_tracking",
    action="store_true",
    help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
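
For context (an editorial illustration, not part of the diff itself), here is a minimal sketch of how these three new flags are meant to be combined; the cut-down parser below mirrors only the options added above, and the sample values are made up:

import argparse

# Cut-down stand-in for parse_args(), mirroring only the three new options
parser = argparse.ArgumentParser()
parser.add_argument("--checkpointing_steps", type=str, default=None)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--with_tracking", action="store_true")

# Save every 500 steps, resume from an earlier "step_1000" folder, and enable tracking
args = parser.parse_args(["--checkpointing_steps", "500", "--resume_from_checkpoint", "step_1000", "--with_tracking"])
print(args.checkpointing_steps, args.resume_from_checkpoint, args.with_tracking)  # 500 step_1000 True
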
args = parser.parse_args()

# Sanity checks
@@ -208,7 +225,8 @@ def main():
args = parse_args()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# If we're using tracking, we also need to initialize it here so that it picks up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -427,18 +445,10 @@ def group_texts(examples):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()

# Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
# shorter in multiprocess)

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
@@ -453,6 +463,23 @@ def group_texts(examples):
num_training_steps=args.max_train_steps,
)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Figure out how often we should save the Accelerator states
if args.checkpointing_steps is not None:
    checkpointing_steps = args.checkpointing_steps
    # A purely numeric value means "every n steps"; anything else (e.g. "epoch") is kept as a string
    if checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)
else:
    checkpointing_steps = None

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
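
As an aside, the `checkpointing_steps` parsing above boils down to three cases; a tiny standalone sketch of the same behavior (plain Python, nothing here depends on Accelerate):

def parse_checkpointing_steps(value):
    # None -> no step-based checkpointing, numeric string -> every n steps, anything else (e.g. "epoch") stays a string
    if value is None:
        return None
    return int(value) if value.isdigit() else value

assert parse_checkpointing_steps(None) is None
assert parse_checkpointing_steps("500") == 500
assert parse_checkpointing_steps("epoch") == "epoch"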

@@ -467,11 +494,38 @@ def group_texts(examples):
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0

# Potentially load in the weights and states from a previous save
resume_step = None
if args.resume_from_checkpoint is not None:
    if args.resume_from_checkpoint != "":
        path = args.resume_from_checkpoint
    else:
        # No folder was given explicitly, so fall back to the most recent checkpoint in the working directory
        dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
        dirs.sort(key=os.path.getctime)
        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
    accelerator.print(f"Resumed from checkpoint: {path}")
    accelerator.load_state(path)
    # Extract the epoch or step to resume from out of the checkpoint folder name
    if "epoch" in path:
        args.num_train_epochs -= int(path.replace("epoch_", ""))
    else:
        resume_step = int(path.replace("step_", ""))
        args.num_train_epochs -= resume_step // len(train_dataloader)
        resume_step = resume_step % len(train_dataloader)  # steps already done in the partially finished epoch

for epoch in range(args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if resume_step is not None and epoch == 0 and step < resume_step:
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
@@ -481,6 +535,10 @@ def group_texts(examples):
progress_bar.update(1)
completed_steps += 1

if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
accelerator.save_state(f"step_{completed_steps}")

if completed_steps >= args.max_train_steps:
break
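
For readers who have not used the new checkpointing API yet, here is a minimal sketch of the save/resume round trip this loop relies on (assuming an Accelerate release new enough to provide `save_state`/`load_state`; the tiny model and folder name are placeholders):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# Everything that went through `prepare` (plus RNG states) is captured in the folder
accelerator.save_state("step_100")

# ...later, or in a fresh run, restore model, optimizer and RNG states from that folder
accelerator.load_state("step_100")

The folder name is the only thing the resume logic above has to work with: `step_N` means N update steps were completed, so `N // len(train_dataloader)` full epochs are already done and the remaining `N % len(train_dataloader)` batches are skipped at the start of the first resumed epoch.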

@@ -502,6 +560,16 @@ def group_texts(examples):

logger.info(f"epoch {epoch}: perplexity: {perplexity}")

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
)
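
The diff shows tracker initialization and the per-epoch `accelerator.log` call; for completeness, a minimal sketch of the whole tracker life cycle Accelerate exposes (the project name and metric values below are made up, and `end_training` is not part of this diff):

from accelerate import Accelerator

# "all" makes Accelerate use every tracker it can find in the environment (e.g. TensorBoard, W&B)
accelerator = Accelerator(log_with="all")
accelerator.init_trackers("my_project", config={"learning_rate": 5e-5, "num_train_epochs": 3})

for epoch in range(3):
    # ...training would happen here...
    accelerator.log({"train_loss": 0.1 * (3 - epoch), "epoch": epoch}, step=epoch)

# Flush and close all trackers once training is done
accelerator.end_training()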

if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
@@ -512,6 +580,9 @@ def group_texts(examples):
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)

if args.checkpointing_steps == "epoch":
accelerator.save_state(f"epoch_{epoch}")

if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
examples/pytorch/language-modeling/run_mlm_no_trainer.py (86 changes: 80 additions & 6 deletions)
@@ -194,6 +194,23 @@ def parse_args():
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
    "--checkpointing_steps",
    type=str,
    default=None,
    help="How often the various states should be saved: pass an integer to save every n steps, or 'epoch' to save at the end of each epoch.",
)
parser.add_argument(
    "--resume_from_checkpoint",
    type=str,
    default=None,
    help="Path of a checkpoint folder from which training should resume.",
)
parser.add_argument(
    "--with_tracking",
    action="store_true",
    help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()

# Sanity checks
@@ -219,7 +236,8 @@ def main():
args = parse_args()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
# If we're using tracking, we also need to initialize it here so that it picks up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -468,11 +486,6 @@ def group_texts(examples):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
model.tie_weights()
@@ -494,6 +507,23 @@ def group_texts(examples):
num_training_steps=args.max_train_steps,
)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Figure out how often we should save the Accelerator states
if args.checkpointing_steps is not None:
    checkpointing_steps = args.checkpointing_steps
    # A purely numeric value means "every n steps"; anything else (e.g. "epoch") is kept as a string
    if checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)
else:
    checkpointing_steps = None

# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("mlm_no_trainer", args)

# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
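
The `total_batch_size` computed above is plain arithmetic, but a quick worked example makes the meaning concrete (the numbers are made up):

per_device_train_batch_size = 8
num_processes = 4                   # e.g. four GPUs under `accelerate launch`
gradient_accumulation_steps = 2

# Samples that contribute to a single optimizer update
total_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 64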

@@ -508,11 +538,38 @@ def group_texts(examples):
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0

# Potentially load in the weights and states from a previous save
resume_step = None
if args.resume_from_checkpoint is not None:
    if args.resume_from_checkpoint != "":
        path = args.resume_from_checkpoint
    else:
        # No folder was given explicitly, so fall back to the most recent checkpoint in the working directory
        dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
        dirs.sort(key=os.path.getctime)
        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
    accelerator.print(f"Resumed from checkpoint: {path}")
    accelerator.load_state(path)
    # Extract the epoch or step to resume from out of the checkpoint folder name
    if "epoch" in path:
        args.num_train_epochs -= int(path.replace("epoch_", ""))
    else:
        resume_step = int(path.replace("step_", ""))
        args.num_train_epochs -= resume_step // len(train_dataloader)
        resume_step = resume_step % len(train_dataloader)  # steps already done in the partially finished epoch

for epoch in range(args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if resume_step is not None and epoch == 0 and step < resume_step:
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
@@ -522,6 +579,10 @@ def group_texts(examples):
progress_bar.update(1)
completed_steps += 1

if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
accelerator.save_state(f"step_{completed_steps}")

if completed_steps >= args.max_train_steps:
break

@@ -543,6 +604,16 @@ def group_texts(examples):

logger.info(f"epoch {epoch}: perplexity: {perplexity}")

if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
)

if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
@@ -553,6 +624,9 @@ def group_texts(examples):
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
)

if args.checkpointing_steps == "epoch":
accelerator.save_state(f"epoch_{epoch}")

if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)