From a4ee800f65db1eb7a6863ddc8b082507dc8398d6 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Wed, 3 Feb 2021 23:15:28 +0000
Subject: [PATCH 01/15] Add CPU dense weight transform

---
 include/tvm/relay/attrs/nn.h            |  15 ++++
 python/tvm/relay/op/nn/_nn.py           |   9 +++
 python/tvm/relay/op/nn/nn.py            |  33 ++++++++
 python/tvm/relay/op/strategy/generic.py |  13 +++
 python/tvm/relay/op/strategy/x86.py     |  27 ++++---
 python/tvm/topi/nn/dense.py             |  69 ++++++++++++++++
 python/tvm/topi/x86/__init__.py         |   1 +
 python/tvm/topi/x86/dense.py            | 103 +++++++++++++++++-------
 python/tvm/topi/x86/dense_alter_op.py   |  68 ++++++++++++++++
 src/relay/op/nn/nn.cc                   |  33 +++++++-
 src/relay/op/nn/nn.h                    |  46 +++++++++++
 11 files changed, 376 insertions(+), 41 deletions(-)
 create mode 100644 python/tvm/topi/x86/dense_alter_op.py

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index c3c58e54517c..05a5ca426bb6 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -939,6 +939,21 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
   }
 };
 
+/*! \brief Attributes for dense_weight_transform operator */
+struct DenseWeightTransformAttrs : public tvm::AttrsNode<DenseWeightTransformAttrs> {
+  tvm::String weight_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(DenseWeightTransformAttrs, "relay.attrs.DenseAttrs") {
+    TVM_ATTR_FIELD(weight_layout)
+        .set_default("NKn")
+        .describe("Dimension ordering of weight. Can be 'NKn', 'NK16n', etc.");
+    TVM_ATTR_FIELD(out_dtype)
+      .set_default(NullValue<DataType>())
+      .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*! \brief Attributes for batch matmul operator */
 struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
   tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 37ee6b6e929f..85d92a267136 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -77,6 +77,15 @@ def legalize_dense(attrs, inputs, types):
 reg.register_strategy("nn.dense", strategy.dense_strategy)
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+@reg.register_alter_op_layout("nn.dense")
+def alter_op_layout_dense(attrs, inputs, tinfos, out_type):
+    """Alternate the layout of dense"""
+    return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)
+
+# dense_weight_transform
+reg.register_strategy("nn.contrib_dense_weight_transform", strategy.dense_weight_transform_strategy)
+reg.register_pattern("nn.contrib_dense_weight_transform", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
 
 # fifo_buffer
 @reg.register_compute("nn.fifo_buffer")
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 562cee5f53bb..5727888e7847 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -1435,6 +1435,39 @@ def dense(data, weight, units=None, out_dtype=""):
     return _make.dense(data, weight, units, out_dtype)
 
 
+def contrib_dense_weight_transform(data, weight, weight_layout="NKn", out_dtype=""):
+    """Dense operator.
+    Applies a linear transformation
+
+    .. math::
+
+    `Y = X * W^T`
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator,
+        of shape `(d_1, d_2, ..., d_n, units_in)`.
+
+    weight : tvm.relay.Expr
+        The transformed weight expressions, 3-D matrix,
+        of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
+
+    weight_layout : str, optional
+        Layout of the weight.
+ + out_dtype : str, optional + Specifies the output data type for mixed precision dense, + of shape `(d_1, d_2, ..., d_n, units)`. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.contrib_dense_weight_transform(data, weight, weight_layout, out_dtype) + + def fifo_buffer(data, buffer, axis): """FIFO buffer to enable computation reuse in CNNs with sliding indow input diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3ad75faf4bc1..5c14d1bdefe8 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -731,6 +731,19 @@ def dense_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("dense_weight_transform_strategy") +def dense_weight_transform_strategy(attrs, inputs, out_type, target): + """dense_weight_transform generic strategy""" + logger.warning("dense_weight_transform is not optimized for this platform.") + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dense(topi.nn.dense_weight_transform), + wrap_topi_schedule(topi.generic.schedule_dense), + name="dense_weight_transform.generic", + ) + return strategy + + # batch_matmul def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False): """wrap batch_matmul topi compute""" diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index edfaaeefc5df..c29cbe511e35 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -364,14 +364,13 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - m, _ = inputs[0].shape same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype dtype = inputs[0].dtype u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_nopack), - wrap_topi_schedule(topi.x86.schedule_dense_nopack), - name="dense_nopack.x86", + wrap_compute_dense(topi.x86.dense_pack), + wrap_topi_schedule(topi.x86.schedule_dense_pack), + name="dense_pack.x86", plevel=10, ) @@ -407,14 +406,18 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): name="dense_mkldnn.x86", plevel=15, ) - with SpecializedCondition(m >= 16): - # this implementation may not be well-optimized, so use plevel=5 for now. - strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_pack), - wrap_topi_schedule(topi.x86.schedule_dense_pack), - name="dense_pack.x86", - plevel=5, - ) + return strategy + + +@dense_weight_transform_strategy.register("cpu") +def dense_weight_transform_strategy_cpu(attrs, inputs, out_type, target): + """dense_weight_transform x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_weight_transform), + wrap_topi_schedule(topi.x86.schedule_dense_weight_transform), + name="dense_weight_transform.x86", + ) return strategy diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index bb6ea90c3fcd..ceca961004af 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -104,3 +104,72 @@ def dense_legalize(attrs, inputs, types): # not to change by default # pylint: disable=unused-argument return None + + +def dense_weight_transform(data, weight, bias=None, out_dtype=None): + """The default implementation of dense_weight_transform in topi. 
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        2-D with shape [batch, in_dim]
+
+    weight : tvm.te.Tensor
+        3-D with shape [out_dim // pack_weight_tile, in_dim, pack_weight_tile]
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [out_dim]
+
+    out_dtype : Optional[str]
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        2-D with shape [batch, out_dim]
+    """
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim // packw_bn, in_dim, packw_bn
+    N = N * packw_bn
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute(
+        (M, N),
+        lambda y, x: te.sum(
+            data[y, k].astype(out_dtype)
+            * weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
+            axis=k,
+        ),
+        name="T_dense_weight_transform",
+        tag="dense_weight_transform",
+    )
+    if bias is not None:
+        C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
+    return C
+
+
+@tvm.target.generic_func
+def dense_alter_layout(attrs, inputs, tinfos, out_type):
+    """Change dense layout.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current dense op
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    out_type : type
+        The output type
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
+    """
+    # not to change by default
+    return None
diff --git a/python/tvm/topi/x86/__init__.py b/python/tvm/topi/x86/__init__.py
index 154511010a1c..bb6a7cdd4122 100644
--- a/python/tvm/topi/x86/__init__.py
+++ b/python/tvm/topi/x86/__init__.py
@@ -39,4 +39,5 @@
 from .conv3d_transpose import *
 from .sparse import *
 from .conv2d_alter_op import *
+from .dense_alter_op import *
 from .scatter import *
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 15d7a1a310d6..47a06edb9ace 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -26,11 +26,12 @@
 from tvm.contrib import mkldnn
 
 from .utils import get_fp32_len
+from .injective import schedule_injective, schedule_injective_from_existing
 from ..
import generic, tag from ..utils import traverse_inline, get_const_tuple -def _schedule_dense_pack_template(cfg, s, C): +def _schedule_dense_weight_transform_template(cfg, s, C, O): A, packedB = s[C].op.input_tensors CC = s.cache_write(C, "global") @@ -39,9 +40,10 @@ def _schedule_dense_pack_template(cfg, s, C): yt, yo, yi = cfg["tile_y"].apply(s, C, y) xt, xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(yt, xt, yo, xo, yi, xi) - xyt = s[C].fuse(yt, xt) - s[C].parallel(xyt) + s[C].reorder(xt, yt, yo, xo, yi, xi) + xyt = s[C].fuse(xt, yt) + if C == O: + s[C].parallel(xyt) xyo = s[C].fuse(yo, xo) s[C].unroll(yi) s[C].vectorize(xi) @@ -51,12 +53,27 @@ def _schedule_dense_pack_template(cfg, s, C): ko, ki = cfg["tile_k"].apply(s, CC, k) s[CC].reorder(ko, ki, y, x) s[CC].vectorize(x) - s[CC].unroll(y) - s[CC].unroll(ki) - z, y, x = s[packedB].op.axis - s[packedB].reorder(z, x, y) - s[packedB].parallel(z) + tile_inner = cfg["tile_inner"].size[-1] + if tile_inner > 1: + yo, yi = s[CC].split(y, tile_inner) + s[CC].reorder(ko, yo, ki, yi, x) + s[CC].unroll(yo) + s[CC].unroll(ki) + s[CC].unroll(yi) + else: + s[CC].unroll(ki) + s[CC].unroll(y) + + if C != O: + y, x = s[O].op.axis + yt, yo, yi = cfg["tile_y"].apply(s, O, y) + xt, xo, xi = cfg["tile_x"].apply(s, O, x) + s[O].reorder(xt, yt, yo, xo, yi, xi) + xyt = s[O].fuse(xt, yt) + s[C].compute_at(s[O], xyt) + s[O].vectorize(xi) + s[O].parallel(xyt) return s @@ -81,7 +98,7 @@ def _schedule_dense_nopack_template(cfg, s, C): return s -def _default_dense_pack_config(cfg, M, N, K): +def _default_dense_weight_transform_config(cfg, M, N, K): # Generate default schedule for dynamic shape. if isinstance(M, tvm.tir.Var): M = 16 @@ -116,6 +133,7 @@ def _default_dense_pack_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii]) cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii]) cfg["tile_k"] = SplitEntity([K, 1]) + cfg["tile_inner"] = SplitEntity([M // tiley_ii, tiley_ii]) def _default_dense_nopack_config(cfg, M, N, K): @@ -181,26 +199,43 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s +def dense_pack(data, weight, bias=None, out_dtype=None): + return dense_weight_transform(data, weight, bias, out_dtype) + +def schedule_dense_pack(outs): + return schedule_dense_weight_transform(outs) -@autotvm.register_topi_compute("dense_pack.x86") -def dense_pack(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense with packing""" +@autotvm.register_topi_compute("dense_weight_transform.x86") +def dense_weight_transform(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense with transformed weight.""" if out_dtype is None: out_dtype = data.dtype M, K = get_const_tuple(data.shape) # batch, in_dim - N, _ = get_const_tuple(weight.shape) # out_dim + if len(weight.shape) == 3: + N, _, packw_bn = get_const_tuple(weight.shape) # out_dim + N = N * packw_bn + else: + N, _ = get_const_tuple(weight.shape) # out_dim # create tuning space cfg.define_split("tile_y", M, num_outputs=3) cfg.define_split("tile_x", N, num_outputs=3) cfg.define_split("tile_k", K, num_outputs=2) + cfg.define_split("tile_inner", M, num_outputs=2, filter=lambda y: y.size[-1] <= 16) if cfg.is_fallback: - _default_dense_pack_config(cfg, M, N, K) - - packw_bn = cfg["tile_x"].size[-1] - packw_shape = (N // packw_bn, K, packw_bn) - packw = te.compute( - packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight" - ) + _default_dense_weight_transform_config(cfg, M, N, K) + + if len(weight.shape) == 2: + 
packw_bn = cfg["tile_x"].size[-1] + packw_shape = (N // packw_bn, K, packw_bn) + if autotvm.GLOBAL_SCOPE.in_tuning: + # Directly use modified data layout placeholder. + packw = tvm.te.placeholder(packw_shape, weight.dtype, name="packed_weight") + else: + packw = te.compute( + packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight" + ) + else: + packw = weight idxdiv = tvm.tir.indexdiv idxmod = tvm.tir.indexmod @@ -212,21 +247,21 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None): * packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype), axis=k, ), - tag="dense_pack", + tag="dense_weight_transform", ) if bias is not None: C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) return C -@autotvm.register_topi_schedule("dense_pack.x86") -def schedule_dense_pack(cfg, outs): - """Create the schedule for dense_pack""" +@autotvm.register_topi_schedule("dense_weight_transform.x86") +def schedule_dense_weight_transform(cfg, outs): + """Create the schedule for dense_weight_transform""" s = te.create_schedule([x.op for x in outs]) def _callback(op): - if "dense_pack" in op.tag: - _schedule_dense_pack_template(cfg, s, op.output(0)) + if "dense_weight_transform" in op.tag: + _schedule_dense_weight_transform_template(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) return s @@ -276,7 +311,19 @@ def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule("dense_mkl.x86") def schedule_dense_mkl(_, outs): """Create schedule for dense_mkl""" - return generic.schedule_extern(outs) + #return generic.schedule_extern(outs) + s = te.create_schedule([x.op for x in outs]) + te.schedule.AutoInlineInjective(s) + + def _callback(op): + if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag: + schedule_injective_from_existing(s, op.output(0)) + + #traverse_inline(s, outs[0].op, _callback) + for out in outs: + if "dense" not in out.op.name: + schedule_injective_from_existing(s, out) + return s @autotvm.register_topi_compute("dense_mkldnn.x86") diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py new file mode 100644 index 000000000000..8183baf234da --- /dev/null +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Dense alter op functions for x86"""
+
+import tvm
+from tvm import te
+from tvm import relay
+from tvm import autotvm
+from .dense import _default_dense_weight_transform_config
+from ..utils import get_const_tuple
+from ..nn import dense_alter_layout
+
+
+@dense_alter_layout.register("cpu")
+def _alter_dense_layout(attrs, inputs, tinfos, out_type):
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    new_attrs = {}
+    data_tensor, weight_tensor = tinfos
+    out_dtype = out_type.dtype
+    M, K = get_const_tuple(data_tensor.shape)
+    N, _ = get_const_tuple(weight_tensor.shape)
+
+    impl, outs = relay.backend.compile_engine.select_implementation(
+        relay.op.get("nn.dense"), attrs, tinfos, out_type, target
+    )
+    workload = autotvm.task.get_workload(outs)
+    if workload:
+        cfg = dispatch_ctx.query(target, workload)
+        topi_impl = workload[0]
+        if topi_impl == "dense_weight_transform.x86":
+            if cfg.is_fallback:
+                _default_dense_weight_transform_config(cfg, M, N, K)
+            packw_bn = cfg["tile_x"].size[-1]
+            new_attrs["weight_layout"] = "NK%dn" % packw_bn
+            new_attrs["out_dtype"] = out_dtype
+            new_weight = te.placeholder(
+                (N // packw_bn, K, packw_bn), dtype=weight_tensor.dtype,
+            )
+            # Relay dense doesn't have bias.
+            new_workload = autotvm.task.args_to_workload(
+                [
+                    data_tensor,
+                    new_weight,
+                    None,
+                    out_dtype,
+                ],
+                topi_impl,
+            )
+            dispatch_ctx.update(target, new_workload, cfg)
+            return relay.nn.contrib_dense_weight_transform(*inputs, **new_attrs)
+
+    return None
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 8ace82be9ff8..e6a4ae677a8d 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -184,7 +184,38 @@ RELAY_REGISTER_OP("nn.dense")
     .add_argument("data", "nD Tensor", "Input data.")
     .add_argument("weight", "2D Tensor", "Weight matrix.")
     .set_support_level(1)
-    .add_type_rel("Dense", DenseRel<DenseAttrs>);
+    .add_type_rel("Dense", DenseRel<DenseAttrs>)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout<DenseAttrs>);
+
+// relay.nn.contrib_dense_weight_transform
+TVM_REGISTER_NODE_TYPE(DenseWeightTransformAttrs);
+
+// Positional relay function to create dense_weight_transform operator used by frontend FFI.
+Expr MakeDenseWeightTransform(Expr data, Expr weight, String weight_layout, DataType out_dtype) {
+  auto attrs = make_object<DenseWeightTransformAttrs>();
+  attrs->weight_layout = weight_layout;
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("nn.contrib_dense_weight_transform");
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_weight_transform").set_body_typed(MakeDenseWeightTransform);
+
+RELAY_REGISTER_OP("nn.contrib_dense_weight_transform")
+    .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
+
+- **data**: `(x1, x2, ..., xn, input_dim)`
+- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<DenseWeightTransformAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "nD Tensor", "Input data.")
+    .add_argument("weight", "3D Tensor", "Packed weight matrix.")
+    .set_support_level(10)
+    .add_type_rel("DenseWeightTransform", DenseWeightTransformRel<DenseWeightTransformAttrs>)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   DensePackedInferCorrectLayout<DenseWeightTransformAttrs>);
 
 // relay.leaky_relu
 TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 9b9cff2dba81..14f8249b160f 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -31,6 +31,8 @@
 
 #include <utility>
 
+#include "../op_common.h"
+
 namespace tvm {
 namespace relay {
 
@@ -88,6 +90,50 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+template <typename AttrType>
+bool DenseWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                             const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr || weight == nullptr) return false;
+
+  const AttrType* param = attrs.as<AttrType>();
+  ICHECK(param != nullptr);
+
+  Array<IndexExpr> oshape = data->shape;
+  oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+template <typename T>
+Array<Array<Layout> > DenseInferCorrectLayout(const Attrs& attrs,
+                                              const Array<Layout>& new_in_layouts,
+                                              const Array<Layout>& old_in_layouts,
+                                              const Array<tvm::relay::Type>& old_in_types) {
+  return Array<Array<Layout> >{
+      {"MK", "NK"},
+      {"MK"}};
+}
+
+template <typename T>
+Array<Array<Layout> > DensePackedInferCorrectLayout(const Attrs& attrs,
+                                                    const Array<Layout>& new_in_layouts,
+                                                    const Array<Layout>& old_in_layouts,
+                                                    const Array<tvm::relay::Type>& old_in_types) {
+  const T* params = attrs.as<T>();
+  return Array<Array<Layout> >{
+      {"MK", params->weight_layout},
+      {"MK"}};
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_NN_NN_H_
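
Note: the core of patch 01 is the NK -> NKn weight pre-packing consumed by the packed_weight compute above. For reference on the layout only, here is a minimal NumPy sketch of the same packing; the helper name and the tile width of 16 are illustrative assumptions, not part of the patch:

    import numpy as np

    def pack_weight_nk_to_nkn(weight, tile=16):
        # (N, K) -> (N // tile, K, tile), so that packed[z, y, x] == weight[z * tile + x, y],
        # mirroring te.compute(packw_shape, lambda z, y, x: weight[z * packw_bn + x, y]).
        n, k = weight.shape
        assert n % tile == 0, "sketch assumes N divisible by the tile width"
        return weight.reshape(n // tile, tile, k).transpose(0, 2, 1)

    w = np.arange(48 * 64, dtype="float32").reshape(48, 64)
    packed = pack_weight_nk_to_nkn(w)
    assert packed.shape == (3, 64, 16)
    assert packed[1, 5, 7] == w[1 * 16 + 7, 5]

Because packw_bn is the innermost factor of the tile_x split, the vectorized x axis of the compute walks elements that are contiguous in the packed buffer.
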
From fae00504a7d8fb3a0e899ccb241421c6a2cd2527 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Wed, 3 Feb 2021 23:36:22 +0000
Subject: [PATCH 02/15] Fix format

---
 include/tvm/relay/attrs/nn.h            |  4 ++--
 python/tvm/relay/op/nn/_nn.py           |  8 +++-----
 python/tvm/relay/op/strategy/generic.py |  4 +---
 python/tvm/relay/op/strategy/x86.py     |  4 +---
 python/tvm/topi/x86/dense.py            |  7 +++++--
 python/tvm/topi/x86/dense_alter_op.py   | 12 ++----------
 src/relay/op/nn/nn.cc                   |  6 ++++--
 src/relay/op/nn/nn.h                    |  8 ++------
 8 files changed, 20 insertions(+), 33 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 05a5ca426bb6..0b33ffb57665 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -946,8 +946,8 @@ struct DenseWeightTransformAttrs : public tvm::AttrsNode<DenseWeightTransformAttrs> {
     TVM_ATTR_FIELD(weight_layout)
         .set_default("NKn")
         .describe("Dimension ordering of weight. Can be 'NKn', 'NK16n', etc.");
     TVM_ATTR_FIELD(out_dtype)
-      .set_default(NullValue<DataType>())
-      .describe("Output data type, set to explicit type under mixed precision setting");
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
   }
 };
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 85d92a267136..05863b3cadb7 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -77,11 +77,13 @@ def legalize_dense(attrs, inputs, types):
 reg.register_strategy("nn.dense", strategy.dense_strategy)
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
 @reg.register_alter_op_layout("nn.dense")
 def alter_op_layout_dense(attrs, inputs, tinfos, out_type):
     """Alternate the layout of dense"""
     return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)
 
+
 # dense_weight_transform
 reg.register_strategy("nn.contrib_dense_weight_transform",
strategy.dense_weight_transform_strategy) reg.register_pattern("nn.contrib_dense_weight_transform", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -920,11 +922,7 @@ def conv_shape_func(attrs, inputs, _): return [ _conv_shape_func( - inputs[0], - inputs[1], - convert(strides), - convert(padding), - convert(dilation), + inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation), ) ] diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 5c14d1bdefe8..f32ac9c3c093 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -164,9 +164,7 @@ def fast_softmax_strategy(attrs, inputs, out_type, target): # so it should only be used together with auto-scheduler. strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_softmax(topi.nn.fast_softmax), - naive_schedule, - name="fast_softmax.generic", + wrap_compute_softmax(topi.nn.fast_softmax), naive_schedule, name="fast_softmax.generic", ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index c29cbe511e35..df91c5c50d0c 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -304,9 +304,7 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): # or packed layouts. if layout == "NCDHW": strategy.add_implementation( - wrap_compute_conv3d(topi.nn.conv3d_ncdhw), - naive_schedule, - name="conv3d_ncdhw.x86", + wrap_compute_conv3d(topi.nn.conv3d_ncdhw), naive_schedule, name="conv3d_ncdhw.x86", ) elif layout == "NDHWC": strategy.add_implementation( diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 47a06edb9ace..42989f5d9534 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -199,12 +199,15 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + def dense_pack(data, weight, bias=None, out_dtype=None): return dense_weight_transform(data, weight, bias, out_dtype) + def schedule_dense_pack(outs): return schedule_dense_weight_transform(outs) + @autotvm.register_topi_compute("dense_weight_transform.x86") def dense_weight_transform(cfg, data, weight, bias=None, out_dtype=None): """Compute dense with transformed weight.""" @@ -311,7 +314,7 @@ def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): @autotvm.register_topi_schedule("dense_mkl.x86") def schedule_dense_mkl(_, outs): """Create schedule for dense_mkl""" - #return generic.schedule_extern(outs) + # return generic.schedule_extern(outs) s = te.create_schedule([x.op for x in outs]) te.schedule.AutoInlineInjective(s) @@ -319,7 +322,7 @@ def _callback(op): if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag: schedule_injective_from_existing(s, op.output(0)) - #traverse_inline(s, outs[0].op, _callback) + # traverse_inline(s, outs[0].op, _callback) for out in outs: if "dense" not in out.op.name: schedule_injective_from_existing(s, out) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 8183baf234da..71aac60f0802 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -49,18 +49,10 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): packw_bn = cfg["tile_x"].size[-1] new_attrs["weight_layout"] = "NK%dn" % packw_bn new_attrs["out_dtype"] = out_dtype - new_weight = te.placeholder( - (N // packw_bn, K, packw_bn), dtype=weight_tensor.dtype, - ) + new_weight = te.placeholder((N // packw_bn, K, packw_bn), 
dtype=weight_tensor.dtype,) # Relay dense doesn't have bias. new_workload = autotvm.task.args_to_workload( - [ - data_tensor, - new_weight, - None, - out_dtype, - ], - topi_impl, + [data_tensor, new_weight, None, out_dtype,], topi_impl, ) dispatch_ctx.update(target, new_workload, cfg) return relay.nn.contrib_dense_weight_transform(*inputs, **new_attrs) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index e6a4ae677a8d..911908d72a4c 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -199,7 +199,8 @@ Expr MakeDenseWeightTransform(Expr data, Expr weight, String weight_layout, Data return Call(op, {data, weight}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_weight_transform").set_body_typed(MakeDenseWeightTransform); +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_weight_transform") + .set_body_typed(MakeDenseWeightTransform); RELAY_REGISTER_OP("nn.contrib_dense_weight_transform") .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. @@ -215,7 +216,8 @@ RELAY_REGISTER_OP("nn.contrib_dense_weight_transform") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) .add_type_rel("DenseWeightTransform", DenseWeightTransformRel) - .set_attr("FInferCorrectLayout", DensePackedInferCorrectLayout); + .set_attr("FInferCorrectLayout", + DensePackedInferCorrectLayout); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 14f8249b160f..31ba3a7d166c 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -118,9 +118,7 @@ Array > DenseInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - return Array >{ - {"MK", "NK"}, - {"MK"}}; + return Array >{{"MK", "NK"}, {"MK"}}; } template @@ -129,9 +127,7 @@ Array > DensePackedInferCorrectLayout(const Attrs& attrs, const Array& old_in_layouts, const Array& old_in_types) { const T* params = attrs.as(); - return Array >{ - {"MK", params->weight_layout}, - {"MK"}}; + return Array >{{"MK", params->weight_layout}, {"MK"}}; } } // namespace relay From fbb1e420afee3fdc854633760a3c07a4f262ae94 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 3 Feb 2021 23:46:39 +0000 Subject: [PATCH 03/15] Fix python format --- python/tvm/relay/op/nn/_nn.py | 6 +++++- python/tvm/relay/op/strategy/generic.py | 4 +++- python/tvm/relay/op/strategy/x86.py | 4 +++- python/tvm/topi/x86/dense_alter_op.py | 13 +++++++++++-- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 05863b3cadb7..59e428fdb6b5 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -922,7 +922,11 @@ def conv_shape_func(attrs, inputs, _): return [ _conv_shape_func( - inputs[0], inputs[1], convert(strides), convert(padding), convert(dilation), + inputs[0], + inputs[1], + convert(strides), + convert(padding), + convert(dilation), ) ] diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f32ac9c3c093..5c14d1bdefe8 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -164,7 +164,9 @@ def fast_softmax_strategy(attrs, inputs, out_type, target): # so it should only be used together with auto-scheduler. 
strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_softmax(topi.nn.fast_softmax), naive_schedule, name="fast_softmax.generic", + wrap_compute_softmax(topi.nn.fast_softmax), + naive_schedule, + name="fast_softmax.generic", ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index df91c5c50d0c..c29cbe511e35 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -304,7 +304,9 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): # or packed layouts. if layout == "NCDHW": strategy.add_implementation( - wrap_compute_conv3d(topi.nn.conv3d_ncdhw), naive_schedule, name="conv3d_ncdhw.x86", + wrap_compute_conv3d(topi.nn.conv3d_ncdhw), + naive_schedule, + name="conv3d_ncdhw.x86", ) elif layout == "NDHWC": strategy.add_implementation( diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 71aac60f0802..17e17f131930 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -49,10 +49,19 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): packw_bn = cfg["tile_x"].size[-1] new_attrs["weight_layout"] = "NK%dn" % packw_bn new_attrs["out_dtype"] = out_dtype - new_weight = te.placeholder((N // packw_bn, K, packw_bn), dtype=weight_tensor.dtype,) + new_weight = te.placeholder( + (N // packw_bn, K, packw_bn), + dtype=weight_tensor.dtype, + ) # Relay dense doesn't have bias. new_workload = autotvm.task.args_to_workload( - [data_tensor, new_weight, None, out_dtype,], topi_impl, + [ + data_tensor, + new_weight, + None, + out_dtype, + ], + topi_impl, ) dispatch_ctx.update(target, new_workload, cfg) return relay.nn.contrib_dense_weight_transform(*inputs, **new_attrs) From b794b9d98dea5d4259ecc92e7c97697799d4d691 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 3 Feb 2021 23:59:55 +0000 Subject: [PATCH 04/15] Fix pylint --- python/tvm/topi/nn/dense.py | 1 + python/tvm/topi/x86/dense.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index ceca961004af..a055aecf3998 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name,unused-variable """TVM operator fully connected compute.""" import tvm from tvm import te, auto_scheduler diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 42989f5d9534..0e78735787f7 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name,too-many-locals,unused-variable +# pylint: disable=no-value-for-parameter """x86 dense operators""" from __future__ import absolute_import as _abs import tvm @@ -26,7 +27,7 @@ from tvm.contrib import mkldnn from .utils import get_fp32_len -from .injective import schedule_injective, schedule_injective_from_existing +from .injective import schedule_injective_from_existing from .. 
import generic, tag from ..utils import traverse_inline, get_const_tuple From 6eb5772c638e7425393c3280b338cac430c00697 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 4 Feb 2021 00:02:47 +0000 Subject: [PATCH 05/15] Minor fix --- python/tvm/topi/nn/dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index a055aecf3998..8ab7ecaf1281 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,unused-variable +# pylint: disable=invalid-name,unused-argument """TVM operator fully connected compute.""" import tvm from tvm import te, auto_scheduler From cf48da32d55837a59fcf216b6469bb0aedc1f541 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 4 Feb 2021 00:58:38 +0000 Subject: [PATCH 06/15] Add test --- python/tvm/topi/x86/dense_alter_op.py | 2 +- .../python/relay/test_pass_alter_op_layout.py | 33 ++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 17e17f131930..34904b3d71cb 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -26,7 +26,7 @@ from ..nn import dense_alter_layout -@dense_alter_layout.register("cpu") +@dense_alter_layout.register(["cpu", "arm_cpu"]) def _alter_dense_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 58c279d750ec..53ca98b7aebf 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -18,7 +18,7 @@ import pytest import tvm -from tvm import relay +from tvm import relay, topi from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.relay.testing import run_infer_type @@ -1248,6 +1248,36 @@ def expected(): assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a) +def test_alter_op_dense(): + def before(): + x = relay.var("x", shape=(32, 64)) + weight = relay.var("weight", shape=(48, 64)) + y = relay.nn.dense(x, weight) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(32, 64)) + weight = relay.var("weight", shape=(48, 64)) + target_layout = "NK16n" + weight_transform = relay.layout_transform(weight, "NK", target_layout) + y = relay.nn.contrib_dense_weight_transform( + x, weight_transform, weight_layout=target_layout, out_dtype="float32" + ) + y = relay.Function(analysis.free_vars(y), y) + return y + + for target, _ in tvm.testing.enabled_targets(): + with tvm.target.Target(target): + with TempOpAttr( + "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout + ): + a = before() + a = run_opt_pass(a, transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": test_alter_op() test_alter_return_none() @@ -1269,3 +1299,4 @@ def expected(): test_alter_layout_nhwc_arm() test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() + test_alter_op_dense() From d6cc506e0ec50c4df588f36d0e5083992055727d Mon Sep 17 00:00:00 2001 
From: Yao Wang Date: Thu, 4 Feb 2021 22:40:36 +0000 Subject: [PATCH 07/15] Do not need to infer layout for dense --- include/tvm/relay/attrs/nn.h | 15 --------------- python/tvm/relay/op/nn/nn.py | 8 ++++---- python/tvm/topi/x86/dense_alter_op.py | 9 +++++---- src/relay/op/nn/nn.cc | 17 ++++++----------- src/relay/op/nn/nn.h | 17 ----------------- 5 files changed, 15 insertions(+), 51 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 0b33ffb57665..c3c58e54517c 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -939,21 +939,6 @@ struct DenseAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for dense_weight_transform operator */ -struct DenseWeightTransformAttrs : public tvm::AttrsNode { - tvm::String weight_layout; - DataType out_dtype; - - TVM_DECLARE_ATTRS(DenseWeightTransformAttrs, "relay.attrs.DenseAttrs") { - TVM_ATTR_FIELD(weight_layout) - .set_default("NKn") - .describe("Dimension ordering of weight. Can be 'NKn', 'NK16n', etc."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type, set to explicit type under mixed precision setting"); - } -}; - /*! \brief Attributes for batch matmul operator */ struct BatchMatmulAttrs : public tvm::AttrsNode { tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5727888e7847..0d3026518091 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1435,7 +1435,7 @@ def dense(data, weight, units=None, out_dtype=""): return _make.dense(data, weight, units, out_dtype) -def contrib_dense_weight_transform(data, weight, weight_layout="NKn", out_dtype=""): +def contrib_dense_weight_transform(data, weight, units=None, out_dtype=""): """Dense operator. Applies a linear transformation @@ -1453,8 +1453,8 @@ def contrib_dense_weight_transform(data, weight, weight_layout="NKn", out_dtype= The transformed weight expressions, 3-D matrix, of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`. - weight_layout : str, optional - Layout of the weight. + units : int, optional + Number of hidden units of the dense transformation. out_dtype : str, optional Specifies the output data type for mixed precision dense, @@ -1465,7 +1465,7 @@ def contrib_dense_weight_transform(data, weight, weight_layout="NKn", out_dtype= result : tvm.relay.Expr The computed result. 
""" - return _make.contrib_dense_weight_transform(data, weight, weight_layout, out_dtype) + return _make.contrib_dense_weight_transform(data, weight, units, out_dtype) def fifo_buffer(data, buffer, axis): diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 34904b3d71cb..14c5ca574d3c 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -30,7 +30,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - new_attrs = {} data_tensor, weight_tensor = tinfos out_dtype = out_type.dtype M, K = get_const_tuple(data_tensor.shape) @@ -47,8 +46,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): if cfg.is_fallback: _default_dense_weight_transform_config(cfg, M, N, K) packw_bn = cfg["tile_x"].size[-1] - new_attrs["weight_layout"] = "NK%dn" % packw_bn - new_attrs["out_dtype"] = out_dtype + weight_layout = "NK%dn" % packw_bn new_weight = te.placeholder( (N // packw_bn, K, packw_bn), dtype=weight_tensor.dtype, @@ -64,6 +62,9 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): topi_impl, ) dispatch_ctx.update(target, new_workload, cfg) - return relay.nn.contrib_dense_weight_transform(*inputs, **new_attrs) + weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout) + return relay.nn.contrib_dense_weight_transform( + inputs[0], weight_transform, None, out_dtype + ) return None diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 911908d72a4c..2f7ae2f871dc 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -184,16 +184,13 @@ RELAY_REGISTER_OP("nn.dense") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "2D Tensor", "Weight matrix.") .set_support_level(1) - .add_type_rel("Dense", DenseRel) - .set_attr("FInferCorrectLayout", DenseInferCorrectLayout); + .add_type_rel("Dense", DenseRel); // relay.nn.contrib_dense_weight_transform -TVM_REGISTER_NODE_TYPE(DenseWeightTransformAttrs); - // Positional relay function to create dense_weight_transform operator used by frontend FFI. -Expr MakeDenseWeightTransform(Expr data, Expr weight, String weight_layout, DataType out_dtype) { - auto attrs = make_object(); - attrs->weight_layout = weight_layout; +Expr MakeDenseWeightTransform(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { + auto attrs = make_object(); + attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.contrib_dense_weight_transform"); return Call(op, {data, weight}, Attrs(attrs), {}); @@ -210,14 +207,12 @@ RELAY_REGISTER_OP("nn.contrib_dense_weight_transform") - **out**: `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(2) .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) - .add_type_rel("DenseWeightTransform", DenseWeightTransformRel) - .set_attr("FInferCorrectLayout", - DensePackedInferCorrectLayout); + .add_type_rel("DenseWeightTransform", DenseWeightTransformRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 31ba3a7d166c..75dab129899f 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -113,23 +113,6 @@ bool DenseWeightTransformRel(const Array& types, int num_inputs, const Att return true; } -template -Array > DenseInferCorrectLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - return Array >{{"MK", "NK"}, {"MK"}}; -} - -template -Array > DensePackedInferCorrectLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - const T* params = attrs.as(); - return Array >{{"MK", params->weight_layout}, {"MK"}}; -} - } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_NN_H_ From bf4d859b4c1882310942d1c4fb6e8321faf5b491 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 4 Feb 2021 22:45:03 +0000 Subject: [PATCH 08/15] Fix test --- tests/python/relay/test_pass_alter_op_layout.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 53ca98b7aebf..59ff5847744b 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1262,7 +1262,7 @@ def expected(): target_layout = "NK16n" weight_transform = relay.layout_transform(weight, "NK", target_layout) y = relay.nn.contrib_dense_weight_transform( - x, weight_transform, weight_layout=target_layout, out_dtype="float32" + x, weight_transform, units=None, out_dtype="float32" ) y = relay.Function(analysis.free_vars(y), y) return y @@ -1279,6 +1279,7 @@ def expected(): if __name__ == "__main__": + """ test_alter_op() test_alter_return_none() test_alter_layout() @@ -1299,4 +1300,5 @@ def expected(): test_alter_layout_nhwc_arm() test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() + """ test_alter_op_dense() From e02458e9edae2a876d82978fe0a9fbc2d0b2c897 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 5 Feb 2021 00:19:45 +0000 Subject: [PATCH 09/15] Rename dense_pack --- python/tvm/relay/op/nn/_nn.py | 6 ++--- python/tvm/relay/op/nn/nn.py | 4 ++-- python/tvm/relay/op/strategy/generic.py | 12 +++++----- python/tvm/relay/op/strategy/x86.py | 19 +++++++++++----- python/tvm/topi/nn/dense.py | 8 +++---- python/tvm/topi/x86/dense.py | 30 +++++++++---------------- python/tvm/topi/x86/dense_alter_op.py | 10 ++++----- src/relay/op/nn/nn.cc | 15 ++++++------- src/relay/op/nn/nn.h | 4 ++-- 9 files changed, 52 insertions(+), 56 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 59e428fdb6b5..3a4414e57108 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -84,9 +84,9 @@ def alter_op_layout_dense(attrs, inputs, tinfos, out_type): return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type) -# dense_weight_transform -reg.register_strategy("nn.contrib_dense_weight_transform", strategy.dense_weight_transform_strategy) 
-reg.register_pattern("nn.contrib_dense_weight_transform", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +# dense_pack +reg.register_strategy("nn.contrib_dense_pack", strategy.dense_pack_strategy) +reg.register_pattern("nn.contrib_dense_pack", reg.OpPattern.OUT_ELEMWISE_FUSABLE) # fifo_buffer diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 0d3026518091..0c233a6e3b53 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1435,7 +1435,7 @@ def dense(data, weight, units=None, out_dtype=""): return _make.dense(data, weight, units, out_dtype) -def contrib_dense_weight_transform(data, weight, units=None, out_dtype=""): +def contrib_dense_pack(data, weight, units=None, out_dtype=""): """Dense operator. Applies a linear transformation @@ -1465,7 +1465,7 @@ def contrib_dense_weight_transform(data, weight, units=None, out_dtype=""): result : tvm.relay.Expr The computed result. """ - return _make.contrib_dense_weight_transform(data, weight, units, out_dtype) + return _make.contrib_dense_pack(data, weight, units, out_dtype) def fifo_buffer(data, buffer, axis): diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 5c14d1bdefe8..f35303895ddc 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -731,15 +731,15 @@ def dense_strategy(attrs, inputs, out_type, target): return strategy -@override_native_generic_func("dense_weight_transform_strategy") -def dense_weight_transform_strategy(attrs, inputs, out_type, target): - """dense_weight_transform generic strategy""" - logger.warning("dense_weight_transform is not optimized for this platform.") +@override_native_generic_func("dense_pack_strategy") +def dense_pack_strategy(attrs, inputs, out_type, target): + """dense_pack generic strategy""" + logger.warning("dense_pack is not optimized for this platform.") strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_dense(topi.nn.dense_weight_transform), + wrap_compute_dense(topi.nn.dense_pack), wrap_topi_schedule(topi.generic.schedule_dense), - name="dense_weight_transform.generic", + name="dense_pack.generic", ) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index c29cbe511e35..f33c45b248d6 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -367,6 +367,13 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype dtype = inputs[0].dtype u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_nopack), + wrap_topi_schedule(topi.x86.schedule_dense_nopack), + name="dense_nopack.x86", + plevel=5, + ) + strategy.add_implementation( wrap_compute_dense(topi.x86.dense_pack), wrap_topi_schedule(topi.x86.schedule_dense_pack), @@ -409,14 +416,14 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): return strategy -@dense_weight_transform_strategy.register("cpu") -def dense_weight_transform_strategy_cpu(attrs, inputs, out_type, target): - """dense_weight_transform x86 strategy""" +@dense_pack_strategy.register("cpu") +def dense_pack_strategy_cpu(attrs, inputs, out_type, target): + """dense_pack x86 strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_weight_transform), - wrap_topi_schedule(topi.x86.schedule_dense_weight_transform), - 
name="dense_weight_transform.x86", + wrap_compute_dense(topi.x86.dense_pack), + wrap_topi_schedule(topi.x86.schedule_dense_pack), + name="dense_pack.x86", ) return strategy diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index 8ab7ecaf1281..e8ec476b86a5 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -107,8 +107,8 @@ def dense_legalize(attrs, inputs, types): return None -def dense_weight_transform(data, weight, bias=None, out_dtype=None): - """The default implementation of dense_weight_transform in topi. +def dense_pack(data, weight, bias=None, out_dtype=None): + """The default implementation of dense_pack in topi. Parameters ---------- @@ -145,8 +145,8 @@ def dense_weight_transform(data, weight, bias=None, out_dtype=None): * weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype), axis=k, ), - name="T_dense_weight_transform", - tag="dense_weight_transform", + name="T_dense_pack", + tag="dense_pack", ) if bias is not None: C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 0e78735787f7..848ba72b442b 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -32,7 +32,7 @@ from ..utils import traverse_inline, get_const_tuple -def _schedule_dense_weight_transform_template(cfg, s, C, O): +def _schedule_dense_pack_template(cfg, s, C, O): A, packedB = s[C].op.input_tensors CC = s.cache_write(C, "global") @@ -99,7 +99,7 @@ def _schedule_dense_nopack_template(cfg, s, C): return s -def _default_dense_weight_transform_config(cfg, M, N, K): +def _default_dense_pack_config(cfg, M, N, K): # Generate default schedule for dynamic shape. if isinstance(M, tvm.tir.Var): M = 16 @@ -201,16 +201,8 @@ def _callback(op): return s -def dense_pack(data, weight, bias=None, out_dtype=None): - return dense_weight_transform(data, weight, bias, out_dtype) - - -def schedule_dense_pack(outs): - return schedule_dense_weight_transform(outs) - - -@autotvm.register_topi_compute("dense_weight_transform.x86") -def dense_weight_transform(cfg, data, weight, bias=None, out_dtype=None): +@autotvm.register_topi_compute("dense_pack.x86") +def dense_pack(cfg, data, weight, bias=None, out_dtype=None): """Compute dense with transformed weight.""" if out_dtype is None: out_dtype = data.dtype @@ -226,7 +218,7 @@ def dense_weight_transform(cfg, data, weight, bias=None, out_dtype=None): cfg.define_split("tile_k", K, num_outputs=2) cfg.define_split("tile_inner", M, num_outputs=2, filter=lambda y: y.size[-1] <= 16) if cfg.is_fallback: - _default_dense_weight_transform_config(cfg, M, N, K) + _default_dense_pack_config(cfg, M, N, K) if len(weight.shape) == 2: packw_bn = cfg["tile_x"].size[-1] @@ -251,21 +243,21 @@ def dense_weight_transform(cfg, data, weight, bias=None, out_dtype=None): * packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype), axis=k, ), - tag="dense_weight_transform", + tag="dense_pack", ) if bias is not None: C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) return C -@autotvm.register_topi_schedule("dense_weight_transform.x86") -def schedule_dense_weight_transform(cfg, outs): - """Create the schedule for dense_weight_transform""" +@autotvm.register_topi_schedule("dense_pack.x86") +def schedule_dense_pack(cfg, outs): + """Create the schedule for dense_pack""" s = te.create_schedule([x.op for x in outs]) def _callback(op): - if "dense_weight_transform" in op.tag: - 
_schedule_dense_weight_transform_template(cfg, s, op.output(0), outs[0]) + if "dense_pack" in op.tag: + _schedule_dense_pack_template(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) return s diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 14c5ca574d3c..5e15c8bf5368 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -21,7 +21,7 @@ from tvm import te from tvm import relay from tvm import autotvm -from .dense import _default_dense_weight_transform_config +from .dense import _default_dense_pack_config from ..utils import get_const_tuple from ..nn import dense_alter_layout @@ -42,9 +42,9 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): if workload: cfg = dispatch_ctx.query(target, workload) topi_impl = workload[0] - if topi_impl == "dense_weight_transform.x86": + if topi_impl == "dense_pack.x86": if cfg.is_fallback: - _default_dense_weight_transform_config(cfg, M, N, K) + _default_dense_pack_config(cfg, M, N, K) packw_bn = cfg["tile_x"].size[-1] weight_layout = "NK%dn" % packw_bn new_weight = te.placeholder( @@ -63,8 +63,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): ) dispatch_ctx.update(target, new_workload, cfg) weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout) - return relay.nn.contrib_dense_weight_transform( - inputs[0], weight_transform, None, out_dtype - ) + return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype) return None diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 2f7ae2f871dc..3e3d94c614c3 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -186,20 +186,19 @@ RELAY_REGISTER_OP("nn.dense") .set_support_level(1) .add_type_rel("Dense", DenseRel); -// relay.nn.contrib_dense_weight_transform -// Positional relay function to create dense_weight_transform operator used by frontend FFI. -Expr MakeDenseWeightTransform(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { +// relay.nn.contrib_dense_pack +// Positional relay function to create dense_pack operator used by frontend FFI. +Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("nn.contrib_dense_weight_transform"); + static const Op& op = Op::Get("nn.contrib_dense_pack"); return Call(op, {data, weight}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_weight_transform") - .set_body_typed(MakeDenseWeightTransform); +TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack); -RELAY_REGISTER_OP("nn.contrib_dense_weight_transform") +RELAY_REGISTER_OP("nn.contrib_dense_pack") .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. 
- **data**: `(x1, x2, ..., xn, input_dim)` @@ -212,7 +211,7 @@ RELAY_REGISTER_OP("nn.contrib_dense_weight_transform") .add_argument("data", "nD Tensor", "Input data.") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) - .add_type_rel("DenseWeightTransform", DenseWeightTransformRel); + .add_type_rel("DensePack", DensePackRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 75dab129899f..c00e2e02b369 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -91,8 +91,8 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, } template -bool DenseWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); From cbd2d860354ad82cd3bf8b21eea380586e17e409 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 5 Feb 2021 00:24:22 +0000 Subject: [PATCH 10/15] Fix test --- tests/python/relay/test_pass_alter_op_layout.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 59ff5847744b..fecec4113da2 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1261,7 +1261,7 @@ def expected(): weight = relay.var("weight", shape=(48, 64)) target_layout = "NK16n" weight_transform = relay.layout_transform(weight, "NK", target_layout) - y = relay.nn.contrib_dense_weight_transform( + y = relay.nn.contrib_dense_pack( x, weight_transform, units=None, out_dtype="float32" ) y = relay.Function(analysis.free_vars(y), y) @@ -1279,7 +1279,6 @@ def expected(): if __name__ == "__main__": - """ test_alter_op() test_alter_return_none() test_alter_layout() @@ -1300,5 +1299,4 @@ def expected(): test_alter_layout_nhwc_arm() test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() - """ test_alter_op_dense() From f1f8d4be25a0e724a4bab278672e93a513dfa903 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 5 Feb 2021 01:29:51 +0000 Subject: [PATCH 11/15] Fix lint --- tests/python/relay/test_pass_alter_op_layout.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index fecec4113da2..41186884bdb2 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1261,9 +1261,7 @@ def expected(): weight = relay.var("weight", shape=(48, 64)) target_layout = "NK16n" weight_transform = relay.layout_transform(weight, "NK", target_layout) - y = relay.nn.contrib_dense_pack( - x, weight_transform, units=None, out_dtype="float32" - ) + y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32") y = relay.Function(analysis.free_vars(y), y) return y From e3749bf88cf15d5dea0a5b5ae8bf122c45c0fd3d Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 5 Feb 2021 19:13:24 +0000 Subject: [PATCH 12/15] Fix dynamic shape dense --- python/tvm/relay/op/nn/_nn.py | 19 ++++++++++++++++ python/tvm/topi/x86/dense.py | 43 ++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 
From e3749bf88cf15d5dea0a5b5ae8bf122c45c0fd3d Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Fri, 5 Feb 2021 19:13:24 +0000
Subject: [PATCH 12/15] Fix dynamic shape dense

---
 python/tvm/relay/op/nn/_nn.py | 19 ++++++++++++++++
 python/tvm/topi/x86/dense.py  | 43 ++++++++++++++++++++++++-----------
 2 files changed, 49 insertions(+), 13 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 3a4414e57108..cb08649029d6 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1141,6 +1141,25 @@ def dense_shape_func(attrs, inputs, _):
     return ret

+@script
+def _dense_pack_shape_func(data_shape, weight_shape):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0] - 1):
+        out[i] = data_shape[i]
+    out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
+
+    return out
+
+
+@reg.register_shape_func("nn.contrib_dense_pack", False)
+def dense_shape_func(attrs, inputs, _):
+    """
+    Shape function for dense_pack op.
+    """
+    ret = [_dense_pack_shape_func(inputs[0], inputs[1])]
+    return ret
+
+
 @script
 def _batch_matmul_shape_func(data_shape, weight_shape):
     out = output_tensor((data_shape.shape[0],), "int64")
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 848ba72b442b..6011f01c2cb0 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -101,11 +101,11 @@ def _schedule_dense_nopack_template(cfg, s, C):

 def _default_dense_pack_config(cfg, M, N, K):
     # Generate default schedule for dynamic shape.
-    if isinstance(M, tvm.tir.Var):
+    if isinstance(M, (tvm.tir.Var, tvm.tir.Any)):
         M = 16
-    if isinstance(N, tvm.tir.Var):
+    if isinstance(N, (tvm.tir.Var, tvm.tir.Any)):
         N = 16
-    if isinstance(K, tvm.tir.Var):
+    if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
         K = 16

     vec_width = get_fp32_len()
@@ -139,11 +139,11 @@ def _default_dense_nopack_config(cfg, M, N, K):
     # Generate default schedule for dynamic shape.
-    if isinstance(M, tvm.tir.Var):
+    if isinstance(M, (tvm.tir.Var, tvm.tir.Any)):
         M = 16
-    if isinstance(N, tvm.tir.Var):
+    if isinstance(N, (tvm.tir.Var, tvm.tir.Any)):
         N = 16
-    if isinstance(K, tvm.tir.Var):
+    if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
         K = 16

     vec_width = get_fp32_len()
@@ -165,9 +165,15 @@ def dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
     M, K = get_const_tuple(data.shape)
     N, _ = get_const_tuple(weight.shape)
     # create tuning space
-    cfg.define_split("tile_y", 32 if isinstance(M, tvm.tir.Var) else M, num_outputs=2)
-    cfg.define_split("tile_x", 32 if isinstance(N, tvm.tir.Var) else N, num_outputs=2)
-    cfg.define_split("tile_k", 32 if isinstance(K, tvm.tir.Var) else K, num_outputs=2)
+    cfg.define_split(
+        "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2
+    )
     if cfg.is_fallback:
         _default_dense_nopack_config(cfg, M, N, K)
@@ -213,10 +219,21 @@ def dense_pack(cfg, data, weight, bias=None, out_dtype=None):
     else:
         N, _ = get_const_tuple(weight.shape)  # out_dim
     # create tuning space
-    cfg.define_split("tile_y", M, num_outputs=3)
-    cfg.define_split("tile_x", N, num_outputs=3)
-    cfg.define_split("tile_k", K, num_outputs=2)
-    cfg.define_split("tile_inner", M, num_outputs=2, filter=lambda y: y.size[-1] <= 16)
+    cfg.define_split(
+        "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=3
+    )
+    cfg.define_split(
+        "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=3
+    )
+    cfg.define_split(
+        "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2
+    )
+    cfg.define_split(
+        "tile_inner",
+        32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M,
+        num_outputs=2,
+        filter=lambda y: y.size[-1] <= 16,
+    )
     if cfg.is_fallback:
         _default_dense_pack_config(cfg, M, N, K)
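The hybrid-script shape function registered in this patch keeps every leading dimension of the data and recovers the number of output units from the packed weight. A plain-Python restatement of the rule (a sketch only; dense_pack_out_shape is a hypothetical name, not part of the patch):

def dense_pack_out_shape(data_shape, packed_weight_shape):
    # Mirrors _dense_pack_shape_func: copy all but the last data dim, then
    # units = weight_shape[0] * weight_shape[2] for a (N // n, K, n) weight.
    n_chunks, _, tile = packed_weight_shape
    return tuple(data_shape[:-1]) + (n_chunks * tile,)

# e.g. data (8, 64) x packed weight (3, 64, 16) -> output (8, 48)
assert dense_pack_out_shape((8, 64), (3, 64, 16)) == (8, 48)

This is what lets contrib_dense_pack participate in dynamic-shape graphs: the output shape depends only on the data rank and the packed weight's static dimensions, while the tvm.tir.Any checks above keep the fallback schedule valid when M, N, or K is unknown.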
From 271f344506f0335c9098f6a8816b5873b5a13647 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Fri, 5 Feb 2021 21:33:01 +0000
Subject: [PATCH 13/15] Fix lint

---
 python/tvm/relay/op/nn/_nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index cb08649029d6..6ae86c0786e5 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1152,7 +1152,7 @@ def _dense_pack_shape_func(data_shape, weight_shape):

 @reg.register_shape_func("nn.contrib_dense_pack", False)
-def dense_shape_func(attrs, inputs, _):
+def dense_pack_shape_func(attrs, inputs, _):
     """
     Shape function for dense_pack op.
     """
From c256475c2fe94cb9ba478245b452b1be47a9b88b Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Sat, 6 Feb 2021 02:15:58 +0000
Subject: [PATCH 14/15] Fix autotvm task extraction test

---
 tests/python/relay/test_autotvm_task_extraction.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py
index da71ac37f695..d6bfd8d0ec11 100644
--- a/tests/python/relay/test_autotvm_task_extraction.py
+++ b/tests/python/relay/test_autotvm_task_extraction.py
@@ -60,9 +60,9 @@ def test_task_extraction():
     tasks = autotvm.task.extract_from_program(
         mod["main"], target=target, params=params, ops=(dense,)
     )
-    assert len(tasks) == 1
+    assert len(tasks) == 2
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
-    assert len(tasks) == 1
+    assert len(tasks) == 2

     mod, params, _ = get_network("resnet-18", batch_size=1)
     mod_list.append(mod)
     params_list.append(params)
     tasks = autotvm.task.extract_from_program(
         mod["main"], target=target, params=params, ops=(conv2d, dense)
     )
-    assert len(tasks) == 13
+    assert len(tasks) == 14
     tasks = autotvm.task.extract_from_program(
         mod, target=target, params=params, ops=(conv2d, dense)
     )
-    assert len(tasks) == 13
+    assert len(tasks) == 14
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params)
-    assert len(tasks) == 13
+    assert len(tasks) == 14

     mod, params, _ = get_network("resnet3d-18", batch_size=1)
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(conv3d,))
@@ -88,7 +88,7 @@ def test_task_extraction():
     tasks = autotvm.task.extract_from_program(
         mod, target=target, params=params, ops=(conv2d, dense)
     )
-    assert len(tasks) == 20
+    assert len(tasks) == 21

     mod, params, _ = get_network("dcgan", batch_size=1)
     tasks = autotvm.task.extract_from_program(
From 6b5abd16cd973e08e8f3492e08acea6a536095d9 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Sun, 7 Feb 2021 03:49:04 +0000
Subject: [PATCH 15/15] Disable AlterOpLayout in micro_tflite.py tutorial

---
 tutorials/micro/micro_tflite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py
index c28918380265..c979216d0c6b 100644
--- a/tutorials/micro/micro_tflite.py
+++ b/tutorials/micro/micro_tflite.py
@@ -195,7 +195,7 @@
 # Now, compile the model for the target:

 with tvm.transform.PassContext(
-    opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps"]
+    opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps", "AlterOpLayout"]
 ):
     graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params)
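Because AlterOpLayout is the pass that rewrites nn.dense into nn.contrib_dense_pack, disabling it keeps the tutorial's module in its original form for the micro target. A minimal sketch of verifying the effect (illustrative only; op_names is a hypothetical helper, not part of the patch):

import tvm
from tvm import relay

def op_names(mod):
    # Collect every operator name reachable from the module's main function.
    names = set()

    def visit(expr):
        if isinstance(expr, relay.Call) and isinstance(expr.op, tvm.ir.Op):
            names.add(expr.op.name)

    relay.analysis.post_order_visit(mod["main"], visit)
    return names

x = relay.var("x", shape=(1, 64))
w = relay.var("w", shape=(48, 64))
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

# With the pass disabled, nn.dense should survive optimization; dropping
# "AlterOpLayout" from disabled_pass yields nn.contrib_dense_pack on x86.
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
    opt_mod, _ = relay.optimize(mod, target="llvm")
print(op_names(opt_mod))

The task-count bumps in the autotvm test above line up with the same change: with the packed implementation registered alongside the non-packed one, task extraction now finds two dense tasks where it previously found one.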