Skip to content

Commit

Permalink
allow pytorch stream model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan committed Mar 8, 2021
1 parent 347eb07 commit e110bd0
Show file tree
Hide file tree
Showing 14 changed files with 150 additions and 11 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,15 @@ default boolean[] toBooleanArray() {
return ret;
}

/**
* Converts this {@code NDArray} to a String array.
*
* <p>This method is only applicable to the String typed NDArray and not for printing purpose
*
* @return Array of Strings
*/
String[] toStringArray();

/**
* Converts this {@code NDArray} to a Number array based on its {@link DataType}.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ default NDArray stopGradient() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default String[] toStringArray() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default ByteBuffer toByteBuffer() {
Expand Down
5 changes: 5 additions & 0 deletions dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ static Engine newInstance() {
}

private Engine getAlternativeEngine() {
boolean disableAlternative =
Boolean.parseBoolean(System.getProperty("ai.djl.dlr.disable_alternative", "false"));
if (disableAlternative) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,18 @@ public boolean hasGradient() {
return hasGradient;
}

/** {@inheritDoc} */
@Override
public NDArray stopGradient() {
return manager.invoke("stop_gradient", this, null);
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
throw new UnsupportedOperationException("String NDArray is not supported!");
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
boolean disableAlternative =
Boolean.parseBoolean(
System.getProperty("ai.djl.onnx.disable_alternative", "false"));
if (disableAlternative) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
package ai.djl.onnxruntime.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.UUID;

Expand Down Expand Up @@ -117,20 +121,35 @@ public void detach() {
manager = OrtNDManager.getSystemManager();
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
try {
return (String[]) tensor.getValue();
} catch (OrtException e) {
throw new EngineException(e);
}
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return tensor.getByteBuffer().order(ByteOrder.nativeOrder());
}

/** {@inheritDoc} */
@Override
public String toString() {
if (isClosed) {
return "This array is already closed";
}
return "ND: "
+ getShape()
+ ' '
+ getDevice()
+ ' '
+ getDataType()
+ '\n'
+ Arrays.toString(toArray());
String arrStr;
if (getDataType() == DataType.STRING) {
arrStr = Arrays.toString(toStringArray());
} else {
arrStr = Arrays.toString(toArray());
}
return "ND: " + getShape() + ' ' + getDevice() + ' ' + getDataType() + '\n' + arrStr;
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ public static DataType toDataType(OnnxJavaType javaType) {
return DataType.BOOLEAN;
case UNKNOWN:
return DataType.UNKNOWN;
case STRING:
return DataType.STRING;
default:
throw new UnsupportedOperationException("type is not supported: " + javaType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public void testOrt() throws TranslateException, ModelException, IOException {
public void testStringTensor()
throws MalformedModelException, ModelNotFoundException, IOException,
TranslateException {
System.setProperty("ai.djl.onnx.disable_alternative", "true");
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
Expand All @@ -82,12 +83,15 @@ public void testStringTensor()
.build();
try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
Predictor<NDList, NDList> predictor = model.newPredictor()) {
OrtNDManager manager = (OrtNDManager) OrtNDManager.getSystemManager().newSubManager();
OrtNDManager manager = (OrtNDManager) model.getNDManager();
NDArray stringNd =
manager.create(
new String[] {" Re: Jack can't hide from keith@cco.", " I like dogs"},
new Shape(1, 2));
predictor.predict(new NDList(stringNd));
NDList result = predictor.predict(new NDList(stringNd));
Assert.assertEquals(result.size(), 2);
Assert.assertEquals(result.get(0).toLongArray(), new long[] {1});
}
System.clearProperty("ai.djl.onnx.disable_alternative");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ public int getRank() {
}

Engine getAlternativeEngine() {
boolean disableAlternative =
Boolean.parseBoolean(
System.getProperty("ai.djl.paddlepaddle.disable_alternative", "false"));
if (disableAlternative) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.djl.util.PairList;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
Expand Down Expand Up @@ -101,6 +102,18 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
}

/**
* Load PyTorch model from {@link InputStream}.
*
* <p>Currently, only TorchScript file are supported
*
* @param modelStream the stream of the model file
* @throws IOException model loading error
*/
public void load(InputStream modelStream) throws IOException {
block = JniUtils.loadModule((PtNDManager) manager, modelStream, manager.getDevice(), false);
}

private Path findModelFile(String prefix) {
if (Files.isRegularFile(modelDir)) {
Path file = modelDir;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ public ByteBuffer toByteBuffer() {
return JniUtils.getByteBuffer(this);
}

/** {@inheritDoc} */
@Override
public String[] toStringArray() {
throw new UnsupportedOperationException("String NDArray is not supported!");
}

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.integration;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.engine.PtModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.net.URL;
import org.testng.Assert;
import org.testng.annotations.Test;

public class PtModelTest {

@Test
public void testLoadFromStream() throws IOException, TranslateException {
URL url =
new URL("https://djl-ai.s3.amazonaws.com/resources/test-models/traced_resnet18.pt");
try (PtModel model = (PtModel) Model.newInstance("test model")) {
model.load(url.openStream());
try (Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
NDArray array = model.getNDManager().ones(new Shape(1, 3, 224, 224));
NDArray result = predictor.predict(new NDList(array)).singletonOrThrow();
Assert.assertEquals(result.getShape(), new Shape(1, 1000));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ public boolean[] toBooleanArray() {
return result;
}

@Override
public String[] toStringArray() {
// TODO: Parse String Array from bytes[]
throw new UnsupportedOperationException(
"TensorFlow does not supporting printing String NDArray");
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ public int getRank() {
}

private Engine getAlternativeEngine() {
boolean disableAlternative =
Boolean.parseBoolean(
System.getProperty("ai.djl.tflite.disable_alternative", "false"));
if (disableAlternative) {
return null;
}
if (alternativeEngine == null) {
Engine engine = Engine.getInstance();
if (engine.getRank() < getRank()) {
Expand All @@ -67,7 +73,7 @@ private Engine getAlternativeEngine() {
/** {@inheritDoc} */
@Override
public String getVersion() {
return "1.4.0";
return "2.4.1";
}

/** {@inheritDoc} */
Expand Down

0 comments on commit e110bd0

Please sign in to comment.