[Relay][Topi][CPU] Dense with weight transform (apache#7404)
* Add CPU dense weight transform

* Fix format

* Fix python format

* Fix pylint

* Minor fix

* Add test

* Do not need to infer layout for dense

* Fix test

* Rename dense_pack

* Fix test

* Fix lint

* Fix dynamic shape dense

* Fix lint

* Fix autotvm task extraction test

* Disable AlterOpLayout in micro_tflite.py tutorial
kevinthesun authored and Lokiiiiii committed Mar 1, 2021
1 parent 600e159 commit 0fd600f
Showing 13 changed files with 413 additions and 47 deletions.
30 changes: 30 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -78,6 +78,17 @@ def legalize_dense(attrs, inputs, types):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_alter_op_layout("nn.dense")
def alter_op_layout_dense(attrs, inputs, tinfos, out_type):
"""Alternate the layout of dense"""
return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)


# dense_pack
reg.register_strategy("nn.contrib_dense_pack", strategy.dense_pack_strategy)
reg.register_pattern("nn.contrib_dense_pack", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# fifo_buffer
@reg.register_compute("nn.fifo_buffer")
def compute_fifo_buffer(attrs, inputs, out_type):
@@ -1130,6 +1141,25 @@ def dense_shape_func(attrs, inputs, _):
return ret


@script
def _dense_pack_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]

return out


@reg.register_shape_func("nn.contrib_dense_pack", False)
def dense_pack_shape_func(attrs, inputs, _):
"""
Shape function for dense_pack op.
"""
ret = [_dense_pack_shape_func(inputs[0], inputs[1])]
return ret
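
For illustration, the shape function keeps every leading data dimension and replaces the last one with weight_shape[0] * weight_shape[2]. A plain-Python rendering of the same rule, with hypothetical shapes that are not part of this commit:

def dense_pack_out_shape(data_shape, weight_shape):
    # Same rule as _dense_pack_shape_func: copy leading dims, recombine the packed out_dim.
    out = list(data_shape)
    out[-1] = weight_shape[0] * weight_shape[2]
    return tuple(out)

# data (8, 128) with packed weight (4, 128, 16) -> output (8, 4 * 16) = (8, 64)
assert dense_pack_out_shape((8, 128), (4, 128, 16)) == (8, 64)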


@script
def _batch_matmul_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
33 changes: 33 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1435,6 +1435,39 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def contrib_dense_pack(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
.. math::
`Y = X * W^T`
Parameters
----------
data : tvm.relay.Expr
The input data to the operator,
of shape `(d_1, d_2, ..., d_n, units_in)`.
    weight : tvm.relay.Expr
        The transformed weight expression, a 3-D tensor
        of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
units : int, optional
Number of hidden units of the dense transformation.
    out_dtype : str, optional
        Specifies the output data type for mixed precision dense.
    Returns
    -------
    result : tvm.relay.Expr
        The computed result, of shape `(d_1, d_2, ..., d_n, units)`.
"""
return _make.contrib_dense_pack(data, weight, units, out_dtype)
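
A minimal usage sketch of the new operator; the shapes and the InferType check are illustrative assumptions, not taken from this commit. With units_in = 128, units = 64 and pack_weight_tile = 16, the packed weight has shape (4, 128, 16):

import tvm
from tvm import relay

data = relay.var("data", shape=(8, 128), dtype="float32")
packed_weight = relay.var("weight", shape=(64 // 16, 128, 16), dtype="float32")
out = relay.nn.contrib_dense_pack(data, packed_weight, units=64)
mod = tvm.IRModule.from_expr(relay.Function([data, packed_weight], out))
# Type inference should report the result as Tensor[(8, 64), float32].
print(relay.transform.InferType()(mod))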


def fifo_buffer(data, buffer, axis):
"""FIFO buffer to enable computation reuse in CNNs with sliding indow input
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -731,6 +731,19 @@ def dense_strategy(attrs, inputs, out_type, target):
return strategy


@override_native_generic_func("dense_pack_strategy")
def dense_pack_strategy(attrs, inputs, out_type, target):
"""dense_pack generic strategy"""
logger.warning("dense_pack is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense_pack),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense_pack.generic",
)
return strategy


# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap batch_matmul topi compute"""
28 changes: 19 additions & 9 deletions python/tvm/relay/op/strategy/x86.py
@@ -364,14 +364,20 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
def dense_strategy_cpu(attrs, inputs, out_type, target):
"""dense x86 strategy"""
strategy = _op.OpStrategy()
m, _ = inputs[0].shape
same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
dtype = inputs[0].dtype
u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_nopack),
wrap_topi_schedule(topi.x86.schedule_dense_nopack),
name="dense_nopack.x86",
plevel=5,
)

strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
name="dense_pack.x86",
plevel=10,
)

@@ -407,14 +413,18 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
name="dense_mkldnn.x86",
plevel=15,
)
with SpecializedCondition(m >= 16):
# this implementation may not be well-optimized, so use plevel=5 for now.
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
name="dense_pack.x86",
plevel=5,
)
return strategy


@dense_pack_strategy.register("cpu")
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
"""dense_pack x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_pack),
wrap_topi_schedule(topi.x86.schedule_dense_pack),
name="dense_pack.x86",
)
return strategy
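
With this change, dense_pack.x86 carries plevel=10 ahead of dense_nopack.x86 at plevel=5, and the earlier SpecializedCondition(m >= 16) gating around dense_pack is dropped. A hedged end-to-end sketch, with assumed shapes and target string, of compiling a dense workload so the packed path can be selected on x86:

import tvm
from tvm import relay

data = relay.var("data", shape=(16, 128), dtype="float32")
weight = relay.var("weight", shape=(64, 128), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))
# At opt_level=3, AlterOpLayout may rewrite nn.dense into nn.contrib_dense_pack,
# which is then lowered through the dense_pack.x86 implementation.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")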


70 changes: 70 additions & 0 deletions python/tvm/topi/nn/dense.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-argument
"""TVM operator fully connected compute."""
import tvm
from tvm import te, auto_scheduler
@@ -104,3 +105,72 @@ def dense_legalize(attrs, inputs, types):
# not to change by default
# pylint: disable=unused-argument
return None


def dense_pack(data, weight, bias=None, out_dtype=None):
"""The default implementation of dense_pack in topi.
Parameters
----------
data : tvm.te.Tensor
2-D with shape [batch, in_dim]
    weight : tvm.te.Tensor
        3-D with shape [out_dim // pack_weight_tile, in_dim, pack_weight_tile]
bias : Optional[tvm.te.Tensor]
1-D with shape [out_dim]
out_dtype : Optional[str]
The output type. This is used for mixed precision.
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
if out_dtype is None:
out_dtype = data.dtype
M, K = get_const_tuple(data.shape) # batch, in_dim
    N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim // packw_bn, in_dim, packw_bn
N = N * packw_bn

idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
k = te.reduce_axis((0, K), name="k")
C = te.compute(
(M, N),
lambda y, x: te.sum(
data[y, k].astype(out_dtype)
* weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
axis=k,
),
name="T_dense_pack",
tag="dense_pack",
)
if bias is not None:
C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
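
To make the packed indexing concrete: weight[x // packw_bn, k, x % packw_bn] is simply the original weight[x, k] with the output dimension split into tiles of packw_bn. A NumPy sanity check under small assumed shapes (illustrative only):

import numpy as np

M, K, N, packw_bn = 4, 8, 16, 4  # assumed shapes for the check
data = np.random.rand(M, K).astype("float32")
weight = np.random.rand(N, K).astype("float32")
# Pack [N, K] into [N // packw_bn, K, packw_bn]: packed[n, k, b] == weight[n * packw_bn + b, k]
packed = weight.reshape(N // packw_bn, packw_bn, K).transpose(0, 2, 1)
out = np.einsum("mk,nkb->mnb", data, packed).reshape(M, N)
np.testing.assert_allclose(out, data @ weight.T, rtol=1e-5)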


@tvm.target.generic_func
def dense_alter_layout(attrs, inputs, tinfos, out_type):
"""Change dense layout.
Parameters
----------
    attrs : tvm.ir.Attrs
        Attributes of the current dense op
    inputs : tvm.relay.Expr
        Grouped input symbols
    tinfos : list
        Input shape and dtype
    out_type : type
        The output type
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level.
"""
# not to change by default
return None
1 change: 1 addition & 0 deletions python/tvm/topi/x86/__init__.py
@@ -39,4 +39,5 @@
from .conv3d_transpose import *
from .sparse import *
from .conv2d_alter_op import *
from .dense_alter_op import *
from .scatter import *
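
The body of python/tvm/topi/x86/dense_alter_op.py imported above is not shown in this excerpt. A minimal sketch of what a CPU registration of dense_alter_layout could look like; the fixed tile size and the layout strings are assumptions for illustration, not the committed implementation (which would normally consult the tuned schedule config):

from tvm import relay
from tvm.topi.nn import dense_alter_layout
from tvm.topi.utils import get_const_tuple


@dense_alter_layout.register("cpu")
def _alter_dense_layout(attrs, inputs, tinfos, out_type):
    _, weight_tensor = tinfos
    N, _ = get_const_tuple(weight_tensor.shape)  # original weight is [out_dim, in_dim]
    packw_bn = 16  # assumed tile size; a real implementation queries the tuned config
    if N % packw_bn != 0:
        return None  # keep plain nn.dense when the output dim does not tile evenly
    # Pre-transform the weight to [N // packw_bn, K, packw_bn] and switch the op.
    packed_weight = relay.layout_transform(inputs[1], "NK", "NK%dn" % packw_bn)
    return relay.nn.contrib_dense_pack(inputs[0], packed_weight, None, out_type.dtype)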