Skip to content

Commit

Permalink
[rust] Fix NDArrayTests failure on cuda (#3319)
Browse files Browse the repository at this point in the history
* [rust] Fix NDArrayTests failure on cuda

* Update
  • Loading branch information
xyang16 authored Jul 10, 2024
1 parent 9bcf3b9 commit 97b9a93
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.testing.Assertions;
import ai.djl.testing.TestRequirements;

import org.testng.Assert;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -106,8 +107,7 @@ public void testToDataType() {
NDArray int8 = uint8.toType(DataType.INT8, false);
NDArray int32 = int8.toType(DataType.INT32, false);
NDArray uint32 = int32.toType(DataType.UINT32, false);
NDArray int64 = uint32.toType(DataType.INT64, false);
NDArray f16 = int64.toType(DataType.FLOAT16, false);
NDArray f16 = uint32.toType(DataType.FLOAT16, false);
NDArray bf16 = f16.toType(DataType.BFLOAT16, false);
NDArray f32 = bf16.toType(DataType.FLOAT32, false);
NDArray f64 = f32.toType(DataType.FLOAT64, false);
Expand All @@ -116,6 +116,19 @@ public void testToDataType() {
}
}

@Test
public void testI64toF16() {
TestRequirements.notGpu();
try (NDManager manager = NDManager.newBaseManager("Rust")) {
NDArray array = manager.create(2);
Assert.assertEquals(array.getDataType(), DataType.INT32);
NDArray int64 = array.toType(DataType.INT64, false);
NDArray f16 = int64.toType(DataType.FLOAT16, false);
NDArray bool = f16.toType(DataType.BOOLEAN, false);
Assert.assertTrue(bool.getBoolean());
}
}

@Test
public void testComparisonOp() {
try (NDManager manager = NDManager.newBaseManager("Rust")) {
Expand Down
7 changes: 7 additions & 0 deletions testing/src/main/java/ai/djl/testing/TestRequirements.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ public static void gpu() {
}
}

/** Requires the test runs on non-GPU. */
public static void notGpu() {
if (Engine.getInstance().getGpuCount() > 0) {
throw new SkipException("This test requires non-GPU to run");
}
}

/** Requires that the test runs on macOS M1. */
public static void macosM1() {
if (!System.getProperty("os.name").toLowerCase().startsWith("mac")
Expand Down

0 comments on commit 97b9a93

Please sign in to comment.