diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c75bd2dd3c09..94ee9282e4fa 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1857,16 +1857,18 @@ def nms(self, inputs, input_types): scores = inputs[1] iou_threshold = inputs[2] + num_boxes = _op.shape_of(scores) + + # TVM NMS assumes score > 0 + scores = scores - _op.min(scores) + _op.const(1.0) # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {}) - - # Prepare input data for get_valid_counts data = _op.concatenate([scores, boxes], -1) data = _op.expand_dims(data, 0, 1) - # Leverage get_valid_counts to sort the data and clear invalid boxes - ct, data, indices = get_relay_op("get_valid_counts")( - data, score_threshold=-1.0, id_index=-1, score_index=0 - ) + # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count + indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32") + indices = _op.expand_dims(indices, 0, 1) + ct = num_boxes # Perform Non-Maximum Suppression, # PyTorch NMS doesn't have parameter top_k and max_output_size diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 74d9c78d0e3d..04f08b903bf1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1675,10 +1675,10 @@ def _gen_rand_inputs(num_boxes): boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 boxes[:, 2] += boxes[:, 0] boxes[:, 3] += boxes[:, 1] - scores = torch.rand(num_boxes, dtype=torch.float) + scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32)) return boxes, scores - targets = ["llvm"] # dynamic nms does not work on gpu + targets = ["llvm", "cuda"] for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: in_boxes, in_scores = _gen_rand_inputs(num_boxes) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index f5197494a345..e4545ec4ef5e 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -23,6 +23,7 @@ import tvm +import tvm.testing from tvm import relay from tvm.runtime.vm import VirtualMachine from tvm.contrib.download import download @@ -70,7 +71,7 @@ def generate_jit_model(index): ] model_func = model_funcs[index] - model = TraceWrapper(model_func(pretrained=True)) + model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200)) model.eval() inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size))) @@ -94,46 +95,40 @@ def test_detection_models(): download(img_url, img) input_shape = (1, 3, in_size, in_size) - target = "llvm" + input_name = "input0" shape_list = [(input_name, input_shape)] - score_threshold = 0.9 scripted_model = generate_jit_model(1) mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) - with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): - vm_exec = relay.vm.compile(mod, target=target, params=params) - - ctx = tvm.cpu() - vm = VirtualMachine(vm_exec, ctx) data = process_image(img) - pt_res = scripted_model(data) - data = data.detach().numpy() - vm.set_input("main", **{input_name: data}) - tvm_res = vm.run() - - # Note: due to accumulated numerical error, we can't directly compare results - # with pytorch output. Some boxes might have a quite tiny difference in score - # and the order can become different. We just measure how many valid boxes - # there are for input image. - pt_scores = pt_res[1].detach().numpy().tolist() - tvm_scores = tvm_res[1].asnumpy().tolist() - num_pt_valid_scores = num_tvm_valid_scores = 0 - - for score in pt_scores: - if score >= score_threshold: - num_pt_valid_scores += 1 - else: - break - - for score in tvm_scores: - if score >= score_threshold: - num_tvm_valid_scores += 1 - else: - break - - assert num_pt_valid_scores == num_tvm_valid_scores, ( - "Output mismatch: Under score threshold {}, Pytorch has {} valid " - "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores) - ) + data_np = data.detach().numpy() + + with torch.no_grad(): + pt_res = scripted_model(data) + + for target in ["llvm", "cuda"]: + with tvm.transform.PassContext(opt_level=3): + vm_exec = relay.vm.compile(mod, target=target, params=params) + + ctx = tvm.context(target, 0) + vm = VirtualMachine(vm_exec, ctx) + + vm.set_input("main", **{input_name: data_np}) + tvm_res = vm.run() + + # Bounding boxes + tvm.testing.assert_allclose( + pt_res[0].cpu().numpy(), tvm_res[0].asnumpy(), rtol=1e-5, atol=1e-5 + ) + # Scores + tvm.testing.assert_allclose( + pt_res[1].cpu().numpy(), tvm_res[1].asnumpy(), rtol=1e-5, atol=1e-5 + ) + # Class ids + np.testing.assert_equal(pt_res[2].cpu().numpy(), tvm_res[2].asnumpy()) + + score_threshold = 0.9 + print("Num boxes:", pt_res[0].cpu().numpy().shape[0]) + print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold))