From fcc9be50f020027cc9947f93d507b9e03c34e089 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 6 Oct 2020 14:07:26 -0700 Subject: [PATCH] [Topi] Allow batch_matmul to broadcast along batch dimension. (#6616) * Allow batch_matmul to broadcast along batch dimension. * Added typerel checking. * Fix style issue and respond to feedback. * Fix style. * More formatting issues :( * Fix issues after merge. * Comment update. * Small tweak. --- include/tvm/topi/nn/batch_matmul.h | 67 ------------------- python/tvm/relay/frontend/onnx.py | 9 --- python/tvm/topi/nn/batch_matmul.py | 28 +++++--- python/tvm/topi/testing/batch_matmul.py | 7 +- python/tvm/topi/x86/batch_matmul.py | 10 +-- src/relay/op/nn/nn.cc | 6 +- src/topi/nn.cc | 6 -- tests/python/frontend/onnx/test_forward.py | 1 - .../topi/python/test_topi_batch_matmul.py | 21 +++--- 9 files changed, 43 insertions(+), 112 deletions(-) delete mode 100644 include/tvm/topi/nn/batch_matmul.h diff --git a/include/tvm/topi/nn/batch_matmul.h b/include/tvm/topi/nn/batch_matmul.h deleted file mode 100644 index bffddca8010f..000000000000 --- a/include/tvm/topi/nn/batch_matmul.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Batch matmul op constructions - * \file nn/batch_matmul.h - */ -#ifndef TVM_TOPI_NN_BATCH_MATMUL_H_ -#define TVM_TOPI_NN_BATCH_MATMUL_H_ - -#include -#include - -#include - -namespace tvm { -namespace topi { -namespace nn { - -using namespace tvm::te; - -/*! - * \brief Creates an operation that calculates matrix multiplication in batch. - * - * \param x Tensor with shape [batch, M, K] - * \param y Tensor with shape [batch, N, K] - * - * \return Tensor with shape [batch, M, N] - */ -inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) { - CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; - CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; - - auto batch = x->shape[0]; - auto M = x->shape[1]; - auto K = x->shape[2]; - auto N = y->shape[1]; - - auto k = tvm::te::reduce_axis(Range(0, K), "k"); - auto result = tvm::te::compute( - {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); }, - "tensor", "batch_matmul"); - - return result; -} - -} // namespace nn -} // namespace topi -} // namespace tvm - -#endif // TVM_TOPI_NN_BATCH_MATMUL_H_ diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 59fdb32d1a16..f4cf9572b93f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -539,15 +539,6 @@ def flatten_to_3d(x, x_shape): # Convert a and b into 3 dimensional tensors. a = flatten_to_3d(inputs[0], a_shape) b = flatten_to_3d(inputs[1], b_shape) - # Broadcast b to match batch size of a - new_b_shape = _op.concatenate( - [ - _op.strided_slice(_op.shape_of(a), [0], [1]), - _op.strided_slice(_op.shape_of(b), [1], [3]), - ], - 0, - ) - b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 34a8c6dafc87..9b926a1182d8 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -22,7 +22,7 @@ def batch_matmul(x, y, oshape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. + data in batch. Supports broadcasting for batch dimension. Parameters ---------- @@ -32,24 +32,30 @@ def batch_matmul(x, y, oshape=None): y : tvm.te.Tensor 3-D with shape [batch, N, K] + oshape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" + x_shape = get_const_tuple(x.shape) + y_shape = get_const_tuple(y.shape) + XB = x_shape[0] + YB = y_shape[0] + _, M, K = x.shape + k = te.reduce_axis((0, K), name="k") if oshape is None: - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - x_shape = get_const_tuple(x.shape) - y_shape = get_const_tuple(y.shape) - assert x_shape[0] == y_shape[0], "batch dimension doesn't match" + assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" - batch, M, K = x.shape + batch = max(XB, YB) N = y.shape[1] - k = te.reduce_axis((0, K), name="k") oshape = (batch, M, N) - else: - _, _, K = x.shape - k = te.reduce_axis((0, K), name="k") return te.compute( - oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + oshape, + lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), + tag="batch_matmul", ) diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py index 0a991f633c3c..a48c92967c77 100644 --- a/python/tvm/topi/testing/batch_matmul.py +++ b/python/tvm/topi/testing/batch_matmul.py @@ -35,9 +35,10 @@ def batch_matmul(x, y): out : numpy.ndarray 3-D with shape [batch, M, N] """ - batch, M, _ = x.shape - N = y.shape[1] + XB, M, _ = x.shape + YB, N, _ = y.shape + batch = max(XB, YB) out = np.zeros((batch, M, N)).astype(x.dtype) for i in range(batch): - out[i] = np.dot(x[i], y[i].T) + out[i] = np.dot(x[i if XB != 1 else 0], y[i if YB != 1 else 0].T) return out diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index c095dcb0b6bb..e3f08160509e 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -27,7 +27,7 @@ @autotvm.register_topi_compute("batch_matmul.x86") def batch_matmul(cfg, x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. + data in batch. Supports broadcasting in batch dimension. Parameters ---------- @@ -45,9 +45,9 @@ def batch_matmul(cfg, x, y, out_shape=None): assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" XB, M, XK = get_const_tuple(x.shape) YB, N, YK = get_const_tuple(y.shape) - assert XB == YB, "batch dimension doesn't match" + assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistant" - B = XB + B = max(XB, YB) K = XK if out_shape is not None: assert out_shape[0] == B, "got invalid output shape" @@ -58,7 +58,9 @@ def batch_matmul(cfg, x, y, out_shape=None): k = te.reduce_axis((0, K), name="k") C = te.compute( - (B, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), tag="batch_matmul" + (B, M, N), + lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), + tag="batch_matmul", ) return C diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 38ebe421d38d..1de7ca003772 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -862,8 +863,9 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs } } if (!is_dyn) { - CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "BatchDot: batch dimension doesn't match, " + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]) || reporter->AssertEQ(x->shape[0], 1) || + reporter->AssertEQ(y->shape[0], 1)) + << "BatchDot: batch dimensions don't match, " << " x shape=" << x->shape << ", y shape=" << y->shape; CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) << "BatchDot: shapes of x and y is inconsistent, " diff --git a/src/topi/nn.cc b/src/topi/nn.cc index c03d1b056d35..2c9546507de6 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -68,11 +67,6 @@ TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* r *rv = nn::bias_add(args[0], args[1], args[2]); }); -/* Ops from nn/batch_matmul.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::batch_matmul(args[0], args[1]); -}); - /* Ops from nn/dilate.h */ TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dilate(args[0], args[1], args[2]); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1aeb430de52f..da8629dfcd2b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3628,7 +3628,6 @@ def verify_roi_align( test_clip_min_max_as_inputs() test_onehot() test_matmul() - test_batch_matmul() test_gather() test_gatherelements() test_gather_nd() diff --git a/tests/python/topi/python/test_topi_batch_matmul.py b/tests/python/topi/python/test_topi_batch_matmul.py index 769822e98488..0d82ee69fa26 100644 --- a/tests/python/topi/python/test_topi_batch_matmul.py +++ b/tests/python/topi/python/test_topi_batch_matmul.py @@ -32,16 +32,16 @@ } -def verify_batch_matmul(batch, M, N, K): - x = te.placeholder((batch, M, K), name="x") - y = te.placeholder((batch, N, K), name="y") +def verify_batch_matmul(x_batch, y_batch, M, N, K): + x = te.placeholder((x_batch, M, K), name="x") + y = te.placeholder((y_batch, N, K), name="y") dtype = x.dtype # use memoize to pickle the test data for next time use @memoize("topi.tests.test_topi_batch_matmul") def get_ref_data(): - a_np = np.random.uniform(size=(batch, M, K)).astype(dtype) - b_np = np.random.uniform(size=(batch, N, K)).astype(dtype) + a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype) + b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype) c_np = tvm.topi.testing.batch_matmul(a_np, b_np) return (a_np, b_np, c_np) @@ -67,10 +67,13 @@ def check_device(device, ctx): @tvm.testing.uses_gpu def test_batch_matmul(): - verify_batch_matmul(1, 16, 16, 32) - verify_batch_matmul(5, 16, 16, 32) - verify_batch_matmul(5, 16, 20, 32) - verify_batch_matmul(30, 16, 20, 32) + verify_batch_matmul(1, 1, 16, 16, 32) + verify_batch_matmul(5, 5, 16, 16, 32) + verify_batch_matmul(5, 5, 16, 20, 32) + verify_batch_matmul(30, 30, 16, 20, 32) + # Test batch broadcasting. + verify_batch_matmul(1, 5, 16, 16, 32) + verify_batch_matmul(5, 1, 16, 16, 32) if __name__ == "__main__":