From 8a6a15196a9b821169e7da7eee617563705beef9 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 24 Aug 2016 23:26:45 -0700 Subject: [PATCH] update (#26) * updates (#1) * add scalars * change format * change inferattr interface * remove scalar * remove warning --- nnvm/example/src/operator.cc | 30 +++++++++++++------------- nnvm/include/nnvm/op.h | 32 +++++++++++++++++++++++++++ nnvm/include/nnvm/op_attr_types.h | 4 ++-- nnvm/src/c_api/c_api_symbolic.cc | 17 ++++++++++++++- nnvm/src/core/symbolic.cc | 5 ++++- nnvm/src/pass/infer_shape_type.cc | 36 +++++++++++++++++++------------ nnvm/tests/python/test_symbol.py | 1 - 7 files changed, 91 insertions(+), 34 deletions(-) diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc index 4fac82d243423..c0c729cdb4dd7 100644 --- a/nnvm/example/src/operator.cc +++ b/nnvm/example/src/operator.cc @@ -21,14 +21,14 @@ using nnvm::array_view; // simply return the shape as same inline bool SameShape(const NodeAttrs& attrs, - array_view ishape, - array_view oshape) { - if (ishape.size() == 0 || ishape[0]->ndim() == 0) return false; - for (TShape* pshape : oshape) { - *pshape = *ishape[0]; + std::vector *ishape, + std::vector *oshape) { + if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false; + for (TShape& pshape : *oshape) { + pshape = (*ishape)[0]; } - for (TShape* pshape : ishape) { - *pshape = *ishape[0]; + for (TShape& pshape : *ishape) { + pshape = (*ishape)[0]; } return true; } @@ -51,13 +51,13 @@ NNVM_REGISTER_OP(reshape) }) .attr( "FInferShape", [] (const NodeAttrs& attrs, - array_view ishape, - array_view oshape) { + std::vector *ishape, + std::vector *oshape) { // get parsed attribute const TShape& target = nnvm::get(attrs.parsed); - *oshape[0] = target; - if (ishape[0]->ndim() == 0) return false; - CHECK_EQ(ishape[0]->Size(), target.Size()) + (*oshape)[0] = target; + if ((*ishape)[0].ndim() == 0) return false; + CHECK_EQ((*ishape)[0].Size(), target.Size()) << "Reshape op: source target shape mismatch"; return true; }) @@ -78,9 +78,9 @@ NNVM_REGISTER_OP(cast) .attr("FInferShape", SameShape) .attr( "FInferType", [](const NodeAttrs& attrs, - array_view itype, - array_view otype) { - *otype[0] = nnvm::get(attrs.parsed); + std::vector *itype, + std::vector *otype) { + (*otype)[0] = nnvm::get(attrs.parsed); return true; }); diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index d072b9f7b4200..721e8e736e09b 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -6,6 +6,7 @@ #ifndef NNVM_OP_H_ #define NNVM_OP_H_ +#include #include #include #include @@ -22,6 +23,7 @@ struct NodeAttrs; template class OpMap; class OpRegistryEntry; +using dmlc::ParamFieldInfo; /*! \brief constant to indicate it take any length of positional inputs */ static const uint32_t kVarg = std::numeric_limits::max(); @@ -80,6 +82,8 @@ class Op { * This can be used to generate docstring automatically for the operator. */ std::string description; + /* \brief description of inputs and keyword arguments*/ + std::vector arguments; /*! * \brief number of inputs to the operator, * -1 means it is variable length @@ -149,6 +153,22 @@ class Op { * \return reference to self. */ inline Op& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline Op& add_argument(const std::string &name, + const std::string &type, + const std::string &description); + /*! + * \brief Append list if arguments to the end. + * \param args Additional list of arguments. + * \return reference to self. + */ + inline Op& add_arguments(const std::vector &args); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -340,6 +360,18 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*) return *this; } +inline Op& Op::add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + arguments.push_back({name, type, type, description}); + return *this; +} + +inline Op& Op::add_arguments(const std::vector &args) { + this->arguments.insert(arguments.end(), args.begin(), args.end()); + return *this; +} + inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) this->num_inputs = n; return *this; diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 2cb59ea9495dd..675b93a6c9d2f 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -57,8 +57,8 @@ using FMutateInputs = std::function (const NodeAttrs& attr */ template using FInferNodeEntryAttr = std::function in_attrs, - array_view out_attrs)>; + std::vector *in_attrs, + std::vector *out_attrs)>; /*! * \brief Shape inference function. * Update the shapes given the input shape information. diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 3dbb816d17296..dcdce820c7e60 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -28,11 +28,26 @@ int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char ***arg_descriptions, const char **return_type) { const Op *op = static_cast(creator); + NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); *name = op->name.c_str(); *description = op->description.c_str(); - *num_doc_args = 0; + *num_doc_args = static_cast(op->arguments.size()); + if (return_type) *return_type = nullptr; + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < op->arguments.size(); ++i) { + ret->ret_vec_charp.push_back(op->arguments[i].name.c_str()); + } + for (size_t i = 0; i < op->arguments.size(); ++i) { + ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str()); + } + for (size_t i = 0; i < op->arguments.size(); ++i) { + ret->ret_vec_charp.push_back(op->arguments[i].description.c_str()); + } + *arg_names = dmlc::BeginPtr(ret->ret_vec_charp); + *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size(); + *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2); API_END(); } diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 09e6c63134338..d595880aed1ce 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -151,7 +151,10 @@ void Symbol::Print(std::ostream &os) const { } if (!node->attrs.dict.empty()) { os << "Attrs:\n"; - for (auto &kv : node->attrs.dict) { + // make an ordered copy because unordered_map doesn't guarantee order. + std::map sorted_dict( + node->attrs.dict.begin(), node->attrs.dict.end()); + for (auto &kv : sorted_dict) { os << '\t' << kv.first << '=' << kv.second << '\n'; } } diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 265e576f9eb27..5978ecdb79f2e 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -47,44 +47,52 @@ Graph InferAttr(Graph &&ret, } // temp space for shape inference. - std::vector ishape, oshape; + std::vector ishape, oshape; // number of completed nodes size_t num_unknown = 0; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; + uint32_t num_inputs = inode.inputs.size(); + uint32_t num_outputs = inode.source->num_outputs(); if (inode.source->is_variable()) { if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) { auto it = inode.source->attrs.dict.find(shape_attr_key); if (it != inode.source->attrs.dict.end()) { - CHECK_EQ(inode.source->num_outputs(), 1); + CHECK_EQ(num_outputs, 1); std::istringstream is(it->second); CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute"; } } continue; } - ishape.resize(inode.inputs.size()); - for (uint32_t i = 0; i < ishape.size(); ++i) { - ishape[i] = &rshape[idx.entry_id(inode.inputs[i])]; - } - oshape.resize(inode.source->num_outputs()); - for (uint32_t i = 0; i < oshape.size(); ++i) { - oshape[i] = &rshape[idx.entry_id(nid, i)]; - } if (finfer_shape.count(inode.source->op)) { + ishape.resize(num_inputs, def_value); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; + } + oshape.resize(num_outputs, def_value); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = rshape[idx.entry_id(nid, i)]; + } num_unknown += - !(finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape)); + !(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape)); + for (uint32_t i = 0; i < num_inputs; ++i) { + rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; + } + for (uint32_t i = 0; i < num_outputs; ++i) { + rshape[idx.entry_id(nid, i)] = oshape[i]; + } } else if (is_backward.get(inode.source->op, false)) { // backward operator inference. CHECK_GE(inode.control_deps.size(), 1) << "BackwardOp need to have control_deps to its forward op"; const auto& fnode = idx[inode.control_deps[0]]; - CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs()) + CHECK_EQ(fnode.inputs.size(), num_outputs) << "BackwardOp need to correspond to the forward node"; bool known = true; for (size_t i = 0; i < fnode.inputs.size(); ++i) { - *oshape[i] = rshape[idx.entry_id(fnode.inputs[i])]; - if (fis_none(*oshape[i])) known = false; + rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])]; + if (fis_none(rshape[idx.entry_id(nid, i)])) known = false; } num_unknown += !known; } diff --git a/nnvm/tests/python/test_symbol.py b/nnvm/tests/python/test_symbol.py index a754f0f60fde1..994f0d14a6057 100644 --- a/nnvm/tests/python/test_symbol.py +++ b/nnvm/tests/python/test_symbol.py @@ -41,7 +41,6 @@ def test_copy(): z = sym.Variable('z') y = sym.exp(sym.add(x, x, name='add', gpu=2), name='exp', gpu=1, attr={"kk": "1"}) - assert y.__copy__().debug_str() == y.debug_str() if __name__ == "__main__":