From 79a6720b0c54e4dfed14e16ca8bb7f2c040c7191 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 22 Jun 2021 12:56:32 -0700 Subject: [PATCH] Fixes #1024, Add back string tensor support (#1040) Change-Id: I7b3f326669dec6739d24131d7507984390817226 --- .../ai/djl/tensorflow/engine/TfNDArray.java | 5 +- .../ai/djl/tensorflow/engine/TfNDManager.java | 8 ++- .../engine/javacpp/JavacppUtils.java | 70 +++++++++++++++---- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 60e1974fc71..eb222f190c8 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -26,6 +26,7 @@ import ai.djl.util.Preconditions; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -171,9 +172,7 @@ public NDArray stopGradient() { /** {@inheritDoc} */ @Override public String[] toStringArray() { - // TODO: Parse String Array from bytes[] - throw new UnsupportedOperationException( - "TensorFlow does not supporting printing String NDArray"); + return new String[] {JavacppUtils.getString(getHandle(), StandardCharsets.UTF_8)}; } /** {@inheritDoc} */ diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index a200673b8c8..bfe5588817c 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -62,7 +62,6 @@ public NDArray create(Shape shape, DataType dataType) { } /** {@inheritDoc} */ - @SuppressWarnings({"unchecked", "try"}) @Override public TfNDArray create(Buffer data, Shape shape, DataType dataType) { int size = data.remaining(); @@ -99,6 +98,13 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) { return new TfNDArray(this, handle); } + /** {@inheritDoc} */ + @Override + public NDArray create(String data) { + TFE_TensorHandle handle = JavacppUtils.createStringTensor(data); + return new TfNDArray(this, handle); + } + /** {@inheritDoc} */ @Override public final Engine getEngine() { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java index 9ef587aecb8..1b7945291b7 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/javacpp/JavacppUtils.java @@ -23,10 +23,12 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -45,6 +47,7 @@ import org.tensorflow.internal.c_api.TF_Session; import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.internal.c_api.TF_TString; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.internal.c_api.global.tensorflow; import org.tensorflow.proto.framework.ConfigProto; @@ -62,7 +65,7 @@ private JavacppUtils() {} @SuppressWarnings({"unchecked", "try"}) public static SavedModelBundle loadSavedModelBundle( String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Status status = TF_Status.newStatus(); // allocate parameters for TF_LoadSessionFromSavedModel @@ -141,7 +144,7 @@ public static TF_Tensor[] runSession( int numInputs = inputTensorHandles.length; int numOutputs = outputOpHandles.length; int numTargets = targetOpHandles.length; - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { // TODO: check with sig-jvm if TF_Output here is freed TF_Output inputs = new TF_Output(numInputs); PointerPointer inputValues = new PointerPointer<>(numInputs); @@ -199,7 +202,7 @@ public static TF_Tensor[] runSession( @SuppressWarnings({"unchecked", "try"}) public static TFE_Context createEagerSession( boolean async, int devicePlacementPolicy, ConfigProto config) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions(); TF_Status status = TF_Status.newStatus(); if (config != null) { @@ -218,7 +221,7 @@ public static TFE_Context createEagerSession( @SuppressWarnings({"unchecked", "try"}) public static Device getDevice(TFE_TensorHandle handle) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Status status = TF_Status.newStatus(); BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status); String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8); @@ -232,7 +235,7 @@ public static DataType getDataType(TFE_TensorHandle handle) { @SuppressWarnings({"unchecked", "try"}) public static Shape getShape(TFE_TensorHandle handle) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Status status = TF_Status.newStatus(); int numDims = tensorflow.TFE_TensorHandleNumDims(handle, status); status.throwExceptionIfNotOK(); @@ -258,7 +261,7 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) { @SuppressWarnings({"unchecked", "try"}) public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Tensor tensor = createEmptyTFTensor(shape, dataType); TF_Status status = TF_Status.newStatus(); TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status); @@ -267,13 +270,36 @@ public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataTy } } + @SuppressWarnings({"unchecked", "try"}) + public static TFE_TensorHandle createStringTensor(String src) { + int dType = TfDataType.toTf(DataType.STRING); + long[] dims = {}; + long numBytes = Loader.sizeof(TF_TString.class); + try (PointerScope ignored = new PointerScope()) { + TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes); + Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes); + TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1); + byte[] buf = src.getBytes(StandardCharsets.UTF_8); + tensorflow.TF_TString_Copy(data, new BytePointer(buf), buf.length); + TF_Status status = TF_Status.newStatus(); + TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status); + status.throwExceptionIfNotOK(); + return handle.retainReference(); + } + } + @SuppressWarnings({"unchecked", "try"}) public static TFE_TensorHandle createTFETensorFromByteBuffer( ByteBuffer buf, Shape shape, DataType dataType) { int dType = TfDataType.toTf(dataType); long[] dims = shape.getShape(); - long numBytes = shape.size() * dataType.getNumOfBytes(); - try (PointerScope scope = new PointerScope()) { + long numBytes; + if (dataType == DataType.STRING) { + numBytes = buf.remaining() + 1; + } else { + numBytes = shape.size() * dataType.getNumOfBytes(); + } + try (PointerScope ignored = new PointerScope()) { TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes); // get data pointer in native engine Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes); @@ -287,7 +313,7 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer( @SuppressWarnings({"unchecked", "try"}) public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); @@ -297,7 +323,7 @@ public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) { @SuppressWarnings({"unchecked", "try"}) public static TFE_TensorHandle createTFETensor(TF_Tensor handle) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TFE_TensorHandle tensor = AbstractTFE_TensorHandle.newTensor(handle, status); status.throwExceptionIfNotOK(); @@ -305,9 +331,26 @@ public static TFE_TensorHandle createTFETensor(TF_Tensor handle) { } } + @SuppressWarnings({"unchecked", "try"}) + public static String getString(TFE_TensorHandle handle, Charset charset) { + try (PointerScope ignored = new PointerScope()) { + // convert to TF_Tensor + TF_Status status = TF_Status.newStatus(); + TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator(); + status.throwExceptionIfNotOK(); + + Pointer pointer = + tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor)); + + TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1); + BytePointer bp = tensorflow.TF_TString_GetDataPointer(data); + return bp.getString(charset); + } + } + @SuppressWarnings({"unchecked", "try"}) public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { // convert to TF_Tensor TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator(); @@ -328,7 +371,7 @@ public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) { @SuppressWarnings({"unchecked", "try"}) public static TFE_TensorHandle toDevice( TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) { - try (PointerScope scope = new PointerScope()) { + try (PointerScope ignored = new PointerScope()) { String deviceName = toTfDevice(device); TF_Status status = TF_Status.newStatus(); TFE_TensorHandle newHandle = @@ -372,8 +415,7 @@ public static String toTfDevice(Device device) { } else if (device.getDeviceType().equals(Device.Type.GPU)) { return "/device:GPU:" + device.getDeviceId(); } else { - throw new EngineException( - "Unknown device type to TensorFlow Engine: " + device.toString()); + throw new EngineException("Unknown device type to TensorFlow Engine: " + device); } } }