From 887c2636047920ad8d2bad38b60f4db8441c44d9 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 13 Apr 2021 10:34:49 -0700 Subject: [PATCH] Remove 2nd engine dependency from IrisTranslator Change-Id: I55755277e82bb160afdc56bbe283d8f3ec68c732 --- onnxruntime/onnxruntime-engine/build.gradle | 4 ++-- .../softmax_regression/IrisClassificationModelLoader.java | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/onnxruntime/onnxruntime-engine/build.gradle b/onnxruntime/onnxruntime-engine/build.gradle index 759c60e3f92..8b1adcbbc60 100644 --- a/onnxruntime/onnxruntime-engine/build.gradle +++ b/onnxruntime/onnxruntime-engine/build.gradle @@ -8,8 +8,8 @@ dependencies { exclude group: "junit", module: "junit" } - testRuntimeOnly project(":pytorch:pytorch-engine") - testRuntimeOnly "ai.djl.pytorch:pytorch-native-auto:${pytorch_version}-SNAPSHOT" + // testRuntimeOnly project(":pytorch:pytorch-engine") + // testRuntimeOnly "ai.djl.pytorch:pytorch-native-auto:${pytorch_version}-SNAPSHOT" testRuntimeOnly "org.slf4j:slf4j-simple:${slf4j_version}" } diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java index 8b48fbb393b..8c514175ac9 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/tabular/softmax_regression/IrisClassificationModelLoader.java @@ -33,6 +33,7 @@ import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -110,7 +111,12 @@ public NDList processInput(TranslatorContext ctx, IrisFlower input) { /** {@inheritDoc} */ @Override public Classifications processOutput(TranslatorContext ctx, NDList list) { - return new Classifications(synset, list.get(1)); + float[] data = list.get(1).toFloatArray(); + List probabilities = new ArrayList<>(data.length); + for (float f : data) { + probabilities.add((double) f); + } + return new Classifications(synset, probabilities); } /** {@inheritDoc} */