diff --git a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java index 2fa22dfd2c4..53eca2e14b4 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainBertOnCode.java @@ -12,7 +12,9 @@ */ package ai.djl.examples.training; +import ai.djl.Device; import ai.djl.Model; +import ai.djl.examples.training.util.Arguments; import ai.djl.modality.nlp.preprocess.UnicodeNormalizer; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; @@ -25,6 +27,7 @@ import ai.djl.training.ParallelTrain; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; +import ai.djl.training.TrainingResult; import ai.djl.training.dataset.Batch; import ai.djl.training.initializer.TruncatedNormalInitializer; import ai.djl.training.listener.TrainingListener.Defaults; @@ -78,6 +81,11 @@ public final class TrainBertOnCode { private TrainBertOnCode() {} public static void main(String[] args) { + TrainBertOnCode.runExample(args); + } + + public static TrainingResult runExample(String[] args) { + BertArguments arguments = (BertArguments) new BertArguments().parseArgs(args); Random rand = new Random(89724308); // get all applicable files List files = listSourceFiles(new File(".").toPath()); @@ -89,25 +97,35 @@ public static void main(String[] args) { Dictionary dictionary = buildDictionary(countedTokens, 35000); // Create model & trainer - Model model = createBertPretrainingModel(dictionary); - Trainer trainer = createBertPretrainingTrainer(model); - - // Initialize training - Shape inputShape = new Shape(MAX_SEQUENCE_LENGTH, 512); - trainer.initialize(inputShape, inputShape, inputShape, inputShape); - ParallelTrain parallelTrain = new ParallelTrain(trainer.getDevices()); - - for (int epoch = 0; epoch < EPOCHS; ++epoch) { - List maskedInstances = createEpochData(rand, dictionary, parsedFiles); - for (int idx = BATCH_SIZE; idx < maskedInstances.size(); ++idx) { - try (NDManager ndManager = trainer.getManager().newSubManager()) { - List batchData = maskedInstances.subList(idx - BATCH_SIZE, idx); - Batch batch = createBatch(ndManager, batchData); - // the following uses the GPUs alternating - // EasyTrain.trainBatch(trainer, batch); - // this actually uses both GPUs at once - parallelTrain.trainBatch(trainer, batch); + try (Model model = createBertPretrainingModel(dictionary)) { + + try (Trainer trainer = createBertPretrainingTrainer(model, arguments)) { + + // Initialize training + Shape inputShape = new Shape(MAX_SEQUENCE_LENGTH, 512); + trainer.initialize(inputShape, inputShape, inputShape, inputShape); + ParallelTrain parallelTrain = new ParallelTrain(trainer.getDevices()); + + trainer.notifyListeners(listener -> listener.onTrainingBegin(trainer)); + for (int epoch = 0; epoch < arguments.getEpoch(); ++epoch) { + List maskedInstances = + createEpochData(rand, dictionary, parsedFiles, arguments); + for (int idx = BATCH_SIZE; idx < maskedInstances.size(); ++idx) { + try (NDManager ndManager = trainer.getManager().newSubManager()) { + List batchData = + maskedInstances.subList(idx - BATCH_SIZE, idx); + Batch batch = + createBatch(ndManager, batchData, idx, maskedInstances.size()); + // the following uses the GPUs alternating + // EasyTrain.trainBatch(trainer, batch); + // this actually uses both GPUs at once + parallelTrain.trainBatch(trainer, batch); + } + } + trainer.notifyListeners(listener -> listener.onEpoch(trainer)); } + trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer)); + return trainer.getTrainingResult(); } } } @@ -121,7 +139,7 @@ private static Model createBertPretrainingModel(Dictionary dictionary) { return model; } - private static Trainer createBertPretrainingTrainer(Model model) { + private static Trainer createBertPretrainingTrainer(Model model, BertArguments arguments) { Tracker learningRateTracker = WarmUpTracker.builder() .optWarmUpBeginValue(0f) @@ -143,13 +161,16 @@ private static Trainer createBertPretrainingTrainer(Model model) { TrainingConfig trainingConfig = new DefaultTrainingConfig(new BertPretrainingLoss()) .optOptimizer(optimizer) - // TODO: why does this not log *anything*? + .optDevices(Device.getDevices(arguments.getMaxGpus())) .addTrainingListeners(Defaults.logging()); return model.newTrainer(trainingConfig); } private static List createEpochData( - Random rand, Dictionary dictionary, List parsedFiles) { + Random rand, + Dictionary dictionary, + List parsedFiles, + BertArguments arguments) { // turn data into sentence pairs containing consecutive lines List sentencePairs = new ArrayList<>(); parsedFiles.forEach(parsedFile -> parsedFile.addToSentencePairs(sentencePairs)); @@ -161,6 +182,7 @@ private static List createEpochData( // Create masked instances for training return sentencePairs .stream() + .limit(arguments.getLimit()) .map( sentencePair -> new MaskedInstance( @@ -172,7 +194,8 @@ private static List createEpochData( .collect(Collectors.toList()); } - private static Batch createBatch(NDManager ndManager, List instances) { + private static Batch createBatch( + NDManager ndManager, List instances, int idx, int dataSize) { NDList inputs = new NDList( batchFromList(ndManager, instances, MaskedInstance::getTokenIds), @@ -184,7 +207,6 @@ private static Batch createBatch(NDManager ndManager, List insta nextSentenceLabelsFromList(ndManager, instances), batchFromList(ndManager, instances, MaskedInstance::getMaskedIds), batchFromList(ndManager, instances, MaskedInstance::getLabelMask)); - // TODO: Use batch progress return new Batch( ndManager, inputs, @@ -192,8 +214,8 @@ private static Batch createBatch(NDManager ndManager, List insta instances.size(), Batchifier.STACK, Batchifier.STACK, - 0, - 0); + idx, + dataSize); } private static NDArray batchFromList(NDManager ndManager, List batchData) { @@ -582,4 +604,14 @@ public String getRandomToken(Random rand) { return tokens.get(rand.nextInt(tokens.size())); } } + + private static class BertArguments extends Arguments { + + @Override + protected void initialize() { + super.initialize(); + epoch = EPOCHS; + batchSize = BATCH_SIZE; + } + } } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java b/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java index c02cb61bfc9..54942f6f917 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainCaptcha.java @@ -55,7 +55,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java index cdb4b76f455..5759857c8dc 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnist.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnist.java @@ -50,7 +50,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index 894fc818a0d..13ce0f2f742 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -48,7 +48,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java b/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java index f15130f27e1..c0a23f9a8a0 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java @@ -71,7 +71,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java index 60926df8364..6c760d18841 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java @@ -86,7 +86,7 @@ public static void main(String[] args) public static TrainingResult runExample(String[] args) throws IOException, ModelNotFoundException, MalformedModelException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java index ddece45b821..eaada6e0088 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java @@ -64,7 +64,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java b/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java index 2d8dcaa16f2..063e2c0f54b 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainTicTacToe.java @@ -57,7 +57,7 @@ public static void main(String[] args) throws IOException { } public static TrainingResult runExample(String[] args) throws IOException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java b/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java index 1a681d9e60a..d2b3512a8f8 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithHpo.java @@ -50,7 +50,7 @@ public static void main(String[] args) throws IOException, TranslateException { } public static TrainingResult runExample(String[] args) throws IOException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java index d8e74e1d620..be21aef2a1c 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java @@ -52,7 +52,6 @@ import java.util.Arrays; import java.util.Map; import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.DefaultParser; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; @@ -71,10 +70,8 @@ public static void main(String[] args) public static TrainingResult runExample(String[] args) throws IOException, ParseException, ModelNotFoundException, MalformedModelException, TranslateException { - Options options = OptimizerArguments.getOptions(); - DefaultParser parser = new DefaultParser(); - CommandLine cmd = parser.parse(options, args, null, false); - OptimizerArguments arguments = new OptimizerArguments(cmd); + OptimizerArguments arguments = + (OptimizerArguments) new OptimizerArguments().parseArgs(args); try (Model model = getModel(arguments)) { // get training dataset @@ -240,9 +237,11 @@ private static class OptimizerArguments extends Arguments { private String optimizer; - public OptimizerArguments(CommandLine cmd) { - super(cmd); + public OptimizerArguments() {} + @Override + protected void setCmd(CommandLine cmd) { + super.setCmd(cmd); if (cmd.hasOption("optimizer")) { optimizer = cmd.getOptionValue("optimizer"); } else { @@ -250,8 +249,9 @@ public OptimizerArguments(CommandLine cmd) { } } - public static Options getOptions() { - Options options = Arguments.getOptions(); + @Override + public Options getOptions() { + Options options = super.getOptions(); options.addOption( Option.builder("z") .longOpt("optimizer") diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java index 59601ea7128..0ccddeb4336 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java @@ -66,7 +66,7 @@ public static void main(String[] args) public static TrainingResult runExample(String[] args) throws IOException, TranslateException, ModelException, URISyntaxException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java index ae291aa2592..455a445700a 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java @@ -77,7 +77,7 @@ public static void main(String[] args) throws ModelException, IOException, Trans public static TrainingResult runExample(String[] args) throws IOException, ModelException, TranslateException { - Arguments arguments = Arguments.parseArgs(args); + Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { return null; } diff --git a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java index c8f4216053a..4da3cf42503 100644 --- a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java +++ b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java @@ -26,23 +26,28 @@ public class Arguments { - private int epoch; - private int batchSize; - private int maxGpus; - private boolean isSymbolic; - private boolean preTrained; - private String outputDir; - private long limit; - private String modelDir; - private Map criteria; - - public Arguments(CommandLine cmd) { + protected int epoch; + protected int batchSize; + protected int maxGpus; + protected boolean isSymbolic; + protected boolean preTrained; + protected String outputDir; + protected long limit; + protected String modelDir; + protected Map criteria; + + protected void initialize() { + epoch = 2; + maxGpus = Device.getGpuCount(); + outputDir = "build/model"; + limit = Long.MAX_VALUE; + modelDir = null; + } + + protected void setCmd(CommandLine cmd) { if (cmd.hasOption("epoch")) { epoch = Integer.parseInt(cmd.getOptionValue("epoch")); - } else { - epoch = 2; } - maxGpus = Device.getGpuCount(); if (cmd.hasOption("max-gpus")) { maxGpus = Math.min(Integer.parseInt(cmd.getOptionValue("max-gpus")), maxGpus); } @@ -56,18 +61,12 @@ public Arguments(CommandLine cmd) { if (cmd.hasOption("output-dir")) { outputDir = cmd.getOptionValue("output-dir"); - } else { - outputDir = "build/model"; } if (cmd.hasOption("max-batches")) { limit = Long.parseLong(cmd.getOptionValue("max-batches")) * batchSize; - } else { - limit = Long.MAX_VALUE; } if (cmd.hasOption("model-dir")) { modelDir = cmd.getOptionValue("model-dir"); - } else { - modelDir = null; } if (cmd.hasOption("criteria")) { Type type = new TypeToken>() {}.getType(); @@ -75,8 +74,9 @@ public Arguments(CommandLine cmd) { } } - public static Arguments parseArgs(String[] args) { - Options options = Arguments.getOptions(); + public Arguments parseArgs(String[] args) { + initialize(); + Options options = getOptions(); try { DefaultParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args, null, false); @@ -84,14 +84,15 @@ public static Arguments parseArgs(String[] args) { printHelp("./gradlew run --args='[OPTIONS]'", options); return null; } - return new Arguments(cmd); + setCmd(cmd); + return this; } catch (ParseException e) { printHelp("./gradlew run --args='[OPTIONS]'", options); } return null; } - public static Options getOptions() { + public Options getOptions() { Options options = new Options(); options.addOption( Option.builder("h").longOpt("help").hasArg(false).desc("Print this help.").build()); @@ -196,7 +197,7 @@ public Map getCriteria() { return criteria; } - private static void printHelp(String msg, Options options) { + private void printHelp(String msg, Options options) { HelpFormatter formatter = new HelpFormatter(); formatter.setLeftPadding(1); formatter.setWidth(120); diff --git a/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java b/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java new file mode 100644 index 00000000000..a28d67c2040 --- /dev/null +++ b/examples/src/test/java/ai/djl/examples/training/TrainBertTest.java @@ -0,0 +1,24 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.training; + +import org.testng.annotations.Test; + +public class TrainBertTest { + + @Test + public void testTrainBert() { + String[] args = new String[] {"-g", "1", "-m", "1", "-e", "1"}; + TrainBertOnCode.runExample(args); + } +}