Skip to content

Commit

Permalink
Fix bug introduced by multi-engine feature
Browse files Browse the repository at this point in the history
Change-Id: I17d3db91e878adbdaad193b1e435850329e0f930
  • Loading branch information
frankfliu committed May 20, 2020
1 parent c43faed commit a8c4bfc
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 68 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ public static Engine getInstance() {
return getEngine(System.getProperty("ai.djl.default_engine", DEFAULT_ENGINE));
}

/**
* Returns if the specified engine is available.
*
* @param engineName the name of Engine to check
* @return {@code true} if the specified engine is available
* @see EngineProvider
*/
public static boolean hasEngine(String engineName) {
return ALL_ENGINES.containsKey(engineName);
}

/**
* Returns the {@code Engine} with the given name.
*
Expand Down
60 changes: 20 additions & 40 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,47 +104,32 @@ public <S, T> ZooModel<S, T> loadModel(Criteria<S, T> criteria)
Path modelPath = repository.getResourceDirectory(artifact);

// Check if the engine is specified in Criteria, use it if it is.
// Otherwise check the modelzoo supported engine and grab the first engine in the list.
// Otherwise check the modelzoo supported engine and grab a random engine in the list.
// Otherwise if none of them is specified or model zoo is null, go to default engine.

String engine = criteria.getEngine();
if (engine == null || engine.isEmpty()) {
if (modelZoo != null) {
engine = Engine.getInstance().getEngineName();
if (!modelZoo.getSupportedEngines().contains(engine)) {
engine = modelZoo.getSupportedEngines().iterator().next();
if (engine == null && modelZoo != null) {
String defaultEngine = Engine.getInstance().getEngineName();
for (String supportedEngine : modelZoo.getSupportedEngines()) {
if (supportedEngine.equals(defaultEngine)) {
engine = supportedEngine;
break;
} else if (Engine.hasEngine(supportedEngine)) {
engine = supportedEngine;
}
}
}

try {
Model model = createModel(Device.defaultDevice(), artifact, arguments, engine);
model.load(modelPath, artifact.getName(), criteria.getOptions());
return new ZooModel<>(model, translator);
} catch (IllegalArgumentException e) {

if (e.getMessage().contains("Deep learning engine not found:")) {
StringBuilder errorMsg = new StringBuilder(200);

errorMsg.append("Your Criteria Filters: ");
errorMsg.append(criteria.getFilters().toString());

errorMsg.append(", Under Model Zoo: ");
errorMsg.append((modelZoo == null) ? "null" : modelZoo.getClass().toString());

errorMsg.append(", Is using Engine: ");
errorMsg.append(
(engine == null || engine.isEmpty())
? Engine.getInstance().getEngineName()
: engine);
errorMsg.append(
", But the engine could not be found, "
+ "please try adding the dependency engine to the gradle file");

throw new UnsupportedOperationException(errorMsg.toString(), e);
if (engine == null) {
throw new ModelNotFoundException(
"No supported engine available for model zoo: "
+ modelZoo.getGroupId());
}
throw e;
}
if (engine != null && !Engine.hasEngine(engine)) {
throw new ModelNotFoundException(engine + " is not supported.");
}

Model model = createModel(Device.defaultDevice(), artifact, arguments, engine);
model.load(modelPath, artifact.getName(), criteria.getOptions());
return new ZooModel<>(model, translator);
} finally {
if (progress != null) {
progress.end();
Expand All @@ -167,11 +152,6 @@ protected Model createModel(
return Model.newInstance(device, engine);
}

protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments)
throws IOException {
return Model.newInstance(device);
}

/**
* Returns the first artifact that matches a given criteria.
*
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/repository/zoo/ModelZoo.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)

Set<String> supportedEngine = zoo.getSupportedEngines();
String engine = criteria.getEngine();

if (engine != null && !engine.isEmpty() && !supportedEngine.contains(engine)) {
if (engine != null && !supportedEngine.contains(engine)) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.inference.Predictor;
import ai.djl.integration.util.TestUtils;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDArray;
Expand Down Expand Up @@ -45,6 +46,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import org.testng.SkipException;
import org.testng.annotations.Test;

public class ResnetTest {
Expand Down Expand Up @@ -143,6 +145,9 @@ public void testLoadTrain()

private ZooModel<Image, Classifications> getModel()
throws IOException, ModelNotFoundException, MalformedModelException {
if (!TestUtils.isMxnet()) {
throw new SkipException("Resnet50-cifar10 model only available in MXNet");
}

Criteria<Image, Classifications> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.basicdataset.PikachuDetection;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.inference.Predictor;
import ai.djl.integration.util.TestUtils;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.MultiBoxDetection;
Expand Down Expand Up @@ -49,6 +50,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.Test;

public class SingleShotDetectionTest {
Expand Down Expand Up @@ -129,6 +131,9 @@ private TrainingConfig setupTrainingConfig() {

private ZooModel<Image, DetectedObjects> getModel()
throws IOException, ModelNotFoundException, MalformedModelException {
if (!TestUtils.isMxnet()) {
throw new SkipException("SSD-pikachu model only available in MXNet");
}

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ public ZooModel<Image, Classifications> loadModel(

/** {@inheritDoc} */
@Override
protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments) {
protected Model createModel(
Device device, Artifact artifact, Map<String, Object> arguments, String engine) {
int width = ((Double) arguments.getOrDefault("width", 28d)).intValue();
int height = ((Double) arguments.getOrDefault("height", 28d)).intValue();
int input = width * height;
Expand All @@ -119,7 +120,7 @@ protected Model createModel(Device device, Artifact artifact, Map<String, Object
.mapToInt(Double::intValue)
.toArray();

Model model = Model.newInstance(device);
Model model = Model.newInstance(device, engine);
model.setBlock(new Mlp(input, output, hidden));
return model;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ private Block resnetBlock(Map<String, Object> arguments) {
return blockBuilder.build();
}

/** {@inheritDoc} */
@Override
protected Model createModel(
Device device, Artifact artifact, Map<String, Object> arguments, String engine) {
Expand All @@ -139,14 +140,6 @@ protected Model createModel(
return model;
}

/** {@inheritDoc} */
@Override
protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments) {
Model model = Model.newInstance(device);
model.setBlock(resnetBlock(arguments));
return model;
}

private static final class FactoryImpl implements TranslatorFactory<Image, Classifications> {

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,6 @@ protected Model createModel(
return model;
}

/** {@inheritDoc} */
@Override
protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments) {
Model model = Model.newInstance(device);
model.setBlock(customSSDBlock(arguments));
return model;
}

private static final class FactoryImpl implements TranslatorFactory<Image, DetectedObjects> {

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,6 @@ protected Model createModel(
return customGloveBlock(model, artifact, arguments);
}

/** {@inheritDoc} */
@Override
protected Model createModel(Device device, Artifact artifact, Map<String, Object> arguments)
throws IOException {
Model model = Model.newInstance(device);
return customGloveBlock(model, artifact, arguments);
}

/** {@inheritDoc} */
@Override
public ZooModel<NDList, NDList> loadModel(
Expand Down

0 comments on commit a8c4bfc

Please sign in to comment.