diff --git a/examples/pytorch/mvgrl/graph/main.py b/examples/pytorch/mvgrl/graph/main.py index bebadf319928..38f5aaf8f7b1 100644 --- a/examples/pytorch/mvgrl/graph/main.py +++ b/examples/pytorch/mvgrl/graph/main.py @@ -130,8 +130,8 @@ def collate(samples): print("Epoch {}, Loss {:.4f}".format(epoch, loss_all)) - if loss < best: - best = loss + if loss_all < best: + best = loss_all best_t = epoch cnt_wait = 0 th.save(model.state_dict(), f"{args.dataname}.pkl")