Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stopGradient on PyTorch #708

Merged
merged 1 commit into from
Mar 1, 2021
Merged

stopGradient on PyTorch #708

merged 1 commit into from
Mar 1, 2021

Conversation

enpasos
Copy link
Contributor

@enpasos enpasos commented Mar 1, 2021

One of the enhancements in 0.10.0 was "Adds the NDArray stopGradient and scaleGradient functions (#548)".
It works great on MXNet. Thanks for the feature.

To benefit from one of the 0.10.0 key features "Upgrades PyTorch to 1.7.1" I wired up the stopGradient for PyTorch which makes up the line of change in this pull request.

It works for me ... I have tested it with

    @Test
    void testStopGradient() {

        try (NDManager manager = NDManager.newBaseManager()) {

            NDArray x = manager.create(new float[] {1.0f}, new Shape(1));
            x.attachGradient();
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                NDArray y = x.mul(x);
                gc.backward(y);
                NDArray grad = x.getGradient();
                System.out.println("grad: " + grad);
                assertEquals(2f, grad.getFloat(0));
            }

            x = manager.create(new float[] {1.0f}, new Shape(1));
            x.attachGradient();
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                NDArray z = x.mul(x.stopGradient());
                gc.backward(z);
                NDArray grad = x.getGradient();
                System.out.println("grad: " + grad);
                assertEquals(1f, grad.getFloat(0));
            }

        }
    }

@zachgk
Copy link
Contributor

zachgk commented Mar 1, 2021

Thanks for your contribution @enpasos! Since you wrote the test for it, do you want to add that to our test suite as well? You can add it in NDArrayOtherOpTest

@stu1130 stu1130 merged commit f81bef5 into deepjavalibrary:master Mar 1, 2021
@stu1130
Copy link
Contributor

stu1130 commented Mar 1, 2021

Thanks for your contribution @enpasos! Since you wrote the test for it, do you want to add that to our test suite as well? You can add it in NDArrayOtherOpTest

Sorry I didn't see your comment. @enpasos can you raise another PR. Thanks

@enpasos enpasos mentioned this pull request Mar 1, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants