From 22a08776cdc6267e5389816c64b61d0e27b96c01 Mon Sep 17 00:00:00 2001
From: Ritwik Das
Date: Wed, 2 Dec 2020 19:57:13 -0800
Subject: [PATCH] Fix trt Test (#7016)

* Fix trt Test

* Fixed stuff

* Done

* fix 0

* Trigger Build

Co-authored-by: Ubuntu
---
 tests/python/contrib/test_tensorrt.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index de9822289528..aadfa1303655 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -1050,7 +1050,7 @@ def test_tensorrt_dynamic_batch():
     batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
     x_shape = (relay.Any(), 1, 8, 8)
     x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
-    result_dict = {}
+    result_arr = [{} for _ in range(len(batches_to_test))]
     for use_trt in [True, False]:
         x = relay.var("x", shape=x_shape, dtype="float32")
         out = relay.nn.relu(x)
@@ -1058,18 +1058,18 @@
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod = relay.tensorrt.EnableTrt(mod)
+            mod, _ = tensorrt.partition_for_tensorrt(mod)
 
         if not skip_runtime_test():
             with relay.build_config(opt_level=3):
                 relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
 
             for i, batch_size in enumerate(batches_to_test):
-                result_dict[(i, use_trt)] = relay_exec.evaluate()(x_data[:batch_size, ...])
+                result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...])
 
     if not skip_runtime_test():
         for i in range(len(batches_to_test)):
-            assert_result_matches(result_dict[(i, True)], result_dict[(i, False)])
+            assert_result_dict_holds(result_arr[i])
 
 
 def test_tensorrt_dynamic_batch_conv():
@@ -1080,7 +1080,7 @@
     x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
     k_shape = (16, 32, 3, 3)
     params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")}
-    result_dict = {}
+    result_arr = [{} for _ in range(len(batches_to_test))]
     for use_trt in [True, False]:
         x = relay.var("x", shape=x_shape, dtype="float32")
         kernel = relay.var("kernel", shape=k_shape, dtype="float32")
@@ -1089,20 +1089,18 @@
         mod = tvm.IRModule()
         mod["main"] = f
         if use_trt:
-            mod = tensorrt.partition_for_tensorrt(mod, params)
+            mod, _ = tensorrt.partition_for_tensorrt(mod, params)
 
         if not skip_runtime_test():
             with relay.build_config(opt_level=3):
                 relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
 
             for i, batch_size in enumerate(batches_to_test):
-                result_dict[(i, use_trt)] = relay_exec.evaluate()(
-                    x=x_data[:batch_size, ...], **params
-                )
+                result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...], **params)
 
     if not skip_runtime_test():
         for i in range(len(batches_to_test)):
-            assert_result_matches(result_dict[(i, True)], result_dict[(i, False)])
+            assert_result_dict_holds(result_arr[i])
 
 
 def test_maskrcnn_resnet50() -> None: