diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 0e22e0c09274..57184ccb8b77 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -209,6 +209,101 @@ def conv_output_shape( oshape: list output shape """ + + assert len(x_shape) == len(w_shape) + assert len(x_shape) in (4, 5) + + if tensor_format == 0: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[1] + w_chan_input = w_shape[1] + x_shape = x_shape[2:] + w_shape = w_shape[2:] + + elif tensor_format == 1: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[-1] + w_chan_input = w_shape[-1] + assert len(x_shape) == 4, "CuDNN layout NHWC is only well-defined for 4d tensors" + x_shape = x_shape[1:-1] + w_shape = w_shape[1:-1] + + elif tensor_format == 2: + n_output = x_shape[0] + c_output = w_shape[0] + x_chan = x_shape[1] + w_chan_input = w_shape[1] + w_lanes = tvm.runtime.DataType(conv_dtype).lanes + assert w_lanes == 1 + x_shape = x_shape[2:] + w_shape = w_shape[2:] + + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + x_lanes = tvm.runtime.DataType(data_dtype).lanes + assert x_chan * x_lanes == w_chan_input * groups, ( + "Mismatched dimensions, data has {} channels/group " + "(dimension {} with {} lanes/value, {} groups), " + "but weights require {} input channels/group" + ).format(x_chan // groups, x_chan, x_lanes, groups, w_chan_input) + + output_dims = [] + for x_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip( + x_shape, w_shape, pad, stride, dilation + ): + output_dim = 1 + (x_shape_i + 2 * pad_i - (((w_shape_i - 1) * dilation_i) + 1)) // stride_i + output_dims.append(output_dim) + + if tensor_format in [0, 2]: + output = [n_output, c_output, *output_dims] + elif tensor_format == 1: + output = [n_output, *output_dims, c_output] + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + return output + + +def _conv_output_shape_from_cudnn( + tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1 +): + """Get output shape of 2D or 3D convolution. The output of this + function should be identical to that of conv_output_shape, but + requires a GPU with CuDNN to be present. This is maintained for + testing purposes to validate the output of conv_output_shape. 
+ + Parameters + ---------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + x_shape: list + input shape + w_shape: list + weight shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + oshape: list + output shape + + """ dims = len(x_shape) assert dims in (4, 5) @@ -217,7 +312,7 @@ def conv_output_shape( ) oshape = np.zeros((dims), dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape") + func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn") func( tensor_format, dims - 2, diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index ad3b959338bb..2d7f82694929 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -156,7 +156,9 @@ void OutputShape(int format, int dims, int groups, const int pad[], const int st dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only defined for 4d tensors"; + // Set Input CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], @@ -295,7 +297,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn") .set_body([](TVMArgs args, TVMRetValue* ret) { int format = args[0]; int dims = args[1]; diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 8a929f550a4f..7651bdea36a6 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -14,6 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
+ +import sys + +import pytest + import tvm from tvm import te from tvm.contrib import cudnn @@ -23,6 +28,12 @@ import tvm.testing +requires_cudnn = pytest.mark.skipif( + tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None, + reason="CuDNN is not enabled", +) + + def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 @@ -38,9 +49,6 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return if data_dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return @@ -123,10 +131,6 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): height = 32 width = 32 - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return - # schedule xshape = [batch, in_channel, depth, height, width] wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w] @@ -205,11 +209,8 @@ def verify_softmax_4d(shape, dtype="float32"): @tvm.testing.requires_gpu +@requires_cudnn def test_softmax(): - if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): - print("skip because cudnn is not enabled...") - return - verify_softmax((32, 10), -1) verify_softmax((3, 4), -1) verify_softmax((1, 5), -1, "float64") @@ -217,7 +218,84 @@ def test_softmax(): verify_softmax_4d((1, 16, 256, 256), "float64") +test_kwargs_default_2d = { + "tensor_format": 0, + "pad": [1, 1], + "stride": [1, 1], + "dilation": [1, 1], + "x_shape": [16, 4, 32, 32], + "w_shape": [8, 4, 3, 3], + "groups": 1, + "conv_dtype": "float32", + "data_dtype": "float32", +} +test_kwargs_default_3d = { + "tensor_format": 0, + "pad": [1, 1, 1], + "stride": [1, 1, 1], + "dilation": [1, 1, 1], + "x_shape": [16, 4, 32, 32, 32], + "w_shape": [8, 4, 3, 3, 3], + "groups": 1, + "conv_dtype": "float32", + "data_dtype": "float32", +} +conv_output_shape_conditions = { + "2d_small": test_kwargs_default_2d, + "2d_large": { + **test_kwargs_default_2d, + "x_shape": [16, 32, 512, 1024], + "w_shape": [8, 32, 5, 5], + }, + "2d_pad": {**test_kwargs_default_2d, "pad": [2, 3]}, + "2d_stride": {**test_kwargs_default_2d, "stride": [2, 3]}, + "2d_dilation": {**test_kwargs_default_2d, "dilation": [2, 3]}, + "2d_groups": {**test_kwargs_default_2d, "groups": 4, "w_shape": [8, 1, 3, 3]}, + "2d_NHWC": { + **test_kwargs_default_2d, + "tensor_format": 1, + "x_shape": [16, 32, 32, 4], + "w_shape": [8, 3, 3, 4], + }, + "2d_NCHW_VECT_C": { + **test_kwargs_default_2d, + "tensor_format": 2, + "w_shape": [8, 16, 3, 3], + "data_dtype": "int8x4", + }, + "3d_small": test_kwargs_default_3d, + "3d_large": { + **test_kwargs_default_3d, + "x_shape": [16, 32, 64, 128, 256], + "w_shape": [8, 32, 5, 5, 5], + }, + "3d_pad": {**test_kwargs_default_3d, "pad": [2, 3, 4]}, + "3d_stride": {**test_kwargs_default_3d, "stride": [2, 3, 4]}, + "3d_dilation": {**test_kwargs_default_3d, "dilation": [2, 3, 4]}, + "3d_groups": {**test_kwargs_default_3d, "groups": 4, "w_shape": [8, 1, 3, 3, 3]}, + "3d_NCHW_VECT_C": { + **test_kwargs_default_3d, + "tensor_format": 2, + "w_shape": [8, 16, 3, 3, 3], + "data_dtype": "int8x4", + }, +} + + +@pytest.fixture( + params=[pytest.param(kwargs, id=name) for name, kwargs in conv_output_shape_conditions.items()] +) +def conv_output_shape_kwargs(request): + return 
request.param + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv_output_shape(conv_output_shape_kwargs): + shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs) + shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs) + assert shape_from_cudnn == shape_from_python + + if __name__ == "__main__": - test_conv2d() - test_conv3d() - test_softmax() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 57f07b3f00e5..13f5525bfee8 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -508,7 +508,7 @@ def verify_any_conv2d( kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) targets = None - if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): + if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True): targets = [("cuda -libs=cudnn", tvm.cuda(0))] check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)
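For readers without a cuDNN-enabled build, here is a minimal sketch of the per-dimension arithmetic that the new pure-Python conv_output_shape performs for the default NCHW layout (tensor_format=0). The helper name below is illustrative only and is not part of the patch.

# Illustrative sketch: mirrors the formula used by the pure-Python
# conv_output_shape added above for tensor_format=0 (NCHW).
def expected_nchw_output_shape(x_shape, w_shape, pad, stride, dilation):
    # Each spatial dim: 1 + (x + 2*pad - ((w - 1)*dilation + 1)) // stride
    spatial = [
        1 + (x + 2 * p - ((w - 1) * d + 1)) // s
        for x, w, p, s, d in zip(x_shape[2:], w_shape[2:], pad, stride, dilation)
    ]
    return [x_shape[0], w_shape[0], *spatial]

# The "2d_small" case from test_cudnn.py: NCHW input 16x4x32x32, weights 8x4x3x3,
# pad/stride/dilation all 1 -> [16, 8, 32, 32].
print(expected_nchw_output_shape([16, 4, 32, 32], [8, 4, 3, 3], [1, 1], [1, 1], [1, 1]))

With a TVM build that includes the cuDNN contrib module, the same result comes from cudnn.conv_output_shape(**test_kwargs_default_2d) without any GPU attached, while cudnn._conv_output_shape_from_cudnn(**test_kwargs_default_2d) still requires a CUDA device with cuDNN; that equivalence is exactly what test_conv_output_shape cross-checks.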