Skip to content

Commit

Permalink
Update BERT example (#681)
Browse files Browse the repository at this point in the history
- Add TrainBertTest
- Add try-with-resources to model and trainer
- Add progress when creating batches
- Use arguments in BERT
- Add missing training listener notifications

Also updated:
- Make the example training Arguments class inheritable so that subclasses can change
both the default values and the set of available options. Previously, the defaults could
not be modified and it was more difficult to use derived argument classes.

Change-Id: I085e1d42fc1849c1586ed0cd8272c44d9f4e9dce
  • Loading branch information
zachgk authored Feb 24, 2021
1 parent b7f72bd commit b4a93e4
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
*/
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.examples.training.util.Arguments;
import ai.djl.modality.nlp.preprocess.UnicodeNormalizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
Expand All @@ -25,6 +27,7 @@
import ai.djl.training.ParallelTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Batch;
import ai.djl.training.initializer.TruncatedNormalInitializer;
import ai.djl.training.listener.TrainingListener.Defaults;
Expand Down Expand Up @@ -78,6 +81,11 @@ public final class TrainBertOnCode {
private TrainBertOnCode() {}

/**
 * Command-line entry point: forwards the arguments to {@code runExample} and ignores the
 * result it returns.
 *
 * @param args the command-line arguments, passed through unchanged
 */
public static void main(String[] args) {
    runExample(args);
}

public static TrainingResult runExample(String[] args) {
BertArguments arguments = (BertArguments) new BertArguments().parseArgs(args);
Random rand = new Random(89724308);
// get all applicable files
List<Path> files = listSourceFiles(new File(".").toPath());
Expand All @@ -89,25 +97,35 @@ public static void main(String[] args) {
Dictionary dictionary = buildDictionary(countedTokens, 35000);

// Create model & trainer
Model model = createBertPretrainingModel(dictionary);
Trainer trainer = createBertPretrainingTrainer(model);

// Initialize training
Shape inputShape = new Shape(MAX_SEQUENCE_LENGTH, 512);
trainer.initialize(inputShape, inputShape, inputShape, inputShape);
ParallelTrain parallelTrain = new ParallelTrain(trainer.getDevices());

for (int epoch = 0; epoch < EPOCHS; ++epoch) {
List<MaskedInstance> maskedInstances = createEpochData(rand, dictionary, parsedFiles);
for (int idx = BATCH_SIZE; idx < maskedInstances.size(); ++idx) {
try (NDManager ndManager = trainer.getManager().newSubManager()) {
List<MaskedInstance> batchData = maskedInstances.subList(idx - BATCH_SIZE, idx);
Batch batch = createBatch(ndManager, batchData);
// the following uses the GPUs alternating
// EasyTrain.trainBatch(trainer, batch);
// this actually uses both GPUs at once
parallelTrain.trainBatch(trainer, batch);
try (Model model = createBertPretrainingModel(dictionary)) {

try (Trainer trainer = createBertPretrainingTrainer(model, arguments)) {

// Initialize training
Shape inputShape = new Shape(MAX_SEQUENCE_LENGTH, 512);
trainer.initialize(inputShape, inputShape, inputShape, inputShape);
ParallelTrain parallelTrain = new ParallelTrain(trainer.getDevices());

trainer.notifyListeners(listener -> listener.onTrainingBegin(trainer));
for (int epoch = 0; epoch < arguments.getEpoch(); ++epoch) {
List<MaskedInstance> maskedInstances =
createEpochData(rand, dictionary, parsedFiles, arguments);
for (int idx = BATCH_SIZE; idx < maskedInstances.size(); ++idx) {
try (NDManager ndManager = trainer.getManager().newSubManager()) {
List<MaskedInstance> batchData =
maskedInstances.subList(idx - BATCH_SIZE, idx);
Batch batch =
createBatch(ndManager, batchData, idx, maskedInstances.size());
// the following uses the GPUs alternating
// EasyTrain.trainBatch(trainer, batch);
// this actually uses both GPUs at once
parallelTrain.trainBatch(trainer, batch);
}
}
trainer.notifyListeners(listener -> listener.onEpoch(trainer));
}
trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
return trainer.getTrainingResult();
}
}
}
Expand All @@ -121,7 +139,7 @@ private static Model createBertPretrainingModel(Dictionary dictionary) {
return model;
}

private static Trainer createBertPretrainingTrainer(Model model) {
private static Trainer createBertPretrainingTrainer(Model model, BertArguments arguments) {
Tracker learningRateTracker =
WarmUpTracker.builder()
.optWarmUpBeginValue(0f)
Expand All @@ -143,13 +161,16 @@ private static Trainer createBertPretrainingTrainer(Model model) {
TrainingConfig trainingConfig =
new DefaultTrainingConfig(new BertPretrainingLoss())
.optOptimizer(optimizer)
// TODO: why does this not log *anything*?
.optDevices(Device.getDevices(arguments.getMaxGpus()))
.addTrainingListeners(Defaults.logging());
return model.newTrainer(trainingConfig);
}

private static List<MaskedInstance> createEpochData(
Random rand, Dictionary dictionary, List<ParsedFile> parsedFiles) {
Random rand,
Dictionary dictionary,
List<ParsedFile> parsedFiles,
BertArguments arguments) {
// turn data into sentence pairs containing consecutive lines
List<SentencePair> sentencePairs = new ArrayList<>();
parsedFiles.forEach(parsedFile -> parsedFile.addToSentencePairs(sentencePairs));
Expand All @@ -161,6 +182,7 @@ private static List<MaskedInstance> createEpochData(
// Create masked instances for training
return sentencePairs
.stream()
.limit(arguments.getLimit())
.map(
sentencePair ->
new MaskedInstance(
Expand All @@ -172,7 +194,8 @@ private static List<MaskedInstance> createEpochData(
.collect(Collectors.toList());
}

private static Batch createBatch(NDManager ndManager, List<MaskedInstance> instances) {
private static Batch createBatch(
NDManager ndManager, List<MaskedInstance> instances, int idx, int dataSize) {
NDList inputs =
new NDList(
batchFromList(ndManager, instances, MaskedInstance::getTokenIds),
Expand All @@ -184,16 +207,15 @@ private static Batch createBatch(NDManager ndManager, List<MaskedInstance> insta
nextSentenceLabelsFromList(ndManager, instances),
batchFromList(ndManager, instances, MaskedInstance::getMaskedIds),
batchFromList(ndManager, instances, MaskedInstance::getLabelMask));
// TODO: Use batch progress
return new Batch(
ndManager,
inputs,
labels,
instances.size(),
Batchifier.STACK,
Batchifier.STACK,
0,
0);
idx,
dataSize);
}

private static NDArray batchFromList(NDManager ndManager, List<int[]> batchData) {
Expand Down Expand Up @@ -582,4 +604,14 @@ public String getRandomToken(Random rand) {
return tokens.get(rand.nextInt(tokens.size()));
}
}

/**
 * Arguments variant used by the BERT example. After the inherited initialization has run,
 * the epoch count and batch size are set to this example's {@code EPOCHS} and
 * {@code BATCH_SIZE} constants.
 */
private static class BertArguments extends Arguments {

    /** Runs the inherited initialization first, then applies the BERT-specific values. */
    @Override
    protected void initialize() {
        super.initialize();
        this.batchSize = BATCH_SIZE;
        this.epoch = EPOCHS;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static void main(String[] args)
public static TrainingResult runExample(String[] args)
throws IOException, ModelNotFoundException, MalformedModelException,
TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static void main(String[] args) throws IOException {
}

public static TrainingResult runExample(String[] args) throws IOException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static void main(String[] args) throws IOException, TranslateException {
}

public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import java.util.Arrays;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
Expand All @@ -71,10 +70,8 @@ public static void main(String[] args)
public static TrainingResult runExample(String[] args)
throws IOException, ParseException, ModelNotFoundException, MalformedModelException,
TranslateException {
Options options = OptimizerArguments.getOptions();
DefaultParser parser = new DefaultParser();
CommandLine cmd = parser.parse(options, args, null, false);
OptimizerArguments arguments = new OptimizerArguments(cmd);
OptimizerArguments arguments =
(OptimizerArguments) new OptimizerArguments().parseArgs(args);

try (Model model = getModel(arguments)) {
// get training dataset
Expand Down Expand Up @@ -240,18 +237,21 @@ private static class OptimizerArguments extends Arguments {

private String optimizer;

public OptimizerArguments(CommandLine cmd) {
super(cmd);
public OptimizerArguments() {}

@Override
protected void setCmd(CommandLine cmd) {
super.setCmd(cmd);
if (cmd.hasOption("optimizer")) {
optimizer = cmd.getOptionValue("optimizer");
} else {
optimizer = "adam";
}
}

public static Options getOptions() {
Options options = Arguments.getOptions();
@Override
public Options getOptions() {
Options options = super.getOptions();
options.addOption(
Option.builder("z")
.longOpt("optimizer")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static void main(String[] args)

public static TrainingResult runExample(String[] args)
throws IOException, TranslateException, ModelException, URISyntaxException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static void main(String[] args) throws ModelException, IOException, Trans

public static TrainingResult runExample(String[] args)
throws IOException, ModelException, TranslateException {
Arguments arguments = Arguments.parseArgs(args);
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
return null;
}
Expand Down
Loading

0 comments on commit b4a93e4

Please sign in to comment.