Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Fix PyTorch NMS conversion for negative scores #7137

Merged
merged 6 commits into from
Dec 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,16 +1857,18 @@ def nms(self, inputs, input_types):
scores = inputs[1]
iou_threshold = inputs[2]

num_boxes = _op.shape_of(scores)

# TVM NMS assumes score > 0
scores = scores - _op.min(scores) + _op.const(1.0)
# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})

# Prepare input data for get_valid_counts
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# Leverage get_valid_counts to sort the data and clear invalid boxes
ct, data, indices = get_relay_op("get_valid_counts")(
data, score_threshold=-1.0, id_index=-1, score_index=0
)
# PyTorch NMS doesn't have score_threshold, so there is no need to run get_valid_counts
masahi marked this conversation as resolved.
Show resolved Hide resolved
indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
ct = num_boxes

# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
Expand Down
4 changes: 2 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,10 +1675,10 @@ def _gen_rand_inputs(num_boxes):
boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5
boxes[:, 2] += boxes[:, 0]
boxes[:, 3] += boxes[:, 1]
scores = torch.rand(num_boxes, dtype=torch.float)
scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32))
return boxes, scores

targets = ["llvm"] # dynamic nms does not work on gpu
targets = ["llvm", "cuda"]

for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
Expand Down
69 changes: 32 additions & 37 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tvm

import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download
Expand Down Expand Up @@ -70,7 +71,7 @@ def generate_jit_model(index):
]

model_func = model_funcs[index]
model = TraceWrapper(model_func(pretrained=True))
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad to see that rpn_pre_nms_top_n_test is able to limit the proposals before NMS. I am not sure whether this parameter is specified in real use cases; using the parameter's default value for benchmarking seems to make more sense to me.

Copy link
Member Author

@masahi masahi Dec 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the default value of 1000 they picked is fairly conservative. It means that for each level in the feature pyramid, of which there are 5 if we use a ResNet-50 backbone, we get a maximum of 1000 x 5 boxes as input to RPN. They have another parameter, rpn_post_nms_top_n_test, which is like a top-k applied after NMS. This value is also 1000 by default and, unlike rpn_pre_nms_top_n_test, it is not per class. This means we always have 1000 boxes after NMS regardless of rpn_pre_nms_top_n_test.


model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
Expand All @@ -94,46 +95,40 @@ def test_detection_models():
download(img_url, img)

input_shape = (1, 3, in_size, in_size)
target = "llvm"

input_name = "input0"
shape_list = [(input_name, input_shape)]
score_threshold = 0.9

scripted_model = generate_jit_model(1)
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
vm = VirtualMachine(vm_exec, ctx)
data = process_image(img)
pt_res = scripted_model(data)
data = data.detach().numpy()
vm.set_input("main", **{input_name: data})
tvm_res = vm.run()

# Note: due to accumulated numerical error, we can't directly compare results
# with pytorch output. Some boxes might have a quite tiny difference in score
# and the order can become different. We just measure how many valid boxes
# there are for input image.
pt_scores = pt_res[1].detach().numpy().tolist()
tvm_scores = tvm_res[1].asnumpy().tolist()
num_pt_valid_scores = num_tvm_valid_scores = 0

for score in pt_scores:
if score >= score_threshold:
num_pt_valid_scores += 1
else:
break

for score in tvm_scores:
if score >= score_threshold:
num_tvm_valid_scores += 1
else:
break

assert num_pt_valid_scores == num_tvm_valid_scores, (
"Output mismatch: Under score threshold {}, Pytorch has {} valid "
"boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores)
)
data_np = data.detach().numpy()

with torch.no_grad():
pt_res = scripted_model(data)

for target in ["llvm", "cuda"]:
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)

vm.set_input("main", **{input_name: data_np})
tvm_res = vm.run()

# Bounding boxes
tvm.testing.assert_allclose(
pt_res[0].cpu().numpy(), tvm_res[0].asnumpy(), rtol=1e-5, atol=1e-5
)
# Scores
tvm.testing.assert_allclose(
pt_res[1].cpu().numpy(), tvm_res[1].asnumpy(), rtol=1e-5, atol=1e-5
)
# Class ids
np.testing.assert_equal(pt_res[2].cpu().numpy(), tvm_res[2].asnumpy())

score_threshold = 0.9
print("Num boxes:", pt_res[0].cpu().numpy().shape[0])
print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold))