* [TOP] complete level2
* [TOP] add split
Showing 9 changed files with 756 additions and 249 deletions.
convolution.cc (new file, 214 lines added):
/*!
 * Copyright (c) 2017 by Contributors
 * \file convolution.cc
 * \brief Convolution operators
 */
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/top/nn.h>
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

// conv2d
DMLC_REGISTER_PARAMETER(ConvParam);

inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
                             std::vector<TShape>* in_shape,
                             std::vector<TShape>* out_shape) {
  const ConvParam& param = nnvm::get<ConvParam>(attrs.parsed);
  if (param.use_bias) {
    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
  }
  CHECK_EQ(out_shape->size(), 1U);

  TShape dshape = in_shape->at(0);
  if (dshape.ndim() == 0) return false;
  dshape = ConvertLayout(dshape, param.layout, kNCHW);

  CHECK_EQ(dshape.ndim(), 4U) << "Input data should be 4D";
  CHECK_EQ(param.kernel_size.ndim(), 2U);
  CHECK_EQ(param.strides.ndim(), 2U)
      << "incorrect stride size: " << param.strides;
  CHECK_EQ(param.dilation.ndim(), 2U)
      << "incorrect dilate size: " << param.dilation;
  CHECK_EQ(dshape[1] % param.groups, 0U)
      << "input channels must divide group size";
  CHECK_EQ(param.channels % param.groups, 0U)
      << "output channels must divide group size";

  TShape wshape({param.channels / param.groups,
                 dshape[1] / param.groups,
                 param.kernel_size[0],
                 param.kernel_size[1]});

  wshape = ConvertLayout(wshape, kNCHW, param.layout);
  wshape[0] *= param.groups;

  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, ConvParam::kWeight, wshape);
  if (param.use_bias) {
    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape,
                            ConvParam::kBias, TShape({param.channels}));
  }
  // dilation
  dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0];
  dim_t dilated_ksize_x = 1 + (param.kernel_size[1] - 1) * param.dilation[1];
  TShape oshape({dshape[0], param.channels, 0, 0});
  if (dshape[2] != 0) {
    oshape[2] = (dshape[2] + param.padding[0] * 2 - dilated_ksize_y) / param.strides[0] + 1;
  }
  if (dshape[3] != 0) {
    oshape[3] = (dshape[3] + param.padding[1] * 2 - dilated_ksize_x) / param.strides[1] + 1;
  }
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
                           ConvertLayout(oshape, kNCHW, param.layout));
  // Perform incomplete shape inference. Fill in the missing values in data shape.
  // 1) We can always fill in the batch_size.
  // 2) We can back-calculate the input height/width if the corresponding stride is 1.
  oshape = ConvertLayout((*out_shape)[0], param.layout, kNCHW);
  dshape[0] = oshape[0];
  if (oshape[2] && param.strides[0] == 1) {
    dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param.padding[0];
  }
  if (oshape[3] && param.strides[1] == 1) {
    dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param.padding[1];
  }
  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, ConvParam::kData,
                          ConvertLayout(dshape, kNCHW, param.layout));
  // Check whether the kernel sizes are valid
  if (dshape[2] != 0) {
    CHECK_LE(dilated_ksize_y, dshape[2] + 2 * param.padding[0])
        << "kernel size exceed input";
  }
  if (dshape[3] != 0) {
    CHECK_LE(dilated_ksize_x, dshape[3] + 2 * param.padding[1])
        << "kernel size exceed input";
  }
  return true;
}
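
// Worked example of the inference above, using illustrative (hypothetical)
// parameter values: for an NCHW input of shape (1, 3, 224, 224) with
// channels = 64, groups = 1, kernel_size = (3, 3), strides = (1, 1),
// padding = (1, 1) and dilation = (1, 1), the dilated kernel stays 3x3, so
//   oshape[2] = (224 + 2*1 - 3) / 1 + 1 = 224
//   oshape[3] = (224 + 2*1 - 3) / 1 + 1 = 224
// and the inferred output shape is (1, 64, 224, 224). With stride 1 the same
// relation can be inverted to back-fill an unknown input height/width, which
// is what the partial-shape-inference branch above does.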

NNVM_REGISTER_OP(conv2d)
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of
outputs. If `use_bias` is True,
a bias vector is created and added to the outputs.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "4D Tensor", "Weight matrix.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(ConvParam::__FIELDS__())
.set_attr_parser(ParamParser<ConvParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<ConvParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<ConvParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_support_level(2);
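
// Weight-shape note for grouped convolution, using illustrative (hypothetical)
// values: if the NCHW input has in_channels = 32 and the op is invoked with
// channels = 64, groups = 2 and kernel_size = (3, 3), Conv2DInferShape builds
// the weight shape as ({64/2, 32/2, 3, 3}) and then multiplies the first axis
// by groups, giving (64, 16, 3, 3), i.e. (channels, in_channels / groups,
// kH, kW); the optional bias gets shape (64,).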

DMLC_REGISTER_PARAMETER(ConvTransposeParam);

inline bool ConvTransposeInferShape(const nnvm::NodeAttrs& attrs,
                                    std::vector<TShape>* in_shape,
                                    std::vector<TShape>* out_shape) {
  const ConvTransposeParam& param = nnvm::get<ConvTransposeParam>(attrs.parsed);
  if (param.use_bias) {
    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
  }
  CHECK_EQ(out_shape->size(), 1U);
  const TShape& dshape = (*in_shape)[ConvTransposeParam::kData];
  if (dshape.ndim() == 0) return false;
  TShape dshape_nchw = ConvertLayout(dshape, param.layout, kNCHW);

  CHECK_EQ(dshape_nchw[1] % param.groups, 0U)
      << "input num_filter must divide group size";
  CHECK_EQ(param.channels % param.groups, 0U)
      << "output num_filter must divide group size";
  CHECK_EQ(param.kernel_size.ndim(), 2U)
      << "incorrect kernel size: " << param.kernel_size;
  CHECK_EQ(param.strides.ndim(), 2U)
      << "incorrect stride size: " << param.strides;
  CHECK_EQ(param.dilation.ndim(), 2U)
      << "incorrect dilate size: " << param.dilation;

  TShape wshape({dshape_nchw[1],
                 param.channels / param.groups,
                 param.kernel_size[0], param.kernel_size[1]});
  wshape = ConvertLayout(wshape, kNCHW, param.layout);

  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, ConvTransposeParam::kWeight, wshape);

  if (param.use_bias) {
    NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape,
                            ConvTransposeParam::kBias,
                            TShape({param.channels}));
  }
  // dilation
  dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0];
  dim_t dilated_ksize_x = 1 + (param.kernel_size[1] - 1) * param.dilation[1];
  // output shape.
  TShape oshape({dshape_nchw[0], param.channels, 0, 0});
  oshape[2] = (param.strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
               2 * param.padding[0] + param.output_padding[0]);

  oshape[3] = (param.strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
               2 * param.padding[1] + param.output_padding[1]);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
                           ConvertLayout(oshape, kNCHW, param.layout));
  return true;
}
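
// Worked example of the transposed output shape, using illustrative
// (hypothetical) values: for an NCHW input of shape (1, 64, 112, 112) with
// channels = 32, kernel_size = (3, 3), strides = (2, 2), padding = (1, 1),
// dilation = (1, 1) and output_padding = (1, 1),
//   oshape[2] = 2 * (112 - 1) + 3 - 2*1 + 1 = 224
//   oshape[3] = 2 * (112 - 1) + 3 - 2*1 + 1 = 224
// so the output is (1, 32, 224, 224), which is exactly the input size of the
// stride-2 conv2d that would have produced a 112x112 output from a 224x224
// input with the same kernel and padding.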

NNVM_REGISTER_OP(conv2d_transpose)
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
           out_height and out_width are calculated as::
               out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
               out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "4D Tensor", "Weight matrix.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(ConvTransposeParam::__FIELDS__())
.set_attr_parser(ParamParser<ConvTransposeParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<ConvTransposeParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<ConvTransposeParam>)
.set_attr<FInferShape>("FInferShape", ConvTransposeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_support_level(2);
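
// Note on output_padding, continuing the illustrative values above: a
// stride-2 conv2d maps both a 223-row and a 224-row input to 112 output rows,
// so the plain transposed formula (2 * (112 - 1) + 3 - 2*1 = 223) cannot tell
// the two apart; setting output_padding[0] = 1 selects the 224 case.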

}  // namespace top
}  // namespace nnvm