[THRUST] Faster multi dimensional argsort by segmented sort #7195
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Current implementation of thrust argsort, when given multi dimensional inputs to sort along the inner most axis, is very inefficient: it does
n_iter
calls to thrust sort. Seetvm/src/runtime/contrib/thrust/thrust.cu
Lines 50 to 65 in bad149e
When the outer dimension is large, the performance of thrust argsort is far from optimal. In particular, the thrust numbers shown in the TIR mergesort PR #7099 do not reflect the true performance thrust can achieve.
This PR replaces
n_iter
calls to thrust argsort with one segmented sort by key. Since thrust doesn't provide API to do segmented sort, I used a neat back-to-back stable-sort-by-key trick explained in https://groups.google.com/forum/#!topic/thrust-users/BoLsxO6b4FY. My implementation is a bit more complicated because we need to do segmented sort by key, not just segmented sort.Here are the numbers I get using the same benchmark script used in #7099, measured on GTX 1070 ti. When the outer dimension is small (like 2, 2, 2000 case), my change makes it slower due to the overhead from two calls to
stable_sort_by_key
. But other than that, it is much faster than one we have now.Also, I removed
tvm.contrib.thrust.sort_nms
andargsort_nms_thrust
, since they are not used anymore.please review @kazum @Laurawly
(cc @mbrookhart when you are back, this should be exciting for you!)