diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6dd164c6e35e..50221c7baf28 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -19,6 +19,7 @@ """TF: Tensorflow frontend.""" import warnings from collections import defaultdict +from collections import deque # Numpy support import numpy as np @@ -1765,6 +1766,43 @@ def _impl(inputs, attr, params, mod): return _impl +def _broadcast_args(): + def _impl(inputs, attr, params, mod): + if isinstance(inputs[0], _expr.Var): + s0 = params[inputs[0].name_hint] + else: + s0 = _infer_value(inputs[0], params, mod) + if isinstance(inputs[1], _expr.Var): + s1 = params[inputs[1].name_hint] + else: + s1 = _infer_value(inputs[1], params, mod) + s0 = list(s0.asnumpy().reshape([-1])) + s1 = list(s1.asnumpy().reshape([-1])) + s0_size, s1_size = len(s0), len(s1) + + out = deque([]) + for i in range(1, min(s0_size, s1_size) + 1): + if s0[s0_size - i] == s1[s1_size - i]: + out.appendleft(s0[s0_size - i]) + elif s0[s0_size - i] == 1: + out.appendleft(s1[s1_size - i]) + else: + assert s1[s1_size - i] == 1, "Incompatible broadcast type %s and %s" % ( + s0[s0_size - i], + s1[s1_size - i], + ) + out.appendleft(s0[s0_size - i]) + if s0_size < s1_size: + for i in range(s0_size + 1, s1_size + 1): + out.appendleft(s1[s1_size - i]) + if s1_size < s0_size: + for i in range(s1_size + 1, s0_size + 1): + out.appendleft(s0[s0_size - i]) + return _expr.const(list(out), attr["T"].name) + + return _impl + + def _broadcast_to(): def _impl(inputs, attr, params, mod): if isinstance(inputs[1], _expr.Var): @@ -2740,6 +2778,7 @@ def _impl(inputs, attr, params, mod): "BatchToSpaceND": _batch_to_space_nd(), "BiasAdd": _bias_add(), "BroadcastTo": _broadcast_to(), + "BroadcastArgs": _broadcast_args(), "Cast": _cast(), "Ceil": AttrCvt("ceil"), "CheckNumerics": _check_numerics(), @@ -2833,6 +2872,7 @@ def _impl(inputs, attr, params, mod): "Round": AttrCvt("round"), "Rsqrt": _rsqrt(), "Select": _where(), + "SelectV2": _where(), "Selu": _selu(), "Shape": _shape(), "Sigmoid": AttrCvt("sigmoid"), @@ -3936,7 +3976,6 @@ def _backtrack_construct(self, node_name): raise ImportError("Unable to import tensorflow which is required {}".format(e)) input_op_name = node_name.split(":")[0].split("^")[-1] - if input_op_name not in self._nodes: node = self._tf_node_map[input_op_name] attr = self._parse_attr(node.attr) @@ -3997,7 +4036,6 @@ def _backtrack_construct(self, node_name): inputs[i] = actual_input op = self._convert_operator(node.op, node.name, inputs, attr) - if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) op = [ @@ -4019,7 +4057,6 @@ def _backtrack_construct(self, node_name): tn = node_name.split(":") tensor_slot = int(tn[1]) if len(tn) > 1 else 0 return out[tensor_slot] - return out[0] diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 8446ef3d590b..f8d3e2386070 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3050,6 +3050,33 @@ def test_forward_resize(): _test_resize_nearest_neighbor_dynamic_shape((1, 16, 16, 3), scale=[2, 2]) +####################################################################### +# BroadcastArgs +# ----------- + + +def _test_broadcast_args(in_shape_1, in_shape_2): + """ One iteration of broadcast_args""" + + shape_1 = np.array(in_shape_1).astype("int32") + shape_2 = np.array(in_shape_2).astype("int32") + + with tf.Graph().as_default(): + shape_1 = constant_op.constant(shape_1, shape=shape_1.shape, dtype=shape_1.dtype) + shape_2 = constant_op.constant(shape_2, shape=shape_2.shape, dtype=shape_2.dtype) + tf.raw_ops.BroadcastArgs(s0=shape_1, s1=shape_2) + + compare_tf_with_tvm(None, "", "BroadcastArgs:0", opt_level=0) + + +def test_forward_broadcast_args(): + """ Resize Bilinear """ + + _test_broadcast_args((4, 1, 32, 32), [4, 8, 32, 32]) + _test_broadcast_args((6, 32, 32, 1), [6, 32, 32, 16]) + _test_broadcast_args((32, 32, 16), [6, 32, 32, 16]) + + ####################################################################### # BroadcastTo # ----------- @@ -3593,7 +3620,7 @@ def test_forward_logical(): ####################################################################### -# Where, Select +# Where, Select, SelectV2 # ------------- def test_forward_where(): """ Where: return elements depending on conditions"""