Skip to content

Commit

Permalink
add user_defined_grads
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 committed May 20, 2023
1 parent 1f74545 commit 6348897
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions python/paddle/fluid/tests/unittests/test_lerp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
import unittest

import numpy as np
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core
Expand Down Expand Up @@ -237,12 +233,9 @@ def setUp(self):
self.init_shape()
self.init_xyshape()
self.init_wshape()
x = np.arange(1.0, 101.0).astype(np.float32).reshape(self.xshape)
y = np.full(100, 10.0).astype(np.float32).reshape(self.yshape)
w = np.random.random(self.wshape).astype(np.float32)
x = convert_uint16_to_float(convert_float_to_uint16(x))
y = convert_uint16_to_float(convert_float_to_uint16(y))
w = convert_uint16_to_float(convert_float_to_uint16(w))
x = np.arange(1.0, 101.0).astype("float32").reshape(self.xshape)
y = np.full(100, 10.0).astype("float32").reshape(self.yshape)
w = np.random.random(self.wshape).astype("float32")
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y),
Expand All @@ -266,7 +259,12 @@ def test_check_output(self):

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out')
self.check_grad_with_place(
place,
['X', 'Y'],
'Out',
user_defined_grads=[np.zeros(self.xshape), np.ones(self.yshape)],
)


if __name__ == "__main__":
Expand Down

0 comments on commit 6348897

Please sign in to comment.