Skip to content

Commit

Permalink
[TensorFlow] Support NonMaxSuppressionV5 (apache#6933)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Dec 2, 2020
1 parent 634b20c commit aaee1b1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
13 changes: 12 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _impl(inputs, attr, params, mod):
return _impl


def _nms():
def _nms(return_scores=False):
def _impl(inputs, attr, params, mod):
# Get parameter values
try:
Expand Down Expand Up @@ -724,6 +724,16 @@ def _impl(inputs, attr, params, mod):
ret = get_relay_op("strided_slice")(
data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
)

# NonMaxSuppressionV5 returns scores. pad_output is always False for NMSv5.
if return_scores:
if "soft_nms_sigma" in attr and attr["soft_nms_sigma"] != 0.0:
raise tvm.error.OpAttributeUnImplemented(
"soft_nms_sigma for NonMaxSuppressionV5 is not supported"
)
ret_scores = _op.take(inputs[1], ret, axis=0)
return _expr.TupleWrapper(_expr.Tuple([ret, ret_scores, size]), 3)

return ret

return _impl
Expand Down Expand Up @@ -2354,6 +2364,7 @@ def _impl(inputs, attr, params, mod):
"NonMaxSuppressionV2": _nms(),
"NonMaxSuppressionV3": _nms(),
"NonMaxSuppressionV4": _nms(),
"NonMaxSuppressionV5": _nms(True),
"NoOp": _no_op(),
"NotEqual": _broadcast("not_equal"),
"OneHot": _one_hot(),
Expand Down
30 changes: 28 additions & 2 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2662,9 +2662,35 @@ def _test_forward_nms_v4(
)


def _test_forward_nms_v5(
bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"
):
boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
scores = np.random.uniform(size=score_shape).astype(dtype)
max_output_size = np.int32(out_size)
tf.reset_default_graph()
in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
tf.image.non_max_suppression_with_scores(
boxes=in_data_1,
scores=in_data_2,
max_output_size=in_data_3,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
name="nms",
)
compare_tf_with_tvm(
[boxes, scores, max_output_size],
["in_data_1:0", "in_data_2:0", "in_data_3:0"],
["nms/NonMaxSuppressionV5:0", "nms/NonMaxSuppressionV5:1"],
mode="vm",
)


def test_forward_nms():
""" NonMaxSuppressionV3,4 """
for _test_forward_nms in [_test_forward_nms_v3]:
""" NonMaxSuppressionV3,5 """
for _test_forward_nms in [_test_forward_nms_v3, _test_forward_nms_v5]:
_test_forward_nms((5, 4), (5,), 0.7, 0.5, 5)
_test_forward_nms((20, 4), (20,), 0.5, 0.6, 10)
_test_forward_nms((1000, 4), (1000,), 0.3, 0.7, 1000)
Expand Down

0 comments on commit aaee1b1

Please sign in to comment.