[PaddlePaddle Hackathon 4] No.56: add fp16 test and bf16 for bernoulli and trunc #51657

Closed · wants to merge 28 commits
Changes from all commits (28 commits)
a1d0522
add fp16 and bf16 support for bernoulli
longranger2 Mar 14, 2023
f6455e7
add fp16 and bf16 support for trunc
longranger2 Mar 14, 2023
3279c68
Merge branch 'PaddlePaddle:develop' into fp16_56_2
longranger2 Mar 22, 2023
99f5854
fix bug
longranger2 Mar 22, 2023
9ee7d3a
Merge branch 'develop' into fp16_56_2
longranger2 Mar 25, 2023
dce1754
fix bug
longranger2 Apr 3, 2023
63c6f39
Merge branch 'PaddlePaddle:develop' into fp16_56_2
longranger2 Apr 3, 2023
b1771eb
fix bug
longranger2 Apr 22, 2023
528e5b8
fix PR-CI-Codestyle-Check
longranger2 Apr 22, 2023
2fc39e1
fix bug of trunc_kernel.cu
longranger2 Apr 22, 2023
8b8361d
fix bug of trunc_kernel.cu
longranger2 Apr 22, 2023
099d3bb
fix bug of trunc_kernel.cu
longranger2 Apr 22, 2023
22dbf8d
fix bug of trunc and bernoulli
longranger2 May 3, 2023
9db702f
fix bug
longranger2 May 9, 2023
38d7bc1
fix bug
longranger2 May 9, 2023
f4ce773
fix bug of MPType
longranger2 May 10, 2023
bd62029
fix check_variable_and_dtype
longranger2 May 10, 2023
3782bd1
fix bug of MPType
longranger2 May 10, 2023
b20ac1a
fix bug of undefined T
longranger2 May 10, 2023
7def562
fix bug
longranger2 May 11, 2023
3f44c3d
Merge branch 'PaddlePaddle:develop' into fp16_56_2
longranger2 May 12, 2023
3e9063a
Update test_bernoulli_op.py
longranger2 May 12, 2023
13a2c74
Update test_bernoulli_op.py
longranger2 May 15, 2023
3c4e333
Update test_bernoulli_op.py
longranger2 May 15, 2023
e7ad7f2
fix bug of import
longranger2 May 16, 2023
10336f8
Merge branch 'PaddlePaddle:develop' into fp16_56_2
longranger2 May 16, 2023
f922dd8
remove the trunc
longranger2 May 31, 2023
ea1d0ed
Merge branch 'PaddlePaddle:develop' into fp16_56_2
longranger2 May 31, 2023
15 changes: 12 additions & 3 deletions paddle/phi/kernels/gpu/bernoulli_kernel.cu
@@ -26,6 +26,7 @@

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -52,10 +53,12 @@ __global__ void bernoulli_cuda_kernel(
     funcs::uniform_distribution<float> dist;
     float4 rand = dist(&state);
 #pragma unroll
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     for (size_t j = 0; j < 4; j++) {
       size_t idx = i + j;
       if (idx < size) {
-        out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
+        out_data[idx] =
+            static_cast<T>((&rand.x)[j] <= static_cast<MPType>(x_data[idx]));
       }
     }
   }
@@ -85,5 +88,11 @@ void BernoulliKernel(const Context& ctx,

 } // namespace phi

-PD_REGISTER_KERNEL(
-    bernoulli, GPU, ALL_LAYOUT, phi::BernoulliKernel, float, double) {}
+PD_REGISTER_KERNEL(bernoulli,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BernoulliKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
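For context, a minimal sketch of how the newly registered low-precision kernels could be exercised from the Python API once this change lands (assumes a CUDA build of Paddle; the shapes and seed are arbitrary):

import paddle

paddle.seed(2023)
# bernoulli expects probabilities in [0, 1]; cast to float16 to hit the new kernel.
x = paddle.rand([2, 3]).astype("float16")
y = paddle.bernoulli(x)
print(y.dtype)  # expected to stay float16, matching the input dtype

# bfloat16 should follow the same pattern on GPUs that support it.
xb = paddle.rand([2, 3]).astype("bfloat16")
yb = paddle.bernoulli(xb)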
52 changes: 49 additions & 3 deletions python/paddle/fluid/tests/unittests/test_bernoulli_op.py
@@ -15,9 +15,14 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)

 import paddle
+from paddle.fluid import core


 def output_hist(out):
@@ -31,9 +36,14 @@ def output_hist(out):
 class TestBernoulliOp(OpTest):
     def setUp(self):
         self.op_type = "bernoulli"
-        self.inputs = {"X": np.random.uniform(size=(1000, 784))}
+        self.inputs = {
+            "X": np.random.uniform(size=(1000, 784)).astype(self.dtype)
+        }
         self.attrs = {}
-        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
+        self.outputs = {"Out": np.zeros((1000, 784)).astype(self.dtype)}
+
+    def init_dtype(self):
+        self.dtype = np.float32

     def test_check_output(self):
         self.check_output_customized(self.verify_output)
@@ -98,5 +108,41 @@ def test_fixed_random_number(self):
         paddle.enable_static()


+class TestBernoulliFP16Op(TestBernoulliOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestBernoulliBF16Op(OpTest):
+    def setUp(self):
+        self.python_api = paddle.bernoulli
+        self.op_type = "bernoulli"
+        self.dtype = np.uint16
+        self.init_test_case()
+
+        self.inputs = {'X': convert_float_to_uint16(self.x)}
+        self.attrs = {}
+        self.outputs = {'Out': convert_float_to_uint16(self.out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place_customized(self.verify_output, place)
+
+    def init_test_case(self):
+        self.x = np.random.uniform(size=(1000, 784)).astype("float32")
+        self.out = np.zeros((1000, 784)).astype("float32")
+        self.x = convert_uint16_to_float(convert_float_to_uint16(self.x))
+        self.out = convert_uint16_to_float(convert_float_to_uint16(self.out))
+
+    def verify_output(self, outs):
+        hist, prob = output_hist(np.array(outs[0]))
+        np.testing.assert_allclose(hist, prob)
+
+
 if __name__ == "__main__":
     unittest.main()
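The bf16 test relies on the convert_float_to_uint16 / convert_uint16_to_float helpers from eager_op_test. Roughly speaking (this is only a sketch of the idea, not the actual helper implementation), bfloat16 is the upper 16 bits of an IEEE float32 stored in a uint16 array; round-tripping the inputs through the converters makes the float32 reference values match the precision the kernel actually sees:

import numpy as np

def float_to_bf16_bits(x):
    # Keep only the upper 16 bits of each float32 value (the bfloat16 payload).
    return np.right_shift(x.astype(np.float32).view(np.uint32), 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # Re-expand the 16-bit patterns back into float32.
    return np.left_shift(bits.astype(np.uint32), 16).view(np.float32)

x = np.random.uniform(size=(8,)).astype("float32")
roundtrip = bf16_bits_to_float(float_to_bf16_bits(x))
print(np.abs(x - roundtrip).max())  # small, bounded by bf16's ~3 decimal digits of precision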
4 changes: 3 additions & 1 deletion python/paddle/tensor/random.py
@@ -77,7 +77,9 @@ def bernoulli(x, name=None):
     if in_dynamic_mode():
         return _C_ops.bernoulli(x)
     else:
-        check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli")
+        check_variable_and_dtype(
+            x, "x", ["float16", "float32", "float64"], "bernoulli"
+        )

         helper = LayerHelper("randint", **locals())
         out = helper.create_variable_for_type_inference(
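The relaxed check_variable_and_dtype list only matters on the static-graph path; the dynamic path dispatches straight to _C_ops.bernoulli. A minimal static-graph sketch, assuming a CUDA build with this change applied:

import numpy as np
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float16")
    y = paddle.bernoulli(x)  # now passes the dtype check with "float16"

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
x_np = np.random.uniform(size=(2, 3)).astype("float16")
(out,) = exe.run(main_prog, feed={"x": x_np}, fetch_list=[y])
print(out.dtype)  # expected float16, served by the GPU kernel registered above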