Skip to content

Commit

Permalink
[TFLite] added scalar axis value handling in reduce
Browse files Browse the repository at this point in the history
Axis value in reduce can now be specified as scalar
  • Loading branch information
d-smirnov committed Nov 24, 2020
1 parent 448278d commit 9a263a8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,7 +1638,8 @@ def _convert_reduce(self, relay_op, op):
in_expr = self.get_expr(input_tensor.tensor_idx)

# axis
axis = tuple(self.get_tensor_value(input_tensors[1]))
axis_value = self.get_tensor_value(input_tensors[1])
axis = tuple(axis_value) if len(axis_value.shape) > 0 else tuple((axis_value.item(),))

# Options - keep_dims (bool)
assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
Expand Down
18 changes: 11 additions & 7 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,18 +2143,22 @@ def _test_forward_reduce(testop, dtype="float32"):
if dtype == "bool":
data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype), None]
data1 = [
np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
np.array(1, dtype=np.int32),
]
data2 = [
np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
np.array([1, 2], dtype=np.int32),
]
else:
data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
testop(data0)
testop(data0, keep_dims=False)
testop(data0, keep_dims=True)
testop(data1)
testop(data1, keep_dims=False)
testop(data1, keep_dims=True)
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array(1, dtype=np.int32)]
data2 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]

for data in [data0, data1, data2]:
testop(data)
testop(data, keep_dims=False)
testop(data, keep_dims=True)


def _test_forward_reduce_quantized(testop):
Expand Down

0 comments on commit 9a263a8

Please sign in to comment.