Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
laiyin.lyc committed May 18, 2021
1 parent b29648c commit ab98b22
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions tests/python/topi/python/test_topi_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
"x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense),
}

_sparse_conv2d_implement = {"generic": (topi.nn.sparse_conv2d, topi.generic.schedule_sparse_conv2d)}


def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
nr, nc, n = te.var("nr"), te.var("nc"), te.var("n")
Expand Down Expand Up @@ -554,52 +552,47 @@ def test_sparse_add_csr():
tvm.testing.assert_allclose(Z_tvm.asnumpy(), Z_np, atol=1e-4, rtol=1e-4)


def verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, use_relu, device, target, layout):
def test_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, layout):
if layout == "NHWC":
X_np = np.random.randn(M, H, W, K).astype("float32")
elif layout == "NCHW":
X_np = np.random.randn(M, K, H, W).astype("float32")
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
W_np = W_sp_np.todense()
if layout == "NHWC":
Y_np = tvm.topi.testing.conv2d_nhwc_python(X_np, W_np.reshape(1, 1, K, N), 1, 0)
Y_np = tvm.topi.testing.conv2d_nhwc_python(X_np, np.array(W_np).T.reshape(1, 1, K, N), 1, 0)
elif layout == "NCHW":
Y_np = tvm.topi.testing.conv2d_nchw_python(X_np, W_np.reshape(N, K, 1, 1), 1, 0)
if use_relu:
Y_np = np.maximum(Y_np, 0.0)
Y_np = tvm.topi.testing.conv2d_nchw_python(X_np, np.array(W_np).reshape(N, K, 1, 1), 1, 0)

if BS_C == 1:
W_data = te.placeholder(shape=W_sp_np.data.shape[:-1], dtype=str(W_sp_np.data.dtype))
W_sp_np_data = W_sp_np.data.reshape(W_sp_np.data.shape[0], BS_R)
else:
W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
W_sp_np_data = W_sp_np.data
W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))

fcompute, fschedule = tvm.topi.testing.dispatch(target, _sparse_conv2d_implement)
with tvm.target.Target(target):
Y = fcompute(X, W_data, W_indices, W_indptr, layout)
if use_relu:
Y = topi.nn.relu(Y)
s = fschedule([Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), device=device)
func(
tvm.nd.array(X_np, device=device),
tvm.nd.array(W_sp_np.data, device=device),
tvm.nd.array(W_sp_np.indices, device=device),
tvm.nd.array(W_sp_np.indptr, device=device),
Y_tvm,
)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
Y = topi.nn.sparse_conv2d(X, W_data, W_indices, W_indptr, layout)
s = te.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype="float32"))
func(
tvm.nd.array(X_np),
tvm.nd.array(W_sp_np_data),
tvm.nd.array(W_sp_np.indices),
tvm.nd.array(W_sp_np.indptr),
Y_tvm,
)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np.astype("float32"), atol=1e-4, rtol=1e-4)


@tvm.testing.parametrize_targets("llvm", "cuda")
def test_sparse_conv2d_bsr_relu(dev, target):
M, H, W, N, K, BS_R, BS_C, density = 1, 32, 32, 64, 128, 8, 16, 0.9
verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, True, dev, target, "NHWC")
verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, False, dev, target, "NCHW")
verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, 1, density, False, dev, target, "NHWC")
def test_sparse_conv2d():
M, H, W, N, K, BS_R, BS_C, density = 1, 32, 32, 128, 64, 8, 16, 0.9
test_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, "NHWC")
test_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, "NCHW")
test_sparse_conv2d_bsr(M, H, W, N, K, BS_R, 1, density, "NHWC")


if __name__ == "__main__":
Expand All @@ -614,4 +607,4 @@ def test_sparse_conv2d_bsr_relu(dev, target):
test_sparse_dense_csr_reverse()
test_sparse_dense_bsr_reverse()
test_sparse_add_csr()
test_sparse_conv2d_bsr_relu()
test_sparse_conv2d()

0 comments on commit ab98b22

Please sign in to comment.