Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Fix PyTorch NMS conversion for negative scores #7137

Merged
merged 6 commits into from
Dec 26, 2020

Conversation

masahi
Copy link
Member

@masahi masahi commented Dec 20, 2020

While investigating GPU NMS performance on MaskRCNN workload, I found that PyTorch NMS can have negative scores in its input. In that case, converted Relay model produces a wrong result for the two reasons:

  • GPU NMS IR treats negative scores as invalid.
  • Our frontend is using get_valid_counts even though PyTorch doesn't have score_threshold parameter. We are arbitrary using -1 as a score_threshold, which is not correct.

The issue is fixed by adding an offset to the scores appropriately and removing a call to get_valid_counts.

please review @yongwww @kevinthesun @zhiics @t-vi

Copy link
Contributor

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great fix, thank you @masahi !

@masahi
Copy link
Member Author

masahi commented Dec 23, 2020

@kevinthesun @zhiics I asked torchvision people about boxes with negative scores in pytorch/vision#3198 and now I fully understand the issue. See in particular this great answer pytorch/vision#3198 (comment)

My conclusion is that TVM's conversion rule for PyTorch NMS is definitely wrong and needs fixing. Here is my take away from the above discussion:

  • PyTorch detection model has two use of NMS - one in ROIHead and another in RegionProposalNetwork.
  • NMS scores in ROIHead are probability. There, they do score thresholding with user-chosen threshold before NMS. This NMS doesn't send boxes with negative scores to TVM.
  • NMS scores in RegionProposalNetwork correspond to "objectness" and they don't apply softmax or sigmoid to the output from objectness network. The scores are estimate of logit function and negative logit totally makes sense - it just mean probability is < 0.5. So a totally reasonable box can end up having a negative score. We shouldn't arbitrarily cut negative boxes from RPN.

It's highly possible that one of the reasons you didn't get the same output from pytorch detection models after compiling to TVM is due to this incorrect assumption we've been making about negative boxes, because RPN output in PyTorch/TVM are totally different - we are only considering only about half of them.

The good news is, I found a way to reduce the number of boxes that are sent to NMS. See these parameters https:/pytorch/vision/blob/master/torchvision/models/detection/mask_rcnn.py#L68-L71. They are something like topk parameter for each classes separately. The default is 1000, and there are five classes/levels in RPN NMS. So that explain why we are getting about 4500 boxes to RPN NMS. If we set the rpn_pre_nms_top_n_train to 200, we will get at most 1000 boxes, 200 boxes for each level in the feature pyramid. That will significantly make detection model go faster and still consider a lot of boxes to keep accuracy high.

So please take a look at the above issue in torchvision and my comment carefully and let's go ahead and merge this.

@zhiics
Copy link
Member

zhiics commented Dec 23, 2020

@masahi I see the problem now. Thanks for following up with the PyTorch community and fixing out the root cause. Ping @larrywyang @yongwww, please take a look since you worked on NMS more than others before.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

@zhiics @kevinthesun @yongwww

Great news!! I've just checked that with this fix, we can now get the same number of output boxes from pytorch and TVM, which enables meaningful comparison of two outputs. We can do assert_equal with high precision on bbox coordinates and scores. Class ids are identical. Tested on both LLVM and CUDA target.

See the updated test_object_detection.py. I also added an example of setting rpn_pre_nms_top_n_test explicitly. Compared to the default value of 1000, using 200 the runtime of GPU NMS drops from 2 seconds to 90 milliseconds.

@zhiics
Copy link
Member

zhiics commented Dec 24, 2020

BTW, could you please run the Mask R-CNN benchmark to see if there is any big performance difference? There is a tutorial for it.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

@zhiics What before and after you want to know? With this change, we need to set rpn_pre_nms_top_n_test to lower values to get reasonable performance.

I don't think comparison to old code is meaningful because it is not doing the right thing. The old code is surely faster, but that's because we were cheating.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

