-
Notifications
You must be signed in to change notification settings - Fork 34
/
train.py
193 lines (155 loc) · 6.64 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""Train one of the available models."""
# =============================================================================
# Copyright 2021 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from argparse import ArgumentParser, Namespace
from pathlib import Path
import sys
import lightning
import lightning.pytorch as pl
import torch
from lightning.pytorch.strategies import DDPStrategy
from ptlflow import get_model, get_model_reference
from ptlflow.utils.callbacks.logger import LoggerCallback
from ptlflow.utils.utils import (
add_datasets_to_parser,
get_list_of_available_models_list,
)
def _init_parser() -> ArgumentParser:
    """Create the parser with the arguments declared directly by this script.

    Returns
    -------
    ArgumentParser
        A parser that knows the positional ``model`` argument and the
        ``--random_seed``, ``--clear_train_state``, ``--log_dir`` and
        ``--find_unused_parameters`` options.
    """
    base_parser = ArgumentParser()

    # Each entry is (name or flag, keyword arguments for add_argument).
    # The entries are registered in the order they are listed here.
    argument_specs = (
        (
            "model",
            {
                "type": str,
                "choices": get_list_of_available_models_list(),
                "help": "Name of the model to use.",
            },
        ),
        (
            "--random_seed",
            {
                "type": int,
                "default": 1234,
                "help": "A number to seed the pseudo-random generators.",
            },
        ),
        (
            "--clear_train_state",
            {
                "action": "store_true",
                "help": (
                    "Only used if --resume_from_checkpoint is not None. If set, only the weights are loaded from the checkpoint "
                    "and the training state is ignored. Set it when you want to finetune the model from a previous checkpoint."
                ),
            },
        ),
        (
            "--log_dir",
            {
                "type": str,
                "default": "ptlflow_logs",
                "help": "The path to the directory where the logs will be saved.",
            },
        ),
        ("--find_unused_parameters", {"action": "store_true"}),
    )
    for name_or_flag, options in argument_specs:
        base_parser.add_argument(name_or_flag, **options)

    return base_parser
def train(args: Namespace) -> None:
    """Run the training.

    Prints the "untested" warning, seeds the random generators, builds the
    model, optionally restores only the weights from a checkpoint, sets up the
    logger and the checkpoint callbacks, and finally tunes and fits the model.

    Parameters
    ----------
    args : Namespace
        Arguments to configure the training. Besides the options consumed by
        ``get_model`` and by the Lightning trainer, this function reads
        ``random_seed``, ``train_transform_cuda``, ``train_dataset``, ``model``,
        ``pretrained_ckpt``, ``resume_from_checkpoint``, ``clear_train_state``,
        ``log_dir`` and ``find_unused_parameters``. ``train_dataset`` and
        ``resume_from_checkpoint`` may be modified in place.
    """
    _print_untested_warning()

    lightning.fabric.utilities.seed.seed_everything(args.random_seed)

    if args.train_transform_cuda:
        # Imported here because it is only needed in this branch.
        # NOTE(review): presumably "spawn" is required so that worker processes
        # can use CUDA for the training transforms - confirm.
        from torch.multiprocessing import set_start_method

        set_start_method("spawn")

    if args.train_dataset is None:
        args.train_dataset = "chairs-train"
        print('INFO: --train_dataset is not set. It will be set to "chairs-train"')

    # The log sub-directory is named after the model and the training datasets.
    log_model_name = f"{args.model}-{_gen_dataset_id(args.train_dataset)}"

    model = get_model(args.model, args.pretrained_ckpt, args)

    if args.resume_from_checkpoint is not None and args.clear_train_state:
        # Restore model weights, but not the train state.
        # map_location="cpu" makes the load independent of the device the
        # checkpoint was saved from (e.g. a GPU that does not exist on this
        # machine); load_state_dict then copies the values into the model.
        pl_ckpt = torch.load(args.resume_from_checkpoint, map_location="cpu")
        model.load_state_dict(pl_ckpt["state_dict"])
        # Presumably reset so that the trainer created from args below does not
        # restore the training state from the same file - confirm.
        args.resume_from_checkpoint = None

    # Setup loggers and callbacks
    callbacks = []

    lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_logger)

    log_model_dir = str(Path(args.log_dir) / log_model_name)
    tb_logger = pl.loggers.TensorBoardLogger(log_model_dir)

    model.val_dataloader()  # Called just to populate model.val_dataloader_names

    # Weights-only checkpoint, written without a monitored metric.
    model_ckpt_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename=args.model + "_last_{epoch}_{step}", save_weights_only=True, mode="max"
    )
    callbacks.append(model_ckpt_last)

    # Checkpoint with the callback's default settings (not restricted to weights).
    model_ckpt_train = pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename=args.model + "_train_{epoch}_{step}"
    )
    callbacks.append(model_ckpt_train)

    if len(model.val_dataloader_names) > 0:
        # The first validation dataloader name is used as the monitored metric
        # and is also embedded in the checkpoint file name.
        monitored_metric = model.val_dataloader_names[0]
        model_ckpt_best = pl.callbacks.model_checkpoint.ModelCheckpoint(
            filename=args.model + "_best_{" + monitored_metric + ":.2f}_{epoch}_{step}",
            save_weights_only=True,
            save_top_k=1,
            monitor=monitored_metric,
        )
        callbacks.append(model_ckpt_best)

    callbacks.append(LoggerCallback())

    # NOTE(review): Trainer.from_argparse_args and Trainer.tune are not
    # available in Lightning >= 2.0 - confirm the Lightning version this
    # script is expected to run with.
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=tb_logger,
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=args.find_unused_parameters),
    )

    trainer.tune(model)
    trainer.fit(model)
def _gen_dataset_id(dataset_string: str) -> str:
    """Build a short identifier from a dataset string.

    The string is split on "+" into entries. An entry may contain a "*"
    separating a multiplier from the dataset part; the multiplier can come
    before or after it. From each dataset part only the text before the first
    "-" is kept, and the kept pieces are joined with "_".

    Parameters
    ----------
    dataset_string : str
        The dataset string, e.g. "2*chairs-train+things-train".

    Returns
    -------
    str
        The identifier, e.g. "chairs_things".
    """
    short_names = []
    for entry in dataset_string.split("+"):
        spec = entry
        if "*" in entry:
            pieces = entry.split("*")
            left, right = pieces[0], pieces[1]
            try:
                int(left)
            except ValueError:
                # The left side is not a number, so the multiplier is at the end.
                spec = left
            else:
                # The left side is the multiplier; the dataset part follows it.
                spec = right
        short_names.append(spec.partition("-")[0])
    return "_".join(short_names)
def _print_untested_warning():
    """Print a banner warning that this training script has not been tested."""
    border = "###########################################################################"
    banner_lines = (
        border,
        "# WARNING, please read! #",
        "# #",
        "# This training script has not been tested! #",
        "# Therefore, there is no guarantee that a model trained with this script #",
        "# will produce good results after the training! #",
        "# #",
        "# You can find more information at #",
        "# https://ptlflow.readthedocs.io/en/latest/starting/training.html #",
        border,
    )
    # One print call per line, in order.
    for banner_line in banner_lines:
        print(banner_line)
if __name__ == "__main__":
    parser = _init_parser()

    # TODO: Taking the model name straight from sys.argv instead of from the
    # parsed arguments is ugly. However, there seems to be no other way, because
    # the parser needs the model in order to register some of the arguments.
    FlowModel = None
    if len(sys.argv) > 1 and sys.argv[1] not in ("-h", "--help"):
        # The first command line token is taken as the model name.
        FlowModel = get_model_reference(sys.argv[1])
        parser = FlowModel.add_model_specific_args(parser)

    add_datasets_to_parser(parser, "datasets.yml")

    parser = pl.Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    train(args)