Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Jul 10, 2024
1 parent 129c55e commit 3bf3339
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.testing.Assertions;
import ai.djl.testing.TestRequirements;

import org.testng.Assert;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -106,14 +107,24 @@ public void testToDataType() {
NDArray int8 = uint8.toType(DataType.INT8, false);
NDArray int32 = int8.toType(DataType.INT32, false);
NDArray uint32 = int32.toType(DataType.UINT32, false);
NDArray int64 = uint32.toType(DataType.INT64, false);
NDArray bool = int64.toType(DataType.BOOLEAN, false);
Assert.assertTrue(bool.getBoolean());
NDArray f16 = int32.toType(DataType.FLOAT16, false);
NDArray f16 = uint32.toType(DataType.FLOAT16, false);
NDArray bf16 = f16.toType(DataType.BFLOAT16, false);
NDArray f32 = bf16.toType(DataType.FLOAT32, false);
NDArray f64 = f32.toType(DataType.FLOAT64, false);
bool = f64.toType(DataType.BOOLEAN, false);
NDArray bool = f64.toType(DataType.BOOLEAN, false);
Assert.assertTrue(bool.getBoolean());
}
}

@Test
public void testInt64toF16() {
TestRequirements.notGpu();
try (NDManager manager = NDManager.newBaseManager("Rust")) {
NDArray array = manager.create(2);
Assert.assertEquals(array.getDataType(), DataType.INT32);
NDArray int64 = array.toType(DataType.INT64, false);
NDArray f16 = int64.toType(DataType.FLOAT16, false);
NDArray bool = f16.toType(DataType.BOOLEAN, false);
Assert.assertTrue(bool.getBoolean());
}
}
Expand Down Expand Up @@ -221,12 +232,10 @@ public void testExpandDim() {
expected = manager.create(new float[] {4f});
Assert.assertEquals(array.expandDims(0), expected);

// TODO: Add zero-dim test back once the bug is fixed in candle
// https:/huggingface/candle/issues/2327
// zero-dim
// array = manager.create(new Shape(2, 1, 0));
// expected = manager.create(new Shape(2, 1, 1, 0));
// Assert.assertEquals(array.expandDims(2), expected);
array = manager.create(new Shape(2, 1, 0));
expected = manager.create(new Shape(2, 1, 1, 0));
Assert.assertEquals(array.expandDims(2), expected);
}
}

Expand Down
7 changes: 7 additions & 0 deletions testing/src/main/java/ai/djl/testing/TestRequirements.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ public static void gpu() {
}
}

/** Requires the test runs on non-GPU. */
public static void notGpu() {
if (Engine.getInstance().getGpuCount() > 0) {
throw new SkipException("This test requires non-GPU to run");
}
}

/** Requires that the test runs on macOS M1. */
public static void macosM1() {
if (!System.getProperty("os.name").toLowerCase().startsWith("mac")
Expand Down

0 comments on commit 3bf3339

Please sign in to comment.