Skip to content

Commit

Permalink
Use PyTorch index over engine agnostic solutioin to improve the perfo…
Browse files Browse the repository at this point in the history
…rmance (#638)
  • Loading branch information
stu1130 authored Feb 11, 2021
1 parent 0beb220 commit 3bed9fa
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ public void set(Buffer data) {
JniUtils.set(this, buf);
}

/** {@inheritDoc} */
@Override
public NDArray get(long... indices) {
return JniUtils.getItem(this, indices);
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray array) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray
ndArray.getHandle(), value.getHandle(), indicesNd.getHandle());
}

public static PtNDArray getItem(PtNDArray ndArray, long[] indices) {
// use a specialized API here
// due to significant performance gain
// for commonly used data loading call
if (indices.length == 1) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0]));
}
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices));
}

public static PtNDArray clone(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ private PyTorchLibrary() {}

native long torchTo(long handle, int dType, int[] device, boolean copy);

native long torchGetItem(long handle, long index);

native long torchGetItem(long handle, long[] indices);

native long torchToSparse(long handle);

native long torchToDense(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,30 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGetItem__JJ(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->index({static_cast<int64_t>(jindex)}));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGetItem__J_3J(
JNIEnv* env, jobject jthis, jlong jhandle, jlongArray jindices) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
std::vector<int64_t> vec = djl::utils::jni::GetVecFromJLongArray(env, jindices);
std::vector<torch::indexing::TensorIndex> indices;
indices.reserve(vec.size());
std::transform(vec.begin(), vec.end(), std::back_inserter(indices), [](int64_t index) {
return torch::indexing::TensorIndex{index};
});
const auto* result_ptr = new torch::Tensor(tensor_ptr->index(indices));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_tensorClone(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
torch::NoGradGuard NoGradGuard;
Expand Down

0 comments on commit 3bed9fa

Please sign in to comment.