From 5c1a8cab849180570b5f615bdb8183035061709f Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 1 Mar 2021 11:16:29 +0100 Subject: [PATCH 1/2] stopGradient on PyTorch --- .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index cdf9c8193f0..7dfb50b33f4 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -202,7 +202,7 @@ public boolean hasGradient() { /** {@inheritDoc} */ @Override public NDArray stopGradient() { - throw new UnsupportedOperationException("Not supported"); + return JniUtils.detachGradient(this); } /** {@inheritDoc} */ From c4be5f7672de3e6e81466d1470d13e9ef887a940 Mon Sep 17 00:00:00 2001 From: "Matthias.Unverzagt" Date: Mon, 1 Mar 2021 22:21:48 +0100 Subject: [PATCH 2/2] stopGradient test added --- .../tests/ndarray/NDArrayOtherOpTest.java | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 5f120e544bd..3dee160c16a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -12,6 +12,7 @@ */ package ai.djl.integration.tests.ndarray; +import ai.djl.engine.Engine; import ai.djl.engine.EngineException; import ai.djl.ndarray.LazyNDArray; import ai.djl.ndarray.NDArray; @@ -21,6 +22,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.testing.Assertions; +import ai.djl.training.GradientCollector; import ai.djl.util.Hex; import java.nio.FloatBuffer; import org.testng.Assert; @@ -888,4 +890,28 @@ public void testOneHot() { Assert.assertEquals(array.oneHot(3), expected); } } + + @Test + public void testStopGradient() { + try (NDManager manager = NDManager.newBaseManager()) { + // normal gradient + NDArray x = manager.create(new float[] {1.0f}, new Shape(1)); + x.attachGradient(); + try (GradientCollector gc = Engine.getInstance().newGradientCollector()) { + NDArray y = x.mul(x); + gc.backward(y); + NDArray grad = x.getGradient(); + Assert.assertEquals(2f, grad.getFloat(0)); + } + // stop gradient + x = manager.create(new float[] {1.0f}, new Shape(1)); + x.attachGradient(); + try (GradientCollector gc = Engine.getInstance().newGradientCollector()) { + NDArray z = x.mul(x.stopGradient()); + gc.backward(z); + NDArray grad = x.getGradient(); + Assert.assertEquals(1f, grad.getFloat(0)); + } + } + } }