Skip to content

Commit

Permalink
[example] Enable PyTorch for some training example (#3398)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Aug 9, 2024
1 parent 11ff0c1 commit 7ac48dd
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,22 @@ public class TrainCaptchaTest {
public void testTrainCaptcha() throws IOException, TranslateException {
TestRequirements.linux();

// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
// TODO: PyTorch
/*
ai.djl.engine.EngineException: index 11 is out of bounds for dimension 1 with size 11
at app//ai.djl.pytorch.jni.PyTorchLibrary.torchGather(Native Method)
at app//ai.djl.pytorch.jni.JniUtils.pick(JniUtils.java:581)
at app//ai.djl.pytorch.jni.JniUtils.indexAdv(JniUtils.java:417)
at app//ai.djl.pytorch.engine.PtNDArrayIndexer.get(PtNDArrayIndexer.java:74)
at app//ai.djl.ndarray.NDArray.get(NDArray.java:614)
at app//ai.djl.ndarray.NDArray.get(NDArray.java:603)
at app//ai.djl.training.loss.SoftmaxCrossEntropyLoss.evaluate(SoftmaxCrossEntropyLoss.java:86)
at app//ai.djl.training.loss.IndexLoss.evaluate(IndexLoss.java:55)
at app//ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:68)
at app//ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:124)
at app//ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
at app//ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
*/
String[] args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
TrainingResult result = TrainCaptcha.runExample(args);
Assert.assertNotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.examples.training;

import ai.djl.engine.Engine;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -25,14 +24,7 @@ public class TrainMnistWithLSTMTest {

@Test
public void testTrainMnistWithLSTM() throws IOException, TranslateException {
String[] args;
Engine engine = Engine.getEngine("PyTorch");
if (engine.getGpuCount() > 0) {
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
args = new String[] {"-g", "1", "-e", "1", "-m", "2", "--engine", "MXNet"};
} else {
args = new String[] {"-g", "1", "-e", "1", "-m", "2"};
}
String[] args = {"-g", "1", "-e", "1", "-m", "2"};
TrainingResult result = TrainMnistWithLSTM.runExample(args);
Assert.assertNotNull(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ public void testTrainResNet() throws ModelException, IOException, TranslateExcep

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
// only MXNet support symbolic model
String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"};
TrainingResult result = TrainResnetWithCifar10.runExample(args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public void testTrainSentimentAnalysis()
TestRequirements.nightly();
TestRequirements.gpu("MXNet", 1);

// TODO: Add a PyTorch Glove model to model zoo
String[] args = {"-e", "1", "-g", "1", "--engine", "MXNet"};
TrainSentimentAnalysis.runExample(args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

package ai.djl.examples.training;

import ai.djl.engine.Engine;
import ai.djl.training.TrainingResult;
import ai.djl.translate.TranslateException;

Expand All @@ -26,15 +25,7 @@ public class TrainTimeSeriesTest {

@Test
public void testTrainTimeSeries() throws TranslateException, IOException {
String[] args;
Engine engine = Engine.getEngine("PyTorch");
if (engine.getGpuCount() > 0) {
// TODO: PyTorch -- cuDNN error: CUDNN_STATUS_VERSION_MISMATCH
args = new String[] {"-g", "1", "-e", "5", "-b", "32", "--engine", "MXNet"};
} else {
args = new String[] {"-g", "1", "-e", "5", "-b", "32"};
}

String[] args = {"-g", "1", "-e", "5", "-b", "32"};
TrainingResult result = TrainTimeSeries.runExample(args);
Assert.assertNotNull(result);
float loss = result.getTrainLoss();
Expand Down

0 comments on commit 7ac48dd

Please sign in to comment.