Skip to content

Commit

Permalink
【Hackathon No57】add_bf16_fp16 unittest for conv3d & conv3d_transpose (P…
Browse files Browse the repository at this point in the history
…addlePaddle#52195)

* add test+conv3d_transpose_part2

* fix some merge error

* fix codestyle

* fix typo

* fix codestyle

* fix some error

* add redef float2uint

* fix conv3d and conv3d_transpose
  • Loading branch information
Difers authored and ZzSean committed May 5, 2023
1 parent 85f599e commit 656a5a2
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 20 deletions.
128 changes: 114 additions & 14 deletions python/paddle/fluid/tests/unittests/test_conv3d_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
import unittest

import numpy as np
from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import (
OpTest,
convert_float_to_uint16,
get_numeric_gradient,
paddle_static_guard,
)

import paddle
from paddle.fluid import core
from paddle.fluid.tests.unittests.testsuite import create_op


def conv3d_forward_naive(
Expand Down Expand Up @@ -179,6 +185,77 @@ def init_kernel_type(self):
globals()[cls_name] = TestCUDNNCase


def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DCUDNNBF16(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Output']
)

def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=(not self.use_mkldnn)
)

def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')

self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)

def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')

self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)

def test_check_grad(self):
place = core.CUDAPlace(0)
numeric_input_grads = self.get_numeric_grad(place, 'Input')
numeric_fliter_grads = self.get_numeric_grad(place, 'Filter')

self.check_grad_with_place(
place,
{'Input', 'Filter'},
'Output',
user_defined_grads=[numeric_input_grads, numeric_fliter_grads],
check_dygraph=(not self.use_mkldnn),
)

cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DCUDNNBF16


def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
Expand Down Expand Up @@ -323,19 +400,37 @@ def setUp(self):
'dilations': self.dilations,
}

input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)

output = conv3d_forward_naive(
input,
filter,
self.groups,
conv3d_param,
).astype(self.dtype)
)

if self.is_bfloat16_op():
output = convert_float_to_uint16(output)
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
self.inputs_fp32 = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
else:
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}

self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
Expand All @@ -358,8 +453,6 @@ def test_check_output(self):
)

def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
Expand All @@ -371,8 +464,7 @@ def test_check_grad(self):
)

def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return

place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
Expand All @@ -385,8 +477,7 @@ def test_check_grad_no_filter(self):
)

def test_check_grad_no_input(self):
if self.dtype == np.float16:
return

place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
Expand Down Expand Up @@ -617,6 +708,14 @@ def init_kernel_type(self):
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64


# ----------------Conv3DCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DOp)
create_test_cudnn_bf16_class(TestWithGroup1)
create_test_cudnn_bf16_class(TestWithGroup2)
create_test_cudnn_bf16_class(TestWith1x1)
create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1)


# ---- test asymmetric padding ----


Expand Down Expand Up @@ -1114,4 +1213,5 @@ def run_8():


if __name__ == '__main__':

unittest.main()
147 changes: 142 additions & 5 deletions python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@
import paddle

paddle.enable_static()
from eager_op_test import OpTest
from eager_op_test import OpTest, copy_bits_from_float_to_uint16

from paddle.fluid import core


def convert_float_to_uint16(float_list, data_format="NCHW"):
if data_format == "NHWC":
float_list = np.transpose(float_list, [0, 4, 1, 2, 3])

new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)

if data_format == "NHWC":
new_output = np.transpose(new_output, [0, 2, 3, 4, 1])
return new_output


def conv3dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
Expand Down Expand Up @@ -134,6 +148,86 @@ def _get_padding_with_SAME(input_shape, kernel_size, kernel_stride):
return out


def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestConv3DTransposeCUDNNFP16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)

def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set={'Filter'}
)

def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set={'Input'}
)

cls_name = "{}_{}".format(parent.__name__, "CUDNNFP16OP")
TestConv3DTransposeCUDNNFP16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNFP16


def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DTransposeCUDNNBF16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
{'Input', 'Filter'},
'Output',
)

def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
)

def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
)

cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DTransposeCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNBF16


def conv3d_transpose_wrapper(
x,
weight,
Expand Down Expand Up @@ -172,12 +266,16 @@ def setUp(self):
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_kernel_type()
self.init_test_case()

input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)

self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
Expand All @@ -189,9 +287,21 @@ def setUp(self):
}

output = conv3dtranspose_forward_naive(
input_, filter_, self.attrs
input, filter, self.attrs
).astype("float32")

if self.is_bfloat16_op():
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
else:
self.inputs = {
'Input': input,
'Filter': filter,
}
output = output.astype(self.dtype)

self.outputs = {'Output': output}

def test_check_output(self):
Expand Down Expand Up @@ -264,6 +374,9 @@ def init_op_type(self):
self.op_type = "conv3d_transpose"
self.python_api = conv3d_transpose_wrapper

def init_kernel_type(self):
self.dtype = np.float32


class TestWithSymmetricPad(TestConv3DTransposeOp):
def init_test_case(self):
Expand Down Expand Up @@ -596,6 +709,30 @@ def init_op_type(self):
self.python_api = conv3d_transpose_wrapper


# ----------------Conv3DTransposeCUDNN fp16----------------
create_test_cudnn_fp16_class(TestConv3DTransposeOp)
create_test_cudnn_fp16_class(TestWithSymmetricPad)
create_test_cudnn_fp16_class(TestWithAsymmetricPad)
create_test_cudnn_fp16_class(TestWithSAMEPad)
create_test_cudnn_fp16_class(TestWithVALIDPad)
create_test_cudnn_fp16_class(TestWithStride)
create_test_cudnn_fp16_class(TestWithGroups)
create_test_cudnn_fp16_class(TestWithDilation)
create_test_cudnn_fp16_class(Test_NHWC)


# ----------------Conv3DTransposeCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DTransposeOp)
create_test_cudnn_bf16_class(TestWithSymmetricPad)
create_test_cudnn_bf16_class(TestWithAsymmetricPad)
create_test_cudnn_bf16_class(TestWithSAMEPad)
create_test_cudnn_bf16_class(TestWithVALIDPad)
create_test_cudnn_bf16_class(TestWithStride)
create_test_cudnn_bf16_class(TestWithGroups)
create_test_cudnn_bf16_class(TestWithDilation)
create_test_cudnn_bf16_class(Test_NHWC)


class TestConv3dTranspose(unittest.TestCase):
def error_weight_input(self):
array = np.array([1], dtype=np.float32)
Expand Down
Loading

0 comments on commit 656a5a2

Please sign in to comment.