
[Relay][Topi][CPU] Dense with weight transform #7404

Merged: 15 commits, Feb 9, 2021
15 changes: 15 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -939,6 +939,21 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for dense_weight_transform operator */
struct DenseWeightTransformAttrs : public tvm::AttrsNode<DenseWeightTransformAttrs> {
tvm::String weight_layout;
DataType out_dtype;

  TVM_DECLARE_ATTRS(DenseWeightTransformAttrs, "relay.attrs.DenseWeightTransformAttrs") {
TVM_ATTR_FIELD(weight_layout)
.set_default("NKn")
.describe("Dimension ordering of weight. Can be 'NKn', 'NK16n', etc.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

/*! \brief Attributes for batch matmul operator */
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
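
For orientation, a weight_layout such as "NK16n" means the 2-D weight of shape (N, K) is packed into a 3-D tensor, with an inner tile of 16 split off from the N axis. A minimal numpy sketch of that packing (illustrative only, not part of this PR):

```python
import numpy as np

def pack_weight_nk16n(weight, tile=16):
    """Pack an (N, K) weight into the (N // tile, K, tile) "NK16n" layout."""
    n, k = weight.shape
    assert n % tile == 0, "N must be divisible by the tile size"
    # (N, K) -> (N // tile, tile, K) -> (N // tile, K, tile)
    return weight.reshape(n // tile, tile, k).transpose(0, 2, 1)

w = np.random.rand(64, 128).astype("float32")
packed = pack_weight_nk16n(w)
assert packed.shape == (4, 128, 16)
# element (x, k) of the original weight lives at [x // 16, k, x % 16]
assert packed[1, 5, 3] == w[1 * 16 + 3, 5]
```

This indexing matches the TOPI compute later in this diff, which reads weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].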
11 changes: 11 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -78,6 +78,17 @@ def legalize_dense(attrs, inputs, types):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_alter_op_layout("nn.dense")
def alter_op_layout_dense(attrs, inputs, tinfos, out_type):
"""Alternate the layout of dense"""
return topi.nn.dense_alter_layout(attrs, inputs, tinfos, out_type)


# dense_weight_transform
reg.register_strategy("nn.contrib_dense_weight_transform", strategy.dense_weight_transform_strategy)
reg.register_pattern("nn.contrib_dense_weight_transform", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# fifo_buffer
@reg.register_compute("nn.fifo_buffer")
def compute_fifo_buffer(attrs, inputs, out_type):
33 changes: 33 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1435,6 +1435,39 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def contrib_dense_weight_transform(data, weight, weight_layout="NKn", out_dtype=""):
"""Dense operator.
Applies a linear transformation

.. math::

`Y = X * W^T`

Parameters
----------
data : tvm.relay.Expr
The input data to the operator,
of shape `(d_1, d_2, ..., d_n, units_in)`.

weight : tvm.relay.Expr
The transformed weight expressions, 3-D matrix,
of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.

weight_layout : str, optional
Layout of the weight.

    out_dtype : str, optional
        Specifies the output data type for mixed precision dense.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result, of shape `(d_1, d_2, ..., d_n, units)`.
"""
return _make.contrib_dense_weight_transform(data, weight, weight_layout, out_dtype)
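
A minimal usage sketch, assuming a TVM build that includes this PR; the 3-D weight shape follows the docstring above with a pack tile of 16:

```python
import tvm
from tvm import relay

units_in, units, tile = 128, 64, 16
data = relay.var("data", shape=(8, units_in), dtype="float32")
# the weight is supplied already packed: (units // tile, units_in, tile)
weight = relay.var("weight", shape=(units // tile, units_in, tile), dtype="float32")
out = relay.nn.contrib_dense_weight_transform(data, weight, weight_layout="NK16n")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(relay.transform.InferType()(mod))  # result type: Tensor[(8, 64), float32]
```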


def fifo_buffer(data, buffer, axis):
"""FIFO buffer to enable computation reuse in CNNs with sliding indow input

13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -731,6 +731,19 @@ def dense_strategy(attrs, inputs, out_type, target):
return strategy


@override_native_generic_func("dense_weight_transform_strategy")
def dense_weight_transform_strategy(attrs, inputs, out_type, target):
"""dense_weight_transform generic strategy"""
logger.warning("dense_weight_transform is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense_weight_transform),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense_weight_transform.generic",
)
return strategy


# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap batch_matmul topi compute"""
27 changes: 15 additions & 12 deletions python/tvm/relay/op/strategy/x86.py
@@ -364,14 +364,13 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target):
def dense_strategy_cpu(attrs, inputs, out_type, target):
    """dense x86 strategy"""
    strategy = _op.OpStrategy()
-    m, _ = inputs[0].shape
    same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
    dtype = inputs[0].dtype
    u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
    strategy.add_implementation(
-        wrap_compute_dense(topi.x86.dense_nopack),
-        wrap_topi_schedule(topi.x86.schedule_dense_nopack),
-        name="dense_nopack.x86",
        wrap_compute_dense(topi.x86.dense_pack),
        wrap_topi_schedule(topi.x86.schedule_dense_pack),
        name="dense_pack.x86",
        plevel=10,
    )

@@ -407,14 +406,18 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
name="dense_mkldnn.x86",
plevel=15,
)
-    with SpecializedCondition(m >= 16):
-        # this implementation may not be well-optimized, so use plevel=5 for now.
-        strategy.add_implementation(
-            wrap_compute_dense(topi.x86.dense_pack),
-            wrap_topi_schedule(topi.x86.schedule_dense_pack),
-            name="dense_pack.x86",
-            plevel=5,
-        )
return strategy


@dense_weight_transform_strategy.register("cpu")
def dense_weight_transform_strategy_cpu(attrs, inputs, out_type, target):
"""dense_weight_transform x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_weight_transform),
wrap_topi_schedule(topi.x86.schedule_dense_weight_transform),
name="dense_weight_transform.x86",
)
return strategy


70 changes: 70 additions & 0 deletions python/tvm/topi/nn/dense.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-argument
"""TVM operator fully connected compute."""
import tvm
from tvm import te, auto_scheduler
@@ -104,3 +105,72 @@ def dense_legalize(attrs, inputs, types):
# not to change by default
# pylint: disable=unused-argument
return None


def dense_weight_transform(data, weight, bias=None, out_dtype=None):
"""The default implementation of dense_weight_transform in topi.

Parameters
----------
data : tvm.te.Tensor
2-D with shape [batch, in_dim]

    weight : tvm.te.Tensor
        3-D with shape [out_dim // packw_bn, in_dim, packw_bn]
        (the weight already packed by the layout transform)

bias : Optional[tvm.te.Tensor]
1-D with shape [out_dim]

out_dtype : Optional[str]
The output type. This is used for mixed precision.

Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
if out_dtype is None:
out_dtype = data.dtype
M, K = get_const_tuple(data.shape) # batch, in_dim
    N, _, packw_bn = get_const_tuple(weight.shape)  # out_dim // packw_bn, in_dim, packw_bn
    N = N * packw_bn  # recover the full out_dim

idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
k = te.reduce_axis((0, K), name="k")
C = te.compute(
(M, N),
lambda y, x: te.sum(
data[y, k].astype(out_dtype)
* weight[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype),
axis=k,
),
name="T_dense_weight_transform",
tag="dense_weight_transform",
)
if bias is not None:
C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
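
A short sketch, assuming this PR, of driving the TOPI compute directly through the default TE flow:

```python
import tvm
from tvm import te, topi

M, K, N, tile = 32, 128, 64, 16
A = te.placeholder((M, K), name="A", dtype="float32")
# weight placeholder already in the packed (N // tile, K, tile) layout
W = te.placeholder((N // tile, K, tile), name="W", dtype="float32")
C = topi.nn.dense_weight_transform(A, W)
s = te.create_schedule(C.op)  # naive schedule; real targets use schedule_dense_*
f = tvm.build(s, [A, W, C], target="llvm")
```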


@tvm.target.generic_func
def dense_alter_layout(attrs, inputs, tinfos, out_type):
"""Change dense layout.

Parameters
----------
attrs : tvm.ir.Attrs
        Attributes of the current dense op
inputs : tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
out_type: type
The output type

Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level.
"""
# not to change by default
return None
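
The x86 override for this hook lives in python/tvm/topi/x86/dense_alter_op.py (see the __init__.py change below). A hedged sketch of the general shape of such an override; the fixed tile size and the function body here are illustrative assumptions, not the PR's exact code:

```python
from tvm import relay
from tvm.topi.nn.dense import dense_alter_layout

@dense_alter_layout.register("cpu")
def _alter_dense_layout(attrs, inputs, tinfos, out_type):
    _, weight_tensor = tinfos
    N, _ = weight_tensor.shape
    tile = 16  # assumption: a real implementation derives this from the tuned config
    if int(N) % tile != 0:
        return None  # keep the plain dense op
    weight_layout = "NK%dn" % tile
    # insert an explicit layout_transform so the packed op sees a 3-D weight
    packed_weight = relay.layout_transform(inputs[1], "NK", weight_layout)
    return relay.nn.contrib_dense_weight_transform(
        inputs[0], packed_weight, weight_layout=weight_layout, out_dtype=out_type.dtype
    )
```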
1 change: 1 addition & 0 deletions python/tvm/topi/x86/__init__.py
@@ -39,4 +39,5 @@
from .conv3d_transpose import *
from .sparse import *
from .conv2d_alter_op import *
from .dense_alter_op import *
from .scatter import *