diff --git a/paddle/phi/kernels/gpu/bernoulli_kernel.cu b/paddle/phi/kernels/gpu/bernoulli_kernel.cu index edcf29e2d88d3..be41dcb524947 100644 --- a/paddle/phi/kernels/gpu/bernoulli_kernel.cu +++ b/paddle/phi/kernels/gpu/bernoulli_kernel.cu @@ -26,6 +26,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" @@ -52,10 +53,12 @@ __global__ void bernoulli_cuda_kernel( funcs::uniform_distribution dist; float4 rand = dist(&state); #pragma unroll + using MPType = typename phi::dtype::MPTypeTrait::Type; for (size_t j = 0; j < 4; j++) { size_t idx = i + j; if (idx < size) { - out_data[idx] = static_cast((&rand.x)[j] <= x_data[idx]); + out_data[idx] = + static_cast((&rand.x)[j] <= static_cast(x_data[idx])); } } } @@ -85,5 +88,11 @@ void BernoulliKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL( - bernoulli, GPU, ALL_LAYOUT, phi::BernoulliKernel, float, double) {} +PD_REGISTER_KERNEL(bernoulli, + GPU, + ALL_LAYOUT, + phi::BernoulliKernel, + phi::dtype::float16, + phi::dtype::bfloat16, + float, + double) {} diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py index af08b07237ff4..b4ae430d5ead6 100644 --- a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py +++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py @@ -15,9 +15,14 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import ( + OpTest, + convert_float_to_uint16, + convert_uint16_to_float, +) import paddle +from paddle.fluid import core def output_hist(out): @@ -31,9 +36,14 @@ def output_hist(out): class TestBernoulliOp(OpTest): def setUp(self): self.op_type = "bernoulli" - self.inputs = {"X": np.random.uniform(size=(1000, 784))} + self.inputs = { + "X": np.random.uniform(size=(1000, 784)).astype(self.dtype) + } self.attrs = {} - self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + self.outputs = {"Out": np.zeros((1000, 784)).astype(self.dtype)} + + def init_dtype(self): + self.dtype = np.float32 def test_check_output(self): self.check_output_customized(self.verify_output) @@ -98,5 +108,41 @@ def test_fixed_random_number(self): paddle.enable_static() +class TestBernoulliFP16Op(TestBernoulliOp): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestBernoulliBF16Op(OpTest): + def setUp(self): + self.python_api = paddle.bernoulli + self.op_type = "bernoulli" + self.dtype = np.uint16 + self.init_test_case() + + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.attrs = {} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place_customized(self.verify_output, place) + + def init_test_case(self): + self.x = np.random.uniform(size=(1000, 784)).astype("float32") + self.out = np.zeros((1000, 784)).astype("float32") + self.x = convert_uint16_to_float(convert_float_to_uint16(self.x)) + self.out = convert_uint16_to_float(convert_float_to_uint16(self.out)) + + def verify_output(self, outs): + hist, prob = output_hist(np.array(outs[0])) + np.testing.assert_allclose(hist, prob) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index a8206ff95bf59..37584fb2efefd 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -77,7 +77,9 @@ def bernoulli(x, name=None): if in_dynamic_mode(): return _C_ops.bernoulli(x) else: - check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli") + check_variable_and_dtype( + x, "x", ["float16", "float32", "float64"], "bernoulli" + ) helper = LayerHelper("randint", **locals()) out = helper.create_variable_for_type_inference(