With the fix in this PR applied and using rpn_pre_nms_top_n_test = 200, on GTX 1070 ti

Torch GPU: 0.093 sec
TVM GPU: 0.39 sec (no tuning or cudnn)

@zhiics
Copy link
Member

zhiics commented Dec 24, 2020

@masahi Thanks. I was just trying to see how different it would be.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

ok I can provide NMS performance comparison as a one data point. When I was investigating GPU NMS performance issue, the old code was taking 630 milliseconds while with this fix it was 2.1 seconds. But again, that's because the new code is dealing with far more boxes and our GPU NMS code is currently extremely slow due to the sequential loop.

According to the numbers I posted in #7154, on CPU NMS is fast: the old code was spending only 8 milliseconds. So I don't expect NMS on CPU would be a big issue.

After NMS, PyTorch detection model does post-NMS topk, which selects 1000 boxes for later processing. So the perf difference should only be in NMS.

@yongwww
Copy link
Member

yongwww commented Dec 24, 2020

@masahi thanks for the analysis and fix, the use of boxes with negative scores is interesting.

Previously, we used get_valid_counts before NMS to pre-process boxes, moving invalid boxes to the bottom of box tensor, which helps reduce the computation of NMS by skipping invalid boxes. If negative boxes are valid, then NMS will do computation for all boxes, which causes performance regression. Then we should consider improving the performance.

python/tvm/relay/frontend/pytorch.py Show resolved Hide resolved
@@ -70,7 +71,7 @@ def generate_jit_model(index):
]

model_func = model_funcs[index]
model = TraceWrapper(model_func(pretrained=True))
model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad to see rpn_pre_nms_top_n_test is able to limit the proposals before nms. I am not sure if the parameter is specified for real use cases, seems using default of the parameter to do benchamarking makes more sense to me

Copy link
Member Author

@masahi masahi Dec 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the default parameter 1000 they picked is fairly conservative. This means for each level in the feature pyramid, of which there is 5 if we use resnet 50 backbone, we get maximum of 1000 x 5 boxes as input to RPN. They have another parameter rpn_post_nms_top_n_test, which is like topk applied after NMS. This value is also by default 1000 and it is not per class unlike rpn_pre_nms_top_n_test. This means we always have 1000 boxes after NMS regardless of rpn_pre_nms_top_n_test.

@yongwww
Copy link
Member

yongwww commented Dec 24, 2020

With the fix in this PR applied and using rpn_pre_nms_top_n_test = 200, on GTX 1070 ti

Torch GPU: 0.093 sec
TVM GPU: 0.39 sec (no tuning or cudnn)

How about the perf numbers with default value of rpn_pre_nms_top_n_test?

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

How about the perf numbers with default value of rpn_pre_nms_top_n_test?

It is dominated by NMS which takes 2.1 seconds with default rpn_pre_nms_top_n_test.

But even in the default setting, PyTorch NMS is fairly fast and it is not a bottleneck at all. So the root issue is our terrible GPU performance. We need to fix our NMS and there is no excuse to cheat in the frontend.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

Also I find that gluoncv MaskRCNN mostly follows the design of PyTorch MaskRCNN. But they apply sigmoid to objectness network outputs, so in their case there are no negative scores and all inputs to RPN NMS, which could be thousand of boxes, are valid, even if "by valid" we mean the previous wrong definition of having positive score.

So without this fix, if you compare MaskRCNN performance of PyTorch and GluonCV after we compile them to TVM, PyTorch model would run much faster because our cheat would work to ignore negative boxes, while there is no negative boxes in GluonCV MaskRCNN. So equivalent models give inconsistent result after they get compiled to TVM. This is non sense.

So we have to accept the fact that we need to deal with lots of boxes in MaskRCNN.

@masahi
Copy link
Member Author

masahi commented Dec 24, 2020

ok I ran CPU perf comparison with old code vs new code + default param. As expected, there is almost no perf regression.

Old code: 2.1364135856199997 sec
After the fix in this PR: 2.1756499899999997 sec

