Skip to content

Commit

Permalink
Introduce NodePtr (apache#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 26, 2018
1 parent bc82192 commit 512bac9
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 8 deletions.
1 change: 1 addition & 0 deletions nnvm/include/nnvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <dmlc/base.h>
#include <dmlc/any.h>
#include <dmlc/memory.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <dmlc/array_view.h>
Expand Down
2 changes: 1 addition & 1 deletion nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode;
typedef const NodePtr* GNode;
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
Expand Down
3 changes: 1 addition & 2 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class Node;

/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case we need
* even faster graph composition than 3M ops/sec.
* to the node, so this alias can be changed in case.
*
* By default, NodePtr is a std::shared_ptr of node
*/
Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/core/graph_attr_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
std::vector<size_t> inputs_rptr{0}, control_rptr{0};

DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nnvm::Node>& n) {
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
// nodes_
Expand Down
4 changes: 1 addition & 3 deletions nnvm/src/core/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Node::~Node() {
// explicit deletion via DFS
// this is used to avoid stackoverflow caused by chain of deletions
std::vector<Node*> stack{this};
std::vector<std::shared_ptr<Node> > to_delete;
std::vector<NodePtr> to_delete;
while (!stack.empty()) {
Node* n = stack.back();
stack.pop_back();
Expand All @@ -37,8 +37,6 @@ Node::~Node() {
}

NodePtr Node::Create() {
// NOTE: possible change to thread local memory pool
// via std::allocate_shared instead for faster allocation.
return std::make_shared<Node>();
}

Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void test_node_speed() {
auto add = nnvm::Op::Get("add");
double tstart = dmlc::GetTime();
size_t rep = 1000;
size_t n = 100;
size_t n = 1000;
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
Expand Down

0 comments on commit 512bac9

Please sign in to comment.