Skip to content

Commit

Permalink
[Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 …
Browse files Browse the repository at this point in the history
…models (apache#7901)
  • Loading branch information
srinidhigoud authored and Trevor Morris committed May 6, 2021
1 parent a36234e commit e0e1010
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
43 changes: 40 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""TF: Tensorflow frontend."""
import warnings
from collections import defaultdict
from collections import deque

# Numpy support
import numpy as np
Expand Down Expand Up @@ -1770,6 +1771,43 @@ def _impl(inputs, attr, params, mod):
return _impl


def _broadcast_args():
def _impl(inputs, attr, params, mod):
if isinstance(inputs[0], _expr.Var):
s0 = params[inputs[0].name_hint]
else:
s0 = _infer_value(inputs[0], params, mod)
if isinstance(inputs[1], _expr.Var):
s1 = params[inputs[1].name_hint]
else:
s1 = _infer_value(inputs[1], params, mod)
s0 = list(s0.asnumpy().reshape([-1]))
s1 = list(s1.asnumpy().reshape([-1]))
s0_size, s1_size = len(s0), len(s1)

out = deque([])
for i in range(1, min(s0_size, s1_size) + 1):
if s0[s0_size - i] == s1[s1_size - i]:
out.appendleft(s0[s0_size - i])
elif s0[s0_size - i] == 1:
out.appendleft(s1[s1_size - i])
else:
assert s1[s1_size - i] == 1, "Incompatible broadcast type %s and %s" % (
s0[s0_size - i],
s1[s1_size - i],
)
out.appendleft(s0[s0_size - i])
if s0_size < s1_size:
for i in range(s0_size + 1, s1_size + 1):
out.appendleft(s1[s1_size - i])
if s1_size < s0_size:
for i in range(s1_size + 1, s0_size + 1):
out.appendleft(s0[s0_size - i])
return _expr.const(list(out), attr["T"].name)

return _impl


def _broadcast_to():
def _impl(inputs, attr, params, mod):
if isinstance(inputs[1], _expr.Var):
Expand Down Expand Up @@ -2745,6 +2783,7 @@ def _impl(inputs, attr, params, mod):
"BatchToSpaceND": _batch_to_space_nd(),
"BiasAdd": _bias_add(),
"BroadcastTo": _broadcast_to(),
"BroadcastArgs": _broadcast_args(),
"Cast": _cast(),
"Ceil": AttrCvt("ceil"),
"CheckNumerics": _check_numerics(),
Expand Down Expand Up @@ -2838,6 +2877,7 @@ def _impl(inputs, attr, params, mod):
"Round": AttrCvt("round"),
"Rsqrt": _rsqrt(),
"Select": _where(),
"SelectV2": _where(),
"Selu": _selu(),
"Shape": _shape(),
"Sigmoid": AttrCvt("sigmoid"),
Expand Down Expand Up @@ -3941,7 +3981,6 @@ def _backtrack_construct(self, node_name):
raise ImportError("Unable to import tensorflow which is required {}".format(e))

input_op_name = node_name.split(":")[0].split("^")[-1]

if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)
Expand Down Expand Up @@ -4002,7 +4041,6 @@ def _backtrack_construct(self, node_name):
inputs[i] = actual_input

op = self._convert_operator(node.op, node.name, inputs, attr)

if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [
Expand All @@ -4024,7 +4062,6 @@ def _backtrack_construct(self, node_name):
tn = node_name.split(":")
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
return out[tensor_slot]

return out[0]


Expand Down
29 changes: 28 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,33 @@ def test_forward_resize():
_test_resize_nearest_neighbor_dynamic_shape((1, 16, 16, 3), scale=[2, 2])


#######################################################################
# BroadcastArgs
# -----------


def _test_broadcast_args(in_shape_1, in_shape_2):
""" One iteration of broadcast_args"""

shape_1 = np.array(in_shape_1).astype("int32")
shape_2 = np.array(in_shape_2).astype("int32")

with tf.Graph().as_default():
shape_1 = constant_op.constant(shape_1, shape=shape_1.shape, dtype=shape_1.dtype)
shape_2 = constant_op.constant(shape_2, shape=shape_2.shape, dtype=shape_2.dtype)
tf.raw_ops.BroadcastArgs(s0=shape_1, s1=shape_2)

compare_tf_with_tvm(None, "", "BroadcastArgs:0", opt_level=0)


def test_forward_broadcast_args():
""" Resize Bilinear """

_test_broadcast_args((4, 1, 32, 32), [4, 8, 32, 32])
_test_broadcast_args((6, 32, 32, 1), [6, 32, 32, 16])
_test_broadcast_args((32, 32, 16), [6, 32, 32, 16])


#######################################################################
# BroadcastTo
# -----------
Expand Down Expand Up @@ -3636,7 +3663,7 @@ def test_forward_logical():


#######################################################################
# Where, Select
# Where, Select, SelectV2
# -------------
def test_forward_where():
""" Where: return elements depending on conditions"""
Expand Down

0 comments on commit e0e1010

Please sign in to comment.