The old code spends 5 milliseconds on NMS. Here is the stat from profile.
https:/masahi/torchscript-to-tvm/blob/master/maskrcnn/maskrcnn_cpu_vm_profile_old.txt#L51

The new code gets far more boxes, but NMS spends only 20 milliseconds on it.
https:/masahi/torchscript-to-tvm/blob/master/maskrcnn/maskrcnn_cpu_vm_profile_1000.txt#L24

The reason the new code doesn't make CPU perf worse is simple. Even though now NMS gets more boxes, PyTorch detection models does post NMS topk after RPN, see https:/pytorch/vision/blob/90645ccd0e774ad76200245e32222a23d09f2312/torchvision/models/detection/rpn.py#L261-L263

So no matter how many boxes NMS gets, the number of boxes after NMS + post NMS topk doesn't change and bounded by rpn_pre_nms_top_n_test parameter whose default is 1000. The CPU perf is always dominated by the large dense layer.

Moreover, I'd like to reiterate that the new code is more correct and gives essentially the same output as PyTorch. At this point, if you are still not happy with the this fix, I'm genuinely curious why.

@kevinthesun
Copy link
Contributor

@masahi Thanks for this investigating and improvement. Indeed this change won't affect cpu perf much even when MKL is enabled and large dynamic shape dense becomes faster. One interesting thing about output result is: for pytorch 1.7, we can exact match the results of tvm vs pt with this change, but for pytorch 1.4 there is still mismatch which won't affect final accuracy. I'm fine with this change now.

BTW, enabling MKL on my Intel Xeon Platinum machine with 18 cores can reduce the latency of pt maskrcnn from 1000 ms to 600 ms. Those large dynamic shape dense layers do contribute a lot to the latency.

Copy link
Contributor

@kevinthesun kevinthesun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @zhiics @yongwww PTAL.

@masahi
Copy link
Member Author

masahi commented Dec 25, 2020

One interesting thing about output result is: for pytorch 1.7, we can exact match the results of tvm vs pt with this change, but for pytorch 1.4 there is still mismatch which won't affect final accuracy

Interesting, I didn't test on 1.4. If even the number of output box are different, then it must be a NMS issue, since NMS is the only thing that could change the number of box. I'm also fine if we can get the same output as the latest pytorch.

Thanks for validating my fix. Luckily, I also found a way to parallelize the inner loop of GPU NMS which should give a massive speedup. The change is this one masahi@c75c6ef but since I'm away from my GPU at the moment, I haven't tested yet. It also reduces the number of IOU tests from O(N ** 2) to O (# selected boxes * N).

Hopefully next week I can send a PR with good perf improvement on GPU NMS and hence PyTorch MaskRCNN performance on GPU.

@zhiics zhiics merged commit 4c13ae9 into apache:main Dec 26, 2020
@zhiics
Copy link
Member

zhiics commented Dec 26, 2020

Thanks @masahi @kevinthesun @yongwww

tkonolige pushed a commit to tkonolige/incubator-tvm that referenced this pull request Jan 11, 2021
* Fix pytorch nms conversion for negative scores

* updated mask rcnn test to verify outputs and also run cuda target

* set rpn_post_nms_top_n_test to 200

* fix parameter name

* dump output box information

* simplifying
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
* Fix pytorch nms conversion for negative scores

* updated mask rcnn test to verify outputs and also run cuda target

* set rpn_post_nms_top_n_test to 200

* fix parameter name

* dump output box information

* simplifying
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
* Fix pytorch nms conversion for negative scores

* updated mask rcnn test to verify outputs and also run cuda target

* set rpn_post_nms_top_n_test to 200

* fix parameter name

* dump output box information

* simplifying
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
* Fix pytorch nms conversion for negative scores

* updated mask rcnn test to verify outputs and also run cuda target

* set rpn_post_nms_top_n_test to 200

* fix parameter name

* dump output box information

* simplifying
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants