[CuDNN] Remove GPU dependency from tvm.contrib.cudnn.conv_output_shape #8275

Merged
merged 1 commit on Jun 21, 2021
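As context for the diff below, a minimal usage sketch of what this change enables (illustrative only, not part of the patch; the argument values are borrowed from the 2d_small test case added in tests/python/contrib/test_cudnn.py): tvm.contrib.cudnn.conv_output_shape now computes the output shape in pure Python, so it can run on a machine without a GPU, while the cuDNN-backed path is kept under the new name _conv_output_shape_from_cudnn purely for validation.

from tvm.contrib import cudnn

# Pure-Python shape inference; no CUDA device or CuDNN-enabled build required.
oshape = cudnn.conv_output_shape(
    tensor_format=0,          # CUDNN_TENSOR_NCHW
    pad=[1, 1],
    stride=[1, 1],
    dilation=[1, 1],
    x_shape=[16, 4, 32, 32],  # NCHW input
    w_shape=[8, 4, 3, 3],     # OIHW weights
    data_dtype="float32",
    conv_dtype="float32",
    groups=1,
)
assert oshape == [16, 8, 32, 32]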
97 changes: 96 additions & 1 deletion python/tvm/contrib/cudnn.py
@@ -209,6 +209,101 @@ def conv_output_shape(
    oshape: list
        output shape
    """

    assert len(x_shape) == len(w_shape)
    assert len(x_shape) in (4, 5)

    if tensor_format == 0:
        n_output = x_shape[0]
        c_output = w_shape[0]
        x_chan = x_shape[1]
        w_chan_input = w_shape[1]
        x_shape = x_shape[2:]
        w_shape = w_shape[2:]

    elif tensor_format == 1:
        n_output = x_shape[0]
        c_output = w_shape[0]
        x_chan = x_shape[-1]
        w_chan_input = w_shape[-1]
        assert len(x_shape) == 4, "CuDNN layout NHWC is only well-defined for 4d tensors"
        x_shape = x_shape[1:-1]
        w_shape = w_shape[1:-1]

    elif tensor_format == 2:
        n_output = x_shape[0]
        c_output = w_shape[0]
        x_chan = x_shape[1]
        w_chan_input = w_shape[1]
        w_lanes = tvm.runtime.DataType(conv_dtype).lanes
        assert w_lanes == 1
        x_shape = x_shape[2:]
        w_shape = w_shape[2:]

    else:
        raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format))

    x_lanes = tvm.runtime.DataType(data_dtype).lanes
    assert x_chan * x_lanes == w_chan_input * groups, (
        "Mismatched dimensions, data has {} channels/group "
        "(dimension {} with {} lanes/value, {} groups), "
        "but weights require {} input channels/group"
    ).format(x_chan // groups, x_chan, x_lanes, groups, w_chan_input)

    output_dims = []
    for x_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip(
        x_shape, w_shape, pad, stride, dilation
    ):
        output_dim = 1 + (x_shape_i + 2 * pad_i - (((w_shape_i - 1) * dilation_i) + 1)) // stride_i
        output_dims.append(output_dim)

    if tensor_format in [0, 2]:
        output = [n_output, c_output, *output_dims]
    elif tensor_format == 1:
        output = [n_output, *output_dims, c_output]
    else:
        raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format))

    return output
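As a quick sanity check on the output-size formula above (a worked example for illustration, not part of the diff): for the default 2D test case below, each spatial dimension has x_shape_i = 32, pad_i = 1, w_shape_i = 3, dilation_i = 1, and stride_i = 1, so the dilated kernel extent is (3 - 1) * 1 + 1 = 3 and output_dim = 1 + (32 + 2 * 1 - 3) // 1 = 32; the spatial size is preserved and the full output shape is [16, 8, 32, 32].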


def _conv_output_shape_from_cudnn(
    tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
):
    """Get output shape of 2D or 3D convolution. The output of this
    function should be identical to that of conv_output_shape, but
    unlike conv_output_shape it requires a GPU with CuDNN to be
    present. It is maintained for testing purposes to validate the
    output of conv_output_shape.

    Parameters
    ----------
    tensor_format: int
        0: CUDNN_TENSOR_NCHW
        1: CUDNN_TENSOR_NHWC
        2: CUDNN_TENSOR_NCHW_VECT_C
    pad: int or list
        padding
    stride: int or list
        stride
    dilation: int or list
        dilation
    x_shape: list
        input shape
    w_shape: list
        weight shape
    data_dtype: str
        data type
    conv_dtype: str
        convolution type
    groups: int
        number of groups

    Returns
    -------
    oshape: list
        output shape

    """
    dims = len(x_shape)
    assert dims in (4, 5)

@@ -217,7 +312,7 @@ def conv_output_shape(
    )
    oshape = np.zeros((dims), dtype=np.int32)

    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape")
    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
    func(
        tensor_format,
        dims - 2,
6 changes: 4 additions & 2 deletions src/runtime/contrib/cudnn/conv_forward.cc
@@ -156,7 +156,9 @@ void OutputShape(int format, int dims, int groups, const int pad[], const int st
dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));

if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only defined for 4d tensors";

// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
@@ -295,7 +297,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int format = args[0];
int dims = args[1];
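Since the packed function is registered under a new name, Python code that probes for cuDNN support must look up the new symbol. A minimal sketch of the guard pattern (matching its use in the updated tests and in test_any.py below; illustrative only):

import tvm

# allow_missing=True makes get_global_func return None instead of raising
# when the symbol is absent, e.g. on a build compiled without CuDNN.
func = tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True)
cudnn_enabled = func is not None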
106 changes: 92 additions & 14 deletions tests/python/contrib/test_cudnn.py
@@ -14,6 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys

import pytest

import tvm
from tvm import te
from tvm.contrib import cudnn
@@ -23,6 +28,12 @@
import tvm.testing


requires_cudnn = pytest.mark.skipif(
    tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None,
    reason="CuDNN is not enabled",
)


def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
    in_channel = 4
    out_channel = 16
@@ -38,9 +49,6 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
    height = 32
    width = 32

    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
        print("skip because cudnn is not enabled...")
        return
    if data_dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
        print("Skip because gpu does not have fp16 support")
        return
@@ -123,10 +131,6 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1):
    height = 32
    width = 32

    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
        print("skip because cudnn is not enabled...")
        return

    # schedule
    xshape = [batch, in_channel, depth, height, width]
    wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w]
@@ -205,19 +209,93 @@ def verify_softmax_4d(shape, dtype="float32"):


@tvm.testing.requires_gpu
@requires_cudnn
def test_softmax():
    if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
        print("skip because cudnn is not enabled...")
        return

    verify_softmax((32, 10), -1)
    verify_softmax((3, 4), -1)
    verify_softmax((1, 5), -1, "float64")
    verify_softmax_4d((1, 16, 256, 256))
    verify_softmax_4d((1, 16, 256, 256), "float64")


test_kwargs_default_2d = {
    "tensor_format": 0,
    "pad": [1, 1],
    "stride": [1, 1],
    "dilation": [1, 1],
    "x_shape": [16, 4, 32, 32],
    "w_shape": [8, 4, 3, 3],
    "groups": 1,
    "conv_dtype": "float32",
    "data_dtype": "float32",
}
test_kwargs_default_3d = {
    "tensor_format": 0,
    "pad": [1, 1, 1],
    "stride": [1, 1, 1],
    "dilation": [1, 1, 1],
    "x_shape": [16, 4, 32, 32, 32],
    "w_shape": [8, 4, 3, 3, 3],
    "groups": 1,
    "conv_dtype": "float32",
    "data_dtype": "float32",
}
conv_output_shape_conditions = {
    "2d_small": test_kwargs_default_2d,
    "2d_large": {
        **test_kwargs_default_2d,
        "x_shape": [16, 32, 512, 1024],
        "w_shape": [8, 32, 5, 5],
    },
    "2d_pad": {**test_kwargs_default_2d, "pad": [2, 3]},
    "2d_stride": {**test_kwargs_default_2d, "stride": [2, 3]},
    "2d_dilation": {**test_kwargs_default_2d, "dilation": [2, 3]},
    "2d_groups": {**test_kwargs_default_2d, "groups": 4, "w_shape": [8, 1, 3, 3]},
    "2d_NHWC": {
        **test_kwargs_default_2d,
        "tensor_format": 1,
        "x_shape": [16, 32, 32, 4],
        "w_shape": [8, 3, 3, 4],
    },
    "2d_NCHW_VECT_C": {
        **test_kwargs_default_2d,
        "tensor_format": 2,
        "w_shape": [8, 16, 3, 3],
        "data_dtype": "int8x4",
    },
    "3d_small": test_kwargs_default_3d,
    "3d_large": {
        **test_kwargs_default_3d,
        "x_shape": [16, 32, 64, 128, 256],
        "w_shape": [8, 32, 5, 5, 5],
    },
    "3d_pad": {**test_kwargs_default_3d, "pad": [2, 3, 4]},
    "3d_stride": {**test_kwargs_default_3d, "stride": [2, 3, 4]},
    "3d_dilation": {**test_kwargs_default_3d, "dilation": [2, 3, 4]},
    "3d_groups": {**test_kwargs_default_3d, "groups": 4, "w_shape": [8, 1, 3, 3, 3]},
    "3d_NCHW_VECT_C": {
        **test_kwargs_default_3d,
        "tensor_format": 2,
        "w_shape": [8, 16, 3, 3, 3],
        "data_dtype": "int8x4",
    },
}


@pytest.fixture(
    params=[pytest.param(kwargs, id=name) for name, kwargs in conv_output_shape_conditions.items()]
)
def conv_output_shape_kwargs(request):
    return request.param


@tvm.testing.requires_gpu
@requires_cudnn
def test_conv_output_shape(conv_output_shape_kwargs):
    shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs)
    shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs)
    assert shape_from_cudnn == shape_from_python


if __name__ == "__main__":
    test_conv2d()
    test_conv3d()
    test_softmax()
    sys.exit(pytest.main(sys.argv))
2 changes: 1 addition & 1 deletion tests/python/relay/test_any.py
@@ -508,7 +508,7 @@ def verify_any_conv2d(
    kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)

    targets = None
    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True):
        targets = [("cuda -libs=cudnn", tvm.cuda(0))]

    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)