Skip to content

Commit

Permalink
Fixes #1024, Add back string tensor support (#1040)
Browse files Browse the repository at this point in the history
Change-Id: I7b3f326669dec6739d24131d7507984390817226
  • Loading branch information
frankfliu authored Jun 22, 2021
1 parent 08fdab4 commit 79a6720
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.djl.util.Preconditions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -171,9 +172,7 @@ public NDArray stopGradient() {
/** {@inheritDoc} */
@Override
public String[] toStringArray() {
// TODO: Parse String Array from bytes[]
throw new UnsupportedOperationException(
"TensorFlow does not supporting printing String NDArray");
return new String[] {JavacppUtils.getString(getHandle(), StandardCharsets.UTF_8)};
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ public NDArray create(Shape shape, DataType dataType) {
}

/** {@inheritDoc} */
@SuppressWarnings({"unchecked", "try"})
@Override
public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
int size = data.remaining();
Expand Down Expand Up @@ -99,6 +98,13 @@ public TfNDArray create(Buffer data, Shape shape, DataType dataType) {
return new TfNDArray(this, handle);
}

/** {@inheritDoc} */
@Override
public NDArray create(String data) {
TFE_TensorHandle handle = JavacppUtils.createStringTensor(data);
return new TfNDArray(this, handle);
}

/** {@inheritDoc} */
@Override
public final Engine getEngine() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -45,6 +47,7 @@
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_TString;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.ConfigProto;
Expand All @@ -62,7 +65,7 @@ private JavacppUtils() {}
@SuppressWarnings({"unchecked", "try"})
public static SavedModelBundle loadSavedModelBundle(
String exportDir, String[] tags, ConfigProto config, RunOptions runOptions) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();

// allocate parameters for TF_LoadSessionFromSavedModel
Expand Down Expand Up @@ -141,7 +144,7 @@ public static TF_Tensor[] runSession(
int numInputs = inputTensorHandles.length;
int numOutputs = outputOpHandles.length;
int numTargets = targetOpHandles.length;
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
// TODO: check with sig-jvm if TF_Output here is freed
TF_Output inputs = new TF_Output(numInputs);
PointerPointer<TF_Tensor> inputValues = new PointerPointer<>(numInputs);
Expand Down Expand Up @@ -199,7 +202,7 @@ public static TF_Tensor[] runSession(
@SuppressWarnings({"unchecked", "try"})
public static TFE_Context createEagerSession(
boolean async, int devicePlacementPolicy, ConfigProto config) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TFE_ContextOptions opts = TFE_ContextOptions.newContextOptions();
TF_Status status = TF_Status.newStatus();
if (config != null) {
Expand All @@ -218,7 +221,7 @@ public static TFE_Context createEagerSession(

@SuppressWarnings({"unchecked", "try"})
public static Device getDevice(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
BytePointer pointer = tensorflow.TFE_TensorHandleDeviceName(handle, status);
String device = new String(pointer.getStringBytes(), StandardCharsets.UTF_8);
Expand All @@ -232,7 +235,7 @@ public static DataType getDataType(TFE_TensorHandle handle) {

@SuppressWarnings({"unchecked", "try"})
public static Shape getShape(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
int numDims = tensorflow.TFE_TensorHandleNumDims(handle, status);
status.throwExceptionIfNotOK();
Expand All @@ -258,7 +261,7 @@ public static TF_Tensor createEmptyTFTensor(Shape shape, DataType dataType) {

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataType) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = createEmptyTFTensor(shape, dataType);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
Expand All @@ -267,13 +270,36 @@ public static TFE_TensorHandle createEmptyTFETensor(Shape shape, DataType dataTy
}
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createStringTensor(String src) {
int dType = TfDataType.toTf(DataType.STRING);
long[] dims = {};
long numBytes = Loader.sizeof(TF_TString.class);
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
byte[] buf = src.getBytes(StandardCharsets.UTF_8);
tensorflow.TF_TString_Copy(data, new BytePointer(buf), buf.length);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle handle = AbstractTFE_TensorHandle.newTensor(tensor, status);
status.throwExceptionIfNotOK();
return handle.retainReference();
}
}

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensorFromByteBuffer(
ByteBuffer buf, Shape shape, DataType dataType) {
int dType = TfDataType.toTf(dataType);
long[] dims = shape.getShape();
long numBytes = shape.size() * dataType.getNumOfBytes();
try (PointerScope scope = new PointerScope()) {
long numBytes;
if (dataType == DataType.STRING) {
numBytes = buf.remaining() + 1;
} else {
numBytes = shape.size() * dataType.getNumOfBytes();
}
try (PointerScope ignored = new PointerScope()) {
TF_Tensor tensor = AbstractTF_Tensor.allocateTensor(dType, dims, numBytes);
// get data pointer in native engine
Pointer pointer = tensorflow.TF_TensorData(tensor).capacity(numBytes);
Expand All @@ -287,7 +313,7 @@ public static TFE_TensorHandle createTFETensorFromByteBuffer(

@SuppressWarnings({"unchecked", "try"})
public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
status.throwExceptionIfNotOK();
Expand All @@ -297,17 +323,34 @@ public static TF_Tensor resolveTFETensor(TFE_TensorHandle handle) {

@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle createTFETensor(TF_Tensor handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle tensor = AbstractTFE_TensorHandle.newTensor(handle, status);
status.throwExceptionIfNotOK();
return tensor.retainReference();
}
}

@SuppressWarnings({"unchecked", "try"})
public static String getString(TFE_TensorHandle handle, Charset charset) {
try (PointerScope ignored = new PointerScope()) {
// convert to TF_Tensor
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
status.throwExceptionIfNotOK();

Pointer pointer =
tensorflow.TF_TensorData(tensor).capacity(tensorflow.TF_TensorByteSize(tensor));

TF_TString data = new TF_TString(pointer).capacity(pointer.position() + 1);
BytePointer bp = tensorflow.TF_TString_GetDataPointer(data);
return bp.getString(charset);
}
}

@SuppressWarnings({"unchecked", "try"})
public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
// convert to TF_Tensor
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = tensorflow.TFE_TensorHandleResolve(handle, status).withDeallocator();
Expand All @@ -328,7 +371,7 @@ public static ByteBuffer getByteBuffer(TFE_TensorHandle handle) {
@SuppressWarnings({"unchecked", "try"})
public static TFE_TensorHandle toDevice(
TFE_TensorHandle handle, TFE_Context eagerSessionHandle, Device device) {
try (PointerScope scope = new PointerScope()) {
try (PointerScope ignored = new PointerScope()) {
String deviceName = toTfDevice(device);
TF_Status status = TF_Status.newStatus();
TFE_TensorHandle newHandle =
Expand Down Expand Up @@ -372,8 +415,7 @@ public static String toTfDevice(Device device) {
} else if (device.getDeviceType().equals(Device.Type.GPU)) {
return "/device:GPU:" + device.getDeviceId();
} else {
throw new EngineException(
"Unknown device type to TensorFlow Engine: " + device.toString());
throw new EngineException("Unknown device type to TensorFlow Engine: " + device);
}
}
}

0 comments on commit 79a6720

Please sign in to comment.