Skip to content

Commit

Permalink
Merge pull request #522 from vue1999/stage-two-weights
Browse files Browse the repository at this point in the history
Stage two weights
  • Loading branch information
ilyes319 authored Jul 23, 2024
2 parents 0d2c32d + d8f6ae1 commit 475a0ce
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
18 changes: 9 additions & 9 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,19 +585,19 @@ def run(args: argparse.Namespace) -> None:
swa: Optional[tools.SWAContainer] = None
swas = [False]
if args.swa:
assert dipole_only is False, "swa for dipole fitting not implemented"
assert dipole_only is False, "Stage Two for dipole fitting not implemented"
swas.append(True)
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
else:
if args.start_swa > args.max_num_epochs:
logging.info(
f"Start swa must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
)
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
logging.info(f"Setting start swa to {args.start_swa}")
logging.info(f"Setting start Stage Two to {args.start_swa}")
if args.loss == "forces_only":
raise ValueError("Can not select swa with forces only loss.")
raise ValueError("Can not select Stage Two with forces only loss.")
if args.loss == "virials":
loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
energy_weight=args.swa_energy_weight,
Expand All @@ -617,15 +617,15 @@ def run(args: argparse.Namespace) -> None:
dipole_weight=args.swa_dipole_weight,
)
logging.info(
f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}"
f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}"
)
else:
loss_fn_energy = modules.WeightedEnergyForcesLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
)
logging.info(
f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}"
f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}"
)
swa = tools.SWAContainer(
model=AveragedModel(model),
Expand Down Expand Up @@ -807,7 +807,7 @@ def run(args: argparse.Namespace) -> None:
if rank == 0:
# Save entire model
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_swa.model")
model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
Expand All @@ -821,10 +821,10 @@ def run(args: argparse.Namespace) -> None:
),
}
if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model"))
try:
path_complied = Path(args.model_dir) / (
args.name + "_swa_compiled.model"
args.name + "_stagetwo_compiled.model"
)
logging.info(f"Compiling model, saving metadata {path_complied}")
model_compiled = jit.compile(deepcopy(model))
Expand Down
37 changes: 22 additions & 15 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,46 +388,51 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--forces_weight", help="weight of forces loss", type=float, default=100.0
)
parser.add_argument(
"--swa_forces_weight",
help="weight of forces loss after starting swa",
"--swa_forces_weight","--stage_two_forces_weight",
help="weight of forces loss after starting Stage Two (previously called swa)",
type=float,
default=100.0,
dest="swa_forces_weight",
)
parser.add_argument(
"--energy_weight", help="weight of energy loss", type=float, default=1.0
)
parser.add_argument(
"--swa_energy_weight",
help="weight of energy loss after starting swa",
"--swa_energy_weight","--stage_two_energy_weight",
help="weight of energy loss after starting Stage Two (previously called swa)",
type=float,
default=1000.0,
dest="swa_energy_weight",
)
parser.add_argument(
"--virials_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_virials_weight",
help="weight of virials loss after starting swa",
"--swa_virials_weight", "--stage_two_virials_weight",
help="weight of virials loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
dest="swa_virials_weight",
)
parser.add_argument(
"--stress_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_stress_weight",
help="weight of stress loss after starting swa",
"--swa_stress_weight", "--stage_two_stress_weight",
help="weight of stress loss after starting Stage Two (previously called swa)",
type=float,
default=10.0,
dest="swa_stress_weight",
)
parser.add_argument(
"--dipole_weight", help="weight of dipoles loss", type=float, default=1.0
)
parser.add_argument(
"--swa_dipole_weight",
help="weight of dipoles after starting swa",
"--swa_dipole_weight","--stage_two_dipole_weight",
help="weight of dipoles after starting Stage Two (previously called swa)",
type=float,
default=1.0,
dest="swa_dipole_weight",
)
parser.add_argument(
"--config_type_weights",
Expand Down Expand Up @@ -462,7 +467,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--lr", help="Learning rate of optimizer", type=float, default=0.01
)
parser.add_argument(
"--swa_lr", help="Learning rate of optimizer in swa", type=float, default=1e-3
"--swa_lr", "--stage_two_lr", help="Learning rate of optimizer in Stage Two (previously called swa)", type=float, default=1e-3, dest="swa_lr"
)
parser.add_argument(
"--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
Expand All @@ -489,16 +494,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
default=0.9993,
)
parser.add_argument(
"--swa",
help="use Stochastic Weight Averaging, which decreases the learning rate and increases the energy weight at the end of the training to help converge them",
"--swa", "--stage_two",
help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them",
action="store_true",
default=False,
dest="swa",
)
parser.add_argument(
"--start_swa",
help="Number of epochs before switching to swa",
"--start_swa","--start_stage_two",
help="Number of epochs before changing to Stage Two loss weights",
type=int,
default=None,
dest="start_swa",
)
parser.add_argument(
"--ema",
Expand Down
4 changes: 2 additions & 2 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def train(
) # Can break if exponential LR, TODO fix that!
else:
if swa_start:
logging.info("Changing loss based on SWA")
logging.info("Changing loss based on Stage Two Weights")
lowest_loss = np.inf
swa_start = False
keep_last = True
Expand Down Expand Up @@ -233,7 +233,7 @@ def train(
patience_counter += 1
if patience_counter >= patience and epoch < swa.start:
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement and starting swa"
f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two"
)
epoch = swa.start
elif patience_counter >= patience and epoch >= swa.start:
Expand Down

0 comments on commit 475a0ce

Please sign in to comment.