From 16063d8e8b253c6fa27ea4518b405081fa0c939c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 23 Sep 2017 20:52:04 -0700 Subject: [PATCH] [COMPILER] GraphHash based cache system, allow dump and query duplicated functions. (#30) --- nnvm/include/nnvm/graph.h | 4 +- nnvm/include/nnvm/pass_functions.h | 11 + nnvm/python/nnvm/compiler/__init__.py | 2 + nnvm/python/nnvm/compiler/build_module.py | 2 +- nnvm/python/nnvm/compiler/compile_engine.py | 99 +++++++ nnvm/python/nnvm/testing/__init__.py | 2 + nnvm/python/nnvm/testing/config.py | 2 +- nnvm/src/compiler/compile_engine.cc | 272 ++++++++++++++++++ nnvm/src/compiler/compile_engine.h | 96 +++++++ nnvm/src/compiler/graph_fuse.cc | 238 ++++++++------- .../{graph_deep_compare.cc => graph_hash.cc} | 108 ++++++- nnvm/src/compiler/graph_hash.h | 82 ++++++ nnvm/src/core/graph.cc | 2 +- nnvm/src/pass/print_graph_ir.cc | 4 +- .../python/compiler/test_compiler_cache.py | 42 +++ nnvm/tests/python/compiler/test_op_fusion.py | 8 +- nnvm/tests/python/compiler/test_top_level1.py | 18 +- nnvm/tests/python/compiler/test_top_level2.py | 6 +- nnvm/tests/python/compiler/test_top_level4.py | 7 +- 19 files changed, 856 insertions(+), 149 deletions(-) create mode 100644 nnvm/python/nnvm/compiler/compile_engine.py create mode 100644 nnvm/src/compiler/compile_engine.cc create mode 100644 nnvm/src/compiler/compile_engine.h rename nnvm/src/compiler/{graph_deep_compare.cc => graph_hash.cc} (52%) create mode 100644 nnvm/src/compiler/graph_hash.h create mode 100644 nnvm/tests/python/compiler/test_compiler_cache.py diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 514705f65acd..a555ecd68a65 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -63,11 +63,11 @@ class Graph { * \return The indexed graph. * \sa IndexedGraph */ - const IndexedGraph& indexed_graph(); + const IndexedGraph& indexed_graph() const; private: // internal structure of indexed graph - std::shared_ptr indexed_graph_; + mutable std::shared_ptr indexed_graph_; }; /*! diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index 56ac11bab8e6..e4884a196e79 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -41,6 +41,17 @@ inline std::string SaveJSON(Graph graph) { return ret.GetAttr("json"); } + +/*! + * \brief Print graph ir + * \param graph The graph to be printed + * \return The graph ir string. + */ +inline std::string PrintGraphIR(Graph graph) { + Graph ret = ApplyPass(std::move(graph), "PrintGraphIR"); + return ret.GetAttr("graphir"); +} + /*! * \brief Add control flow dependencies between nodes. * diff --git a/nnvm/python/nnvm/compiler/__init__.py b/nnvm/python/nnvm/compiler/__init__.py index 1d8b9219b4f7..08cdb1850a39 100644 --- a/nnvm/python/nnvm/compiler/__init__.py +++ b/nnvm/python/nnvm/compiler/__init__.py @@ -5,6 +5,7 @@ from . import build_module from . build_module import build, optimize, build_config +from . compile_engine import engine, graph_key from .. import symbol as _symbol from .. import graph as _graph @@ -14,5 +15,6 @@ from .. 
import top as _top + tvm.register_extension(_symbol.Symbol, _symbol.Symbol) tvm.register_extension(_graph.Graph, _graph.Graph) diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 27d86e9a6967..18b2a163c692 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -184,7 +184,7 @@ def build(graph, target, shape, dtype="float32", params=None): graph._set_json_attr("target", target, "str") graph._set_json_attr("opt_level", cfg.opt_level, "int") graph = graph.apply("InferShape").apply("InferType") - graph = graph.apply("GraphFusePartition").apply("GraphFuse") + graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile") libmod = graph_attr._move_out_module(graph, "module") return graph, libmod, params diff --git a/nnvm/python/nnvm/compiler/compile_engine.py b/nnvm/python/nnvm/compiler/compile_engine.py new file mode 100644 index 000000000000..426c56d2d9c3 --- /dev/null +++ b/nnvm/python/nnvm/compiler/compile_engine.py @@ -0,0 +1,99 @@ +# pylint: disable=invalid-name +"""Compiler engine interface to internal engine""" +import tvm + +_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems") +_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache") +_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem") +_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem") +_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph") +_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey") + +@tvm.register_node +class GraphKey(tvm.node.NodeBase): + """Key of a graph compilation context""" + @property + def graph(self): + return _graph_key_get_graph(self) + + +@tvm.register_node +class GraphCacheEntry(tvm.node.NodeBase): + """CacheEntry of compilation into a TVM Function""" + pass + + +@tvm.register_node +class GraphFunc(tvm.node.NodeBase): + """Compiled result of a graph into a TVM Function""" + pass + + +class Engine(object): + """Global singleton compilation engine.""" + def items(self): + """List the available cache key value pairs. + + Returns + ------- + item_list : list of (GraphKey, GraphCacheEntry) + The existing cache items + """ + res = _list_cache_items() + assert len(res) % 2 == 0 + return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)] + + def clear_cache(self): + """Clear the existing cached functions.""" + _clear_cache() + + def __setitem__(self, key, value): + """Clear the existing cached functions.""" + if isinstance(value, GraphCacheEntry): + _set_cache_item(key, value.graph_func) + else: + _set_cache_item(key, value) + + def __getitem__(self, key): + """Clear the existing cached functions.""" + return _get_cache_item(key) + + def dump(self): + """Return a string representation of engine dump + + Returns + ------- + dump : str + The dumped string representation + """ + items = self.items() + res = "====================================\n" + res += "CompilerEngine dump, %d items cached\n" % len(items) + for key, value in items: + res += "------------------------------------\n" + res += "target={}\n".format(key.target) + res += "inputs={}\n".format(key.inputs) + res += "use_count={}\n".format(value.use_count) + res += "func_name={}\n".format(value.graph_func.func_name) + res += key.graph.ir() + "\n" + res += "===================================\n" + return res + +engine = Engine() + + +def graph_key(graph, inputs, target): + """Construct a new graph key. 
+ + Parameters + ---------- + graph : Graph + The computation graph structure + + inputs : list of Tensor(placeholder) + The input requirement to the graph. + + target : str + The target of compilation. + """ + return _make_graph_key(graph, inputs, target) diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 1241e403b1a0..6dd015d872ea 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -1 +1,3 @@ """Utilities for testcase""" + +from .config import ctx_list diff --git a/nnvm/python/nnvm/testing/config.py b/nnvm/python/nnvm/testing/config.py index 26d1d41014cf..a96e4b4ea8e1 100644 --- a/nnvm/python/nnvm/testing/config.py +++ b/nnvm/python/nnvm/testing/config.py @@ -2,7 +2,7 @@ import os import tvm -def test_ctx_list(): +def ctx_list(): """Get context list for testcases""" device_list = os.environ.get("NNVM_TEST_TARGETS", "") device_list = (device_list.split(",") if device_list diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc new file mode 100644 index 000000000000..d31612f5a826 --- /dev/null +++ b/nnvm/src/compiler/compile_engine.cc @@ -0,0 +1,272 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file compile_engine.cc + * \brief The compile engine. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./graph_hash.h" +#include "./compile_engine.h" + +namespace nnvm { +namespace compiler { + +using namespace tvm; + +/*! + * \brief Get type flag from TVM Type + * + * \param type the tvm type. + * \return corresponding DLDataType + */ +int GetTypeFlag(tvm::Type type) { + if (type == tvm::Float(32)) return 0; + LOG(FATAL) << "cannot convert " << type; + return 0; +} +// convert from type flag to tvm type. +Type GetTVMType(int type_flag) { + if (type_flag == 0) return tvm::Float(32); + LOG(FATAL) << "unknown type_flag=" << type_flag; + return Float(32); +} + +// internal compile engine +class CompileEngine { + public: + static CompileEngine* Global() { + static CompileEngine inst; + return &inst; + } + // lower graph possible get back an cached op. + GraphFunc Lower(Graph graph, + const Array& inputs, + const std::string& target, + const Op* schedule_op_key, + const NodeAttrs& schedule_op_attr) { + GraphKey key = GraphKeyNode::make(graph, inputs, target); + std::lock_guard lock(mutex_); + auto it = cache_.find(key); + if (it != cache_.end()) { + ++(it->second->use_count); + return it->second->graph_func; + } + GraphFunc f = DoLower(key->graph, key->inputs, key->target, + schedule_op_key, schedule_op_attr); + std::shared_ptr n = std::make_shared(); + n->graph_func = f; + n->use_count = 1; + cache_[key] = GraphCacheEntry(n); + return f; + } + // List all items in the cache. + Array ListCacheItems() { + std::lock_guard lock(mutex_); + Array items; + for (auto& kv : cache_) { + items.push_back(kv.first); + std::shared_ptr n = + std::make_shared(*(kv.second.operator->())); + items.push_back(GraphCacheEntry(n)); + } + return items; + } + // Find the function given graph key. + GraphCacheEntry Find(const GraphKey& key) { + std::lock_guard lock(mutex_); + auto it = cache_.find(key); + if (it != cache_.end()) { + return it->second; + } else { + return GraphCacheEntry(); + } + } + // Find the function given graph key. 
+ void Set(const GraphKey& key, GraphFunc func) { + std::lock_guard lock(mutex_); + std::shared_ptr n = std::make_shared(); + n->graph_func = func; + n->use_count = 1; + cache_[key] = GraphCacheEntry(n); + } + // Find the function given graph key. + void Clear() { + std::lock_guard lock(mutex_); + cache_.clear(); + } + // run the actual lowering process + GraphFunc DoLower(Graph graph, + const Array& inputs, + const std::string& target, + const Op* schedule_op_key, + const NodeAttrs& schedule_op_attr) { + // shape, type + static auto& fcompute = + nnvm::Op::GetAttr("FTVMCompute"); + static auto& fschedule = + nnvm::Op::GetAttr("FTVMSchedule"); + + std::vector ishape; + std::vector idtype; + + for (const tvm::Tensor t : inputs) { + std::vector shape; + for (Expr v : t->shape) { + CHECK(v.as()); + shape.push_back(v.as()->value); + } + ishape.emplace_back(TShape(shape.begin(), shape.end())); + idtype.emplace_back(GetTypeFlag(t->dtype)); + } + graph = pass::InferShape(graph, ishape); + graph = pass::InferType(graph, idtype); + + const ShapeVector& shape_vec = graph.GetAttr("shape"); + const DTypeVector& dtype_vec = graph.GetAttr("dtype"); + const IndexedGraph& idx = graph.indexed_graph(); + CHECK_EQ(inputs.size(), idx.input_nodes().size()); + + std::vector tensor_vec(idx.num_node_entries()); + for (size_t i = 0; i < idx.input_nodes().size(); ++i) { + uint32_t nid = idx.input_nodes()[i]; + tensor_vec[idx.entry_id(nid, 0)] = inputs[i]; + } + + std::ostringstream readable_name; + readable_name << "fuse"; + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const auto& inode = idx[nid]; + if (inode.source->is_variable()) continue; + Array inputs, out_info; + readable_name << "_" << inode.source->op()->name; + // input array + for (const IndexedGraph::NodeEntry& e : inode.inputs) { + const tvm::Tensor& t = tensor_vec[idx.entry_id(e)]; + CHECK(t.defined()); + inputs.push_back(t); + } + // output hint + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { + Array shape; + for (int64_t x : shape_vec[idx.entry_id(nid, i)]) { + CHECK_LE(x, static_cast(std::numeric_limits::max())); + shape.push_back(make_const(Int(32), x)); + } + out_info.push_back( + placeholder(shape, + GetTVMType(dtype_vec[idx.entry_id(nid, i)]))); + } + // get default + Array out = fcompute[inode.source->op()]( + inode.source->attrs, inputs, out_info); + CHECK_EQ(out.size(), inode.source->num_outputs()); + // schedule on root node, and use master's schedule + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + tensor_vec[eid] = out[index]; + } + } + // Schedule on final output. 
+ Array outputs; + Array all_args = inputs; + for (const IndexedGraph::NodeEntry& e : idx.outputs()) { + const tvm::Tensor& t = tensor_vec[idx.entry_id(e)]; + CHECK(t.defined()); + outputs.push_back(t); + all_args.push_back(t); + } + Schedule sch = fschedule[schedule_op_key]( + schedule_op_attr, outputs, target); + std::shared_ptr gf = std::make_shared(); + gf->target = target; + gf->func_name = GetUniqeName(readable_name.str()); + gf->inputs = inputs; + gf->outputs = outputs; + static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower"); + gf->funcs = flower(sch, all_args, gf->func_name); + return GraphFunc(gf); + } + + private: + // Get unique name + std::string GetUniqeName(std::string name) { + while (true) { + auto it = name_map_.find(name); + if (it == name_map_.end()) { + name_map_[name] = 1; + return name; + } else { + std::ostringstream os; + os << name << "_" << it->second; + ++(it->second); + name = os.str(); + } + } + return name; + } + + // global mutex + std::mutex mutex_; + // the name map + std::unordered_map name_map_; + // the compiler cache + std::unordered_map cache_; +}; + +GraphFunc GraphLower(Graph graph, + const Array& inputs, + const std::string& target, + const Op* schedule_op_key, + const NodeAttrs& schedule_op_attr) { + return CompileEngine::Global()->Lower( + graph, inputs, target, schedule_op_key, schedule_op_attr); +} + +// Expose cache to front end +TVM_REGISTER_GLOBAL("nnvm.compiler.ListCacheItems") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + *rv = CompileEngine::Global()->ListCacheItems(); + }); + +TVM_REGISTER_GLOBAL("nnvm.compiler.ClearCache") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + CompileEngine::Global()->Clear(); + }); + +// NOTE: this involves graph lookup and can be slow +TVM_REGISTER_GLOBAL("nnvm.compiler.GetCacheItem") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + *rv = CompileEngine::Global()->Find(args[0]); + }); + +TVM_REGISTER_GLOBAL("nnvm.compiler.SetCacheItem") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + CompileEngine::Global()->Set(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + *rv = args[0].operator GraphKey()->graph; + }); + +TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey") +.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { + *rv = GraphKeyNode::make(args[0], args[1], args[2]); + }); + + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const GraphFuncNode *op, IRPrinter *p) { + p->stream << "GraphFunc(name=" << op->func_name + << ", addr=" << op << ")"; +}); + +} // namespace compiler +} // namespace nnvm diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h new file mode 100644 index 000000000000..91ab2588a593 --- /dev/null +++ b/nnvm/src/compiler/compile_engine.h @@ -0,0 +1,96 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file compile_engine.h + * \brief Internal engine to compile a subgraph fragment and cache compilation. + */ +#ifndef NNVM_COMPILER_COMPILE_ENGINE_H_ +#define NNVM_COMPILER_COMPILE_ENGINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./graph_hash.h" + +namespace nnvm { +namespace compiler { + +/*! 
\brief A TVM Node to represent compiled graph function */ +struct GraphFuncNode : public tvm::Node { + /* \brief compiled target */ + std::string target; + /*! \brief Function name */ + std::string func_name; + /* \brief The inputs to the function */ + tvm::Array inputs; + /* \brief The outputs to the function */ + tvm::Array outputs; + /*! \brief The lowered functions */ + tvm::Array funcs; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("target", &target); + v->Visit("func_name", &func_name); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("funcs", &funcs); + } + + static constexpr const char* _type_key = "GraphFunc"; + TVM_DECLARE_NODE_TYPE_INFO(GraphFuncNode, tvm::Node); +}; + +TVM_DEFINE_NODE_REF(GraphFunc, GraphFuncNode); + +/*! \brief Cache Entry in the graph */ +struct GraphCacheEntryNode : public tvm::Node { + /*! \brief The graph function */ + GraphFunc graph_func; + /*! \brief Usage statistics */ + int use_count{0}; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("graph_func", &graph_func); + v->Visit("use_count", &use_count); + } + static constexpr const char* _type_key = "GraphCacheEntry"; + TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node); +}; + +class GraphCacheEntry : public ::tvm::NodeRef { + public: + GraphCacheEntry() {} + explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {} + GraphCacheEntryNode* operator->() { + return static_cast(node_.get()); + } + using ContainerType = GraphCacheEntryNode; +}; + +/*! + * \brief Call compile engine to lower a graph with given inputs. + * + * \param graph The graph to be compiled + * \param inputs The input specification. + * \param schedule_op_key The hint key for the schedule. + * \param schedule_op_attr The hint attribute for the schedule. + * + * \return func A lowered tvm function. + */ +GraphFunc GraphLower(Graph graph, + const Array& inputs, + const std::string& target, + const Op* schedule_op_key, + const NodeAttrs& schedule_op_attr); + +} // namespace compiler +} // namespace nnvm + +#endif // NNVM_COMPILER_COMPILE_ENGINE_H_ diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc index e3935ed95ff1..bd7cdbae4e98 100644 --- a/nnvm/src/compiler/graph_fuse.cc +++ b/nnvm/src/compiler/graph_fuse.cc @@ -4,15 +4,16 @@ * \brief Fuse the operators together. */ #include +#include #include #include #include #include -#include +#include #include #include -#include #include +#include "./compile_engine.h" #include "../runtime/graph_executor.h" namespace nnvm { @@ -20,8 +21,6 @@ namespace compiler { using namespace tvm; -using DLTypeVector = std::vector; - // The single fuse rule. enum class FuseRule { kUknown, @@ -29,8 +28,14 @@ enum class FuseRule { kRealize }; +/*! + * \brief Get DLDataType from dtype flag. + * + * \param type_flag The data type flag + * \return corresponding DLDataType + */ DLDataType GetDLType(int type_flag) { - if (type_flag == 0) return Type2TVMType(Float(32)); + if (type_flag == 0) return tvm::Type2TVMType(tvm::Float(32)); LOG(FATAL) << "unknown type_flag=" << type_flag; return Type2TVMType(Float(32)); } @@ -48,13 +53,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { // Get attributes from the graph const ShapeVector& shape_vec = g.GetAttr("shape"); - const DTypeVector& dtype_vec = g.GetAttr("dtype"); - // Transform to dltype - // In future, directly fo type inference in dltype. 
- DLTypeVector dltype_vec = DLTypeVector(dtype_vec.size()); - for (size_t i = 0; i < dtype_vec.size(); ++i) { - dltype_vec[i] = GetDLType(dtype_vec[i]); - } // Reference counter of each op node // For now, always store result when an op is referred more than once. @@ -174,7 +172,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { g.attrs["group_root"] = std::make_shared(std::move(group_vec)); g.attrs["group_master"] = std::make_shared(std::move(master_vec)); g.attrs["pattern"] = std::make_shared(std::move(pattern_vec)); - g.attrs["dltype"] = std::make_shared(std::move(dltype_vec)); return g; } @@ -182,16 +179,15 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { NNVM_REGISTER_PASS(GraphFusePartition) .set_body(GraphFusePartition) .depend_graph_attr("shape") -.depend_graph_attr("dtype") -.provide_graph_attr("dltype"); +.depend_graph_attr("dtype"); -struct NodeEntryHash { +struct INodeEntryHash { size_t operator()(const IndexedGraph::NodeEntry& e) const { return e.node_id; } }; -struct NodeEntryEqual { +struct INodeEntryEqual { size_t operator()(const IndexedGraph::NodeEntry& a, const IndexedGraph::NodeEntry& b) const { return a.node_id == b.node_id && a.index == b.index; @@ -200,30 +196,29 @@ struct NodeEntryEqual { // Auxiliary data structure for representing fused op. struct FuseEntry { - // The inputs - std::vector inputs; + // subgraph of the fragement + Graph subgraph; // The input map - std::unordered_map imap; - // Output tensors - Array outputs; - // Placeholder for inputs - Array placeholder; - // Computing schedule - Schedule schedule; - // Function name - std::string func_name; + std::unordered_map imap; + // reverse map to the old input entry + std::unordered_map reverse_imap; + // TVM Placeholder for inputs + std::unordered_map input_info; + // Whether we can flatten data + bool flatten_data; + // The corresponding function. + GraphFunc compiled_func; }; // Fuse the partitioned graph into segments. // Create a new graph with fused noded. // Also inheritate attribute shape, dltype from previous graph. 
-nnvm::Graph GraphFuse(nnvm::Graph g) { +nnvm::Graph GraphFuseCompile(nnvm::Graph g) { // setup ref counter const IndexedGraph& idx = g.indexed_graph(); // Get attributes from the graph const ShapeVector& shape_vec = g.GetAttr("shape"); - const DLTypeVector& dltype_vec = g.GetAttr("dltype"); const DTypeVector& dtype_vec = g.GetAttr("dtype"); const std::vector& group_vec = g.GetAttr >("group_root"); const std::vector& master_vec = g.GetAttr >("group_master"); @@ -238,11 +233,11 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { CHECK_GE(group_vec[nid], 0); int root_id = group_vec[nid]; FuseEntry& fe = fuse_vec[root_id]; - TOpPattern pt = pattern_vec[root_id]; + fe.flatten_data = (pattern_vec[root_id] == kElemWise); for (const auto& e : inode.inputs) { if (group_vec[e.node_id] != root_id && fe.imap.count(e) == 0) { Array shape; - if (pt == kElemWise) { + if (fe.flatten_data) { // elementwise support flatten int64_t prod = 1; for (int64_t x : shape_vec[idx.entry_id(e)]) { @@ -257,93 +252,85 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { } } std::ostringstream os_name; - os_name << "input" << fe.inputs.size(); + os_name << "input" << fe.imap.size(); Tensor data = placeholder( - shape, TVMType2Type(dltype_vec[idx.entry_id(e)]), + shape, TVMType2Type(GetDLType(dtype_vec[idx.entry_id(e)])), os_name.str()); - fe.imap[e] = data; - fe.inputs.push_back(e); - fe.placeholder.push_back(data); + NodeEntry garg = Symbol::CreateVariable(os_name.str()).outputs[0]; + fe.imap[e] = garg; + fe.reverse_imap[garg.node.get()] = e; + fe.input_info[garg.node.get()] = std::move(data); } } } - // Setup the Tensor - std::vector tensor_vec(idx.num_node_entries()); - static auto& fcompute = - nnvm::Op::GetAttr("FTVMCompute"); - static auto& fschedule = - nnvm::Op::GetAttr("FTVMSchedule"); + // Setup the Subgraph + std::vector subgraph_vec(idx.num_node_entries()); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; int root_id = group_vec[nid]; FuseEntry& fe = fuse_vec[root_id]; - Array inputs, out_info; + // copy and create subgraph node. 
+ NodePtr gnode = Node::Create(); + gnode->attrs = inode.source->attrs; // input loading for (const auto& e : inode.inputs) { if (group_vec[e.node_id] != root_id) { auto it = fe.imap.find(e); CHECK(it != fe.imap.end()); - inputs.push_back(it->second); + gnode->inputs.push_back(it->second); } else { - Tensor t = tensor_vec[idx.entry_id(e)]; - CHECK(t.defined()); - inputs.push_back(t); + const NodeEntry& ne = subgraph_vec[idx.entry_id(e)]; + CHECK(!idx[e.node_id].source->is_variable()); + CHECK(ne.node != nullptr); + gnode->inputs.push_back(ne); } } - // output hint - for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { - Array shape; - for (int64_t x : shape_vec[idx.entry_id(nid, i)]) { - CHECK_LE(x, static_cast(std::numeric_limits::max())); - shape.push_back(make_const(Int(32), x)); - } - out_info.push_back( - placeholder(shape, - TVMType2Type(dltype_vec[idx.entry_id(nid, i)]))); - } - // get default - Array out = fcompute[inode.source->op()]( - inode.source->attrs, inputs, out_info); - CHECK_EQ(out.size(), inode.source->num_outputs()); // schedule on root node, and use master's schedule if (nid != root_id) { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index); - tensor_vec[eid] = out[index]; + subgraph_vec[eid] = NodeEntry{gnode, index, 0}; } } else { - fe.outputs = out; - int master = master_vec[root_id]; - CHECK_GE(master, 0); - fe.schedule = fschedule[idx[master].source->op()]( - idx[master].source->attrs, fe.outputs, target); - std::ostringstream os; - os << idx[master].source->attrs.name + "_id" << nid; - fe.func_name = os.str(); + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + fe.subgraph.outputs.push_back(NodeEntry{gnode, index, 0}); + } } } - static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower"); - static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target"); + // Start lowering + Array func_list; + std::unordered_set func_set; - Array funcs; - for (const FuseEntry& fe : fuse_vec) { - if (fe.schedule.defined()) { - Array args = fe.placeholder; - for (tvm::Tensor x : fe.outputs) { - args.push_back(x); - } - Array ret = flower(fe.schedule, args, fe.func_name); - for (LoweredFunc x : ret) { - funcs.push_back(x); + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const auto& inode = idx[nid]; + if (inode.source->is_variable()) continue; + int root_id = group_vec[nid]; + if (nid != root_id) continue; + int master = master_vec[root_id]; + FuseEntry& fe = fuse_vec[root_id]; + + const IndexedGraph& subidx = fe.subgraph.indexed_graph(); + CHECK_EQ(subidx.input_nodes().size(), fe.imap.size()); + CHECK_EQ(subidx.input_nodes().size(), fe.input_info.size()); + + Array inputs; + for (uint32_t sub_input_id : subidx.input_nodes()) { + auto it = fe.input_info.find(subidx[sub_input_id].source); + inputs.push_back(it->second); + } + fe.compiled_func = GraphLower(fe.subgraph, inputs, target, + idx[master].source->op(), + idx[master].source->attrs); + for (LoweredFunc f : fe.compiled_func->funcs) { + if (!func_set.count(f.get())) { + func_set.insert(f.get()); + func_list.push_back(f); } } } - - tvm::runtime::Module module = fbuild(funcs, target); - // Final step: Remap the node, with given attribute + // Rebuild the fused graph const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op"); - std::unordered_map old_new; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; @@ -351,36 +338,41 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { nnvm::NodePtr 
np = nnvm::Node::Create(); np->attrs = inode.source->attrs; old_new[nid] = np; - } else { - int root_id = group_vec[nid]; - if (nid != root_id) continue; - FuseEntry& fe = fuse_vec[root_id]; - nnvm::NodePtr np = nnvm::Node::Create(); - np->attrs.op = tvm_op; - np->attrs.name = inode.source->attrs.name; - runtime::TVMOpParam param; - param.func_name = fuse_vec[nid].func_name; - param.num_inputs = static_cast(fe.inputs.size()); - param.num_outputs = static_cast(fe.outputs.size()); - param.flatten_data = pattern_vec[nid] == kElemWise; - param.UpdateDict(&(np->attrs.dict)); - np->attrs.parsed = std::move(param); - for (const auto& e : fe.inputs) { - auto it = old_new.find(e.node_id); - CHECK(it != old_new.end()) - << "cannot find node_id=" << e.node_id; - np->inputs.emplace_back( - nnvm::NodeEntry{it->second, e.index, e.version}); - } - for (const uint32_t node_id : inode.control_deps) { - auto it = old_new.find(node_id); - CHECK(it != old_new.end()); - np->control_deps.emplace_back(it->second); - } - old_new[nid] = np; + continue; } - } + int root_id = group_vec[nid]; + if (nid != root_id) continue; + FuseEntry& fe = fuse_vec[root_id]; + const IndexedGraph& subidx = fe.subgraph.indexed_graph(); + nnvm::NodePtr np = nnvm::Node::Create(); + np->attrs.op = tvm_op; + np->attrs.name = inode.source->attrs.name; + runtime::TVMOpParam param; + param.func_name = fe.compiled_func->func_name; + param.num_inputs = static_cast(fe.imap.size()); + param.num_outputs = static_cast(fe.subgraph.outputs.size()); + param.flatten_data = fe.flatten_data; + param.UpdateDict(&(np->attrs.dict)); + np->attrs.parsed = std::move(param); + for (uint32_t sub_input_id : subidx.input_nodes()) { + // Need to make sure subgraph input order meets order of the graph input + auto rit = fe.reverse_imap.find(subidx[sub_input_id].source); + CHECK(rit != fe.reverse_imap.end()); + const IndexedGraph::NodeEntry& e = rit->second; + auto it = old_new.find(e.node_id); + CHECK(it != old_new.end()) + << "cannot find node_id=" << e.node_id; + np->inputs.emplace_back( + nnvm::NodeEntry{it->second, e.index, e.version}); + } + for (const uint32_t node_id : inode.control_deps) { + auto it = old_new.find(node_id); + CHECK(it != old_new.end()); + np->control_deps.emplace_back(it->second); + } + old_new[nid] = np; + } nnvm::Graph ret; for (const auto& e : idx.outputs()) { auto it = old_new.find(group_vec[e.node_id]); @@ -389,6 +381,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ret.outputs.emplace_back( nnvm::NodeEntry{it->second, e.index, e.version}); } + const IndexedGraph& new_idx = ret.indexed_graph(); ShapeVector new_shape_vec = ShapeVector(new_idx.num_node_entries(), TShape()); DTypeVector new_dtype_vec = DTypeVector(new_idx.num_node_entries()); @@ -401,18 +394,23 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { uint32_t old_eid = idx.entry_id(nid, i); new_shape_vec[new_eid] = shape_vec[old_eid]; new_dtype_vec[new_eid] = dtype_vec[old_eid]; - new_dltype_vec[new_eid] = tvm::runtime::TVMType2String(dltype_vec[old_eid]); + new_dltype_vec[new_eid] = tvm::runtime::TVMType2String( + GetDLType(dtype_vec[old_eid])); } } ret.attrs["shape"] = std::make_shared(std::move(new_shape_vec)); ret.attrs["dtype"] = std::make_shared(std::move(new_dtype_vec)); ret.attrs["dltype"] = std::make_shared(std::move(new_dltype_vec)); + // Setup module + static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target"); + tvm::runtime::Module module = fbuild(func_list, target); ret.attrs["module"] = std::make_shared(std::move(module)); ret = nnvm::ApplyPass(ret, 
"PlanMemory"); return ret; } -NNVM_REGISTER_PASS(GraphFuse) -.set_body(GraphFuse); +NNVM_REGISTER_PASS(GraphFuseCompile) +.set_body(GraphFuseCompile); + } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/graph_deep_compare.cc b/nnvm/src/compiler/graph_hash.cc similarity index 52% rename from nnvm/src/compiler/graph_deep_compare.cc rename to nnvm/src/compiler/graph_hash.cc index df578165ed6b..a38b6e135d43 100644 --- a/nnvm/src/compiler/graph_deep_compare.cc +++ b/nnvm/src/compiler/graph_hash.cc @@ -3,22 +3,124 @@ * \file graph_deep_compare.cc * \brief Deep compare two graph structure */ +#include #include #include #include +#include #include +#include #include "./node_attr.h" +#include "./graph_hash.h" namespace nnvm { namespace compiler { +using namespace tvm; +using tvm::ir::IntImm; + +size_t HashPlaceHolder(const Tensor& t) { + size_t key = t->shape.size(); + key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits()); + for (Expr s : t->shape) { + if (const IntImm* op = s.as()) { + key = dmlc::HashCombine(key, op->value); + } + } + return key; +} + +bool PlaceHolderEqual(const Tensor& a, const Tensor& b) { + if (a->shape.size() != b->shape.size()) return false; + if (a->dtype != b->dtype) return false; + for (size_t i = 0; i < a->shape.size(); ++i) { + const IntImm* a_value = a->shape[i].as(); + const IntImm* b_value = b->shape[i].as(); + if (a_value && b_value == nullptr) return false; + if (b_value && a_value == nullptr) return false; + if (a_value == nullptr && b_value == nullptr) { + continue; + } + if (a_value->value != b_value->value) return false; + } + return true; +} + +size_t GraphKeyHash::Hash(const GraphKey& gkey) { + if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_; + size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target); + key = dmlc::HashCombine(key, gkey->inputs.size()); + for (size_t i = 0; i < gkey->inputs.size(); ++i) { + key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i])); + } + if (key == 0) key = 1; + gkey->cache_hash_key_ = key; + return key; +} + +bool GraphKeyEqual::Equal(const GraphKey& a, + const GraphKey& b) { + if (a->target != b->target) return false; + if (a->inputs.size() != b->inputs.size()) return false; + for (size_t i = 0; i < a->inputs.size(); ++i) { + if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false; + } + if (GraphDeepCompare(a->graph, b->graph, false).length() != 0) return false; + return true; +} + +GraphKey GraphKeyNode::make(Graph graph, + tvm::Array inputs, + std::string target) { + std::shared_ptr n + = std::make_shared(); + n->graph = std::move(graph); + n->inputs = inputs; + n->target = std::move(target); + return GraphKey(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const GraphKeyNode *op, IRPrinter *p) { + p->stream << "GraphKeyNode("<< op << ")"; +}); + + +// Run graph hash +size_t GraphHash(const Graph& graph) { + const IndexedGraph& idx = graph.indexed_graph(); + size_t key = 0; + // Combine a linearized sequence of ops in subgraph + key = dmlc::HashCombine(key, idx.num_nodes()); + std::hash str_hash; + std::vector hash_temp; + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + const IndexedGraph::Node& inode = idx[nid]; + // Use name instad op address so it is deterministic across runs + key = dmlc::HashCombine(key, inode.source->op()p); + if (inode.source->is_variable()) continue; + hash_temp.clear(); + for (const auto& kv : GetAttrDict(inode.source->attrs)) { + 
hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second)); + } + // to make sure it is deterministic + // since unordered_map is not deterministic + std::sort(hash_temp.begin(), hash_temp.end()); + for (size_t value : hash_temp) { + key = dmlc::HashCombine(key, value); + } + } + return key; +} + // deep compare the graph structure // not considering the graph attributes // return non-empty error message if the graph mismatch. // the comparator won't match name of intermediate node. // compare_var_attr -std::string DeepCompare(Graph a, Graph b, - bool compare_variable_attr) { +std::string GraphDeepCompare(const Graph& a, + const Graph& b, + bool compare_variable_attr) { const IndexedGraph& idxa = a.indexed_graph(); const IndexedGraph& idxb = b.indexed_graph(); std::ostringstream err; @@ -113,7 +215,7 @@ std::string DeepCompare(Graph a, Graph b, TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare") .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { - *rv = DeepCompare(args[0], args[1], args[2]); + *rv = GraphDeepCompare(args[0], args[1], args[2]); }); } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/graph_hash.h b/nnvm/src/compiler/graph_hash.h new file mode 100644 index 000000000000..f6f93a9d7e95 --- /dev/null +++ b/nnvm/src/compiler/graph_hash.h @@ -0,0 +1,82 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file graph_hash.h + * \brief The graph hashing function. + */ +#ifndef NNVM_COMPILER_GRAPH_HASH_H_ +#define NNVM_COMPILER_GRAPH_HASH_H_ + +#include +#include +#include +#include + +namespace nnvm { +namespace compiler { + +class GraphKey; + +/*! \brief Key to a graph compiler cache */ +struct GraphKeyNode : public tvm::Node { + /*! \brief The graph structure */ + Graph graph; + /* \brief The inputs to the function */ + tvm::Array inputs; + /*! \brief The target */ + std::string target; + // Cached internal hash key, invisible to the user. + // The graph hash key is ensured always not to be 0 + mutable size_t cache_hash_key_{0}; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("inputs", &inputs); + v->Visit("target", &target); + } + + static GraphKey make(Graph graph, + tvm::Array inputs, + std::string target); + static constexpr const char* _type_key = "GraphKey"; + TVM_DECLARE_NODE_TYPE_INFO(GraphKeyNode, tvm::Node); +}; + +TVM_DEFINE_NODE_REF(GraphKey, GraphKeyNode); + +/*! \brief Hashing function for graph key */ +struct GraphKeyHash { + size_t operator()(const GraphKey& gkey) const { + return Hash(gkey); + } + static size_t Hash(const GraphKey& gkey); +}; + +/*! \brief function for graph key */ +struct GraphKeyEqual { + bool operator()(const GraphKey& a, + const GraphKey& b) const { + return Equal(a, b); + } + static bool Equal(const GraphKey& a, const GraphKey& b); +}; + +/*! + * \brief Create a hash code for a given graph. + * \return The hash code of the graph. + */ +size_t GraphHash(const Graph& graph); + +/*! + * \brief Compare two graphs + * return empty string if they are equal + * otherwise return error message + * \param a The first graph. + * \param b The second graph. + * \return empty string if they are equal, otherwise return error message. 
+ */ +std::string GraphDeepCompare(const Graph& a, + const Graph& b, + bool compare_variable_attr); +} // namespace compiler +} // namespace nnvm + +#endif // NNVM_COMPILER_GRAPH_HASH_H_ diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 474f3104f145..d1b6efb66dbb 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -9,7 +9,7 @@ namespace nnvm { -const IndexedGraph& Graph::indexed_graph() { +const IndexedGraph& Graph::indexed_graph() const { if (indexed_graph_ == nullptr) { indexed_graph_.reset(new IndexedGraph(*this)); } diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index 52298c0a77a5..80b6df2006ba 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -180,7 +180,7 @@ void PrintGraphIR_(Graph src, } // save a graph to json -Graph PrintGraphIR(Graph src) { +Graph PrintGraphIRPass(Graph src) { std::ostringstream os; std::vector join_entry_attrs, join_node_attrs; if (src.attrs.count("join_entry_attrs") != 0) { @@ -200,7 +200,7 @@ Graph PrintGraphIR(Graph src) { // register pass NNVM_REGISTER_PASS(PrintGraphIR) .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") -.set_body(PrintGraphIR); +.set_body(PrintGraphIRPass); } // namespace pass } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_compiler_cache.py b/nnvm/tests/python/compiler/test_compiler_cache.py new file mode 100644 index 000000000000..f7666b39f005 --- /dev/null +++ b/nnvm/tests/python/compiler/test_compiler_cache.py @@ -0,0 +1,42 @@ +import numpy as np +import tvm +import nnvm.symbol as sym +import nnvm.compiler +import nnvm.runtime + +def test_compile_cache(): + x = sym.Variable("x") + y = sym.Variable("y") + z = sym.exp(y + x) + shape = (10, 1) + dtype = tvm.float32 + shape_dict = {"x": shape, "y": shape} + def verify(graph, lib): + m = nnvm.runtime.create(graph, lib, tvm.cpu(0)) + # get member functions + na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype)) + nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype)) + m.run(x=na, y=nb) + # get outputs + out = m.get_output(0, tvm.nd.empty(shape, dtype)) + np.testing.assert_allclose( + out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy())) + + engine = nnvm.compiler.engine + graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict) + inputs = [tvm.placeholder((10,)), tvm.placeholder((10,))] + + gkey = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs, "llvm") + gkey2 = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs + inputs, "llvm") + gf = engine[gkey] + assert gf is not None + assert engine[gkey2] is None + graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict) + assert graph.index.num_nodes == 3 + verify(graph, lib) + # Test various set external cache + engine.clear_cache() + engine[gkey] = gf + +if __name__ == "__main__": + test_compile_cache() diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py index ca82da2023c0..a7c6ca4b288d 100644 --- a/nnvm/tests/python/compiler/test_op_fusion.py +++ b/nnvm/tests/python/compiler/test_op_fusion.py @@ -4,7 +4,7 @@ import topi from nnvm import symbol as sym from nnvm.compiler import graph_util, graph_attr -from nnvm.testing.config import test_ctx_list +from nnvm.testing import ctx_list def test_ewise_injective(): x = sym.Variable("x") @@ -14,7 +14,7 @@ def test_ewise_injective(): shape_dict = {"x": dshape} dtype = "float32" target = "llvm" - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = 
nnvm.compiler.build(y, target, shape_dict) assert graph.index.num_nodes == 2 m = nnvm.runtime.create(graph, lib, ctx) @@ -37,7 +37,7 @@ def test_conv_ewise_injective(): oshape = (1, 32* 18 * 18) shape_dict = {"x": dshape} - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) # print(graph.ir(join_entry_attrs=["shape"])) @@ -64,7 +64,7 @@ def test_injective_reduce_injective(): dshape = (32, 1, 18, 18) shape_dict = {"x": dshape} - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) assert graph.index.num_nodes == 2 diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index d03782bb4cb5..baeaaf86040a 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -4,7 +4,7 @@ import nnvm.symbol as sym import nnvm.compiler import nnvm.runtime -from nnvm.testing.config import test_ctx_list +from nnvm.testing.config import ctx_list def test_relu(): x = sym.Variable("x") @@ -13,7 +13,7 @@ def test_relu(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) # get member functions @@ -31,7 +31,7 @@ def test_exp(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) # get member functions @@ -54,7 +54,7 @@ def test_log(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) @@ -78,7 +78,7 @@ def test_tanh(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) @@ -102,7 +102,7 @@ def test_sigmoid(): dtype = "float32" dshape = (1, 3, 32, 32) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) # get member functions @@ -125,7 +125,7 @@ def test_softmax(): dtype = "float32" dshape = (10, 1000) oshape = dshape - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): with nnvm.compiler.build_config(opt_level=1): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) @@ -153,7 +153,7 @@ def test_dense(): "dense_weight" : (3, 100), "dense_bias" : (3,), } - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape) m = nnvm.runtime.create(graph, lib, ctx) x_np = np.random.uniform(size=shape["x"]).astype(dtype) @@ -179,7 +179,7 @@ def test_batchnorm(): y = sym.batch_norm( x, gamma, beta, moving_mean, moving_var, epsilon=eps) - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = 
nnvm.compiler.build(y, "llvm", {"x": shape}) m = nnvm.runtime.create(graph, lib, tvm.cpu(0)) x_np = np.random.uniform(size=shape).astype(dtype) diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index 793d6d3e955f..79f29f40ce5c 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -5,7 +5,7 @@ import nnvm.symbol as sym import nnvm.compiler import nnvm.runtime -from nnvm.testing.config import test_ctx_list +from nnvm.testing.config import ctx_list def test_conv2d(): @@ -17,7 +17,7 @@ def test_conv2d(): kshape = (10, 3, 3, 3) oshape = (1, 10, 18, 18) shape_dict = {"x": dshape} - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) # get member functions @@ -46,7 +46,7 @@ def test_grouped_conv2d(): kshape = (32, 1, 3, 3) oshape = (1, 32, 18, 18) shape_dict = {"x": dshape} - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) m = nnvm.runtime.create(graph, lib, ctx) # set input diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 851fc8b7d46c..eac3178e3e45 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -4,7 +4,7 @@ import nnvm.symbol as sym import nnvm.compiler import nnvm.runtime -from nnvm.testing.config import test_ctx_list +from nnvm.testing.config import ctx_list def verify_transpose(dshape, axes): x = sym.Variable("x") @@ -14,7 +14,7 @@ def verify_transpose(dshape, axes): y = sym.transpose(x) y = y + 1 dtype = "float32" - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) # set input @@ -29,7 +29,7 @@ def verify_reduce(dshape, fnp, fsym, **kwargs): x = sym.Variable("x") y = fsym(x + 1, **kwargs) dtype = "float32" - for target, ctx in test_ctx_list(): + for target, ctx in ctx_list(): graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) m = nnvm.runtime.create(graph, lib, ctx) # set input @@ -54,3 +54,4 @@ def test_reduce(): if __name__ == "__main__": test_reduce() test_tranpose() + print(nnvm.compiler.engine.dump())
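
---

Reviewer note (not part of the patch): a minimal usage sketch of the new compile-engine cache API introduced above, assembled from compile_engine.py and tests/python/compiler/test_compiler_cache.py in this change. The shapes, the flattened (10,) placeholders used for the graph key, and the "llvm" target are taken from that test; treat this as an illustrative sketch rather than additional committed code.

    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler

    # Build a small elementwise graph twice; the second build should reuse
    # the cached fused function instead of lowering it again.
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.exp(y + x)
    shape_dict = {"x": (10, 1), "y": (10, 1)}
    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
    graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)

    engine = nnvm.compiler.engine
    # List cached (GraphKey, GraphCacheEntry) pairs and their usage counts.
    for key, entry in engine.items():
        print(key.target, entry.use_count, entry.graph_func.func_name)

    # Query the cache directly with an explicit graph key. The elementwise
    # kernel is keyed on flattened placeholders, mirroring the shapes used in
    # test_compiler_cache.py; a miss returns None.
    inputs = [tvm.placeholder((10,)), tvm.placeholder((10,))]
    gkey = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs, "llvm")
    cached = engine[gkey]

    print(engine.dump())   # human-readable dump of all cached items
    engine.clear_cache()   # drop all cached functions

The dump output is the same format produced by Engine.dump() in compile_engine.py: one block per cached item with target, inputs, use_count, func_name, and the graph IR.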